Commit c81d221b authored by Kevin Shen's avatar Kevin Shen
Browse files

Merge branch 'deepchem:master' into wandb

parents 0efd1f69 875f7f46
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -40,6 +40,7 @@ DeepChem currently supports Python 3.6 through 3.7 and requires these packages o
- [TensorFlow](https://www.tensorflow.org/)
  - `deepchem>=2.4.0` depends on TensorFlow v2
  - `deepchem<2.4.0` depends on TensorFlow v1
- [Tensorflow Addons](https://www.tensorflow.org/addons) for Tensorflow v2 if you want to use advanced optimizers such as AdamW and Sparse Adam. (Optional)

### Soft Requirements

+46 −0
Original line number Diff line number Diff line
@@ -94,6 +94,52 @@ class HingeLoss(Loss):
    return loss


class SquaredHingeLoss(Loss):
  """The Squared Hinge loss function.
  
  Defined as the square of the hinge loss between y_true and y_pred. The Squared Hinge Loss is differentiable.
  """

  def _compute_tf_loss(self, output, labels):
    import tensorflow as tf
    output, labels = _make_tf_shapes_consistent(output, labels)
    return tf.keras.losses.SquaredHinge(reduction='none')(labels, output)

  def _create_pytorch_loss(self):
    import torch

    def loss(output, labels):
      output, labels = _make_pytorch_shapes_consistent(output, labels)
      return torch.mean(
          torch.pow(
              torch.maximum(1 - torch.multiply(labels, output),
                            torch.tensor(0)), 2),
          dim=-1)

    return loss


class PoissonLoss(Loss):
  """The Poisson loss function is defined as the mean of the elements of y_pred - (y_true * log(y_pred) for an input of (y_true, y_pred).
  Poisson loss is generally used for regression tasks where the data follows the poisson
  """

  def _compute_tf_loss(self, output, labels):
    import tensorflow as tf
    output, labels = _make_tf_shapes_consistent(output, labels)
    loss = tf.keras.losses.Poisson(reduction='auto')
    return loss(labels, output)

  def _create_pytorch_loss(self):
    import torch

    def loss(output, labels):
      output, labels = _make_pytorch_shapes_consistent(output, labels)
      return torch.mean(output - labels * torch.log(output))

    return loss


class BinaryCrossEntropy(Loss):
  """The cross entropy between pairs of probabilities.

+114 −0
Original line number Diff line number Diff line
@@ -190,6 +190,120 @@ class Adam(Optimizer):
    return torch.optim.Adam(params, lr, (self.beta1, self.beta2), self.epsilon)


class SparseAdam(Optimizer):
  """The Sparse Adam optimization algorithm, also known as Lazy Adam.
  Sparse Adam is suitable for sparse tensors. It handles sparse updates more efficiently. 
  It only updates moving-average accumulators for sparse variable indices that appear in the current batch, rather than updating the accumulators for all indices.
  """

  def __init__(self,
               learning_rate: Union[float, LearningRateSchedule] = 0.001,
               beta1: float = 0.9,
               beta2: float = 0.999,
               epsilon: float = 1e-08):
    """Construct an Adam optimizer.
    
    Parameters
    ----------
    learning_rate: float or LearningRateSchedule
      the learning rate to use for optimization
    beta1: float
      a parameter of the SparseAdam algorithm
    beta2: float
      a parameter of the SparseAdam algorithm
    epsilon: float
      a parameter of the SparseAdam algorithm
    """
    super(SparseAdam, self).__init__(learning_rate)
    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon

  def _create_tf_optimizer(self, global_step):
    import tensorflow as tf
    import tensorflow_addons as tfa
    if isinstance(self.learning_rate, LearningRateSchedule):
      learning_rate = self.learning_rate._create_tf_tensor(global_step)
    else:
      learning_rate = self.learning_rate
    return tfa.optimizers.LazyAdam(
        learning_rate=learning_rate,
        beta_1=self.beta1,
        beta_2=self.beta2,
        epsilon=self.epsilon)

  def _create_pytorch_optimizer(self, params):
    import torch
    if isinstance(self.learning_rate, LearningRateSchedule):
      lr = self.learning_rate.initial_rate
    else:
      lr = self.learning_rate
    return torch.optim.SparseAdam(params, lr, (self.beta1, self.beta2),
                                  self.epsilon)


class AdamW(Optimizer):
  """The AdamW optimization algorithm.
  AdamW is a variant of Adam, with improved weight decay.
  In Adam, weight decay is implemented as: weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
  In AdamW, weight decay is implemented as: weight_decay (float, optional) – weight decay coefficient (default: 1e-2)
  """

  def __init__(self,
               learning_rate: Union[float, LearningRateSchedule] = 0.001,
               weight_decay: Union[float, LearningRateSchedule] = 0.01,
               beta1: float = 0.9,
               beta2: float = 0.999,
               epsilon: float = 1e-08,
               amsgrad: bool = False):
    """Construct an AdamW optimizer.
    Parameters
    ----------
    learning_rate: float or LearningRateSchedule
      the learning rate to use for optimization
    weight_decay: float or LearningRateSchedule
      weight decay coefficient for AdamW
    beta1: float
      a parameter of the Adam algorithm
    beta2: float
      a parameter of the Adam algorithm
    epsilon: float
      a parameter of the Adam algorithm
    amsgrad: bool
      If True, will use the AMSGrad variant of AdamW (from "On the Convergence of Adam and Beyond"), else will use the original algorithm.
    """
    super(AdamW, self).__init__(learning_rate)
    self.weight_decay = weight_decay
    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon
    self.amsgrad = amsgrad

  def _create_tf_optimizer(self, global_step):
    import tensorflow as tf
    import tensorflow_addons as tfa
    if isinstance(self.learning_rate, LearningRateSchedule):
      learning_rate = self.learning_rate._create_tf_tensor(global_step)
    else:
      learning_rate = self.learning_rate
    return tfa.optimizers.AdamW(
        weight_decay=self.weight_decay,
        learning_rate=learning_rate,
        beta_1=self.beta1,
        beta_2=self.beta2,
        epsilon=self.epsilon,
        amsgrad=self.amsgrad)

  def _create_pytorch_optimizer(self, params):
    import torch
    if isinstance(self.learning_rate, LearningRateSchedule):
      lr = self.learning_rate.initial_rate
    else:
      lr = self.learning_rate
    return torch.optim.AdamW(params, lr, (self.beta1, self.beta2), self.epsilon,
                             self.weight_decay, self.amsgrad)


class RMSProp(Optimizer):
  """RMSProp Optimization algorithm."""

+39 −0
Original line number Diff line number Diff line
@@ -98,6 +98,45 @@ class TestLosses(unittest.TestCase):
    expected = [np.mean([0.9, 1.8]), np.mean([1.4, 0.4])]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
  def test_squared_hinge_loss_tf(self):
    """Test SquaredHingeLoss."""
    loss = losses.SquaredHingeLoss()
    outputs = tf.constant([[0.1, 0.8], [0.4, 0.6]])
    labels = tf.constant([[1.0, -1.0], [-1.0, 1.0]])
    result = loss._compute_tf_loss(outputs, labels).numpy()
    expected = [np.mean([0.8100, 3.2400]), np.mean([1.9600, 0.1600])]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
  def test_squared_hinge_loss_pytorch(self):
    """Test SquaredHingeLoss."""
    loss = losses.SquaredHingeLoss()
    outputs = torch.tensor([[0.1, 0.8], [0.4, 0.6]])
    labels = torch.tensor([[1.0, -1.0], [-1.0, 1.0]])
    result = loss._create_pytorch_loss()(outputs, labels).numpy()
    expected = [np.mean([0.8100, 3.2400]), np.mean([1.9600, 0.1600])]
    assert np.allclose(expected, result)

  def test_poisson_loss_tf(self):
    """Test PoissonLoss."""
    loss = losses.PoissonLoss()
    outputs = tf.constant([[0.1, 0.8], [0.4, 0.6]])
    labels = tf.constant([[0.0, 1.0], [1.0, 0.0]])
    result = loss._compute_tf_loss(outputs, labels).numpy()
    expected = 0.75986
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
  def test_poisson_loss_pytorch(self):
    """Test PoissonLoss."""
    loss = losses.PoissonLoss()
    outputs = torch.tensor([[0.1, 0.8], [0.4, 0.6]])
    labels = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
    result = loss._create_pytorch_loss()(outputs, labels).numpy()
    expected = 0.75986
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
  def test_binary_cross_entropy_tf(self):
    """Test BinaryCrossEntropy."""
+40 −0
Original line number Diff line number Diff line
@@ -7,6 +7,12 @@ try:
except:
  has_tensorflow = False

try:
  import tensorflow_addons as tfa
  has_tensorflow_addons = True
except:
  has_tensorflow_addons = False

try:
  import torch
  has_pytorch = True
@@ -33,6 +39,40 @@ class TestOptimizers(unittest.TestCase):
    torchopt = opt._create_pytorch_optimizer(params)
    assert isinstance(torchopt, torch.optim.Adam)

  @unittest.skipIf(not has_tensorflow_addons,
                   'TensorFlow Addons is not installed')
  def test_adamw_tf(self):
    """Test creating an AdamW optimizer."""
    opt = optimizers.AdamW(learning_rate=0.01)
    global_step = tf.Variable(0)
    tfopt = opt._create_tf_optimizer(global_step)
    assert isinstance(tfopt, tfa.optimizers.AdamW)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
  def test_adamw_pytorch(self):
    """Test creating an AdamW optimizer."""
    opt = optimizers.AdamW(learning_rate=0.01)
    params = [torch.nn.Parameter(torch.Tensor([1.0]))]
    torchopt = opt._create_pytorch_optimizer(params)
    assert isinstance(torchopt, torch.optim.AdamW)

  @unittest.skipIf(not has_tensorflow_addons,
                   'TensorFlow Addons is not installed')
  def test_sparseadam_tf(self):
    """Test creating a SparseAdam optimizer."""
    opt = optimizers.SparseAdam(learning_rate=0.01)
    global_step = tf.Variable(0)
    tfopt = opt._create_tf_optimizer(global_step)
    assert isinstance(tfopt, tfa.optimizers.LazyAdam)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
  def test_sparseadam_pytorch(self):
    """Test creating a SparseAdam optimizer."""
    opt = optimizers.SparseAdam(learning_rate=0.01)
    params = [torch.nn.Parameter(torch.Tensor([1.0]))]
    torchopt = opt._create_pytorch_optimizer(params)
    assert isinstance(torchopt, torch.optim.SparseAdam)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
  def test_adagrad_tf(self):
    """Test creating an AdaGrad optimizer."""
Loading