Commit a207b9c5 authored by Bharath Ramsundar, committed by GitHub

Merge pull request #656 from peastman/gae

A3C supports Generalized Advantage Estimation
parents 6dd2f1b3 24d3c96f
+36 −9
@@ -37,9 +37,10 @@ class A3C(object):
  """
  Implements the Asynchronous Advantage Actor-Critic (A3C) algorithm for reinforcement learning.

-  This algorithm requires the policy to output two quantities: a vector giving the probability of
-  taking each action, and an estimate of the value function for the current state.  It optimizes
-  both outputs at once using a loss that is the sum of three terms:
+  The algorithm is described in Mnih et al., "Asynchronous Methods for Deep Reinforcement Learning"
+  (https://arxiv.org/abs/1602.01783).  This class requires the policy to output two quantities:
+  a vector giving the probability of taking each action, and an estimate of the value function for
+  the current state.  It optimizes both outputs at once using a loss that is the sum of three terms:

  1. The policy loss, which seeks to maximize the discounted reward for each action.
  2. The value loss, which tries to make the value estimate match the actual discounted reward
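As a rough illustration of how such a loss can be assembled (the third term in the Mnih et al. paper is an entropy bonus that encourages exploration), here is a minimal NumPy sketch. The a3c_loss function is hypothetical and not DeepChem API; value_weight and entropy_weight mirror the constructor arguments shown further down in this diff.

import numpy as np

def a3c_loss(action_prob, chosen_action, value, discounted_reward, advantage,
             value_weight=1.0, entropy_weight=0.01):
  """Hypothetical illustration of combining the three A3C loss terms."""
  log_prob = np.log(action_prob + 1e-8)
  # 1. Policy loss: raise the log-probability of actions with positive advantage.
  policy_loss = -np.mean(
      advantage * log_prob[np.arange(len(chosen_action)), chosen_action])
  # 2. Value loss: make the value estimate match the discounted reward.
  value_loss = np.mean((discounted_reward - value) ** 2)
  # 3. Entropy bonus (from the A3C paper) to encourage exploration.
  entropy = -np.mean(np.sum(action_prob * log_prob, axis=1))
  return policy_loss + value_weight * value_loss - entropy_weight * entropy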
@@ -48,6 +49,11 @@ class A3C(object):

  This class only supports environments with discrete action spaces, not continuous ones.  The
  "action" argument passed to the environment is an integer, giving the index of the action to perform.

+  This class supports Generalized Advantage Estimation as described in Schulman et al., "High-Dimensional
+  Continuous Control Using Generalized Advantage Estimation" (https://arxiv.org/abs/1506.02438).
+  This is a method of trading off bias and variance in the advantage estimate, which can sometimes
+  improve the rate of convergence.  Use the advantage_lambda parameter to adjust the tradeoff.
  """

  def __init__(self,
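For reference, the GAE recursion from Schulman et al. works on the TD residual delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) and sets A_t = delta_t + gamma * lambda * A_{t+1}. A minimal standalone NumPy sketch of that recursion follows; the generalized_advantage helper is hypothetical and not part of this class, but the changes to create_rollout further down implement essentially the same recursion in place.

import numpy as np

def generalized_advantage(rewards, values, bootstrap_value, gamma=0.99, lam=0.98):
  """Hypothetical helper showing the GAE recursion for one rollout.

  bootstrap_value is the value estimate for the state reached after the last
  step (use 0.0 if the episode terminated).
  """
  rewards = np.asarray(rewards, dtype=float)
  values = np.asarray(values, dtype=float)
  next_values = np.append(values[1:], bootstrap_value)
  # TD residuals: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
  deltas = rewards + gamma * next_values - values
  # Backward recursion: A_t = delta_t + gamma * lam * A_{t+1}
  advantages = np.zeros_like(deltas)
  running = 0.0
  for t in reversed(range(len(deltas))):
    running = deltas[t] + gamma * lam * running
    advantages[t] = running
  return advantages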
@@ -55,6 +61,7 @@ class A3C(object):
               policy,
               max_rollout_length=20,
               discount_factor=0.99,
+               advantage_lambda=0.98,
               value_weight=1.0,
               entropy_weight=0.01,
               optimizer=None,
@@ -85,6 +92,7 @@ class A3C(object):
    self._policy = policy
    self.max_rollout_length = max_rollout_length
    self.discount_factor = discount_factor
+    self.advantage_lambda = advantage_lambda
    self.value_weight = value_weight
    self.entropy_weight = entropy_weight
    if optimizer is None:
@@ -335,6 +343,9 @@ class _Worker(object):
    actions = []
    rewards = []
    values = []

+    # Generate the rollout.

    for i in range(self.a3c.max_rollout_length):
      if self.env.terminated:
        break
@@ -353,19 +364,35 @@ class _Worker(object):
      actions[i][action] = 1.0
      values.append(float(value))
      rewards.append(self.env.step(action))

+    # Compute an estimate of the reward for the rest of the episode.

    if not self.env.terminated:
-      # Add an estimate of the reward for the rest of the episode.
      feed_dict = self.create_feed_dict(self.env.state)
-      rewards[-1] += self.a3c.discount_factor * float(
+      final_value = self.a3c.discount_factor * float(
          session.run(self.value.out_tensor, feed_dict))
-    for j in range(len(rewards) - 1, 0, -1):
-      rewards[j - 1] += self.a3c.discount_factor * rewards[j]
+    else:
+      final_value = 0.0

+    # Compute the output arrays.

    rewards_array = np.array(rewards)
-    advantages = rewards_array - np.array(values)
+    values_array = np.array(values)
+    discounted_rewards = rewards_array.copy()
+    discounted_rewards[-1] += final_value
+    advantages = rewards_array - values_array + self.a3c.discount_factor * np.array(
+        values[1:] + [final_value])
+    for j in range(len(rewards) - 1, 0, -1):
+      discounted_rewards[j -
+                         1] += self.a3c.discount_factor * discounted_rewards[j]
+      advantages[
+          j -
+          1] += self.a3c.discount_factor * self.a3c.advantage_lambda * advantages[
+              j]
    if self.env.terminated:
      self.env.reset()
      self.rnn_states = self.graph.rnn_zero_states
-    return np.array(states), np.array(actions), rewards_array, advantages
+    return np.array(states), np.array(actions), discounted_rewards, advantages

  def create_feed_dict(self, state):
    """Create a feed dict for use during a rollout."""