Commit 798e6229 authored by Peter Eastman
Browse files

A3C supports checkpointing

parent a59b5f57
Loading
Loading
Loading
Loading
+36 −10
Original line number Diff line number Diff line
@@ -7,6 +7,8 @@ import numpy as np
import tensorflow as tf
import copy
import multiprocessing
import os
import re
import threading

class A3CLoss(Layer):
@@ -42,7 +44,7 @@ class A3C(object):
  "action" argument passed to the environment is an integer, giving the index of the action to perform.
  """

  def __init__(self, env, policy, max_rollout_length=20, discount_factor=0.99, value_weight=1.0, entropy_weight=0.01):
  def __init__(self, env, policy, max_rollout_length=20, discount_factor=0.99, value_weight=1.0, entropy_weight=0.01, model_dir=None):
    """Create an object for optimizing a policy.

    Parameters
@@ -60,6 +62,8 @@ class A3C(object):
      a scale factor for the value loss term in the loss function
    entropy_weight: float
      a scale factor for the entropy term in the loss function
    model_dir: str
      the directory in which the model will be saved.  If None, a temporary directory will be created.
    """
    self._env = env
    self._policy = policy
@@ -68,11 +72,11 @@ class A3C(object):
    self.value_weight = value_weight
    self.entropy_weight = entropy_weight
    self.optimizer = TFWrapper(tf.train.AdamOptimizer, learning_rate=0.001, beta1=0.9, beta2=0.999)
    (self._graph, self._features, rewards, actions, self._action_prob, self._value) = self._build_graph(None, 'global')
    (self._graph, self._features, rewards, actions, self._action_prob, self._value) = self._build_graph(None, 'global', model_dir)
    with self._graph._get_tf("Graph").as_default():
      self._session = tf.Session()

  def _build_graph(self, tf_graph, scope):
  def _build_graph(self, tf_graph, scope, model_dir):
    """Construct a TensorGraph containing the policy and loss calculations."""
    features = Feature(shape=[None]+list(self._env.state_shape))
    policy_layers = self._policy.create_layers(features)
@@ -81,7 +85,7 @@ class A3C(object):
    rewards = Weights(shape=(None, 1))
    actions = Label(shape=(None, self._env.n_actions))
    loss = A3CLoss(self.value_weight, self.entropy_weight, in_layers=[rewards, actions, action_prob, value])
    graph = TensorGraph(batch_size=self.max_rollout_length, use_queue=False, graph=tf_graph)
    graph = TensorGraph(batch_size=self.max_rollout_length, use_queue=False, graph=tf_graph, model_dir=model_dir)
    graph.add_output(action_prob)
    graph.add_output(value)
    graph.set_loss(loss)
@@ -91,7 +95,7 @@ class A3C(object):
        graph.build()
    return graph, features, rewards, actions, action_prob, value

  def fit(self, total_steps):
  def fit(self, total_steps, max_checkpoints_to_keep=5, checkpoint_interval=600):
    """Train the policy.

    Parameters
@@ -99,6 +103,11 @@ class A3C(object):
    total_steps: int
      the total number of time steps to perform on the environment, across all rollouts
      on all threads
    max_checkpoints_to_keep: int
      the maximum number of checkpoint files to keep.  When this number is reached, older
      files are deleted.
    checkpoint_interval: float
      the time interval at which to save checkpoints, measured in seconds
    """
    with self._graph._get_tf("Graph").as_default():
      train_op = self._graph._get_tf('train_op')
@@ -107,13 +116,21 @@ class A3C(object):
      workers = []
      threads = []
      for i in range(multiprocessing.cpu_count()):
        workers.append(Worker(self, i))
        workers.append(_Worker(self, i))
      for worker in workers:
        thread = threading.Thread(name=worker.scope, target=lambda: worker.run(step_count, total_steps))
        threads.append(thread)
        thread.start()
      for thread in threads:
        thread.join()
      saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
      checkpoint_index = 0
      while True:
        threads = [t for t in threads if t.isAlive()]
        if len(threads) > 0:
          threads[0].join(checkpoint_interval)
        checkpoint_index += 1
        saver.save(self._session, self._graph.save_file, global_step=checkpoint_index)
        if len(threads) == 0:
          break

  def predict(self, state):
    """Compute the policy's output predictions for a state.
@@ -154,8 +171,17 @@ class A3C(object):
      else:
        return np.random.choice(np.arange(self._env.n_actions), p=probabilities[0])

  def restore(self):
    """Load the parameters stored in the newest checkpoint into this object's session.

    The checkpoint is looked up in the model directory of the global graph.

    Raises
    ------
    ValueError
      if no checkpoint is found in the model directory
    """
    checkpoint_path = tf.train.latest_checkpoint(self._graph.model_dir)
    if checkpoint_path is not None:
      # The Saver has to be created while the A3C graph is the default graph.
      with self._graph._get_tf("Graph").as_default():
        tf.train.Saver().restore(self._session, checkpoint_path)
      return
    raise ValueError('No checkpoint found')


class Worker(object):
class _Worker(object):
  """A Worker object is created for each training thread."""

  def __init__(self, a3c, index):
@@ -164,7 +190,7 @@ class Worker(object):
    self.scope = 'worker%d' % index
    self.env = copy.deepcopy(a3c._env)
    self.env.reset()
    self.graph, self.features, self.rewards, self.actions, self.action_prob, self.value = a3c._build_graph(a3c._graph._get_tf('Graph'), self.scope)
    self.graph, self.features, self.rewards, self.actions, self.action_prob, self.value = a3c._build_graph(a3c._graph._get_tf('Graph'), self.scope, None)
    with a3c._graph._get_tf("Graph").as_default():
      local_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
      global_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'global')
+8 −0
Original line number Diff line number Diff line
@@ -60,3 +60,11 @@ class TestA3C(unittest.TestCase):
    assert -0.5 < value[0] < 0.5
    assert action_prob.argmax() == 37
    assert a3c.select_action([0], deterministic=True) == 37

    # Verify that we can create a new A3C object, reload the parameters from the first one, and
    # get the same result.

    new_a3c = dc.rl.A3C(env, TestPolicy(), model_dir=a3c._graph.model_dir)
    new_a3c.restore()
    action_prob2, value2 = new_a3c.predict([0])
    assert value2 == value