Commit a4c83e15 authored by galenxing's avatar galenxing
Browse files

alphashare test

parent 5a475a6a
Loading
Loading
Loading
Loading
+24 −0
Original line number Diff line number Diff line
@@ -46,6 +46,9 @@ from deepchem.models.tensorgraph.layers import TensorWrapper
from deepchem.models.tensorgraph.layers import LSTMStep
from deepchem.models.tensorgraph.layers import AttnLSTMEmbedding
from deepchem.models.tensorgraph.layers import IterRefLSTMEmbedding
from deepchem.models.tensorgraph.layers import AlphaShare
#from deepchem.models.tensorgraph.layers import BetaShare
#from deepchem.models.tensorgraph.layers import SluiceLoss

import deepchem as dc

@@ -394,6 +397,7 @@ class TestLayers(test_util.TensorFlowTestCase):
      for deg_adj in deg_adjs:
        deg_adjs_tf.append(tf.convert_to_tensor(deg_adj, dtype=tf.int32))
      args = [atom_features, degree_slice, membership] + deg_adjs_tf
      test_1 = tf.convert_to_tensor(test_1, dtype=tf.float32)
      out_tensor = GraphConv(out_channels)(*args)
      sess.run(tf.global_variables_initializer())
      out_tensor = out_tensor.eval()
@@ -464,6 +468,24 @@ class TestLayers(test_util.TensorFlowTestCase):
      assert test_out.shape == (n_test, n_feat)
      assert support_out.shape == (n_support, n_feat)

  def test_alpha_share(self):
    """test that alpha share works correctly"""
    batch_size = 50
    length = 10
    test_1 = np.random.rand(batch_size, length)
    test_2 = np.random.rand(batch_size, length)

    with self.test_session() as sess:
      test_1 = tf.convert_to_tensor(test_1, dtype=tf.float32)
      test_2 = tf.convert_to_tensor(test_2, dtype=tf.float32)

      out_tensor = AlphaShare(in_layers=[test_1, test_2])
      sess.run(tf.global_variables_initializer())
      test_1_out_tensor = out_tensor[0].eval
      test_2_out_tensor = out_tensor[1].eval
      assert test_1.shape == test_1_out_tensor.shape
      assert test_2.shape == test_2_out_tensor.shape

  # TODO(rbharath): This test should pass. Fix it!
  #def test_graph_pool(self):
  #  """Test that GraphPool can be invoked."""
@@ -609,3 +631,5 @@ class TestLayers(test_util.TensorFlowTestCase):
      assert result == 1.5
      result = sess.run(tf.gradients(v, v))
      assert result[0] == 1.0

TestLayers.test_alpha_share()