Commit 5b9ac4c1 authored by hsjang001205's avatar hsjang001205
Browse files

Add losses for VAE

parent a81fe7f3
Loading
Loading
Loading
Loading
+14 −10
Original line number Diff line number Diff line
@@ -291,14 +291,18 @@ class VAE_KLDivergence(Loss):
    import tensorflow as tf
    logvar, mu = _make_tf_shapes_consistent(logvar, mu)
    logvar, mu = _ensure_float(logvar, mu)
    return 0.5 * tf.reduce_mean(tf.square(mu) + tf.square(logvar) - tf.math.log(1e-20 + tf.square(logvar)) - 1,-1)
    return 0.5 * tf.reduce_mean(
        tf.square(mu) + tf.square(logvar) -
        tf.math.log(1e-20 + tf.square(logvar)) - 1, -1)

  def _create_pytorch_loss(self):
    import torch

    def loss(logvar, mu):
      logvar, mu = _make_pytorch_shapes_consistent(logvar, mu)
      return 0.5 * torch.mean(torch.square(mu) + torch.square(logvar) - torch.log(1e-20 + torch.square(logvar)) - 1,-1)
      return 0.5 * torch.mean(
          torch.square(mu) + torch.square(logvar) -
          torch.log(1e-20 + torch.square(logvar)) - 1, -1)

    return loss

+73 −34
Original line number Diff line number Diff line
@@ -207,10 +207,20 @@ class TestLosses(unittest.TestCase):
    x = tf.constant([[0.9, 0.4, 0.8], [0.3, 0, 1]])
    reconstruction_x = tf.constant([[0.8, 0.3, 0.7], [0.2, 0, 0.9]])
    result = loss._compute_tf_loss(logvar, mu, x, reconstruction_x).numpy()
    expected = [0.5 * np.mean([0.04+1.0-np.log(1e-20+1.0)-1, 0.49+1.69 - np.log(1e-20 +1.69) - 1])
                -np.mean(np.array([0.9,0.4,0.8])*np.log([0.8,0.3,0.7])+np.array([0.1,0.6,0.2])*np.log([0.2,0.7,0.3])),
                0.5 * np.mean([1.44+0.36-np.log(1e-20+0.36)-1, 0.16+1.44 - np.log(1e-20 +1.44) - 1])
                -np.mean(np.array([0.3,0,1])*np.log([0.2,1e-20,0.9])+np.array([0.7,1,0])*np.log([0.8,1,0.1]))]
    expected = [
        0.5 * np.mean([
            0.04 + 1.0 - np.log(1e-20 + 1.0) - 1,
            0.49 + 1.69 - np.log(1e-20 + 1.69) - 1
        ]) - np.mean(
            np.array([0.9, 0.4, 0.8]) * np.log([0.8, 0.3, 0.7]) +
            np.array([0.1, 0.6, 0.2]) * np.log([0.2, 0.7, 0.3])),
        0.5 * np.mean([
            1.44 + 0.36 - np.log(1e-20 + 0.36) - 1,
            0.16 + 1.44 - np.log(1e-20 + 1.44) - 1
        ]) - np.mean(
            np.array([0.3, 0, 1]) * np.log([0.2, 1e-20, 0.9]) +
            np.array([0.7, 1, 0]) * np.log([0.8, 1, 0.1]))
    ]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
@@ -221,11 +231,22 @@ class TestLosses(unittest.TestCase):
    mu = torch.tensor([[0.2, 0.7], [1.2, 0.4]])
    x = torch.tensor([[0.9, 0.4, 0.8], [0.3, 0, 1]])
    reconstruction_x = torch.tensor([[0.8, 0.3, 0.7], [0.2, 0, 0.9]])
    result = loss._create_pytorch_loss()(logvar, mu, x, reconstruction_x).numpy()
    expected = [0.5 * np.mean([0.04+1.0-np.log(1e-20+1.0)-1, 0.49+1.69 - np.log(1e-20 +1.69) - 1])
                -np.mean(np.array([0.9,0.4,0.8])*np.log([0.8,0.3,0.7])+np.array([0.1,0.6,0.2])*np.log([0.2,0.7,0.3])),
                0.5 * np.mean([1.44+0.36-np.log(1e-20+0.36)-1, 0.16+1.44 - np.log(1e-20 +1.44) - 1])
                -np.mean(np.array([0.3,0,1])*np.log([0.2,1e-20,0.9])+np.array([0.7,1,0])*np.log([0.8,1,0.1]))]
    result = loss._create_pytorch_loss()(logvar, mu, x,
                                         reconstruction_x).numpy()
    expected = [
        0.5 * np.mean([
            0.04 + 1.0 - np.log(1e-20 + 1.0) - 1,
            0.49 + 1.69 - np.log(1e-20 + 1.69) - 1
        ]) - np.mean(
            np.array([0.9, 0.4, 0.8]) * np.log([0.8, 0.3, 0.7]) +
            np.array([0.1, 0.6, 0.2]) * np.log([0.2, 0.7, 0.3])),
        0.5 * np.mean([
            1.44 + 0.36 - np.log(1e-20 + 0.36) - 1,
            0.16 + 1.44 - np.log(1e-20 + 1.44) - 1
        ]) - np.mean(
            np.array([0.3, 0, 1]) * np.log([0.2, 1e-20, 0.9]) +
            np.array([0.7, 1, 0]) * np.log([0.8, 1, 0.1]))
    ]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
