Commit 0efa3f60 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #652 from peastman/rnn

A3C works with recurrent layers
parents 9569295e 95e69566
Loading
Loading
Loading
Loading
+28 −6
Original line number Diff line number Diff line
@@ -27,6 +27,9 @@ class Layer(object):
    self.in_layers = in_layers
    self.op_type = "gpu"
    self.variable_scope = ''
    self.rnn_initial_states = []
    self.rnn_final_states = []
    self.rnn_zero_states = []

  def _get_layer_number(self):
    class_name = self.__class__.__name__
@@ -395,16 +398,35 @@ class GRU(Layer):
      raise ValueError("Must have one parent")
    parent_tensor = inputs[0]
    gru_cell = tf.contrib.rnn.GRUCell(self.n_hidden)
    initial_gru_state = gru_cell.zero_state(self.batch_size, tf.float32)
    out_tensor, rnn_states = tf.nn.dynamic_rnn(
        gru_cell,
        parent_tensor,
        initial_state=initial_gru_state,
        scope=self.name)
    zero_state = gru_cell.zero_state(self.batch_size, tf.float32)
    if set_tensors:
      initial_state = tf.placeholder(tf.float32, zero_state.get_shape())
    else:
      initial_state = zero_state
    out_tensor, final_state = tf.nn.dynamic_rnn(
        gru_cell, parent_tensor, initial_state=initial_state, scope=self.name)
    if set_tensors:
      self._record_variable_scope(self.name)
      self.out_tensor = out_tensor
      self.rnn_initial_states.append(initial_state)
      self.rnn_final_states.append(final_state)
      self.rnn_zero_states.append(np.zeros(zero_state.get_shape(), np.float32))
    return out_tensor

  def none_tensors(self):
    """Detach this layer's TensorFlow objects so it can be pickled.

    Returns the detached values as a list that can later be handed to
    set_tensors() to restore the layer.
    """
    saved = [self.out_tensor, self.rnn_initial_states,
             self.rnn_final_states, self.rnn_zero_states]
    self.out_tensor = None
    self.rnn_initial_states, self.rnn_final_states, self.rnn_zero_states = [], [], []
    return saved

  def set_tensors(self, tensor):
    """Restore the TensorFlow objects previously removed by none_tensors()."""
    out_tensor, initial_states, final_states, zero_states = tensor
    self.out_tensor = out_tensor
    self.rnn_initial_states = initial_states
    self.rnn_final_states = final_states
    self.rnn_zero_states = zero_states


class TimeSeriesDense(Layer):

+19 −0
Original line number Diff line number Diff line
@@ -94,6 +94,10 @@ class TensorGraph(Model):
    self.save_file = "%s/%s" % (self.model_dir, "model")
    self.model_class = None

    self.rnn_initial_states = []
    self.rnn_final_states = []
    self.rnn_zero_states = []

  def _add_layer(self, layer):
    if layer.name is None:
      layer.name = "%s_%s" % (layer.__class__.__name__, len(self.layers) + 1)
@@ -226,6 +230,9 @@ class TensorGraph(Model):
          feed_dict[self.features[0]] = X_b
        if len(self.task_weights) == 1 and w_b is not None and not predict:
          feed_dict[self.task_weights[0]] = w_b
        for (initial_state, zero_state) in zip(self.rnn_initial_states,
                                              self.rnn_zero_states):
          feed_dict[initial_state] = zero_state
        yield feed_dict

  def predict_on_generator(self, generator, transformers=[]):
@@ -328,6 +335,9 @@ class TensorGraph(Model):
        with tf.name_scope(node):
          node_layer = self.layers[node]
          node_layer.create_tensor(training=self._training_placeholder)
          self.rnn_initial_states += node_layer.rnn_initial_states
          self.rnn_final_states += node_layer.rnn_final_states
          self.rnn_zero_states += node_layer.rnn_zero_states
      self.built = True

    for layer in self.layers.values():
@@ -412,7 +422,13 @@ class TensorGraph(Model):
    # Remove out_tensor from the object to be pickled
    must_restore = False
    tensor_objects = self.tensor_objects
    rnn_initial_states = self.rnn_initial_states
    rnn_final_states = self.rnn_final_states
    rnn_zero_states = self.rnn_zero_states
    self.tensor_objects = {}
    self.rnn_initial_states = []
    self.rnn_final_states = []
    self.rnn_zero_states = []
    out_tensors = []
    if self.built:
      must_restore = True
@@ -440,6 +456,9 @@ class TensorGraph(Model):
      self._training_placeholder = training_placeholder
      self.built = True
    self.tensor_objects = tensor_objects
    self.rnn_initial_states = rnn_initial_states
    self.rnn_final_states = rnn_final_states
    self.rnn_zero_states = rnn_zero_states

  def evaluate_generator(self,
                         feed_dict_generator,
+83 −19
Original line number Diff line number Diff line
@@ -33,11 +33,6 @@ class A3CLoss(Layer):
    return self.out_tensor


def _create_feed_dict(features, state):
  return dict((f.out_tensor, np.expand_dims(s, axis=0))
              for f, s in zip(features, state))


class A3C(object):
  """
  Implements the Asynchronous Advantage Actor-Critic (A3C) algorithm for reinforcement learning.
@@ -102,6 +97,7 @@ class A3C(object):
         None, 'global', model_dir)
    with self._graph._get_tf("Graph").as_default():
      self._session = tf.Session()
    self._rnn_states = self._graph.rnn_zero_states

  def _build_graph(self, tf_graph, scope, model_dir):
    """Construct a TensorGraph containing the policy and loss calculations."""
@@ -182,27 +178,52 @@ class A3C(object):
        if len(threads) == 0:
          break

  def predict(self, state):
  def predict(self, state, use_saved_states=True, save_states=True):
    """Compute the policy's output predictions for a state.

    If the policy involves recurrent layers, this method can preserve their internal
    states between calls.  Use the use_saved_states and save_states arguments to specify
    how it should behave.

    Parameters
    ----------
    state: array
      the state of the environment for which to generate predictions
    use_saved_states: bool
      if True, the states most recently saved by a previous call to predict() or select_action()
      will be used as the initial states.  If False, the internal states of all recurrent layers
      will be set to all zeros before computing the predictions.
    save_states: bool
      if True, the internal states of all recurrent layers at the end of the calculation
      will be saved, and any previously saved states will be discarded.  If False, the
      states at the end of the calculation will be discarded, and any previously saved
      states will be kept.

    Returns
    -------
    the array of action probabilities, and the estimated value function
    """
    with self._graph._get_tf("Graph").as_default():
      # Feed the features plus (saved or zeroed) RNN initial states.
      feed_dict = self._create_feed_dict(state, use_saved_states)
      tensors = [self._action_prob.out_tensor, self._value.out_tensor]
      if save_states:
        # Also fetch the RNN final states so they can be carried forward.
        tensors += self._graph.rnn_final_states
      results = self._session.run(tensors, feed_dict=feed_dict)
      if save_states:
        # results[0:2] are the policy outputs; everything after that is the
        # list of RNN final states, in the same order as rnn_final_states.
        self._rnn_states = results[2:]
      return results[:2]

  def select_action(self, state, deterministic=False):
  def select_action(self,
                    state,
                    deterministic=False,
                    use_saved_states=True,
                    save_states=True):
    """Select an action to perform based on the environment's state.

    If the policy involves recurrent layers, this method can preserve their internal
    states between calls.  Use the use_saved_states and save_states arguments to specify
    how it should behave.

    Parameters
    ----------
    state: array
@@ -210,15 +231,29 @@ class A3C(object):
    deterministic: bool
      if True, always return the best action (that is, the one with highest probability).
      If False, randomly select an action based on the computed probabilities.
    use_saved_states: bool
      if True, the states most recently saved by a previous call to predict() or select_action()
      will be used as the initial states.  If False, the internal states of all recurrent layers
      will be set to all zeros before computing the predictions.
    save_states: bool
      if True, the internal states of all recurrent layers at the end of the calculation
      will be saved, and any previously saved states will be discarded.  If False, the
      states at the end of the calculation will be discarded, and any previously saved
      states will be kept.

    Returns
    -------
    the index of the selected action
    """
    with self._graph._get_tf("Graph").as_default():
      feed_dict = _create_feed_dict(self._features, state)
      probabilities = self._session.run(
          self._action_prob.out_tensor, feed_dict=feed_dict)
      feed_dict = self._create_feed_dict(state, use_saved_states)
      tensors = [self._action_prob.out_tensor]
      if save_states:
        tensors += self._graph.rnn_final_states
      results = self._session.run(tensors, feed_dict=feed_dict)
      probabilities = results[0]
      if save_states:
        self._rnn_states = results[1:]
      if deterministic:
        return probabilities.argmax()
      else:
@@ -236,6 +271,18 @@ class A3C(object):
      saver = tf.train.Saver(variables)
      saver.restore(self._session, last_checkpoint)

  def _create_feed_dict(self, state, use_saved_states):
    """Build the feed dict used by predict() and select_action().

    Each feature placeholder is fed the matching element of state (with a
    leading batch dimension of 1), and every recurrent layer's initial-state
    placeholder is fed either the saved states or all zeros.
    """
    feed_dict = {
        feature.out_tensor: np.expand_dims(element, axis=0)
        for feature, element in zip(self._features, state)
    }
    rnn_states = self._rnn_states if use_saved_states else self._graph.rnn_zero_states
    feed_dict.update(zip(self._graph.rnn_initial_states, rnn_states))
    return feed_dict


class _Worker(object):
  """A Worker object is created for each training thread."""
@@ -248,6 +295,7 @@ class _Worker(object):
    self.env.reset()
    self.graph, self.features, self.rewards, self.actions, self.action_prob, self.value, self.advantages = a3c._build_graph(
        a3c._graph._get_tf('Graph'), self.scope, None)
    self.rnn_states = self.graph.rnn_zero_states
    with a3c._graph._get_tf("Graph").as_default():
      local_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     self.scope)
@@ -265,9 +313,12 @@ class _Worker(object):
      session = self.a3c._session
      while step_count[0] < total_steps:
        session.run(self.update_local_variables)
        feed_dict = {}
        for placeholder, value in zip(self.graph.rnn_initial_states,
                                      self.rnn_states):
          feed_dict[placeholder] = value
        episode_states, episode_actions, episode_rewards, episode_advantages = self.create_rollout(
        )
        feed_dict = {}
        for f, s in zip(self.features, episode_states):
          feed_dict[f.out_tensor] = s
        feed_dict[self.rewards.out_tensor] = episode_rewards
@@ -290,10 +341,13 @@ class _Worker(object):
      state = self.env.state
      for j in range(len(state)):
        states[j].append(state[j])
      feed_dict = _create_feed_dict(self.features, state)
      probabilities, value = session.run(
          [self.action_prob.out_tensor, self.value.out_tensor],
      feed_dict = self.create_feed_dict(state)
      results = session.run(
          [self.action_prob.out_tensor, self.value.out_tensor] +
          self.graph.rnn_final_states,
          feed_dict=feed_dict)
      probabilities, value = results[:2]
      self.rnn_states = results[2:]
      action = np.random.choice(np.arange(n_actions), p=probabilities[0])
      actions.append(np.zeros(n_actions))
      actions[i][action] = 1.0
@@ -301,7 +355,7 @@ class _Worker(object):
      rewards.append(self.env.step(action))
    if not self.env.terminated:
      # Add an estimate of the reward for the rest of the episode.
      feed_dict = _create_feed_dict(self.features, self.env.state)
      feed_dict = self.create_feed_dict(self.env.state)
      rewards[-1] += self.a3c.discount_factor * float(
          session.run(self.value.out_tensor, feed_dict))
    for j in range(len(rewards) - 1, 0, -1):
@@ -310,4 +364,14 @@ class _Worker(object):
    advantages = rewards_array - np.array(values)
    if self.env.terminated:
      self.env.reset()
      self.rnn_states = self.graph.rnn_zero_states
    return np.array(states), np.array(actions), rewards_array, advantages

  def create_feed_dict(self, state):
    """Build the feed dict for one step of a rollout.

    Feeds each feature placeholder its state element (batched to size 1) and
    feeds this worker's current saved RNN states to the initial-state
    placeholders.
    """
    feed_dict = {
        feature.out_tensor: np.expand_dims(element, axis=0)
        for feature, element in zip(self.features, state)
    }
    feed_dict.update(zip(self.graph.rnn_initial_states, self.rnn_states))
    return feed_dict
+58 −1
Original line number Diff line number Diff line
from flaky import flaky

import deepchem as dc
from deepchem.models.tensorgraph.layers import Reshape, Variable, SoftMax
from deepchem.models.tensorgraph.layers import Reshape, Variable, SoftMax, GRU
import numpy as np
import tensorflow as tf
import unittest
@@ -85,3 +85,60 @@ class TestA3C(unittest.TestCase):
    new_a3c.fit(0, restore=True)
    action_prob2, value2 = new_a3c.predict([[0]])
    assert value2 == value

  def test_recurrent_states(self):
    """Test that A3C.predict() correctly saves, reuses, and resets the
    internal states of recurrent layers between calls."""

    # The environment produces a fresh random state on every step; only
    # env.state (unchanged between the predict() calls below) matters here.

    class TestEnvironment(dc.rl.Environment):

      def __init__(self):
        super(TestEnvironment, self).__init__([(10,)], 10)
        self._state = [np.random.random(10)]

      def step(self, action):
        self._state = [np.random.random(10)]
        return 0.0

      def reset(self):
        pass

    # The policy includes a single recurrent layer.

    class TestPolicy(dc.rl.Policy):

      def create_layers(self, state, **kwargs):

        # NOTE: 'env' is a closure over the variable assigned below; this
        # works because create_layers() is only invoked after env exists.
        reshaped = Reshape(shape=(1, -1, 10), in_layers=state)
        gru = GRU(n_hidden=10, batch_size=1, in_layers=reshaped)
        output = SoftMax(
            in_layers=[Reshape(in_layers=[gru], shape=(-1, env.n_actions))])
        value = Variable([0.0])
        return {'action_prob': output, 'value': value}

    # We don't care about actually optimizing it, so just run a few rollouts to make
    # sure fit() doesn't crash, then check the behavior of the GRU state.

    env = TestEnvironment()
    a3c = dc.rl.A3C(env, TestPolicy())
    a3c.fit(100)
    # On the first call, the initial state should be all zeros.
    prob1, value1 = a3c.predict(
        env.state, use_saved_states=True, save_states=False)
    # It should still be zeros since we didn't save it last time.
    prob2, value2 = a3c.predict(
        env.state, use_saved_states=True, save_states=True)
    # It should be different now.
    prob3, value3 = a3c.predict(
        env.state, use_saved_states=True, save_states=False)
    # This should be the same as the previous one.
    prob4, value4 = a3c.predict(
        env.state, use_saved_states=True, save_states=False)
    # Now we reset it, so we should get the same result as initially.
    prob5, value5 = a3c.predict(
        env.state, use_saved_states=False, save_states=True)
    assert np.array_equal(prob1, prob2)
    assert np.array_equal(prob1, prob5)
    assert np.array_equal(prob3, prob4)
    assert not np.array_equal(prob2, prob3)