Commit a7318f9e authored by peastman

Converted PPO to TF2

parent 563207ae
+107 −115
@@ -39,7 +39,7 @@ class PPOLoss(object):
    policy_loss = -tf.reduce_mean(
        tf.minimum(ratio * advantage, clipped_ratio * advantage))
    value_loss = tf.reduce_mean(tf.square(reward - value))
    entropy = -tf.reduce_mean(tf.reduce_sum(prob * tf.log(prob), axis=1))
    entropy = -tf.reduce_mean(tf.reduce_sum(prob * tf.math.log(prob), axis=1))
    return policy_loss + self.value_weight * value_loss - self.entropy_weight * entropy

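The only functional change in this hunk is tf.log -> tf.math.log (the bare tf.log alias was removed in TF2); the loss itself is unchanged. For orientation, here is a minimal self-contained sketch of the same clipped-surrogate PPO objective, with illustrative argument names; the ratio computation is not part of the visible hunk:

import tensorflow as tf

def ppo_loss(prob, value, action_one_hot, reward, advantage, old_prob,
             clipping_width=0.2, value_weight=1.0, entropy_weight=0.01):
  # Probability the current policy assigns to the action actually taken.
  new_prob = tf.reduce_sum(prob * action_one_hot, axis=1)
  ratio = new_prob / old_prob
  clipped_ratio = tf.clip_by_value(ratio, 1 - clipping_width,
                                   1 + clipping_width)
  # Clipped surrogate objective, negated because we minimize.
  policy_loss = -tf.reduce_mean(
      tf.minimum(ratio * advantage, clipped_ratio * advantage))
  # Squared-error critic loss plus an entropy bonus to encourage exploration.
  value_loss = tf.reduce_mean(tf.square(reward - value))
  entropy = -tf.reduce_mean(tf.reduce_sum(prob * tf.math.log(prob), axis=1))
  return policy_loss + value_weight * value_loss - entropy_weight * entropy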

@@ -155,57 +155,34 @@ class PPO(object):
      self._optimizer = Adam(learning_rate=0.001, beta1=0.9, beta2=0.999)
    else:
      self._optimizer = optimizer
    self._model = self._build_model(model_dir)
    output_names = policy.output_names
    output_tensors = self._model._output_tensors
    self._value = output_tensors[output_names.index('value')]
    self._action_prob = output_tensors[output_names.index('action_prob')]
    rnn_outputs = [i for i, n in enumerate(output_names) if n == 'rnn_state']
    self._rnn_final_states = [output_tensors[i] for i in rnn_outputs]
    self._session = tf.Session()
    self._train_op = self._model._tf_optimizer.minimize(
        self._model._loss_tensor)
    self._value_index = output_names.index('value')
    self._action_prob_index = output_names.index('action_prob')
    self._rnn_final_state_indices = [
        i for i, n in enumerate(output_names) if n == 'rnn_state'
    ]
    self._rnn_states = policy.rnn_initial_states
    if len(self._rnn_states) > 0 and batch_size != 0:
      raise ValueError(
          'Cannot batch rollouts when the policy contains a recurrent layer.  Set batch_size to 0.'
      )
    self._model = self._build_model(model_dir)
    self._checkpoint = tf.train.Checkpoint()
    self._checkpoint.save_counter  # Ensure the variable has been created
    self._checkpoint.listed = self._model.model.trainable_variables
    self._session.run(self._checkpoint.save_counter.initializer)

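The constructor changes drop the tf.Session and the explicit initializer run: with TF2's object-based checkpointing, save_counter is created (and initialized) the moment it is touched, and restores happen eagerly. A minimal sketch of the pattern, with a hypothetical model and checkpoint directory:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(10,))])

checkpoint = tf.train.Checkpoint()
checkpoint.save_counter  # touching it creates the variable; nothing to run
checkpoint.listed = model.trainable_variables  # track a list of variables

manager = tf.train.CheckpointManager(checkpoint, '/tmp/ppo_demo', max_to_keep=3)
manager.save()                                 # write a checkpoint eagerly
checkpoint.restore(manager.latest_checkpoint)  # no run_restore_ops() needed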
  def _build_model(self, model_dir):
    """Construct a KerasModel containing the policy and loss calculations."""
    state_shape = self._env.state_shape
    state_dtype = self._env.state_dtype
    if not self._state_is_list:
      state_shape = [state_shape]
      state_dtype = [state_dtype]
    features = []
    for s, d in zip(state_shape, state_dtype):
      features.append(
          tf.keras.layers.Input(shape=list(s), dtype=tf.as_dtype(d)))
    policy_model = self._policy.create_model()
    output_names = self._policy.output_names
    loss = PPOLoss(self.value_weight, self.entropy_weight, self.clipping_width,
                   output_names.index('action_prob'),
                   output_names.index('value'))
                   self._action_prob_index, self._value_index)
    model = KerasModel(
        policy_model,
        loss,
        batch_size=self.max_rollout_length,
        model_dir=model_dir,
        optimizer=self._optimizer)
    env = self._env
    example_inputs = [
        np.zeros([model.batch_size] + list(shape), dtype)
        for shape, dtype in zip(state_shape, state_dtype)
    ]
    example_labels = [np.zeros((model.batch_size, env.n_actions))]
    example_weights = [np.zeros(model.batch_size)] * 3
    model._create_training_ops((example_inputs, example_labels,
                                example_weights))
    model._ensure_built()
    return model

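_build_model now only wraps the policy's Keras model in deepchem's KerasModel with the PPOLoss above and calls _ensure_built(); no placeholders or training ops are constructed. The wrapped model is expected to expose outputs matching policy.output_names, typically 'action_prob' and 'value' (plus optional 'rnn_state' entries). A hypothetical minimal policy model in that shape, mirroring the test fixtures at the bottom of this diff:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Input

def create_model(n_state=10, n_actions=10):
  # Shared trunk with separate softmax (action_prob) and scalar (value) heads.
  state = Input(shape=(n_state,))
  hidden = Dense(16, activation=tf.nn.relu)(state)
  action_prob = Dense(n_actions, activation=tf.nn.softmax)(hidden)
  value = Dense(1)(hidden)
  return tf.keras.Model(inputs=state, outputs=[action_prob, value])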
  def fit(self,
@@ -234,7 +211,6 @@ class PPO(object):
    threads = []
    for i in range(self.optimization_rollouts):
      workers.append(_Worker(self, i))
    self._session.run(tf.global_variables_initializer())
    if restore:
      self.restore()
    pool = Pool()
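The dropped session.run(tf.global_variables_initializer()) has no TF2 counterpart: eager variables are initialized when they are created, so there is nothing left to run. For example:

import tensorflow as tf

v = tf.Variable(tf.zeros([2, 2]))  # created and initialized in one step
print(v.numpy())                   # immediately usable; no Session required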
@@ -257,20 +233,14 @@ class PPO(object):
        for batch in batches:
          initial_rnn_states, state_arrays, discounted_rewards, actions_matrix, action_prob, advantages = batch

          # Build the feed dict and run the optimizer.
          # Build the inputs and run the optimizer.

          feed_dict = {}
          for f, s in zip(self._model._input_placeholders, state_arrays):
            feed_dict[f] = s
          for f, s in zip(self._model._input_placeholders[len(state_arrays):],
                          initial_rnn_states):
            feed_dict[f] = np.expand_dims(s, axis=0)
          feed_dict[self._model._weights_placeholders[0]] = discounted_rewards
          feed_dict[self._model._label_placeholders[0]] = actions_matrix
          feed_dict[self._model._weights_placeholders[1]] = advantages
          feed_dict[self._model._weights_placeholders[2]] = action_prob
          feed_dict[self._model._global_step] = step_count
          self._session.run(self._train_op, feed_dict=feed_dict)
          state_arrays = [np.stack(s) for s in state_arrays]
          inputs = state_arrays + [
              np.expand_dims(s, axis=0) for s in initial_rnn_states
          ]
          self._apply_gradients(inputs, actions_matrix, discounted_rewards,
                                advantages, action_prob)

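The feed_dict bookkeeping is replaced by ordinary NumPy packing: each state component is stacked along the batch axis, each initial RNN state gains a leading batch dimension of 1, and the combined list is handed directly to the eager model inside _apply_gradients (defined below). A sketch of the packing with hypothetical shapes:

import numpy as np

# e.g. a 20-step rollout of 10-dimensional states plus one 10-dim RNN state
state_arrays = [[np.random.random(10).astype(np.float32) for _ in range(20)]]
initial_rnn_states = [np.zeros(10, dtype=np.float32)]

state_arrays = [np.stack(s) for s in state_arrays]  # [(20, 10)]
inputs = state_arrays + [
    np.expand_dims(s, axis=0) for s in initial_rnn_states
]  # [(20, 10), (1, 10)]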
      # Update the number of steps taken so far and perform checkpointing.

@@ -280,10 +250,21 @@ class PPO(object):
      step_count += new_steps
      if step_count >= total_steps or time.time(
      ) >= checkpoint_time + checkpoint_interval:
        with self._session.as_default():
        manager.save()
        checkpoint_time = time.time()

  @tf.function(experimental_relax_shapes=True)
  def _apply_gradients(self, inputs, actions_matrix, discounted_rewards,
                       advantages, action_prob):
    """Compute the gradient of the loss function for a batch and update the model."""
    vars = self._model.model.trainable_variables
    with tf.GradientTape() as tape:
      outputs = self._model.model(inputs)
      loss = self._model._loss_fn(outputs, [actions_matrix],
                                  [discounted_rewards, advantages, action_prob])
    gradients = tape.gradient(loss, vars)
    self._model._tf_optimizer.apply_gradients(zip(gradients, vars))

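_apply_gradients is the canonical TF2 custom training step: record the forward pass on a GradientTape, differentiate the loss with respect to the trainable variables, and pass the gradients to the optimizer. A standalone sketch of the same pattern with a placeholder model, loss, and data:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
loss_fn = tf.keras.losses.MeanSquaredError()

@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    predictions = model(x)          # forward pass is recorded on the tape
    loss = loss_fn(y, predictions)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  return loss

train_step(tf.random.normal([8, 4]), tf.random.normal([8, 1]))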
  def _iter_batches(self, rollouts):
    """Given a set of rollouts, merge them into batches for optimization."""

@@ -337,16 +318,8 @@ class PPO(object):
    -------
    the array of action probabilities, and the estimated value function
    """
    if not self._state_is_list:
      state = [state]
    feed_dict = self._create_feed_dict(state, use_saved_states)
    tensors = [self._action_prob, self._value]
    if save_states:
      tensors += self._rnn_final_states
    results = self._session.run(tensors, feed_dict=feed_dict)
    if save_states:
      self._rnn_states = [np.squeeze(r, 0) for r in results[2:]]
    return results[:2]
    results = self._predict_outputs(state, use_saved_states, save_states)
    return [results[i] for i in (self._action_prob_index, self._value_index)]

  def select_action(self,
                    state,
@@ -380,28 +353,45 @@ class PPO(object):
    -------
    the index of the selected action
    """
    if not self._state_is_list:
      state = [state]
    feed_dict = self._create_feed_dict(state, use_saved_states)
    tensors = [self._action_prob]
    if save_states:
      tensors += self._rnn_final_states
    results = self._session.run(tensors, feed_dict=feed_dict)
    probabilities = results[0]
    if save_states:
      self._rnn_states = [np.squeeze(r, 0) for r in results[1:]]
    if deterministic:
      return probabilities.argmax()
    else:
      return np.random.choice(
          np.arange(self._env.n_actions), p=probabilities[0])
    outputs = self._predict_outputs(state, use_saved_states, save_states)
    return self._select_action_from_outputs(outputs, deterministic)

  def restore(self):
    """Reload the model parameters from the most recent checkpoint file."""
    last_checkpoint = tf.train.latest_checkpoint(self._model.model_dir)
    if last_checkpoint is None:
      raise ValueError('No checkpoint found')
    self._checkpoint.restore(last_checkpoint).run_restore_ops(self._session)
    self._checkpoint.restore(last_checkpoint)

  def _predict_outputs(self, state, use_saved_states, save_states):
    """Compute a set of outputs for a state. """
    if not self._state_is_list:
      state = [state]
    if use_saved_states:
      state = state + list(self._rnn_states)
    else:
      state = state + list(self._policy.rnn_initial_states)
    inputs = [np.expand_dims(s, axis=0) for s in state]
    results = self._compute_model(inputs)
    results = [r.numpy() for r in results]
    if save_states:
      self._rnn_states = [
          np.squeeze(results[i], 0) for i in self._rnn_final_state_indices
      ]
    return results

  @tf.function(experimental_relax_shapes=True)
  def _compute_model(self, inputs):
    return self._model.model(inputs)

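Both _compute_model wrappers (here and in _Worker below) exist so the forward pass runs as a compiled graph. experimental_relax_shapes=True asks tf.function to generalize over input shapes instead of retracing for every new rollout length; in later TF releases the flag was renamed reduce_retracing. A toy illustration:

import tensorflow as tf

@tf.function(experimental_relax_shapes=True)
def double(x):
  print('tracing for shape', x.shape)  # executes only when a new trace is made
  return 2 * x

double(tf.ones([3]))  # traces for shape (3,)
double(tf.ones([5]))  # new shape: traces once more with a relaxed (None,) shape
double(tf.ones([7]))  # reuses the relaxed trace; no further retracing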
  def _select_action_from_outputs(self, outputs, deterministic):
    """Given the policy outputs, select an action to perform."""
    action_prob = outputs[self._action_prob_index]
    if deterministic:
      return action_prob.argmax()
    else:
      action_prob = action_prob.flatten()
      return np.random.choice(np.arange(len(action_prob)), p=action_prob)

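_select_action_from_outputs centralizes logic that predict() and select_action() previously duplicated around session runs: argmax of the probability head for deterministic play, otherwise a sample from it. With hypothetical probabilities:

import numpy as np

action_prob = np.array([[0.1, 0.6, 0.3]])  # one batch element, three actions
action_prob.argmax()                       # deterministic: action 1
p = action_prob.flatten()
np.random.choice(np.arange(len(p)), p=p)   # stochastic: action 1 ~60% of the time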
  def _create_feed_dict(self, state, use_saved_states):
    """Create a feed dict for use by predict() or select_action()."""
@@ -423,21 +413,14 @@ class _Worker(object):
    self.env = copy.deepcopy(ppo._env)
    self.env.reset()
    self.model = ppo._build_model(None)
    output_names = ppo._policy.output_names
    output_tensors = self.model._output_tensors
    self.value = output_tensors[output_names.index('value')]
    self.action_prob = output_tensors[output_names.index('action_prob')]
    rnn_outputs = [i for i, n in enumerate(output_names) if n == 'rnn_state']
    self.rnn_final_states = [output_tensors[i] for i in rnn_outputs]
    self.rnn_states = ppo._policy.rnn_initial_states
    local_vars = self.model.model.trainable_variables
    global_vars = ppo._model.model.trainable_variables
    self.update_local_variables = tf.group(
        *[tf.assign(v1, v2) for v1, v2 in zip(local_vars, global_vars)])

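In run() below, the TF1 tf.group of tf.assign ops is replaced by plain eager assignments that copy the global model's weights into the worker's local replica before each set of rollouts. The same synchronization as a self-contained helper; Keras's get_weights/set_weights shortcut is equivalent when the architectures match:

import tensorflow as tf

def sync_weights(local_model, global_model):
  # Copy every trainable variable from the global model into the local one.
  for v_local, v_global in zip(local_model.trainable_variables,
                               global_model.trainable_variables):
    v_local.assign(v_global)
  # Equivalent shortcut:
  # local_model.set_weights(global_model.get_weights())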
  def run(self):
    rollouts = []
    self.ppo._session.run(self.update_local_variables)
    local_vars = self.model.model.trainable_variables
    global_vars = self.ppo._model.model.trainable_variables
    for v1, v2 in zip(local_vars, global_vars):
      v1.assign(v2)
    initial_rnn_states = self.rnn_states
    states, actions, action_prob, rewards, values = self.create_rollout()
    rollouts.append(
@@ -452,7 +435,6 @@ class _Worker(object):
  def create_rollout(self):
    """Generate a rollout."""
    n_actions = self.env.n_actions
    session = self.ppo._session
    states = []
    action_prob = []
    actions = []
@@ -466,14 +448,15 @@ class _Worker(object):
        break
      state = self.env.state
      states.append(state)
      feed_dict = self.create_feed_dict(state)
      results = session.run(
          [self.action_prob, self.value] + self.rnn_final_states,
          feed_dict=feed_dict)
      probabilities, value = results[:2]
      probabilities = np.squeeze(probabilities)
      self.rnn_states = [np.squeeze(r, 0) for r in results[2:]]
      action = np.random.choice(np.arange(n_actions), p=probabilities)
      results = self._compute_model(
          self._create_model_inputs(state, self.rnn_states))
      results = [r.numpy() for r in results]
      value = results[self.ppo._value_index]
      probabilities = np.squeeze(results[self.ppo._action_prob_index])
      self.rnn_states = [
          np.squeeze(results[i], 0) for i in self.ppo._rnn_final_state_indices
      ]
      action = self.ppo._select_action_from_outputs(results, False)
      actions.append(action)
      action_prob.append(probabilities[action])
      values.append(float(value))
@@ -482,9 +465,10 @@ class _Worker(object):
    # Compute an estimate of the reward for the rest of the episode.

    if not self.env.terminated:
      feed_dict = self.create_feed_dict(self.env.state)
      final_value = self.ppo.discount_factor * float(
          session.run(self.value, feed_dict))
      results = self._compute_model(
          self._create_model_inputs(self.env.state, self.rnn_states))
      final_value = self.ppo.discount_factor * results[self.ppo.
                                                       _value_index].numpy()[0]
    else:
      final_value = 0.0
    values.append(final_value)
@@ -492,8 +476,10 @@ class _Worker(object):
      self.env.reset()
      self.rnn_states = self.ppo._policy.rnn_initial_states
    return states, np.array(
        actions, dtype=np.int32), np.array(action_prob), np.array(
            rewards), np.array(values)
        actions, dtype=np.int32), np.array(
            action_prob, dtype=np.float32), np.array(
                rewards, dtype=np.float32), np.array(
                    values, dtype=np.float32)

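create_rollout now returns float32 arrays explicitly: with eager models there is no placeholder to coerce dtypes, so NumPy's float64 defaults would collide with float32 weights and force extra tf.function retraces. process_rollout (body largely unchanged below) turns these arrays into the discounted_rewards and advantages consumed in fit(). As a reminder of that computation, an illustrative discounted cumulative sum, not the commit's exact code:

import numpy as np

def discount(x, gamma):
  """y[i] = x[i] + gamma * x[i + 1] + gamma**2 * x[i + 2] + ..."""
  y = np.zeros_like(x, dtype=np.float32)
  running = 0.0
  for i in reversed(range(len(x))):
    running = x[i] + gamma * running
    y[i] = running
  return y

rewards = np.array([1.0, 0.0, 2.0], dtype=np.float32)
discount(rewards, gamma=0.99)  # -> [2.9602, 1.98, 2.0]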
  def process_rollout(self, states, actions, action_prob, rewards, values,
                      initial_rnn_states):
@@ -517,14 +503,15 @@ class _Worker(object):
    n_actions = self.env.n_actions
    actions_matrix = []
    for action in actions:
      a = np.zeros(n_actions)
      a = np.zeros(n_actions, np.float32)
      a[action] = 1.0
      actions_matrix.append(a)
    actions_matrix = np.array(actions_matrix, dtype=np.float32)
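The loop above builds a float32 one-hot row per action; an equivalent vectorized form, shown for illustration only:

import numpy as np

actions = np.array([2, 0, 1], dtype=np.int32)
actions_matrix = np.eye(4, dtype=np.float32)[actions]  # (3, 4) one-hot rows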

    # Rearrange the states into the proper set of arrays.

    if self.ppo._state_is_list:
      state_arrays = [[] for i in range(len(self.model._input_shapes))]
      state_arrays = [[] for i in range(len(self.env.state_shape))]
      for state in states:
        for j in range(len(state)):
          state_arrays[j].append(state[j])
@@ -541,29 +528,34 @@ class _Worker(object):
    hindsight_states, rewards = self.env.apply_hindsight(
        states, actions, states[-1])
    if self.ppo._state_is_list:
      state_arrays = [[] for i in range(len(self.model._input_shapes))]
      state_arrays = [[] for i in range(len(self.env.state_shape))]
      for state in hindsight_states:
        for j in range(len(state)):
          state_arrays[j].append(state[j])
    else:
      state_arrays = [hindsight_states]
    state_arrays += initial_rnn_states
    feed_dict = {}
    for f, s in zip(self.model._input_placeholders, state_arrays):
      feed_dict[f] = s
    values, probabilities = self.ppo._session.run(
        [self.value, self.action_prob], feed_dict=feed_dict)
    state_arrays = [np.stack(s) for s in state_arrays]
    inputs = state_arrays + [
        np.expand_dims(s, axis=0) for s in initial_rnn_states
    ]
    outputs = self._compute_model(inputs)
    values = outputs[self.ppo._value_index].numpy()
    values = np.append(values.flatten(), 0.0)
    probabilities = outputs[self.ppo._action_prob_index].numpy()
    action_prob = probabilities[np.arange(len(actions)), actions]
    return self.process_rollout(hindsight_states, actions, action_prob,
                                np.array(rewards), np.array(values),
                                np.array(rewards, dtype=np.float32),
                                np.array(values, dtype=np.float32),
                                initial_rnn_states)

  def create_feed_dict(self, state):
    """Create a feed dict for use during a rollout."""
  def _create_model_inputs(self, state, rnn_states):
    """Create the inputs to the model for use during a rollout."""
    if not self.ppo._state_is_list:
      state = [state]
    state = state + self.rnn_states
    feed_dict = dict((f, np.expand_dims(s, axis=0))
                     for f, s in zip(self.model._input_placeholders, state))
    return feed_dict
    state = state + rnn_states
    return [np.expand_dims(s, axis=0) for s in state]

  @tf.function(experimental_relax_shapes=True)
  def _compute_model(self, inputs):
    return self.model.model(inputs)
+6 −6
@@ -107,10 +107,10 @@ class TestPPO(unittest.TestCase):

      def __init__(self):
        super(TestEnvironment, self).__init__((10,), 10)
        self._state = np.random.random(10)
        self._state = np.random.random(10).astype(np.float32)

      def step(self, action):
        self._state = np.random.random(10)
        self._state = np.random.random(10).astype(np.float32)
        return 0.0

      def reset(self):
@@ -129,10 +129,10 @@ class TestPPO(unittest.TestCase):
        rnn_state = Input(shape=(10,))
        reshaped = Reshape((1, 10))(state)
        gru, rnn_final_state = GRU(
            10, return_state=True, return_sequences=True)(
            10, return_state=True, return_sequences=True, time_major=True)(
                reshaped, initial_state=rnn_state)
        output = Softmax()(Reshape((10,))(gru))
        value = dc.models.layers.Variable([0.0])([])
        value = dc.models.layers.Variable([0.0])([state])
        return tf.keras.Model(
            inputs=[state, rnn_state], outputs=[output, value, rnn_final_state])
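Two test fixes here: the constant value layer is given state as an input so the functional graph can be traced, and the GRU now runs with time_major=True, meaning its input is interpreted as [timesteps, batch, features] rather than the default [batch, timesteps, features]. That appears to let a whole rollout, reshaped to (steps, 1, 10), be consumed as one sequence against a single initial RNN state. A toy illustration of the shape convention:

import tensorflow as tf

gru = tf.keras.layers.GRU(
    10, return_state=True, return_sequences=True, time_major=True)
x = tf.zeros([5, 1, 10])             # [timesteps=5, batch=1, features=10]
seq, final_state = gru(x)
print(seq.shape, final_state.shape)  # (5, 1, 10) (1, 10)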

@@ -219,8 +219,8 @@ class TestPPO(unittest.TestCase):

      def create_model(self, **kwargs):
        state = Input(shape=(4,))
        dense1 = Dense(6, activation=tf.nn.relu)(state)
        dense2 = Dense(6, activation=tf.nn.relu)(dense1)
        dense1 = Dense(8, activation=tf.nn.relu)(state)
        dense2 = Dense(8, activation=tf.nn.relu)(dense1)
        output = Dense(4, activation=tf.nn.softmax, use_bias=False)(dense2)
        value = Dense(1)(dense2)
        return tf.keras.Model(inputs=state, outputs=[output, value])
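End to end, the converted class keeps the same public API. A hypothetical usage sketch, assuming an Environment and Policy like the test fixtures above (TestEnvironment is named in the tests; the policy class name is assumed):

from deepchem.rl.ppo import PPO

env = TestEnvironment()  # dc.rl.Environment subclass, as in the tests above
policy = TestPolicy()    # dc.rl.Policy subclass providing create_model()

ppo = PPO(env, policy, max_rollout_length=20, batch_size=64)
ppo.fit(100000)                        # generate rollouts and optimize
action = ppo.select_action(env.state)  # sample from the trained policy
probs, value = ppo.predict(env.state)  # raw action probabilities and value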