Unverified Commit eac8e476 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2493 from atreyamaj/sparse_adam

Added Sparse/Lazy Adam optimizer
parents bb4364ab a3df3181
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -40,6 +40,7 @@ DeepChem currently supports Python 3.6 through 3.7 and requires these packages o
- [TensorFlow](https://www.tensorflow.org/)
  - `deepchem>=2.4.0` depends on TensorFlow v2
  - `deepchem<2.4.0` depends on TensorFlow v1
- [Tensorflow Addons](https://www.tensorflow.org/addons) for Tensorflow v2 if you want to use advanced optimizers such as AdamW and Sparse Adam. (Optional)

### Soft Requirements

+52 −0
Original line number Diff line number Diff line
@@ -190,6 +190,58 @@ class Adam(Optimizer):
    return torch.optim.Adam(params, lr, (self.beta1, self.beta2), self.epsilon)


class SparseAdam(Optimizer):
  """The Sparse Adam optimization algorithm, also known as Lazy Adam.
  Sparse Adam is suitable for sparse tensors. It handles sparse updates more efficiently. 
  It only updates moving-average accumulators for sparse variable indices that appear in the current batch, rather than updating the accumulators for all indices.
  """

  def __init__(self,
               learning_rate: Union[float, LearningRateSchedule] = 0.001,
               beta1: float = 0.9,
               beta2: float = 0.999,
               epsilon: float = 1e-08):
    """Construct an Adam optimizer.

    Parameters
    ----------
    learning_rate: float or LearningRateSchedule
      the learning rate to use for optimization
    beta1: float
      a parameter of the SparseAdam algorithm
    beta2: float
      a parameter of the SparseAdam algorithm
    epsilon: float
      a parameter of the SparseAdam algorithm
    """
    super(SparseAdam, self).__init__(learning_rate)
    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon

  def _create_tf_optimizer(self, global_step):
    import tensorflow as tf
    import tensorflow_addons as tfa
    if isinstance(self.learning_rate, LearningRateSchedule):
      learning_rate = self.learning_rate._create_tf_tensor(global_step)
    else:
      learning_rate = self.learning_rate
    return tfa.optimizers.LazyAdam(
        learning_rate=learning_rate,
        beta_1=self.beta1,
        beta_2=self.beta2,
        epsilon=self.epsilon)

  def _create_pytorch_optimizer(self, params):
    import torch
    if isinstance(self.learning_rate, LearningRateSchedule):
      lr = self.learning_rate.initial_rate
    else:
      lr = self.learning_rate
    return torch.optim.SparseAdam(params, lr, (self.beta1, self.beta2),
                                  self.epsilon)


class RMSProp(Optimizer):
  """RMSProp Optimization algorithm."""

+23 −0
Original line number Diff line number Diff line
@@ -7,6 +7,12 @@ try:
except:
  has_tensorflow = False

try:
  import tensorflow_addons as tfa
  has_tensorflow_addons = True
except:
  has_tensorflow_addons = False

try:
  import torch
  has_pytorch = True
@@ -33,6 +39,23 @@ class TestOptimizers(unittest.TestCase):
    torchopt = opt._create_pytorch_optimizer(params)
    assert isinstance(torchopt, torch.optim.Adam)

  @unittest.skipIf(not has_tensorflow_addons,
                   'TensorFlow Addons is not installed')
  def test_sparseadam_tf(self):
    """Test creating a SparseAdam optimizer."""
    opt = optimizers.SparseAdam(learning_rate=0.01)
    global_step = tf.Variable(0)
    tfopt = opt._create_tf_optimizer(global_step)
    assert isinstance(tfopt, tfa.optimizers.LazyAdam)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
  def test_sparseadam_pytorch(self):
    """Test creating a SparseAdam optimizer."""
    opt = optimizers.SparseAdam(learning_rate=0.01)
    params = [torch.nn.Parameter(torch.Tensor([1.0]))]
    torchopt = opt._create_pytorch_optimizer(params)
    assert isinstance(torchopt, torch.optim.SparseAdam)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
  def test_adagrad_tf(self):
    """Test creating an AdaGrad optimizer."""
+3 −0
Original line number Diff line number Diff line
@@ -232,6 +232,9 @@ Optimizers
.. autoclass:: deepchem.models.optimizers.Adam
  :members:

.. autoclass:: deepchem.models.optimizers.SparseAdam
  :members:

.. autoclass:: deepchem.models.optimizers.RMSProp
  :members:

+5 −0
Original line number Diff line number Diff line
@@ -122,6 +122,10 @@ DeepChem has a number of "soft" requirements.
|                                |               |                                                   |
|                                |               |                                                   |
+--------------------------------+---------------+---------------------------------------------------+
| `Tensorflow Addons`_           | latest        | :code:`dc.models.optimizers`                      |
|                                |               |                                                   |
|                                |               |                                                   |
+--------------------------------+---------------+---------------------------------------------------+
          
.. _`joblib`: https://pypi.python.org/pypi/joblib
.. _`NumPy`: https://numpy.org/
@@ -153,3 +157,4 @@ DeepChem has a number of "soft" requirements.
.. _`Tensorflow Probability`: https://www.tensorflow.org/probability
.. _`Weights & Biases`: https://docs.wandb.com/
.. _`XGBoost`: https://xgboost.readthedocs.io/en/latest/
.. _`Tensorflow Addons`: https://www.tensorflow.org/addons/overview
 No newline at end of file
Loading