Commit 9be691bb authored by Atreya Majumdar's avatar Atreya Majumdar
Browse files

Fix scalenorm t est

parent 9e397b2b
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -613,9 +613,9 @@ def test_scale_norm():
  input_ar = torch.tensor([[1., 99., 10000.], [0.003, 999.37, 23.]])
  layer = torch_layers.ScaleNorm(0.35)
  result1 = layer(input_ar)
  output_ar = np.array([[5.9157897e-05, 5.8566318e-03, 5.9157896e-01],
  output_ar = torch.tensor([[5.9157897e-05, 5.8566318e-03, 5.9157896e-01],
                        [1.7754727e-06, 5.9145141e-01, 1.3611957e-02]])
  assert np.allclose(result1, output_ar)
  assert torch.allclose(result1, output_ar)


@pytest.mark.torch