Commit 3e7f114b authored by hsjang001205's avatar hsjang001205
Browse files

Add losses for VAE

parent 3dd5a2bd
Loading
Loading
Loading
Loading
+19 −4
Original line number Diff line number Diff line
@@ -208,7 +208,7 @@ class SparseSoftmaxCrossEntropy(Loss):
class VAE_ELBO(Loss):
  """The Variational AutoEncoder loss: KL divergence regularization + marginal log-likelihood.
  
  This losses basesd on "Auto-Encoding Variational Bayes" (https://arxiv.org/abs/1312.6114).
  This loss is based on [1]_.
  ELBO (evidence lower bound) is another name for the variational lower bound. 
  BCE means marginal log-likelihood, and KLD means KL divergence with normal distribution.
  Added hyper parameter 'kl_scale' for KLD.
@@ -239,6 +239,12 @@ class VAE_ELBO(Loss):
  Case pytorch
  >>> (VAE_ELBO()._create_pytorch_loss())(torch.tensor(logvar), torch.tensor(mu), torch.tensor(x), torch.tensor(reconstruction_x))
  tensor([0.7017, 0.7624], dtype=torch.float64)
  
  
  References
  ----------
  .. [1] Diederik P Kingma., and Max Welling. "Auto-encoding variational bayes." arXiv preprint arXiv:1312.6114 (2013).
  
  """

  def _compute_tf_loss(self, logvar, mu, x, reconstruction_x, kl_scale=1):
@@ -266,7 +272,7 @@ class VAE_KLDivergence(Loss):
  """The KL_divergence between hidden distribution and normal distribution.
  
  This loss represents the KL divergence between the hidden distribution (given by its distribution parameters) and a normal distribution,
  based on "Auto-Encoding Variational Bayes" (https://arxiv.org/abs/1312.6114).
  based on [1]_.
  
  The logvar should have shape (batch_size, hidden_space) and each term represents
  standard deviation of hidden distribution. The mean should have 
@@ -291,6 +297,11 @@ class VAE_KLDivergence(Loss):
  Case pytorch
  >>> (VAE_KLDivergence()._create_pytorch_loss())(torch.tensor(logvar), torch.tensor(mu))
  tensor([0.1738, 0.5143], dtype=torch.float64)
  
  References
  ----------
  .. [1] Diederik P Kingma., and Max Welling. "Auto-encoding variational bayes." arXiv preprint arXiv:1312.6114 (2013).
  
  """

  def _compute_tf_loss(self, logvar, mu):
@@ -316,8 +327,7 @@ class VAE_KLDivergence(Loss):
class ShannonEntropy(Loss):
  """The ShannonEntropy of discrete-distribution.
  
  This loss represents shannon entropy based on
  "A Brief Introduction to Shannon's Information Theory" (https://arxiv.org/abs/1612.09316).
  This loss represents Shannon entropy based on [1]_.
  
  The inputs should have shape (batch size, num of variable) and represent
  a probability distribution.
@@ -340,6 +350,11 @@ class ShannonEntropy(Loss):
  Case pytorch
  >>> (ShannonEntropy()._create_pytorch_loss())(torch.tensor(inputs))
  tensor([0.3054, 0.1625], dtype=torch.float64)
  
  References
  ----------
  .. [1] Ricky Xiaofeng Chen. "A Brief Introduction to Shannon’s Information Theory." arXiv preprint arXiv:1612.09316 (2016).
  
  """

  def _compute_tf_loss(self, inputs):
+9 −0
Original line number Diff line number Diff line
@@ -199,6 +199,15 @@ Losses
.. autoclass:: deepchem.models.losses.SparseSoftmaxCrossEntropy
  :members:

.. autoclass:: deepchem.models.losses.VAE_ELBO
  :members:

.. autoclass:: deepchem.models.losses.VAE_KLDivergence
  :members:

.. autoclass:: deepchem.models.losses.ShannonEntropy
  :members:

Optimizers
----------