Unverified Commit ea0fe592 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2182 from hsjang001205/add_losses

Add losses for VAE/semi-supervised CVAE
parents bca869d8 ba4f401d
Loading
Loading
Loading
Loading
+171 −0
Original line number Diff line number Diff line
@@ -205,6 +205,177 @@ class SparseSoftmaxCrossEntropy(Loss):
    return loss


class VAE_ELBO(Loss):
  """Evidence lower bound (ELBO) loss for a variational autoencoder.

  Combines a per-sample reconstruction term (mean binary cross-entropy,
  i.e. the marginal log-likelihood) with a KL-divergence regularizer
  against the unit Gaussian prior, weighted by ``kl_scale``, as in _[1].

  ``logvar`` and ``mu`` should have shape (batch_size, hidden_space);
  ``x`` and ``reconstruction_x`` should have shape (batch_size, attribute);
  ``kl_scale`` should be a float.

  NOTE(review): despite its name, ``logvar`` is squared directly inside
  ``VAE_KLDivergence`` (not exponentiated), so it is effectively treated
  as a standard deviation — confirm against callers before renaming.

  Examples
  --------
  Examples for calculating loss using constant tensor.

  batch_size = 2,
  hidden_space = 2,
  num of original attribute = 3
  >>> import numpy as np
  >>> import torch
  >>> import tensorflow as tf
  >>> logvar = np.array([[1.0,1.3],[0.6,1.2]])
  >>> mu = np.array([[0.2,0.7],[1.2,0.4]])
  >>> x = np.array([[0.9,0.4,0.8],[0.3,0,1]])
  >>> reconstruction_x = np.array([[0.8,0.3,0.7],[0.2,0,0.9]])

  Case tensorflow
  >>> VAE_ELBO()._compute_tf_loss(tf.constant(logvar), tf.constant(mu), tf.constant(x), tf.constant(reconstruction_x))
  <tf.Tensor: shape=(2,), dtype=float64, numpy=array([0.70165154, 0.76238271])>

  Case pytorch
  >>> (VAE_ELBO()._create_pytorch_loss())(torch.tensor(logvar), torch.tensor(mu), torch.tensor(x), torch.tensor(reconstruction_x))
  tensor([0.7017, 0.7624], dtype=torch.float64)


  References
  ----------
  .. [1] Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes." arXiv preprint arXiv:1312.6114 (2013).

  """

  def _compute_tf_loss(self, logvar, mu, x, reconstruction_x, kl_scale=1):
    import tensorflow as tf
    x, reconstruction_x = _make_tf_shapes_consistent(x, reconstruction_x)
    x, reconstruction_x = _ensure_float(x, reconstruction_x)
    # Reconstruction term: mean binary cross-entropy over attributes.
    reconstruction_term = tf.keras.losses.binary_crossentropy(
        x, reconstruction_x)
    # Regularization term: KL divergence of the latent distribution.
    kl_term = VAE_KLDivergence()._compute_tf_loss(logvar, mu)
    return reconstruction_term + kl_scale * kl_term

  def _create_pytorch_loss(self):
    import torch
    elementwise_bce = torch.nn.BCELoss(reduction='none')

    def loss(logvar, mu, x, reconstruction_x, kl_scale=1):
      x, reconstruction_x = _make_pytorch_shapes_consistent(x, reconstruction_x)
      # Reconstruction term: mean binary cross-entropy over attributes.
      reconstruction_term = elementwise_bce(reconstruction_x, x).mean(dim=-1)
      # Regularization term: KL divergence of the latent distribution.
      kl_term = (VAE_KLDivergence()._create_pytorch_loss())(logvar, mu)
      return reconstruction_term + kl_scale * kl_term

    return loss


