Commit 9ed04203 authored by mufeili's avatar mufeili
Browse files

Update

parent fad88737
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -186,6 +186,12 @@ class TestLosses(unittest.TestCase):
    expected = [-np.log(softmax[0, 1]), -np.log(softmax[1, 0])]
    assert np.allclose(expected, result)

    labels = tf.constant([[1], [0]])
    result = loss._compute_tf_loss(outputs, labels).numpy()
    softmax = np.exp(y) / np.expand_dims(np.sum(np.exp(y), axis=1), 1)
    expected = [-np.log(softmax[0, 1]), -np.log(softmax[1, 0])]
    assert np.allclose(expected, result)

  @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
  def test_sparse_softmax_cross_entropy_pytorch(self):
    """Test SparseSoftmaxCrossEntropy."""