Commit 263384d7 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #587 from peastman/checkpoint

Improved A3C checkpointing
parents 4c18f113 5a17008c
Loading
Loading
Loading
Loading
+16 −4
Original line number Diff line number Diff line
@@ -123,8 +123,11 @@ class A3C(object):
        graph.build()
    return graph, features, rewards, actions, action_prob, value

  def fit(self, total_steps, max_checkpoints_to_keep=5,
          checkpoint_interval=600):
  def fit(self,
          total_steps,
          max_checkpoints_to_keep=5,
          checkpoint_interval=600,
          restore=False):
    """Train the policy.

    Parameters
@@ -137,6 +140,9 @@ class A3C(object):
      files are deleted.
    checkpoint_interval: float
      the time interval at which to save checkpoints, measured in seconds
    restore: bool
      if True, restore the model from the most recent checkpoint and continue training
      from there.  If False, retrain the model from scratch.
    """
    with self._graph._get_tf("Graph").as_default():
      step_count = [0]
@@ -145,13 +151,17 @@ class A3C(object):
      for i in range(multiprocessing.cpu_count()):
        workers.append(_Worker(self, i))
      self._session.run(tf.global_variables_initializer())
      if restore:
        self.restore()
      for worker in workers:
        thread = threading.Thread(
            name=worker.scope,
            target=lambda: worker.run(step_count, total_steps))
        threads.append(thread)
        thread.start()
      saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
      variables = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope='global')
      saver = tf.train.Saver(variables, max_to_keep=max_checkpoints_to_keep)
      checkpoint_index = 0
      while True:
        threads = [t for t in threads if t.isAlive()]
@@ -212,7 +222,9 @@ class A3C(object):
    if last_checkpoint is None:
      raise ValueError('No checkpoint found')
    with self._graph._get_tf("Graph").as_default():
      saver = tf.train.Saver()
      variables = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope='global')
      saver = tf.train.Saver(variables)
      saver.restore(self._session, last_checkpoint)


+7 −0
Original line number Diff line number Diff line
@@ -73,3 +73,10 @@ class TestA3C(unittest.TestCase):
    new_a3c.restore()
    action_prob2, value2 = new_a3c.predict([[0]])
    assert value2 == value

    # Do the same thing, only using the "restore" argument to fit().

    new_a3c = dc.rl.A3C(env, TestPolicy(), model_dir=a3c._graph.model_dir)
    new_a3c.fit(0, restore=True)
    action_prob2, value2 = new_a3c.predict([[0]])
    assert value2 == value