Commit c4203fe4 authored by Atreya Majumdar's avatar Atreya Majumdar
Browse files

Added poisson loss, tests and documentation

parent dd8665e9
Loading
Loading
Loading
Loading
+20 −0
Original line number Diff line number Diff line
@@ -93,6 +93,26 @@ class HingeLoss(Loss):

    return loss

class PoissonLoss(Loss):
  """The Poisson loss function is defined as the mean of the elements of y_pred - (y_true * log(y_pred) for an input of (y_true, y_pred).
  Poisson loss is generally used for regression tasks where the data follows the poisson

  """

  def _compute_tf_loss(self, output, labels):
    import tensorflow as tf
    output, labels = _make_tf_shapes_consistent(output, labels)
    loss = tf.keras.losses.Poisson(reduction='auto')
    return loss(labels, output)

  def _create_pytorch_loss(self):
    import torch

    def loss(output, labels):
      output, labels = _make_pytorch_shapes_consistent(output, labels)
      return torch.mean(output - labels * torch.log(output))

    return loss

class BinaryCrossEntropy(Loss):
  """The cross entropy between pairs of probabilities.
+20 −0
Original line number Diff line number Diff line
@@ -98,6 +98,26 @@ class TestLosses(unittest.TestCase):
    expected = [np.mean([0.9, 1.8]), np.mean([1.4, 0.4])]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
  def test_poisson_loss_tf(self):
    """Test PoissonLoss."""
    loss = losses.PoissonLoss()
    outputs = tf.constant([[0.1, 0.8], [0.4, 0.6]])
    labels = tf.constant([[0.0, 1.0], [1.0, 0.0]])
    result = loss._compute_tf_loss(outputs, labels).numpy()
    expected = 0.75986
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
  def test_poisson_loss_pytorch(self):
    """Test PoissonLoss."""
    loss = losses.PoissonLoss()
    outputs = torch.tensor([[0.1, 0.8], [0.4, 0.6]])
    labels = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
    result = loss._create_pytorch_loss()(outputs, labels).numpy()
    expected = 0.75986
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
  def test_binary_cross_entropy_tf(self):
    """Test BinaryCrossEntropy."""
+3 −0
Original line number Diff line number Diff line
@@ -193,6 +193,9 @@ Losses
.. autoclass:: deepchem.models.losses.HingeLoss
  :members:

.. autoclass:: deepchem.models.losses.PoissonLoss
  :members:

.. autoclass:: deepchem.models.losses.BinaryCrossEntropy
  :members: