Commit 05ae3a79 authored by mufeili's avatar mufeili
Browse files

Update

parent a739f244
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -188,7 +188,7 @@ class SparseSoftmaxCrossEntropy(Loss):
  def _compute_tf_loss(self, output, labels):
    import tensorflow as tf

    if labels.shape[-1] == 1:
    if len(labels.shape) == len(output.shape):
      labels = tf.squeeze(labels, axis=-1)

    labels = tf.cast(labels, tf.int32)
@@ -205,7 +205,8 @@ class SparseSoftmaxCrossEntropy(Loss):
      # This is for API consistency
      if len(output.shape) == 3:
        output = output.permute(0, 2, 1)
      if labels.shape[-1] == 1:

      if len(labels.shape) == len(output.shape):
        labels = labels.squeeze(-1)
      return ce_loss(output, labels.long())