Commit 06b2d7d5 authored by nd-02110114's avatar nd-02110114
Browse files

💚 fix test error

parent c24b7cbf
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -192,7 +192,7 @@ class SparseSoftmaxCrossEntropy(Loss):

  def _create_pytorch_loss(self):
    import torch
    ce_loss = torch.nn.CrossEntropyLoss(reduction='mean')
    ce_loss = torch.nn.CrossEntropyLoss(reduction='none')

    def loss(output, labels):
      # Convert (batch_size, tasks, classes) to (batch_size, classes, tasks)