@@ -235,8 +256,15 @@ class TestLosses(unittest.TestCase):
    logvar = tf.constant([[1.0, 1.3], [0.6, 1.2]])
    mu = tf.constant([[0.2, 0.7], [1.2, 0.4]])
    result = loss._compute_tf_loss(logvar, mu).numpy()
    expected = [0.5 * np.mean([0.04+1.0-np.log(1e-20+1.0)-1, 0.49+1.69 - np.log(1e-20 +1.69) - 1]),
                0.5 * np.mean([1.44+0.36-np.log(1e-20+0.36)-1, 0.16+1.44 - np.log(1e-20 +1.44) - 1])]
    expected = [
        0.5 * np.mean([
            0.04 + 1.0 - np.log(1e-20 + 1.0) - 1,
            0.49 + 1.69 - np.log(1e-20 + 1.69) - 1
        ]), 0.5 * np.mean([
            1.44 + 0.36 - np.log(1e-20 + 0.36) - 1,
            0.16 + 1.44 - np.log(1e-20 + 1.44) - 1
        ])
    ]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
@@ -246,8 +274,15 @@ class TestLosses(unittest.TestCase):
    logvar = torch.tensor([[1.0, 1.3], [0.6, 1.2]])
    mu = torch.tensor([[0.2, 0.7], [1.2, 0.4]])
    result = loss._create_pytorch_loss()(logvar, mu).numpy()
    expected = [0.5 * np.mean([0.04+1.0-np.log(1e-20+1.0)-1, 0.49+1.69 - np.log(1e-20 +1.69) - 1]),
                0.5 * np.mean([1.44+0.36-np.log(1e-20+0.36)-1, 0.16+1.44 - np.log(1e-20 +1.44) - 1])]
    expected = [
        0.5 * np.mean([
            0.04 + 1.0 - np.log(1e-20 + 1.0) - 1,
            0.49 + 1.69 - np.log(1e-20 + 1.69) - 1
        ]), 0.5 * np.mean([
            1.44 + 0.36 - np.log(1e-20 + 0.36) - 1,
            0.16 + 1.44 - np.log(1e-20 + 1.44) - 1
        ])
    ]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
@@ -256,8 +291,10 @@ class TestLosses(unittest.TestCase):
    loss = losses.ShannonEntropy()
    inputs = tf.constant([[0.7, 0.3], [0.9, 0.1]])
    result = loss._compute_tf_loss(inputs).numpy()
    expected = [-np.mean([0.7*np.log(0.7),0.3*np.log(0.3)]),
                -np.mean([0.9*np.log(0.9),0.1*np.log(0.1)])]
    expected = [
        -np.mean([0.7 * np.log(0.7), 0.3 * np.log(0.3)]),
        -np.mean([0.9 * np.log(0.9), 0.1 * np.log(0.1)])
    ]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
@@ -266,6 +303,8 @@ class TestLosses(unittest.TestCase):
    loss = losses.ShannonEntropy()
    inputs = torch.tensor([[0.7, 0.3], [0.9, 0.1]])
    result = loss._create_pytorch_loss()(inputs).numpy()
    expected = [-np.mean([0.7*np.log(0.7),0.3*np.log(0.3)]),
                -np.mean([0.9*np.log(0.9),0.1*np.log(0.1)])]
    expected = [
        -np.mean([0.7 * np.log(0.7), 0.3 * np.log(0.3)]),
        -np.mean([0.9 * np.log(0.9), 0.1 * np.log(0.1)])
    ]
    assert np.allclose(expected, result)