Unverified Commit dc88b8bd authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2488 from atreyamaj/atreya_adamw

Adding the AdamW optimizer
parents eac8e476 7ac347b2
Loading
Loading
Loading
Loading
+21 −0
Original line number Diff line number Diff line
@@ -94,6 +94,27 @@ class HingeLoss(Loss):
    return loss


class PoissonLoss(Loss):
  """The Poisson loss between predictions and targets.

  For an input of (y_true, y_pred), the loss is defined as the mean of the
  elements of ``y_pred - y_true * log(y_pred)``. Poisson loss is generally
  used for regression tasks where the data follows the Poisson distribution
  (for example, count data).

  A small epsilon is added to ``y_pred`` inside the logarithm on both backends
  so that a prediction of exactly 0 does not produce ``inf`` or ``NaN``.
  """

  def _compute_tf_loss(self, output, labels):
    """Compute the Poisson loss with TensorFlow.

    Delegates to ``tf.keras.losses.Poisson`` after making the shapes of
    ``output`` and ``labels`` consistent.
    """
    import tensorflow as tf
    output, labels = _make_tf_shapes_consistent(output, labels)
    loss = tf.keras.losses.Poisson(reduction='auto')
    return loss(labels, output)

  def _create_pytorch_loss(self):
    """Return a function ``loss(output, labels)`` computing the Poisson loss
    with PyTorch as the mean over all elements."""
    import torch

    # Same value as the default Keras backend epsilon used by
    # tf.keras.losses.Poisson, so both backends agree numerically.
    epsilon = 1e-7

    def loss(output, labels):
      output, labels = _make_pytorch_shapes_consistent(output, labels)
      # Without the epsilon, output == 0 gives log(0) = -inf, and the product
      # with a zero label turns the whole mean into NaN.
      return torch.mean(output - labels * torch.log(output + epsilon))

    return loss


class BinaryCrossEntropy(Loss):
  """The cross entropy between pairs of probabilities.

+63 −1
Original line number Diff line number Diff line
@@ -242,6 +242,68 @@ class SparseAdam(Optimizer):
                                  self.epsilon)


class AdamW(Optimizer):
  """The AdamW optimization algorithm.

  AdamW is a variant of Adam with improved (decoupled) weight decay. In Adam,
  ``weight_decay`` acts as an L2 penalty; in AdamW, ``weight_decay`` is the
  coefficient of a decay applied to the weights directly.
  """

  def __init__(self,
               learning_rate: Union[float, LearningRateSchedule] = 0.001,
               weight_decay: Union[float, LearningRateSchedule] = 0.01,
               beta1: float = 0.9,
               beta2: float = 0.999,
               epsilon: float = 1e-08,
               amsgrad: bool = False):
    """Construct an AdamW optimizer.

    Parameters
    ----------
    learning_rate: float or LearningRateSchedule
      the learning rate to use for optimization
    weight_decay: float or LearningRateSchedule
      weight decay coefficient for AdamW
    beta1: float
      a parameter of the Adam algorithm
    beta2: float
      a parameter of the Adam algorithm
    epsilon: float
      a parameter of the Adam algorithm
    amsgrad: bool
      If True, will use the AMSGrad variant of AdamW (from "On the Convergence
      of Adam and Beyond"), else will use the original algorithm.
    """
    super(AdamW, self).__init__(learning_rate)
    self.weight_decay = weight_decay
    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon
    self.amsgrad = amsgrad

  def _create_tf_optimizer(self, global_step):
    """Build the TensorFlow Addons AdamW optimizer.

    Both ``learning_rate`` and ``weight_decay`` may be LearningRateSchedule
    objects; each one is converted with ``_create_tf_tensor(global_step)``
    before being handed to the optimizer.
    """
    import tensorflow_addons as tfa
    if isinstance(self.learning_rate, LearningRateSchedule):
      learning_rate = self.learning_rate._create_tf_tensor(global_step)
    else:
      learning_rate = self.learning_rate
    # weight_decay is documented to accept a schedule too, so resolve it the
    # same way as the learning rate instead of passing the raw object through.
    if isinstance(self.weight_decay, LearningRateSchedule):
      weight_decay = self.weight_decay._create_tf_tensor(global_step)
    else:
      weight_decay = self.weight_decay
    return tfa.optimizers.AdamW(
        weight_decay=weight_decay,
        learning_rate=learning_rate,
        beta_1=self.beta1,
        beta_2=self.beta2,
        epsilon=self.epsilon,
        amsgrad=self.amsgrad)

  def _create_pytorch_optimizer(self, params):
    """Build the PyTorch AdamW optimizer for the given parameters.

    When ``learning_rate`` or ``weight_decay`` is a LearningRateSchedule, its
    ``initial_rate`` is used as the scalar value passed to the optimizer.
    """
    import torch
    if isinstance(self.learning_rate, LearningRateSchedule):
      lr = self.learning_rate.initial_rate
    else:
      lr = self.learning_rate
    # torch.optim.AdamW needs a plain number for weight_decay; a schedule
    # object would fail, so fall back to the schedule's initial value.
    if isinstance(self.weight_decay, LearningRateSchedule):
      weight_decay = self.weight_decay.initial_rate
    else:
      weight_decay = self.weight_decay
    # Keyword arguments keep this call correct regardless of the positional
    # order of the optimizer's signature.
    return torch.optim.AdamW(
        params,
        lr=lr,
        betas=(self.beta1, self.beta2),
        eps=self.epsilon,
        weight_decay=weight_decay,
        amsgrad=self.amsgrad)


class RMSProp(Optimizer):
  """RMSProp Optimization algorithm."""

+20 −0
Original line number Diff line number Diff line
@@ -98,6 +98,26 @@ class TestLosses(unittest.TestCase):
    expected = [np.mean([0.9, 1.8]), np.mean([1.4, 0.4])]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
  def test_poisson_loss_tf(self):
    """Check the TensorFlow PoissonLoss against a precomputed reference."""
    predictions = tf.constant([[0.1, 0.8], [0.4, 0.6]])
    targets = tf.constant([[0.0, 1.0], [1.0, 0.0]])
    poisson = losses.PoissonLoss()
    computed = poisson._compute_tf_loss(predictions, targets).numpy()
    # Reference value: mean of predictions - targets * log(predictions).
    reference = 0.75986
    assert np.allclose(reference, computed)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
  def test_poisson_loss_pytorch(self):
    """Check the PyTorch PoissonLoss against a precomputed reference."""
    predictions = torch.tensor([[0.1, 0.8], [0.4, 0.6]])
    targets = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
    loss_fn = losses.PoissonLoss()._create_pytorch_loss()
    computed = loss_fn(predictions, targets).numpy()
    # Reference value: mean of predictions - targets * log(predictions).
    reference = 0.75986
    assert np.allclose(reference, computed)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
  def test_binary_cross_entropy_tf(self):
    """Test BinaryCrossEntropy."""
+17 −0
Original line number Diff line number Diff line
@@ -39,6 +39,23 @@ class TestOptimizers(unittest.TestCase):
    torchopt = opt._create_pytorch_optimizer(params)
    assert isinstance(torchopt, torch.optim.Adam)

  @unittest.skipIf(not has_tensorflow_addons,
                   'TensorFlow Addons is not installed')
  def test_adamw_tf(self):
    """Verify that AdamW builds a TensorFlow Addons AdamW optimizer."""
    step_counter = tf.Variable(0)
    adamw = optimizers.AdamW(learning_rate=0.01)
    built = adamw._create_tf_optimizer(step_counter)
    assert isinstance(built, tfa.optimizers.AdamW)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
  def test_adamw_pytorch(self):
    """Verify that AdamW builds a torch.optim.AdamW optimizer."""
    weight = torch.nn.Parameter(torch.Tensor([1.0]))
    adamw = optimizers.AdamW(learning_rate=0.01)
    built = adamw._create_pytorch_optimizer([weight])
    assert isinstance(built, torch.optim.AdamW)

  @unittest.skipIf(not has_tensorflow_addons,
                   'TensorFlow Addons is not installed')
  def test_sparseadam_tf(self):
+6 −0
Original line number Diff line number Diff line
@@ -193,6 +193,9 @@ Losses
.. autoclass:: deepchem.models.losses.HingeLoss
  :members:

.. autoclass:: deepchem.models.losses.PoissonLoss
  :members:

.. autoclass:: deepchem.models.losses.BinaryCrossEntropy
  :members:

@@ -232,6 +235,9 @@ Optimizers
.. autoclass:: deepchem.models.optimizers.Adam
  :members:

.. autoclass:: deepchem.models.optimizers.AdamW
  :members:
  
.. autoclass:: deepchem.models.optimizers.SparseAdam
  :members:

Loading