Commit f87fa51a authored by leswing's avatar leswing
Browse files

Return adjs from Graphcnnpool

parent 829e5157
Loading
Loading
Loading
Loading
+3 −4
Original line number Diff line number Diff line
@@ -3179,11 +3179,10 @@ class GraphCNNPoolLayer(Layer):
    result_A = tf.reshape(result_A, (tf.shape(A)[0], self.num_vertices,
                                     A.get_shape()[2].value, self.num_vertices))
    # We do not need the mask because every graph has self.num_vertices vertices now
    # result = make_bn(result, True, mask=None, name="%s_bn" % self.name)
    if set_tensors:
      self.out_tensor = result
    self.out_tensors = [result, result_A, factors]
    return result
    self.out_tensors = [result, result_A]
    return result, result_A

  def embedding_factors(self, V, no_filters, name="default"):
    no_features = V.get_shape()[-1].value
@@ -3220,7 +3219,7 @@ class GraphCNNPoolLayer(Layer):

def GraphCNNPool(num_vertices, **kwargs):
  gcnnpool_layer = GraphCNNPoolLayer(num_vertices, **kwargs)
  return [PassThroughLayer(x, in_layers=gcnnpool_layer) for x in range(3)]
  return [PassThroughLayer(x, in_layers=gcnnpool_layer) for x in range(2)]


class GraphCNN(Layer):
+3 −3
Original line number Diff line number Diff line
@@ -495,7 +495,7 @@ class PetroskiSuchTensorGraph(TensorGraph):
  def __init__(self,
               n_tasks,
               max_atoms=200,
               dropout=0.2,
               dropout=0.0,
               mode="classification",
               **kwargs):
    """
@@ -528,13 +528,13 @@ class PetroskiSuchTensorGraph(TensorGraph):
    gcnn2 = BatchNorm(
        GraphCNN(num_filters=64, in_layers=[gcnn1, self.adj_matrix, self.mask]))
    gcnn2 = Dropout(self.dropout, in_layers=gcnn2)
    gc_pool, adj_matrix, factors = GraphCNNPool(
    gc_pool, adj_matrix = GraphCNNPool(
        num_vertices=32, in_layers=[gcnn2, self.adj_matrix, self.mask])
    gc_pool = BatchNorm(gc_pool)
    gc_pool = Dropout(self.dropout, in_layers=gc_pool)
    gcnn3 = BatchNorm(GraphCNN(num_filters=32, in_layers=[gc_pool, adj_matrix]))
    gcnn3 = Dropout(self.dropout, in_layers=gcnn3)
    gc_pool2, adj_matrix2, factors = GraphCNNPool(
    gc_pool2, adj_matrix2 = GraphCNNPool(
        num_vertices=8, in_layers=[gcnn3, adj_matrix])
    gc_pool2 = BatchNorm(gc_pool2)
    gc_pool2 = Dropout(self.dropout, in_layers=gc_pool2)
+4 −5
Original line number Diff line number Diff line
@@ -732,7 +732,6 @@ class TestLayers(test_util.TensorFlowTestCase):
      out_tensor = GraphCNN(num_filters=6)(V, adjs)
      sess.run(tf.global_variables_initializer())
      result = out_tensor.eval()
      print(result.shape)
      assert result.shape == (10, 100, 6)

  def test_graphcnnpool(self):
@@ -740,8 +739,8 @@ class TestLayers(test_util.TensorFlowTestCase):
    V = np.random.uniform(size=(10, 100, 50)).astype(np.float32)
    adjs = np.random.uniform(size=(10, 100, 5, 100)).astype(np.float32)
    with self.test_session() as sess:
      out_tensor = GraphCNNPoolLayer(num_vertices=6)(V, adjs)
      vertex_props, adjs = GraphCNNPoolLayer(num_vertices=6)(V, adjs)
      sess.run(tf.global_variables_initializer())
      result = out_tensor.eval()
      print(result.shape)
      assert result.shape == (10, 6, 50)
      vertex_props, adjs = vertex_props.eval(), adjs.eval()
      assert vertex_props.shape == (10, 6, 50)
      assert adjs.shape == (10, 6, 5, 6)