Unverified Commit fbf8b794 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #972 from lilleswing/config-graphconv

Add Model Configuration Params To GraphConvTensorGraph
parents 0d42aef7 e3cc49be
Loading
Loading
Loading
Loading
+143 −128
Original line number Diff line number Diff line
@@ -638,18 +638,33 @@ class PetroskiSuchTensorGraph(TensorGraph):

class GraphConvTensorGraph(TensorGraph):

  def __init__(self, n_tasks, mode="classification", **kwargs):
  def __init__(self,
               n_tasks,
               graph_conv_layers=[64, 64],
               dense_layer_size=128,
               dropout=0.0,
               mode="classification",
               **kwargs):
    """
        Parameters
        ----------
        n_tasks: int
          Number of tasks
        graph_conv_layers: list of int
          Width of channels for the Graph Convolution Layers
        dense_layer_size: int
          Width of channels for Atom Level Dense Layer before GraphPool
        dropout: float
          Droupout dropout probability.  Dropout is applied after the per Atom Level Dense Layer
        mode: str
          Either "classification" or "regression"
        """
    self.n_tasks = n_tasks
    self.mode = mode
    self.error_bars = True if 'error_bars' in kwargs and kwargs['error_bars'] else False
    self.dense_layer_size = dense_layer_size
    self.dropout = dropout
    self.graph_conv_layers = graph_conv_layers
    kwargs['use_queue'] = False
    super(GraphConvTensorGraph, self).__init__(**kwargs)
    self.build_graph()
@@ -666,23 +681,23 @@ class GraphConvTensorGraph(TensorGraph):
    for i in range(0, 10 + 1):
      deg_adj = Feature(shape=(None, i + 1), dtype=tf.int32)
      self.deg_adjs.append(deg_adj)
    in_layer = self.atom_features
    for layer_size in self.graph_conv_layers:
      gc1 = GraphConv(
        64,
          layer_size,
          activation_fn=tf.nn.relu,
        in_layers=[self.atom_features, self.degree_slice, self.membership] +
          in_layers=[in_layer, self.degree_slice, self.membership] +
          self.deg_adjs)
      batch_norm1 = BatchNorm(in_layers=[gc1])
    gp1 = GraphPool(in_layers=[batch_norm1, self.degree_slice, self.membership]
                    + self.deg_adjs)
    gc2 = GraphConv(
        64,
      in_layer = GraphPool(
          in_layers=[batch_norm1, self.degree_slice, self.membership
                    ] + self.deg_adjs)
    dense = Dense(
        out_channels=self.dense_layer_size,
        activation_fn=tf.nn.relu,
        in_layers=[gp1, self.degree_slice, self.membership] + self.deg_adjs)
    batch_norm2 = BatchNorm(in_layers=[gc2])
    gp2 = GraphPool(in_layers=[batch_norm2, self.degree_slice, self.membership]
                    + self.deg_adjs)
    dense = Dense(out_channels=128, activation_fn=tf.nn.relu, in_layers=[gp2])
        in_layers=[in_layer])
    batch_norm3 = BatchNorm(in_layers=[dense])
    batch_norm3 = Dropout(self.dropout, in_layers=[batch_norm3])
    readout = GraphGather(
        batch_size=self.batch_size,
        activation_fn=tf.nn.tanh,