Commit c3377b43 authored by Atreya Majumdar's avatar Atreya Majumdar
Browse files

Changed torch.maximum to np.maximum

parent 54493d9c
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -109,11 +109,12 @@ class SquaredHingeLoss(Loss):
    import torch

    def loss(output, labels):
      import numpy as np
      output, labels = _make_pytorch_shapes_consistent(output, labels)
      return torch.mean(
          torch.pow(
              torch.maximum(1 - torch.multiply(labels, output),
                            torch.tensor(0)), 2),
              np.maximum(1 - torch.multiply(labels, output), torch.tensor(0)),
              2),
          dim=-1)

    return loss