Unverified Commit 2531eca8 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1485 from VIGS25/add-tests-for-graph-layers

#1186: Adding Tests for Graph Layers (DTNNEmbedding, DTNNExtract, WeaveGather)
parents 0253c010 7b6d8f5b
Loading
Loading
Loading
Loading
+78 −0
Original line number Diff line number Diff line
@@ -60,6 +60,9 @@ from deepchem.models.tensorgraph.layers import WeightedLinearCombo
from deepchem.models.tensorgraph.IRV import IRVLayer
from deepchem.models.tensorgraph.IRV import IRVRegularize
from deepchem.models.tensorgraph.IRV import Slice
from deepchem.models.tensorgraph.graph_layers import DTNNEmbedding
from deepchem.models.tensorgraph.graph_layers import DTNNExtract
from deepchem.models.tensorgraph.graph_layers import WeaveGather


class TestLayers(test_util.TensorFlowTestCase):
@@ -928,3 +931,78 @@ class TestLayers(test_util.TensorFlowTestCase):
      sess.run(tf.global_variables_initializer())
      cost = WeightDecay(3.0, 'l2')(0.0)
      assert np.allclose(3.0 * np.sum(values * values) / 2, cost.eval())

  def test_dtnn_embedding(self):
    """Test that DTNNEmbedding can be invoked."""
    n_embedding = 10
    periodic_table_length = 20
    test_tensor_input = np.random.permutation(
        np.arange(0, periodic_table_length // 2, dtype=np.int32))
    with self.test_session() as sess:
      test_tensor = tf.convert_to_tensor(test_tensor_input, dtype=tf.int32)
      dtnn_embedding = DTNNEmbedding(
          n_embedding=n_embedding, periodic_table_length=periodic_table_length)
      dtnn_embedding.create_tensor(in_layers=[test_tensor])

      # Layer is wrapper around embedding lookup, tested that then
      sess.run(tf.global_variables_initializer())
      out_tensor = dtnn_embedding.out_tensor.eval()
      embedding_val = dtnn_embedding.embedding_list.eval()
      expected_output = embedding_val[test_tensor_input]
      self.assertAllClose(out_tensor, expected_output)
      self.assertAllClose(out_tensor.shape,
                          (periodic_table_length // 2, n_embedding))

  def test_dtnn_extract(self):
    """Test that DTNNExtract can be invoked."""
    num_samples = 20
    num_features = 30
    task_id = 15
    test_tensor_input = np.random.randn(num_samples, num_features)
    test_output = test_tensor_input[:, task_id:task_id + 1]
    with self.test_session() as sess:
      test_tensor = tf.convert_to_tensor(test_tensor_input)
      dtnn_extract = DTNNExtract(task_id=task_id)
      dtnn_extract.create_tensor(in_layers=[test_tensor])
      sess.run(tf.global_variables_initializer())
      out_tensor = dtnn_extract.out_tensor.eval()
      self.assertAllClose(test_output, out_tensor)
      self.assertEqual(out_tensor.shape, (num_samples, 1))

  def test_weave_gather(self):
    """Test that WeaveGather can be invoked."""
    batch_size = 1
    num_samples = 2
    num_atoms_per_sample = 5
    num_features = 5
    gaussian_expand = False

    atom_split_np = list()
    for i in range(num_samples):
      atom_split_np.extend([i] * num_atoms_per_sample)
    atom_split_np = np.array(atom_split_np)
    tensor_input_np = np.random.randn(num_samples * num_atoms_per_sample,
                                      num_features)

    # Expected output
    expected_output = list()
    for i in range(num_samples):
      expected_output.append(
          np.sum(
              tensor_input_np[i * num_atoms_per_sample:(i + 1) *
                              num_atoms_per_sample],
              axis=0))
    expected_output = np.array(expected_output)

    with self.test_session() as sess:
      tensor_input_tf = tf.convert_to_tensor(tensor_input_np)
      atom_split_tf = tf.convert_to_tensor(atom_split_np)
      weave_gather = WeaveGather(
          batch_size=batch_size, gaussian_expand=gaussian_expand)
      weave_gather.create_tensor(in_layers=[tensor_input_tf, atom_split_tf])

      sess.run(tf.global_variables_initializer())
      out_tensor = weave_gather.out_tensor.eval()

      self.assertAllClose(expected_output, out_tensor)
      self.assertEqual(expected_output.shape, out_tensor.shape)