Commit a739f244 authored by mufeili's avatar mufeili
Browse files

Update

parent 9ed04203
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -187,7 +187,12 @@ class SparseSoftmaxCrossEntropy(Loss):

  def _compute_tf_loss(self, output, labels):
    import tensorflow as tf

    if labels.shape[-1] == 1:
      labels = tf.squeeze(labels, axis=-1)

    labels = tf.cast(labels, tf.int32)

    return tf.nn.sparse_softmax_cross_entropy_with_logits(labels, output)

  def _create_pytorch_loss(self):