Commit 1d6e7864 authored by Atreya Majumdar's avatar Atreya Majumdar
Browse files

Formatting

parent 400fc446
Loading
Loading
Loading
Loading
+8 −6
Original line number Diff line number Diff line
@@ -77,7 +77,10 @@ def test_combine_mean_std():
  assert not np.array_equal(result2, mean)
  assert np.allclose(result2, mean, atol=0.1)

np.array

@pytest.mark.tensorflow
def test_stack():
  """Test invoking Stack."""
  input1 = np.random.rand(5, 4).astype(np.float32)
  input2 = np.random.rand(5, 4).astype(np.float32)
  result = layers.Stack()([input1, input2])
@@ -603,8 +606,10 @@ def test_DAG_gather():
  membership = np.sort(np.random.randint(0, batch_size, size=(batch_size)))
  outputs = layer([atom_features, membership])


@pytest.mark.pytorch
def test_layer_norm():
  assert(1 == 2)
  """Test invoking LayerNorm."""
  input_ar = torch.tensor([[1., 99., 10000.], [0.003, 999.37, 23.]])
  layer = torch_layers.LayerNorm(input_ar.shape)
@@ -612,6 +617,3 @@ def test_layer_norm():
  output_ar = np.array([[-0.58585864, -0.5687999, 1.1546584],
                        [-0.59738946, 1.1544659, -0.55707645]])
  assert np.allclose(result1, output_ar)