Commit a07508ff authored by peastman
Browse files

Improved A3C checkpointing

parent 98dec473
Loading
Loading
Loading
Loading
+10 −3
Original line number | Diff line number | Diff line
@@ -124,7 +124,7 @@ class A3C(object):
    return graph, features, rewards, actions, action_prob, value

  def fit(self, total_steps, max_checkpoints_to_keep=5,
          checkpoint_interval=600):
          checkpoint_interval=600, restore=False):
    """Train the policy.

    Parameters
@@ -137,6 +137,9 @@ class A3C(object):
      files are deleted.
    checkpoint_interval: float
      the time interval at which to save checkpoints, measured in seconds
    restore: bool
      if True, restore the model from the most recent checkpoint and continue training
      from there.  If False, retrain the model from scratch.
    """
    with self._graph._get_tf("Graph").as_default():
      step_count = [0]
@@ -145,13 +148,16 @@ class A3C(object):
      for i in range(multiprocessing.cpu_count()):
        workers.append(_Worker(self, i))
      self._session.run(tf.global_variables_initializer())
      if restore:
        self.restore()
      for worker in workers:
        thread = threading.Thread(
            name=worker.scope,
            target=lambda: worker.run(step_count, total_steps))
        threads.append(thread)
        thread.start()
      saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
      variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global')
      saver = tf.train.Saver(variables, max_to_keep=max_checkpoints_to_keep)
      checkpoint_index = 0
      while True:
        threads = [t for t in threads if t.isAlive()]
@@ -212,7 +218,8 @@ class A3C(object):
    if last_checkpoint is None:
      raise ValueError('No checkpoint found')
    with self._graph._get_tf("Graph").as_default():
      saver = tf.train.Saver()
      variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global')
      saver = tf.train.Saver(variables)
      saver.restore(self._session, last_checkpoint)