class VAE_KLDivergence(Loss):
  """KL divergence between the encoded latent distribution and N(0, I).

  Computes, per sample, 0.5 * mean(mu^2 + s^2 - log(s^2) - 1) where
  ``s`` is the ``logvar`` input, following _[1]. A small constant
  (1e-20) guards the logarithm against zero.

  Both ``logvar`` and ``mu`` should have shape (batch_size, hidden_space).

  NOTE(review): the ``logvar`` argument is squared directly rather than
  exponentiated, so it is effectively treated as a standard deviation,
  not a log-variance — confirm against callers before renaming.

  Examples
  --------
  Examples for calculating loss using constant tensor.

  batch_size = 2,
  hidden_space = 2,
  >>> import numpy as np
  >>> import torch
  >>> import tensorflow as tf
  >>> logvar = np.array([[1.0,1.3],[0.6,1.2]])
  >>> mu = np.array([[0.2,0.7],[1.2,0.4]])

  Case tensorflow
  >>> VAE_KLDivergence()._compute_tf_loss(tf.constant(logvar), tf.constant(mu))
  <tf.Tensor: shape=(2,), dtype=float64, numpy=array([0.17381787, 0.51425203])>

  Case pytorch
  >>> (VAE_KLDivergence()._create_pytorch_loss())(torch.tensor(logvar), torch.tensor(mu))
  tensor([0.1738, 0.5143], dtype=torch.float64)

  References
  ----------
  .. [1] Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes." arXiv preprint arXiv:1312.6114 (2013).

  """

  def _compute_tf_loss(self, logvar, mu):
    import tensorflow as tf
    logvar, mu = _make_tf_shapes_consistent(logvar, mu)
    logvar, mu = _ensure_float(logvar, mu)
    variance = tf.square(logvar)
    divergence = tf.square(mu) + variance - tf.math.log(1e-20 + variance) - 1
    return 0.5 * tf.reduce_mean(divergence, -1)

  def _create_pytorch_loss(self):
    import torch

    def loss(logvar, mu):
      logvar, mu = _make_pytorch_shapes_consistent(logvar, mu)
      variance = torch.square(logvar)
      divergence = torch.square(mu) + variance - torch.log(
          1e-20 + variance) - 1
      return 0.5 * torch.mean(divergence, -1)

    return loss


class ShannonEntropy(Loss):
  """Shannon entropy of a discrete probability distribution.

  Computes, per sample, the mean of -p * log(p) over the distribution's
  entries, based on _[1]. A small constant (1e-20) guards the logarithm
  against zero probabilities. A single-column input is interpreted as a
  Bernoulli parameter and expanded to the pair (p, 1 - p).

  The inputs should have shape (batch_size, num of variable) and
  represent a probability distribution.

  Examples
  --------
  Examples for calculating loss using constant tensor.

  batch_size = 2,
  num_of variable = variable,
  >>> import numpy as np
  >>> import torch
  >>> import tensorflow as tf
  >>> inputs = np.array([[0.7,0.3],[0.9,0.1]])

  Case tensorflow
  >>> ShannonEntropy()._compute_tf_loss(tf.constant(inputs))
  <tf.Tensor: shape=(2,), dtype=float64, numpy=array([0.30543215, 0.16254149])>

  Case pytorch
  >>> (ShannonEntropy()._create_pytorch_loss())(torch.tensor(inputs))
  tensor([0.3054, 0.1625], dtype=torch.float64)

  References
  ----------
  .. [1] Chen, Ricky Xiaofeng. "A Brief Introduction to Shannon’s Information Theory." arXiv preprint arXiv:1612.09316 (2016).

  """

  def _compute_tf_loss(self, inputs):
    import tensorflow as tf
    # Expand a lone probability column to the full Bernoulli pair (p, 1-p).
    if inputs.shape[-1] == 1:
      inputs = tf.concat([inputs, 1 - inputs], axis=-1)
    return -tf.reduce_mean(inputs * tf.math.log(1e-20 + inputs), -1)

  def _create_pytorch_loss(self):
    import torch

    def loss(inputs):
      # Expand a lone probability column to the full Bernoulli pair (p, 1-p).
      if inputs.shape[-1] == 1:
        inputs = torch.cat((inputs, 1 - inputs), dim=-1)
      return -torch.mean(inputs * torch.log(1e-20 + inputs), -1)

    return loss


