Commit 5bd1a058 authored by Vignesh's avatar Vignesh
Browse files

Added RMSProp Optimizer and tests

parent 7e63235c
Loading
Loading
Loading
Loading
+41 −3
Original line number Diff line number Diff line
@@ -80,6 +80,44 @@ class Adam(Optimizer):
        epsilon=self.epsilon)


class RMSProp(Optimizer):
  """RMSProp Optimization algorithm."""

  def __init__(self,
               learning_rate=0.001,
               momentum=0.0,
               decay=0.9,
               epsilon=1e-10):
    """Construct an RMSProp Optimizer.

        Parameters
        ----------
        learning_rate: float or LearningRateSchedule
            the learning_rate used for optimization
        momentum: float, default 0.0
            a parameter of the RMSProp algorithm
        decay: float, default 0.9
            a parameter of the RMSProp algorithm
        epsilon: float, default 1e-10
            a parameter of the RMSProp algorithm
        """
    self.learning_rate = learning_rate
    self.momentum = momentum
    self.decay = decay
    self.epsilon = epsilon

  def _create_optimizer(self, global_step):
    if isinstance(self.learning_rate, LearningRateSchedule):
      learning_rate = self.learning_rate._create_tensor(global_step)
    else:
      learning_rate = self.learning_rate
    return tf.train.RMSPropOptimizer(
        learning_rate=learning_rate,
        momentum=self.momentum,
        decay=self.decay,
        epsilon=self.epsilon)


class GradientDescent(Optimizer):
  """The gradient descent optimization algorithm."""

+8 −0
Original line number Diff line number Diff line
@@ -14,6 +14,14 @@ class TestLayers(test_util.TensorFlowTestCase):
      tfopt = opt._create_optimizer(global_step)
      assert isinstance(tfopt, tf.train.AdamOptimizer)

  def test_rmsprop(self):
    """Test creating an RMSProp Optimizer."""
    opt = optimizers.RMSProp(learning_rate=0.01)
    with self.session() as sess:
      global_step = tf.Variable(0)
      tfopt = opt._create_optimizer(global_step)
      assert isinstance(tfopt, tf.train.RMSPropOptimizer)

  def test_gradient_descent(self):
    """Test creating a Gradient Descent optimizer."""
    opt = optimizers.GradientDescent(learning_rate=0.01)