Commit 7eca4ad6 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

first

parent d975b876
Loading
Loading
Loading
Loading
+14 −2
Original line number Diff line number Diff line
"""Interface for reinforcement learning."""

from deepchem.utils import TensorFlowStub

# Optional-dependency guard: the real A2C/PPO implementations are only
# imported when both tensorflow and tensorflow_probability can be imported.
try:
  import tensorflow as tf
  import tensorflow_probability as tfp
  from deepchem.rl.a2c import A2C
  from deepchem.rl.ppo import PPO
except ModuleNotFoundError:
  # Any missing module in the imports above lands here.  The names A2C and PPO
  # are still defined so that importing them from this package succeeds; they
  # are bound to TensorFlowStub subclasses instead of the real classes.
  # NOTE(review): a missing tensorflow_probability alone also takes this
  # path, although the stub's name refers to TensorFlow -- confirm intended.

  class A2C(TensorFlowStub):
    pass

  class PPO(TensorFlowStub):
    pass


class Environment(object):
+24 −7
Original line number Diff line number Diff line
"""Advantage Actor-Critic (A2C) algorithm for reinforcement learning."""

from deepchem.models import KerasModel
from deepchem.models.optimizers import Adam
import numpy as np
import tensorflow as tf
import collections
import copy
import multiprocessing
@@ -11,19 +8,32 @@ import os
import re
import threading
import time
from deepchem.models import KerasModel
from deepchem.models.optimizers import Adam


class A2CLossDiscrete(object):
  """This class computes the loss function for A2C with discrete action spaces."""
  """This class computes the loss function for A2C with discrete action spaces.

  Note
  ----
  This class requires tensorflow to be installed.
  """

  def __init__(self, value_weight, entropy_weight, action_prob_index,
               value_index):

    try:
      import tensorflow as tf
    except ModuleNotFoundError:
      raise ValueError("This class requires tensorflow to be installed.")
    self.value_weight = value_weight
    self.entropy_weight = entropy_weight
    self.action_prob_index = action_prob_index
    self.value_index = value_index

  def __call__(self, outputs, labels, weights):
    import tensorflow as tf
    prob = outputs[self.action_prob_index]
    value = outputs[self.value_index]
    reward, advantage = weights
@@ -43,16 +53,18 @@ class A2CLossContinuous(object):

  Note
  ----
  This class requires tensorflow-probability to be installed.
  This class requires tensorflow, tensorflow-probability to be installed.
  """

  def __init__(self, value_weight, entropy_weight, mean_index, std_index,
               value_index):
    try:
      import tensorflow as tf
      import tensorflow_probability as tfp
    except ModuleNotFoundError:
      raise ValueError(
          "This class requires tensorflow-probability to be installed.")
          "This class requires tensorflow, tensorflow-probability to be installed."
      )
    self.value_weight = value_weight
    self.entropy_weight = entropy_weight
    self.mean_index = mean_index
@@ -126,7 +138,8 @@ class A2C(object):

  Note
  ----
  Using this class on continuous action spaces requires that `tensorflow_probability` be installed.
  This class requires tensorflow to be installed.  Using this class on
  continuous action spaces requires that `tensorflow_probability` be installed.
  """

  def __init__(self,
@@ -166,6 +179,10 @@ class A2C(object):
      the directory in which the model will be saved.  If None, a temporary directory will be created.
    use_hindsight: bool
      if True, use Hindsight Experience Replay

    Note
    ----
    This class requires tensorflow to be installed.
    """
    self._env = env
    self._policy = policy
+16 −2
Original line number Diff line number Diff line
@@ -14,7 +14,12 @@ import time


class PPOLoss(object):
  """This class computes the loss function for PPO."""
  """This class computes the loss function for PPO.

  Note
  ----
  This class requires tensorflow to be installed.
  """

  def __init__(self, value_weight, entropy_weight, clipping_width,
               action_prob_index, value_index):
@@ -138,6 +143,10 @@ class PPO(object):
      the directory in which the model will be saved.  If None, a temporary directory will be created.
    use_hindsight: bool
      if True, use Hindsight Experience Replay

    Note
    ----
    This class requires tensorflow to be installed.
    """
    self._env = env
    self._policy = policy
@@ -405,7 +414,12 @@ class PPO(object):


class _Worker(object):
  """A Worker object is created for each training thread."""
  """A Worker object is created for each training thread.

  Note
  ----
  This class requires tensorflow to be installed.
  """

  def __init__(self, ppo, index):
    self.ppo = ppo
+9 −0
Original line number Diff line number Diff line
@@ -21,6 +21,15 @@ except:
  from urllib import urlretrieve  # Python 2


class TensorFlowStub(object):
  """Placeholder base class for classes that cannot work without TensorFlow.

  Instantiating this class, or any subclass that does not override
  `__init__`, always fails with a `ModuleNotFoundError` whose message names
  the concrete class being constructed.
  """

  def __init__(self, *args, **kwargs):
    """Reject construction; all arguments are accepted and ignored.

    Raises
    ------
    ModuleNotFoundError
      always.
    """
    # Use the runtime type so that subclasses report their own name.
    class_name = type(self).__name__
    message = (
        "The class '%s' cannot be used because TensorFlow is not installed" %
        class_name)
    raise ModuleNotFoundError(message)


def pad_array(x, shape, fill=0, both=False):
  """
  Pad an array with a fill value.