Commit f930464f authored by BuildTools's avatar BuildTools
Browse files

Add losses for VAE

parent eda2bcd4
Loading
Loading
Loading
Loading
+107 −0
Original line number Diff line number Diff line
@@ -205,6 +205,113 @@ class SparseSoftmaxCrossEntropy(Loss):
    return loss


class VAE_ELBO(Loss):
  """The Variational AutoEncoder loss, KL Divergence Regularize + marginal log-likelihood.
  
  The logvar and mu should have shape (batch_size, hidden_space).
  The x and reconstruction_x should have (batch_size, attribute).
  kl_scale: KLD regularized weights.
  """

  def _compute_tf_loss(self, logvar, mu, x, reconstruction_x, kl_scale = 1):
    import tensorflow as tf
    x, reconstruction_x = _make_tf_shapes_consistent(x, reconstruction_x)
    x, reconstruction_x = _ensure_float(x, reconstruction_x)
    BCE = tf.keras.losses.binary_crossentropy(x, reconstruction_x)
    KLD = VAE_KLDivergence()._compute_tf_loss(logvar, mu)
    return BCE + kl_scale*KLD

  def _create_pytorch_loss(self):
    import torch
    bce = torch.nn.BCELoss(reduction='none')

    def loss(logvar, mu, x, reconstruction_x, kl_scale = 1):
      x, reconstruction_x = _make_pytorch_shapes_consistent(x, reconstruction_x)
      BCE = torch.mean(bce(x, reconstruction_x), dim=-1)
      KLD = (VAE_KLDivergence()._create_pytorch_loss())(logvar, mu)
      return BCE + kl_scale*KLD

    return loss


class VAE_KLDivergence(Loss):
  """The KL_divergence between hidden distribution and normal distribution
  The logvar should have shape (batch_size, hidden_space) and each term represents
  standard deviation of hidden distribution. The mean shuold have 
  (batch_size, hidden_space) and each term represents mean of hidden distribtuon.
  """

  def _compute_tf_loss(self, logvar, mu):
    import tensorflow as tf
    logvar, mu = _make_tf_shapes_consistent(logvar, mu)
    logvar, mu = _ensure_float(logvar, mu)
    return 0.5 * tf.reduce_mean(tf.square(mu) + tf.square(logvar) - tf.math.log(1e-20 + tf.square(logvar)) - 1,-1)

  def _create_pytorch_loss(self):
    import torch

    def loss(logvar, mu):
      logvar, mu = _make_pytorch_shapes_consistent(logvar, mu)
      return 0.5 * torch.mean(torch.square(mu) + torch.square(logvar) - torch.log(1e-20 + torch.square(logvar)) - 1,-1)

    return loss


class KLDivergence(Loss):
  """The KL_divergence between two distribution D_KL(P||Q).
  The argument should have shape (batch_size, num of variable) and represents
  probabilites distribution. 
  """

  def _compute_tf_loss(self, P, Q):
    import tensorflow as tf
    P, Q = _make_tf_shapes_consistent(P, Q)
    P, Q = _ensure_float(P, Q)
    #extended 1-dimensional inputs two binary distribution
    if P.shape[-1] == 1:
      P = tf.concat([P,1-P], axis = -1)
      Q = tf.concat([Q,1-Q], axis = -1)
    return tf.reduce_mean(P * tf.math.log((P+1e-20) / (Q+1e-20)), axis=-1)

  def _create_pytorch_loss(self):
    import torch

    def loss(P, Q):
      P, Q = _make_pytorch_shapes_consistent(P, Q)
      #extended 1-dimensional inputs two binary distribution
      if P.shape[-1] == 1:
        P = torch.cat((P,1-P), dim = -1)
        Q = torch.cat((Q,1-Q), dim = -1)
      return torch.mean(P * torch.log((P+1e-20) / (Q+1e-20)),dim = -1)

    return loss


class ShannonEntropy(Loss):
  """The ShannonEntropy of discrete-distribution.
  
  The inputs last dimension should be num of variable.
  """

  def _compute_tf_loss(self, inputs):
    import tensorflow as tf
    #extended 1-dimensional inputs to binary distribution
    if inputs.shape[-1] == 1:
      inputs = tf.concat([inputs,1-inputs], axis = -1)
    return tf.reduce_mean(-inputs*tf.math.log(1e-20+inputs), -1)

  def _create_pytorch_loss(self):
    import torch

    def loss(inputs):
      #extended 1-dimensional inputs to binary distribution
      if inputs.shape[-1] == 1:
        inputs = torch.cat((inputs,1-inputs), dim = -1)
      return torch.mean(-inputs*torch.log(1e-20+inputs), -1)

    return loss


def _make_tf_shapes_consistent(output, labels):
  """Try to make inputs have the same shape by adding dimensions of size 1."""
  import tensorflow as tf