Commit 84ef3959 authored by peastman's avatar peastman
Browse files

yapf

parent a07508ff
Loading
Loading
Loading
Loading
+9 −4
Original line number Diff line number Diff line
@@ -123,8 +123,11 @@ class A3C(object):
        graph.build()
    return graph, features, rewards, actions, action_prob, value

  def fit(self, total_steps, max_checkpoints_to_keep=5,
          checkpoint_interval=600, restore=False):
  def fit(self,
          total_steps,
          max_checkpoints_to_keep=5,
          checkpoint_interval=600,
          restore=False):
    """Train the policy.

    Parameters
@@ -156,7 +159,8 @@ class A3C(object):
            target=lambda: worker.run(step_count, total_steps))
        threads.append(thread)
        thread.start()
      variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global')
      variables = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope='global')
      saver = tf.train.Saver(variables, max_to_keep=max_checkpoints_to_keep)
      checkpoint_index = 0
      while True:
@@ -218,7 +222,8 @@ class A3C(object):
    if last_checkpoint is None:
      raise ValueError('No checkpoint found')
    with self._graph._get_tf("Graph").as_default():
      variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global')
      variables = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope='global')
      saver = tf.train.Saver(variables)
      saver.restore(self._session, last_checkpoint)