Commit cbf3580f authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Saving changes

parent 8961cfff
Loading
Loading
Loading
Loading
+19 −16
Original line number Diff line number Diff line
@@ -26,7 +26,9 @@ class SequentialGraph(object):
    n_feat: int
      Number of features per atom.
    """
    #self.graph_topology = GraphTopology(n_atoms, n_feat)
    self.graph = tf.Graph()
    self.session = tf.Session(graph=self.graph, config=config)
    with self.graph.as_default():
      self.graph_topology = GraphTopology(n_feat)
      self.output = self.graph_topology.get_atom_features_placeholder()
    # Keep track of the layers
@@ -34,6 +36,7 @@ class SequentialGraph(object):

  def add(self, layer):
    """Adds a new layer to model."""
    with self.graph.as_default():
      # For graphical layers, add connectivity placeholders 
      if type(layer).__name__ in ['GraphConv', 'GraphGather', 'GraphPool']:
        if (len(self.layers) > 0 and hasattr(self.layers[-1], "__name__")):
+8 −12
Original line number Diff line number Diff line
@@ -52,22 +52,18 @@ class GraphTopology(object):
    self.max_deg = max_deg
    self.min_deg = min_deg

    self.atom_features_placeholder = Input(
        tensor=tf.placeholder(
    self.atom_features_placeholder = tensor=tf.placeholder(
            dtype='float32', shape=(None, self.n_feat),
            name=self.name+'_atom_features'))
            name=self.name+'_atom_features')
    self.deg_adj_lists_placeholders = [
        Input(tensor=tf.placeholder(
          dtype='int32', shape=(None, deg), name=self.name+'_deg_adj'+str(deg)))
        tf.placeholder(
          dtype='int32', shape=(None, deg), name=self.name+'_deg_adj'+str(deg))
        for deg in range(1, self.max_deg+1)]
    self.deg_slice_placeholder = Input(
        tensor=tf.placeholder(
    self.deg_slice_placeholder = tf.placeholder(
            dtype='int32', shape=(self.max_deg-self.min_deg+1,2),
            name="deg_slice",),
            name=self.name+'_deg_slice')
    self.membership_placeholder = Input(
          tensor=tf.placeholder(dtype='int32', shape=(None,), name="membership"),
          name=self.name+'_membership')
    self.membership_placeholder = tf.placeholder(
        dtype='int32', shape=(None,), name=self.name+'_membership')

    # Define the list of tensors to be used as topology
    self.topology = [self.deg_slice_placeholder, self.membership_placeholder]