def _make_tf_shapes_consistent(output, labels):
  """Try to make inputs have the same shape by adding dimensions of size 1."""
  import tensorflow as tf
+111 −0
Original line number Diff line number Diff line
@@ -197,3 +197,114 @@ class TestLosses(unittest.TestCase):
    softmax = np.exp(y) / np.expand_dims(np.sum(np.exp(y), axis=1), 1)
    expected = [-np.log(softmax[0, 1]), -np.log(softmax[1, 0])]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
  def test_VAE_ELBO_tf(self):
    """VAE_ELBO (TensorFlow) equals hand-computed KL term + mean BCE."""
    loss = losses.VAE_ELBO()
    logvar = tf.constant([[1.0, 1.3], [0.6, 1.2]])
    mu = tf.constant([[0.2, 0.7], [1.2, 0.4]])
    x = tf.constant([[0.9, 0.4, 0.8], [0.3, 0, 1]])
    reconstruction_x = tf.constant([[0.8, 0.3, 0.7], [0.2, 0, 0.9]])
    result = loss._compute_tf_loss(logvar, mu, x, reconstruction_x).numpy()
    # Per sample: 0.5*mean(mu^2 + logvar^2 - log(logvar^2) - 1)  (KL term)
    # minus mean(x*log(x_hat) + (1-x)*log(1-x_hat))              (BCE term).
    expected = [
        0.5 * np.mean([
            0.04 + 1.0 - np.log(1e-20 + 1.0) - 1,
            0.49 + 1.69 - np.log(1e-20 + 1.69) - 1
        ]) - np.mean(
            np.array([0.9, 0.4, 0.8]) * np.log([0.8, 0.3, 0.7]) +
            np.array([0.1, 0.6, 0.2]) * np.log([0.2, 0.7, 0.3])),
        0.5 * np.mean([
            1.44 + 0.36 - np.log(1e-20 + 0.36) - 1,
            0.16 + 1.44 - np.log(1e-20 + 1.44) - 1
        ]) - np.mean(
            np.array([0.3, 0, 1]) * np.log([0.2, 1e-20, 0.9]) +
            np.array([0.7, 1, 0]) * np.log([0.8, 1, 0.1]))
    ]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
  def test_VAE_ELBO_pytorch(self):
    """VAE_ELBO (PyTorch) equals hand-computed KL term + mean BCE."""
    loss = losses.VAE_ELBO()
    logvar = torch.tensor([[1.0, 1.3], [0.6, 1.2]])
    mu = torch.tensor([[0.2, 0.7], [1.2, 0.4]])
    x = torch.tensor([[0.9, 0.4, 0.8], [0.3, 0, 1]])
    reconstruction_x = torch.tensor([[0.8, 0.3, 0.7], [0.2, 0, 0.9]])
    result = loss._create_pytorch_loss()(logvar, mu, x,
                                         reconstruction_x).numpy()
    # Per sample: 0.5*mean(mu^2 + logvar^2 - log(logvar^2) - 1)  (KL term)
    # minus mean(x*log(x_hat) + (1-x)*log(1-x_hat))              (BCE term).
    expected = [
        0.5 * np.mean([
            0.04 + 1.0 - np.log(1e-20 + 1.0) - 1,
            0.49 + 1.69 - np.log(1e-20 + 1.69) - 1
        ]) - np.mean(
            np.array([0.9, 0.4, 0.8]) * np.log([0.8, 0.3, 0.7]) +
            np.array([0.1, 0.6, 0.2]) * np.log([0.2, 0.7, 0.3])),
        0.5 * np.mean([
            1.44 + 0.36 - np.log(1e-20 + 0.36) - 1,
            0.16 + 1.44 - np.log(1e-20 + 1.44) - 1
        ]) - np.mean(
            np.array([0.3, 0, 1]) * np.log([0.2, 1e-20, 0.9]) +
            np.array([0.7, 1, 0]) * np.log([0.8, 1, 0.1]))
    ]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
  def test_VAE_KLDivergence_tf(self):
    """VAE_KLDivergence (TensorFlow) matches the closed-form expression."""
    loss = losses.VAE_KLDivergence()
    logvar = tf.constant([[1.0, 1.3], [0.6, 1.2]])
    mu = tf.constant([[0.2, 0.7], [1.2, 0.4]])
    result = loss._compute_tf_loss(logvar, mu).numpy()
    # Per sample: 0.5 * mean(mu^2 + logvar^2 - log(logvar^2) - 1).
    expected = [
        0.5 * np.mean([
            0.04 + 1.0 - np.log(1e-20 + 1.0) - 1,
            0.49 + 1.69 - np.log(1e-20 + 1.69) - 1
        ]), 0.5 * np.mean([
            1.44 + 0.36 - np.log(1e-20 + 0.36) - 1,
            0.16 + 1.44 - np.log(1e-20 + 1.44) - 1
        ])
    ]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
  def test_VAE_KLDivergence_pytorch(self):
    """VAE_KLDivergence (PyTorch) matches the closed-form expression."""
    loss = losses.VAE_KLDivergence()
    logvar = torch.tensor([[1.0, 1.3], [0.6, 1.2]])
    mu = torch.tensor([[0.2, 0.7], [1.2, 0.4]])
    result = loss._create_pytorch_loss()(logvar, mu).numpy()
    # Per sample: 0.5 * mean(mu^2 + logvar^2 - log(logvar^2) - 1).
    expected = [
        0.5 * np.mean([
            0.04 + 1.0 - np.log(1e-20 + 1.0) - 1,
            0.49 + 1.69 - np.log(1e-20 + 1.69) - 1
        ]), 0.5 * np.mean([
            1.44 + 0.36 - np.log(1e-20 + 0.36) - 1,
            0.16 + 1.44 - np.log(1e-20 + 1.44) - 1
        ])
    ]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
  def test_ShannonEntropy_tf(self):
    """ShannonEntropy (TensorFlow) equals the mean of -p*log(p) per row."""
    loss = losses.ShannonEntropy()
    inputs = tf.constant([[0.7, 0.3], [0.9, 0.1]])
    result = loss._compute_tf_loss(inputs).numpy()
    expected = [
        -np.mean([0.7 * np.log(0.7), 0.3 * np.log(0.3)]),
        -np.mean([0.9 * np.log(0.9), 0.1 * np.log(0.1)])
    ]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
  def test_ShannonEntropy_pytorch(self):
    """ShannonEntropy (PyTorch) equals the mean of -p*log(p) per row."""
    loss = losses.ShannonEntropy()
    inputs = torch.tensor([[0.7, 0.3], [0.9, 0.1]])
    result = loss._create_pytorch_loss()(inputs).numpy()
    expected = [
        -np.mean([0.7 * np.log(0.7), 0.3 * np.log(0.3)]),
        -np.mean([0.9 * np.log(0.9), 0.1 * np.log(0.1)])
    ]
    assert np.allclose(expected, result)
+9 −0
Original line number Diff line number Diff line
@@ -199,6 +199,15 @@ Losses
.. autoclass:: deepchem.models.losses.SparseSoftmaxCrossEntropy
  :members:

.. autoclass:: deepchem.models.losses.VAE_ELBO
  :members:

.. autoclass:: deepchem.models.losses.VAE_KLDivergence
  :members:

.. autoclass:: deepchem.models.losses.ShannonEntropy
  :members:

Optimizers
----------