Commit 915ee622 authored by Atreya Majumdar's avatar Atreya Majumdar
Browse files

Added AdamW optimizer + tests, updated docs

parent 9a08f884
Loading
Loading
Loading
Loading
+61 −0
Original line number Diff line number Diff line
@@ -190,6 +190,67 @@ class Adam(Optimizer):
    return torch.optim.Adam(params, lr, (self.beta1, self.beta2), self.epsilon)


class AdamW(Optimizer):
  """The AdamW optimization algorithm.
  AdamW is a variant of Adam, with improved weight decay.
  In Adam, weight decay is implemented as: weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
  In AdamW, weight decay is implemented as: weight_decay (float, optional) – weight decay coefficient (default: 1e-2)
  """

  def __init__(self,
               learning_rate: Union[float, LearningRateSchedule] = 0.001,
               weight_decay: Union[float, LearningRateSchedule] = 0.01,
               beta1: float = 0.9,
               beta2: float = 0.999,
               epsilon: float = 1e-08,
               amsgrad: bool = False):
    """Construct an AdamW optimizer.

    Parameters
    ----------
    learning_rate: float or LearningRateSchedule
      the learning rate to use for optimization
    weight_decay: float or LearningRateSchedule
      weight decay coefficient for AdamW
    beta1: float
      a parameter of the Adam algorithm
    beta2: float
      a parameter of the Adam algorithm
    epsilon: float
      a parameter of the Adam algorithm
    """
    super(AdamW, self).__init__(learning_rate)
    self.weight_decay = weight_decay
    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon
    self.amsgrad = amsgrad

  def _create_tf_optimizer(self, global_step):
    import tensorflow as tf
    import tensorflow_addons as tfa
    if isinstance(self.learning_rate, LearningRateSchedule):
      learning_rate = self.learning_rate._create_tf_tensor(global_step)
    else:
      learning_rate = self.learning_rate
    return tfa.optimizers.AdamW(
        weight_decay=self.weight_decay,
        learning_rate=learning_rate,
        beta_1=self.beta1,
        beta_2=self.beta2,
        epsilon=self.epsilon,
        amsgrad=self.amsgrad)

  def _create_pytorch_optimizer(self, params):
    import torch
    if isinstance(self.learning_rate, LearningRateSchedule):
      lr = self.learning_rate.initial_rate
    else:
      lr = self.learning_rate
    return torch.optim.AdamW(params, lr, (self.beta1, self.beta2), self.epsilon,
                             self.weight_decay, self.amsgrad)


class RMSProp(Optimizer):
  """RMSProp Optimization algorithm."""

+22 −0
Original line number Diff line number Diff line
@@ -7,6 +7,12 @@ try:
except:
  has_tensorflow = False

try:
  import tensorflow_addons as tfa
  has_tensorflow_addons = True
except:
  has_tensorflow_addons = False

try:
  import torch
  has_pytorch = True
@@ -33,6 +39,22 @@ class TestOptimizers(unittest.TestCase):
    torchopt = opt._create_pytorch_optimizer(params)
    assert isinstance(torchopt, torch.optim.Adam)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
  def test_adamw_tf(self):
    """Test creating an AdamW optimizer."""
    opt = optimizers.AdamW(learning_rate=0.01)
    global_step = tf.Variable(0)
    tfopt = opt._create_tf_optimizer(global_step)
    assert isinstance(tfopt, tfa.optimizers.AdamW)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
  def test_adamw_pytorch(self):
    """Test creating an AdamW optimizer."""
    opt = optimizers.AdamW(learning_rate=0.01)
    params = [torch.nn.Parameter(torch.Tensor([1.0]))]
    torchopt = opt._create_pytorch_optimizer(params)
    assert isinstance(torchopt, torch.optim.AdamW)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
  def test_adagrad_tf(self):
    """Test creating an AdaGrad optimizer."""
+3 −0
Original line number Diff line number Diff line
@@ -235,6 +235,9 @@ Optimizers
.. autoclass:: deepchem.models.optimizers.Adam
  :members:

.. autoclass:: deepchem.models.optimizers.AdamW
  :members:

.. autoclass:: deepchem.models.optimizers.RMSProp
  :members: