Commit d634d11a authored by Atreya Majumdar's avatar Atreya Majumdar
Browse files

Added tests

parent 96b0e2f2
Loading
Loading
Loading
Loading
+14 −12
Original line number Diff line number Diff line
@@ -612,28 +612,30 @@ def test_scale_norm():
  """Test invoking ScaleNorm."""
  input_ar = torch.tensor([[1., 99., 10000.], [0.003, 999.37, 23.]])
  layer = torch_layers.ScaleNorm(0.35)
  result1 = layer.forward(input_ar)
  result1 = layer(input_ar)
  output_ar = np.array([[5.9157897e-05, 5.8566318e-03, 5.9157896e-01],
                        [1.7754727e-06, 5.9145141e-01, 1.3611957e-02]])
  assert np.allclose(result1, output_ar)


@pytest.mark.torch
def test_mat_embedding():
  """Test invoking MATEmbedding."""
  torch.manual_seed(0)
  input_ar = torch.tensor([1., 2., 3.])
  layer = torch_layers.MATEmbedding(3, 1, 0.0)
  result1 = layer(input_ar)
  output_ar = np.array([[5.9157897e-05, 5.8566318e-03, 5.9157896e-01],
                        [1.7754727e-06, 5.9145141e-01, 1.3611957e-02]])
  assert np.allclose(result1, output_ar)
  result = layer(input_ar).detach()
  output_ar = torch.tensor([-1.2353])
  assert torch.allclose(result, output_ar, rtol=1e-4)


@pytest.mark.torch
def test_mat_generator():
  """Test invoking MATGenerator."""
  input_ar = torch.tensor([[1., 99., 10000.], [0.003, 999.37, 23.]])
  layer = torch_layers.ScaleNorm(0.35)
  result1 = layer.forward(input_ar)
  output_ar = np.array([[5.9157897e-05, 5.8566318e-03, 5.9157896e-01],
                        [1.7754727e-06, 5.9145141e-01, 1.3611957e-02]])
  assert np.allclose(result1, output_ar)
  torch.manual_seed(0)
  input_ar = torch.tensor([1., 2., 3.])
  layer = torch_layers.MATGenerator(3, 'mean', 1, 1, 0.0)
  mask = torch.tensor([1., 1., 1.])
  result = layer(input_ar, mask)
  output_ar = torch.tensor([-1.4436])
  assert torch.allclose(result, output_ar, rtol=1e-4)