Commit 6d02180d authored by BuildTools's avatar BuildTools
Browse files

Add losses for VAE

parent f930464f
Loading
Loading
Loading
Loading
+6 −5
Original line number Diff line number Diff line
@@ -267,7 +267,7 @@ class KLDivergence(Loss):
    import tensorflow as tf
    P, Q = _make_tf_shapes_consistent(P, Q)
    P, Q = _ensure_float(P, Q)
    # Extend a 1-dimensional probability input to a binary distribution [p, 1-p].
    if P.shape[-1] == 1:
      P = tf.concat([P,1-P], axis = -1)
      Q = tf.concat([Q,1-Q], axis = -1)
@@ -278,7 +278,7 @@ class KLDivergence(Loss):

    def loss(P, Q):
      P, Q = _make_pytorch_shapes_consistent(P, Q)
      # Extend a 1-dimensional probability input to a binary distribution [p, 1-p].
      if P.shape[-1] == 1:
        P = torch.cat((P,1-P), dim = -1)
        Q = torch.cat((Q,1-Q), dim = -1)
@@ -290,12 +290,13 @@ class KLDivergence(Loss):
class ShannonEntropy(Loss):
  """The ShannonEntropy of discrete-distribution.
  
  The inputs last dimension should be num of variable.
  The inputs should have shape (batch size, num of variable) and represents
  probabilites distribution.
  """

  def _compute_tf_loss(self, inputs):
    """Compute the entropy-style loss with TensorFlow.

    Parameters
    ----------
    inputs: tf.Tensor
      Probabilities with shape (batch size, num of variables). A trailing
      dimension of 1 is treated as a single probability p and extended to
      the binary distribution [p, 1 - p].

    Returns
    -------
    tf.Tensor
      Per-sample mean of -p * log(p) over the last axis.
    """
    import tensorflow as tf
    # Extend a 1-dimensional probability input to a binary distribution [p, 1-p].
    if inputs.shape[-1] == 1:
      inputs = tf.concat([inputs, 1 - inputs], axis=-1)
    # NOTE(review): this is the *mean* of -p*log(p) over the last axis, not the
    # sum; canonical Shannon entropy would use reduce_sum. Kept as-is for
    # consistency with the PyTorch implementation. 1e-20 guards against log(0).
    return tf.reduce_mean(-inputs * tf.math.log(1e-20 + inputs), -1)
@@ -304,7 +305,7 @@ class ShannonEntropy(Loss):
    import torch

    def loss(inputs):
      """Per-sample mean of -p * log(p) over the last axis of ``inputs``."""
      probs = inputs
      # Extend a 1-dimensional probability input to a binary distribution [p, 1-p].
      if probs.shape[-1] == 1:
        probs = torch.cat((probs, 1 - probs), dim=-1)
      safe_log = torch.log(probs + 1e-20)  # 1e-20 guards against log(0)
      return torch.mean(-probs * safe_log, -1)