Commit 3dd5a2bd authored by hsjang001205's avatar hsjang001205
Browse files

Add losses for VAE

parent 5b9ac4c1
Loading
Loading
Loading
Loading
+13 −4
Original line number Diff line number Diff line
@@ -224,10 +224,13 @@ class VAE_ELBO(Loss):
  batch_size = 2,
  hidden_space = 2,
  num of original attribute = 3
  >>> import numpy as np
  >>> import torch
  >>> import tensorflow as tf
  >>> logvar = np.array([[1.0,1.3],[0.6,1.2]])
  >>> mu = np.array([[0.2,0.7],[1.2,0.4]])
  >>> x = np.array([[0.9,0.4,0.8],[0.3,0,1]])
  >>> reconstruction = np.array([[0.8,0.3,0.7],[0.2,0,0.9]])
  >>> reconstruction_x = np.array([[0.8,0.3,0.7],[0.2,0,0.9]])
  
  Case tensorflow
  >>> VAE_ELBO()._compute_tf_loss(tf.constant(logvar), tf.constant(mu), tf.constant(x), tf.constant(reconstruction_x))
@@ -275,12 +278,15 @@ class VAE_KLDivergence(Loss):
  
  batch_size = 2,
  hidden_space = 2,
  >>> import numpy as np
  >>> import torch
  >>> import tensorflow as tf
  >>> logvar = np.array([[1.0,1.3],[0.6,1.2]])
  >>> mu = np.array([[0.2,0.7],[1.2,0.4]])
  
  Case tensorflow
  >>> VAE_KLDivergence()._compute_tf_loss(tf.constant(logvar), tf.constant(mu))
  tf.Tensor([0.52783368 0.24813068], shape=(2,), dtype=float64)
  <tf.Tensor: shape=(2,), dtype=float64, numpy=array([0.17381787, 0.51425203])>
  
  Case pytorch
  >>> (VAE_KLDivergence()._create_pytorch_loss())(torch.tensor(logvar), torch.tensor(mu))
@@ -322,15 +328,18 @@ class ShannonEntropy(Loss):
  
  batch_size = 2,
  num_of variable = variable,
  >>> import numpy as np
  >>> import torch
  >>> import tensorflow as tf
  >>> inputs = np.array([[0.7,0.3],[0.9,0.1]])
  
  Case tensorflow
  >>> ShannonEntropy()._compute_tf_loss(tf.constant(inputs))
  tf.Tensor([0.52783368 0.24813068], shape=(2,), dtype=float64)
  <tf.Tensor: shape=(2,), dtype=float64, numpy=array([0.30543215, 0.16254149])>
  
  Case pytorch
  >>> (ShannonEntropy()._create_pytorch_loss())(torch.tensor(inputs))
  tensor([0.1738, 0.5143], dtype=torch.float64)
  tensor([0.3054, 0.1625], dtype=torch.float64)
  """

  def _compute_tf_loss(self, inputs):