Commit 3b2c17d9 authored by peastman's avatar peastman
Browse files

Improvements to computing losses

parent 6fefb05f
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -138,10 +138,10 @@ class SigmoidCrossEntropy(Loss):

  def _create_pytorch_loss(self):
    import torch
    bce = torch.nn.BCELoss(reduction='none')
    bce = torch.nn.BCEWithLogitsLoss(reduction='none')

    def loss(output, labels):
      return bce(torch.sigmoid(output), labels)
      return bce(output, labels)

    return loss

@@ -163,10 +163,10 @@ class SoftmaxCrossEntropy(Loss):

  def _create_pytorch_loss(self):
    import torch
    ls = torch.nn.LogSoftmax(dim=1)

    def loss(output, labels):
      return -torch.sum(
          labels * torch.log(torch.nn.functional.softmax(output, 1)), dim=-1)
      return -torch.sum(labels * ls(output), dim=-1)

    return loss