Unverified Commit a1d6fd27 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1778 from deepchem/latesttf

Fixing tests for graph convolutions, DAG models with TF 2.2RC
parents e0185ee5 cbd90aca
Loading
Loading
Loading
Loading
+79 −1
Original line number Original line Diff line number Diff line
@@ -70,11 +70,20 @@ reference_lists = [
]
]


intervals = get_intervals(reference_lists)
intervals = get_intervals(reference_lists)
# We use E-Z notation for stereochemistry
# https://en.wikipedia.org/wiki/E%E2%80%93Z_notation
possible_bond_stereo = ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"]
possible_bond_stereo = ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"]
bond_fdim_base = 6
bond_fdim_base = 6




def get_feature_list(atom):
def get_feature_list(atom):
  """Returns a list of possible features for this atom.

  Parameters
  ----------
  atom: RDKit.rdchem.Atom
    Atom to get features for 
  """
  features = 6 * [0]
  features = 6 * [0]
  features[0] = safe_index(possible_atom_list, atom.GetSymbol())
  features[0] = safe_index(possible_atom_list, atom.GetSymbol())
  features[1] = safe_index(possible_numH_list, atom.GetTotalNumHs())
  features[1] = safe_index(possible_numH_list, atom.GetTotalNumHs())
@@ -113,7 +122,13 @@ def id_to_features(id, intervals):




def atom_to_id(atom):
def atom_to_id(atom):
  """Return a unique id corresponding to the atom type"""
  """Return a unique id corresponding to the atom type

  Parameters
  ----------
  atom: RDKit.rdchem.Atom
    Atom to convert to ids.
  """
  features = get_feature_list(atom)
  features = get_feature_list(atom)
  return features_to_id(features, intervals)
  return features_to_id(features, intervals)


@@ -122,6 +137,19 @@ def atom_features(atom,
                  bool_id_feat=False,
                  bool_id_feat=False,
                  explicit_H=False,
                  explicit_H=False,
                  use_chirality=False):
                  use_chirality=False):
  """Helper method used to compute per-atom feature vectors.

  Many different featurization methods compute per-atom features such as ConvMolFeaturizer, WeaveFeaturizer. This method computes such features.

  Parameters
  ----------
  bool_id_feat: bool, optional
    Return an array of unique identifiers corresponding to atom type.
  explicit_H: bool, optional
    If true, model hydrogens explicitly
  use_chirality: bool, optional
    If true, use chirality information.
  """
  if bool_id_feat:
  if bool_id_feat:
    return np.array([atom_to_id(atom)])
    return np.array([atom_to_id(atom)])
  else:
  else:
@@ -199,6 +227,16 @@ def atom_features(atom,




def bond_features(bond, use_chirality=False):
def bond_features(bond, use_chirality=False):
  """Helper method used to compute bond feature vectors.

  Many different featurization methods compute bond features
  such as WeaveFeaturizer. This method computes such features.

  Parameters
  ----------
  use_chirality: bool, optional
    If true, use chirality information.
  """
  from rdkit import Chem
  from rdkit import Chem
  bt = bond.GetBondType()
  bt = bond.GetBondType()
  bond_feats = [
  bond_feats = [
@@ -215,6 +253,26 @@ def bond_features(bond, use_chirality=False):


def pair_features(mol, edge_list, canon_adj_list, bt_len=6,
def pair_features(mol, edge_list, canon_adj_list, bt_len=6,
                  graph_distance=True):
                  graph_distance=True):
  """Helper method used to compute atom pair feature vectors.

  Many different featurization methods compute atom pair features
  such as WeaveFeaturizer. Note that atom pair features could be
  for pairs of atoms which aren't necessarily bonded to one
  another. 

  Parameters
  ----------
  mol: TODO
    TODO
  edge_list: list
    List of edges to consider
  canon_adj_list: list
    TODO
  bt_len: int, optional
    TODO
  graph_distance: bool, optional
    TODO
  """
  if graph_distance:
  if graph_distance:
    max_distance = 7
    max_distance = 7
  else:
  else:
@@ -271,6 +329,10 @@ def find_distance(a1, num_atoms, canon_adj_list, max_distance=7):




class ConvMolFeaturizer(Featurizer):
class ConvMolFeaturizer(Featurizer):
  """This class implements the featurization to implement graph convolutions from the Duvenaud graph convolution paper

Duvenaud, David K., et al. "Convolutional networks on graphs for learning molecular fingerprints." Advances in neural information processing systems. 2015.
  """
  name = ['conv_mol']
  name = ['conv_mol']


  def __init__(self, master_atom=False, use_chirality=False,
  def __init__(self, master_atom=False, use_chirality=False,
@@ -381,10 +443,26 @@ class ConvMolFeaturizer(Featurizer):




class WeaveFeaturizer(Featurizer):
class WeaveFeaturizer(Featurizer):
  """This class implements the featurization to implement Weave convolutions from the Google graph convolution paper.

  Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond fingerprints." Journal of computer-aided molecular design 30.8 (2016): 595-608.
  """

  name = ['weave_mol']
  name = ['weave_mol']


  def __init__(self, graph_distance=True, explicit_H=False,
  def __init__(self, graph_distance=True, explicit_H=False,
               use_chirality=False):
               use_chirality=False):
    """
    Parameters
    ----------
    graph_distance: bool, optional
      If true, use graph distance. Otherwise, use Euclidean
      distance.
    explicit_H: bool, optional
      If true, model hydrogens in the molecule.
    use_chirality: bool, optional
      If true, use chiral information in the featurization
    """
    # Distance is either graph distance(True) or Euclidean distance(False,
    # Distance is either graph distance(True) or Euclidean distance(False,
    # only support datasets providing Cartesian coordinates)
    # only support datasets providing Cartesian coordinates)
    self.graph_distance = graph_distance
    self.graph_distance = graph_distance
+6 −2
Original line number Original line Diff line number Diff line
@@ -386,8 +386,12 @@ class MultiConvMol(object):




class WeaveMol(object):
class WeaveMol(object):
  """Holds information about a molecule
  """Molecular featurization object for weave convolutions.
  Molecule struct used in weave models

  These objects are produced by WeaveFeaturizer, and feed into
  WeaveModel. The underlying implementation is inspired by:

  Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond fingerprints." Journal of computer-aided molecular design 30.8 (2016): 595-608.
  """
  """


  def __init__(self, nodes, pairs):
  def __init__(self, nodes, pairs):
+15 −10
Original line number Original line Diff line number Diff line
@@ -54,6 +54,7 @@ def initializeWeightsBiases(prev_layer_size,




class AtomicConvScore(Layer):
class AtomicConvScore(Layer):
  """The scoring function used by the atomic convolution models."""


  def __init__(self, atom_types, layer_sizes, **kwargs):
  def __init__(self, atom_types, layer_sizes, **kwargs):
    super(AtomicConvScore, self).__init__(**kwargs)
    super(AtomicConvScore, self).__init__(**kwargs)
@@ -145,6 +146,19 @@ class AtomicConvScore(Layer):




class AtomicConvModel(KerasModel):
class AtomicConvModel(KerasModel):
  """Implements an Atomic Convolution Model.

  Implements the atomic convolutional networks as introduced in

  Gomes, Joseph, et al. "Atomic convolutional networks for predicting protein-ligand binding affinity." arXiv preprint arXiv:1703.10603 (2017).

  The atomic convolutional networks function as a variant of
  graph convolutions. The difference is that the "graph" here is
  the nearest neighbors graph in 3D space. The AtomicConvModel
  leverages these connections in 3D space to train models that
  learn to predict energetic state starting from the spatial
  geometry of the model.
  """


  def __init__(self,
  def __init__(self,
               frag1_num_atoms=70,
               frag1_num_atoms=70,
@@ -163,16 +177,7 @@ class AtomicConvModel(KerasModel):
               layer_sizes=[32, 32, 16],
               layer_sizes=[32, 32, 16],
               learning_rate=0.001,
               learning_rate=0.001,
               **kwargs):
               **kwargs):
    """Implements an Atomic Convolution Model.
    """   

    Implements the atomic convolutional networks as introduced in
    https://arxiv.org/abs/1703.10603. The atomic convolutional networks
    function as a variant of graph convolutions. The difference is that the
    "graph" here is the nearest neighbors graph in 3D space. The
    AtomicConvModel leverages these connections in 3D space to train models
    that learn to predict energetic state starting from the spatial
    geometry of the model.

    Params
    Params
    ------
    ------
    frag1_num_atoms: int
    frag1_num_atoms: int
+6 −2
Original line number Original line Diff line number Diff line
@@ -417,7 +417,11 @@ class MultitaskFitTransformRegressor(MultitaskRegressor):
          dropout = np.array(1.0)
          dropout = np.array(1.0)
        yield ([X_b, dropout], [y_b], [w_b])
        yield ([X_b, dropout], [y_b], [w_b])


  def predict_on_generator(self, generator, transformers=[], outputs=None):
  def predict_on_generator(self,
                           generator,
                           transformers=[],
                           outputs=None,
                           output_types=None):


    def transform_generator():
    def transform_generator():
      for inputs, labels, weights in generator:
      for inputs, labels, weights in generator:
@@ -427,4 +431,4 @@ class MultitaskFitTransformRegressor(MultitaskRegressor):
        yield ([X_t] + inputs[1:], labels, weights)
        yield ([X_t] + inputs[1:], labels, weights)


    return super(MultitaskFitTransformRegressor, self).predict_on_generator(
    return super(MultitaskFitTransformRegressor, self).predict_on_generator(
        transform_generator(), transformers, outputs)
        transform_generator(), transformers, outputs, output_types)
+223 −89
Original line number Original line Diff line number Diff line
@@ -30,6 +30,21 @@ class TrimGraphOutput(tf.keras.layers.Layer):




class WeaveModel(KerasModel):
class WeaveModel(KerasModel):
  """Implements Google-style Weave Graph Convolutions

  This model implements the Weave style graph convolutions
  from the following paper.

  Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond fingerprints." Journal of computer-aided molecular design 30.8 (2016): 595-608.

  The biggest difference between WeaveModel style convolutions
  and GraphConvModel style convolutions is that Weave
  convolutions model bond features explicitly. This has the
  side effect that it needs to construct a NxN matrix
  explicitly to model bond interactions. This may cause
  scaling issues, but may possibly allow for better modeling
  of subtle bond effects.
  """


  def __init__(self,
  def __init__(self,
               n_tasks,
               n_tasks,
@@ -90,7 +105,9 @@ class WeaveModel(KerasModel):
        update_pair=False)(
        update_pair=False)(
            [weave_layer1A, weave_layer1P, pair_split, atom_to_pair])
            [weave_layer1A, weave_layer1P, pair_split, atom_to_pair])
    dense1 = Dense(self.n_graph_feat, activation=tf.nn.tanh)(weave_layer2A)
    dense1 = Dense(self.n_graph_feat, activation=tf.nn.tanh)(weave_layer2A)
    batch_norm1 = BatchNormalization(epsilon=1e-5)(dense1)
    # Batch normalization causes issues, spitting out NaNs if
    # allowed to train
    batch_norm1 = BatchNormalization(epsilon=1e-5, trainable=False)(dense1)
    weave_gather = layers.WeaveGather(
    weave_gather = layers.WeaveGather(
        batch_size, n_input=self.n_graph_feat,
        batch_size, n_input=self.n_graph_feat,
        gaussian_expand=True)([batch_norm1, atom_split])
        gaussian_expand=True)([batch_norm1, atom_split])
@@ -170,6 +187,12 @@ class WeaveModel(KerasModel):




class DTNNModel(KerasModel):
class DTNNModel(KerasModel):
  """Deep Tensor Neural Networks

  This class implements deep tensor neural networks as first defined in

  Schütt, Kristof T., et al. "Quantum-chemical insights from deep tensor neural networks." Nature communications 8.1 (2017): 1-8.
  """


  def __init__(self,
  def __init__(self,
               n_tasks,
               n_tasks,
@@ -322,6 +345,28 @@ class DTNNModel(KerasModel):




class DAGModel(KerasModel):
class DAGModel(KerasModel):
  """Directed Acyclic Graph models for molecular property prediction.

    This model is based on the following paper: 

    Lusci, Alessandro, Gianluca Pollastri, and Pierre Baldi. "Deep architectures and deep learning in chemoinformatics: the prediction of aqueous solubility for drug-like molecules." Journal of chemical information and modeling 53.7 (2013): 1563-1575.

   The basic idea for this paper is that a molecule is usually
   viewed as an undirected graph. However, you can convert it to
   a series of directed graphs. The idea is that for each atom,
   you make a DAG using that atom as the vertex of the DAG and
   edges pointing "inwards" to it. This transformation is
   implemented in
   `dc.trans.transformers.DAGTransformer.UG_to_DAG`.

   This model accepts ConvMols as input, just as GraphConvModel
   does, but these ConvMol objects must be transformed by
   dc.trans.DAGTransformer. 

   As a note, performance of this model can be a little
   sensitive to initialization. It might be worth training a few
   different instantiations to get a stable set of parameters.
   """


  def __init__(self,
  def __init__(self,
               n_tasks,
               n_tasks,
@@ -382,9 +427,13 @@ class DAGModel(KerasModel):
    if uncertainty:
    if uncertainty:
      if mode != "regression":
      if mode != "regression":
        raise ValueError("Uncertainty is only supported in regression mode")
        raise ValueError("Uncertainty is only supported in regression mode")
      if dropout == 0.0:
      if dropout is None or dropout == 0.0:
        raise ValueError('Dropout must be included to predict uncertainty')
        raise ValueError('Dropout must be included to predict uncertainty')


    ############################################
    print("self.dropout")
    print(self.dropout)
    ############################################
    # Build the model.
    # Build the model.


    atom_features = Input(shape=(self.n_atom_feat,))
    atom_features = Input(shape=(self.n_atom_feat,))
@@ -393,7 +442,6 @@ class DAGModel(KerasModel):
    calculation_masks = Input(shape=(self.max_atoms,), dtype=tf.bool)
    calculation_masks = Input(shape=(self.max_atoms,), dtype=tf.bool)
    membership = Input(shape=tuple(), dtype=tf.int32)
    membership = Input(shape=tuple(), dtype=tf.int32)
    n_atoms = Input(shape=tuple(), dtype=tf.int32)
    n_atoms = Input(shape=tuple(), dtype=tf.int32)
    dropout_switch = tf.keras.Input(shape=tuple())
    dag_layer1 = layers.DAGLayer(
    dag_layer1 = layers.DAGLayer(
        n_graph_feat=self.n_graph_feat,
        n_graph_feat=self.n_graph_feat,
        n_atom_feat=self.n_atom_feat,
        n_atom_feat=self.n_atom_feat,
@@ -402,14 +450,14 @@ class DAGModel(KerasModel):
        dropout=self.dropout,
        dropout=self.dropout,
        batch_size=batch_size)([
        batch_size=batch_size)([
            atom_features, parents, calculation_orders, calculation_masks,
            atom_features, parents, calculation_orders, calculation_masks,
            n_atoms, dropout_switch
            n_atoms
        ])
        ])
    dag_gather = layers.DAGGather(
    dag_gather = layers.DAGGather(
        n_graph_feat=self.n_graph_feat,
        n_graph_feat=self.n_graph_feat,
        n_outputs=self.n_outputs,
        n_outputs=self.n_outputs,
        max_atoms=self.max_atoms,
        max_atoms=self.max_atoms,
        layer_sizes=self.layer_sizes_gather,
        layer_sizes=self.layer_sizes_gather,
        dropout=self.dropout)([dag_layer1, membership, dropout_switch])
        dropout=self.dropout)([dag_layer1, membership])
    n_tasks = self.n_tasks
    n_tasks = self.n_tasks
    if self.mode == 'classification':
    if self.mode == 'classification':
      n_classes = self.n_classes
      n_classes = self.n_classes
@@ -436,8 +484,12 @@ class DAGModel(KerasModel):
        loss = L2Loss()
        loss = L2Loss()
    model = tf.keras.Model(
    model = tf.keras.Model(
        inputs=[
        inputs=[
            atom_features, parents, calculation_orders, calculation_masks,
            atom_features,
            membership, n_atoms, dropout_switch
            parents,
            calculation_orders,
            calculation_masks,
            membership,
            n_atoms  #, dropout_switch
        ],
        ],
        outputs=outputs)
        outputs=outputs)
    super(DAGModel, self).__init__(
    super(DAGModel, self).__init__(
@@ -495,7 +547,126 @@ class DAGModel(KerasModel):
        ], [y_b], [w_b])
        ], [y_b], [w_b])




class _GraphConvKerasModel(tf.keras.Model):

  def __init__(self,
               n_tasks,
               graph_conv_layers,
               dense_layer_size=128,
               dropout=0.0,
               mode="classification",
               number_atom_features=75,
               n_classes=2,
               batch_normalize=True,
               uncertainty=False,
               batch_size=100):
    """An internal keras model class.

    The graph convolutions use a nonstandard control flow so the
    standard Keras functional API can't support them. We instead
    use the imperative "subclassing" API to implement the graph
    convolutions.

    All arguments have the same meaning as in GraphConvModel.
    """
    super(_GraphConvKerasModel, self).__init__()
    if mode not in ['classification', 'regression']:
      raise ValueError("mode must be either 'classification' or 'regression'")

    self.mode = mode
    self.uncertainty = uncertainty

    if not isinstance(dropout, collections.Sequence):
      dropout = [dropout] * (len(graph_conv_layers) + 1)
    if len(dropout) != len(graph_conv_layers) + 1:
      raise ValueError('Wrong number of dropout probabilities provided')
    if uncertainty:
      if mode != "regression":
        raise ValueError("Uncertainty is only supported in regression mode")
      if any(d == 0.0 for d in dropout):
        raise ValueError(
            'Dropout must be included in every layer to predict uncertainty')

    self.graph_convs = [
        layers.GraphConv(layer_size, activation_fn=tf.nn.relu)
        for layer_size in graph_conv_layers
    ]
    self.batch_norms = [
        BatchNormalization(fused=False) if batch_normalize else None
        for _ in range(len(graph_conv_layers) + 1)
    ]
    self.dropouts = [
        Dropout(rate=rate) if rate > 0.0 else None for rate in dropout
    ]
    self.graph_pools = [layers.GraphPool() for _ in graph_conv_layers]
    self.dense = Dense(dense_layer_size, activation=tf.nn.relu)
    self.graph_gather = layers.GraphGather(
        batch_size=batch_size, activation_fn=tf.nn.tanh)
    self.trim = TrimGraphOutput()
    if self.mode == 'classification':
      self.reshape_dense = Dense(n_tasks * n_classes)
      self.reshape = Reshape((n_tasks, n_classes))
      self.softmax = Softmax()
    else:
      self.regression_dense = Dense(n_tasks)
      if self.uncertainty:
        self.uncertainty_dense = Dense(n_tasks)
        self.uncertainty_trim = TrimGraphOutput()
        self.uncertainty_activation = Activation(tf.exp)

  def call(self, inputs, training=False):
    atom_features = inputs[0]
    degree_slice = tf.cast(inputs[1], dtype=tf.int32)
    membership = tf.cast(inputs[2], dtype=tf.int32)
    n_samples = tf.cast(inputs[3], dtype=tf.int32)
    deg_adjs = [tf.cast(deg_adj, dtype=tf.int32) for deg_adj in inputs[4:]]

    in_layer = atom_features
    for i in range(len(self.graph_convs)):
      gc_in = [in_layer, degree_slice, membership] + deg_adjs
      gc1 = self.graph_convs[i](gc_in)
      if self.batch_norms[i] is not None:
        gc1 = self.batch_norms[i](gc1, training=training)
      if training and self.dropouts[i] is not None:
        gc1 = self.dropouts[i](gc1, training=training)
      gp_in = [gc1, degree_slice, membership] + deg_adjs
      in_layer = self.graph_pools[i](gp_in)
    dense = self.dense(in_layer)
    if self.batch_norms[-1] is not None:
      dense = self.batch_norms[-1](dense, training=training)
    if training and self.dropouts[-1] is not None:
      dense = self.dropouts[1](dense, training=training)
    neural_fingerprint = self.graph_gather([dense, degree_slice, membership] +
                                           deg_adjs)
    if self.mode == 'classification':
      logits = self.reshape(self.reshape_dense(neural_fingerprint))
      logits = self.trim([logits, n_samples])
      output = self.softmax(logits)
      outputs = [output, logits, neural_fingerprint]
    else:
      output = self.regression_dense(neural_fingerprint)
      output = self.trim([output, n_samples])
      if self.uncertainty:
        log_var = self.uncertainty_dense(neural_fingerprint)
        log_var = self.uncertainty_trim([log_var, n_samples])
        var = self.uncertainty_activation(log_var)
        outputs = [output, var, output, log_var, neural_fingerprint]
      else:
        outputs = [output, neural_fingerprint]

    return outputs


class GraphConvModel(KerasModel):
class GraphConvModel(KerasModel):
  """Graph Convolutional Models.

  This class implements the graph convolutional model from the
  following paper:


  Duvenaud, David K., et al. "Convolutional networks on graphs for learning molecular fingerprints." Advances in neural information processing systems. 2015.

  """


  def __init__(self,
  def __init__(self,
               n_tasks,
               n_tasks,
@@ -505,10 +676,16 @@ class GraphConvModel(KerasModel):
               mode="classification",
               mode="classification",
               number_atom_features=75,
               number_atom_features=75,
               n_classes=2,
               n_classes=2,
               uncertainty=False,
               batch_size=100,
               batch_size=100,
               batch_normalize=True,
               uncertainty=False,
               **kwargs):
               **kwargs):
    """
    """The wrapper class for graph convolutions.

    Note that since the underlying _GraphConvKerasModel class is
    specified using imperative subclassing style, this model
    cannot make predictions for arbitrary outputs.

    Parameters
    Parameters
    ----------
    ----------
    n_tasks: int
    n_tasks: int
@@ -530,98 +707,47 @@ class GraphConvModel(KerasModel):
        function atom_features in graph_features
        function atom_features in graph_features
    n_classes: int
    n_classes: int
      the number of classes to predict (only used in classification mode)
      the number of classes to predict (only used in classification mode)
    batch_normalize: True
      if True, apply batch normalization to model
    uncertainty: bool
    uncertainty: bool
      if True, include extra outputs and loss terms to enable the uncertainty
      if True, include extra outputs and loss terms to enable the uncertainty
      in outputs to be predicted
      in outputs to be predicted
    """
    """
    if mode not in ['classification', 'regression']:
      raise ValueError("mode must be either 'classification' or 'regression'")
    self.n_tasks = n_tasks
    self.mode = mode
    self.mode = mode
    self.dense_layer_size = dense_layer_size
    self.n_tasks = n_tasks
    self.graph_conv_layers = graph_conv_layers
    self.number_atom_features = number_atom_features
    self.n_classes = n_classes
    self.n_classes = n_classes
    self.batch_size = batch_size
    self.uncertainty = uncertainty
    self.uncertainty = uncertainty
    if not isinstance(dropout, collections.Sequence):
    model = _GraphConvKerasModel(
      dropout = [dropout] * (len(graph_conv_layers) + 1)
        n_tasks,
    if len(dropout) != len(graph_conv_layers) + 1:
        graph_conv_layers=graph_conv_layers,
      raise ValueError('Wrong number of dropout probabilities provided')
        dense_layer_size=dense_layer_size,
    self.dropout = dropout
        dropout=dropout,
    if uncertainty:
        mode=mode,
      if mode != "regression":
        number_atom_features=number_atom_features,
        raise ValueError("Uncertainty is only supported in regression mode")
        n_classes=n_classes,
      if any(d == 0.0 for d in dropout):
        batch_normalize=batch_normalize,
        raise ValueError(
        uncertainty=uncertainty,
            'Dropout must be included in every layer to predict uncertainty')
        batch_size=batch_size)

    if mode == "classification":
    # Build the model.
      output_types = ['prediction', 'loss', 'embedding']

    atom_features = Input(shape=(self.number_atom_features,))
    degree_slice = Input(shape=(2,), dtype=tf.int32)
    membership = Input(shape=tuple(), dtype=tf.int32)
    n_samples = Input(shape=tuple(), dtype=tf.int32)
    dropout_switch = tf.keras.Input(shape=tuple())

    self.deg_adjs = []
    for i in range(0, 10 + 1):
      deg_adj = Input(shape=(i + 1,), dtype=tf.int32)
      self.deg_adjs.append(deg_adj)
    in_layer = atom_features
    for layer_size, dropout in zip(self.graph_conv_layers, self.dropout):
      gc1_in = [in_layer, degree_slice, membership] + self.deg_adjs
      gc1 = layers.GraphConv(layer_size, activation_fn=tf.nn.relu)(gc1_in)
      batch_norm1 = BatchNormalization(fused=False)(gc1)
      if dropout > 0.0:
        batch_norm1 = layers.SwitchedDropout(rate=dropout)(
            [batch_norm1, dropout_switch])
      gp_in = [batch_norm1, degree_slice, membership] + self.deg_adjs
      in_layer = layers.GraphPool()(gp_in)
    dense = Dense(self.dense_layer_size, activation=tf.nn.relu)(in_layer)
    batch_norm3 = BatchNormalization(fused=False)(dense)
    if self.dropout[-1] > 0.0:
      batch_norm3 = layers.SwitchedDropout(rate=self.dropout[-1])(
          [batch_norm3, dropout_switch])
    self.neural_fingerprint = layers.GraphGather(
        batch_size=batch_size,
        activation_fn=tf.nn.tanh)([batch_norm3, degree_slice, membership] +
                                  self.deg_adjs)

    n_tasks = self.n_tasks
    if self.mode == 'classification':
      n_classes = self.n_classes
      logits = Reshape((n_tasks, n_classes))(Dense(n_tasks * n_classes)(
          self.neural_fingerprint))
      logits = TrimGraphOutput()([logits, n_samples])
      output = Softmax()(logits)
      outputs = [output, logits]
      output_types = ['prediction', 'loss']
      loss = SoftmaxCrossEntropy()
      loss = SoftmaxCrossEntropy()
    else:
    else:
      output = Dense(n_tasks)(self.neural_fingerprint)
      output = TrimGraphOutput()([output, n_samples])
      if self.uncertainty:
      if self.uncertainty:
        log_var = Dense(n_tasks)(self.neural_fingerprint)
        output_types = ['prediction', 'variance', 'loss', 'loss', 'embedding']
        log_var = TrimGraphOutput()([log_var, n_samples])
        var = Activation(tf.exp)(log_var)
        outputs = [output, var, output, log_var]
        output_types = ['prediction', 'variance', 'loss', 'loss']


        def loss(outputs, labels, weights):
        def loss(outputs, labels, weights):
          diff = labels[0] - outputs[0]
          diff = labels[0] - outputs[0]
          return tf.reduce_mean(diff * diff / tf.exp(outputs[1]) + outputs[1])
          return tf.reduce_mean(diff * diff / tf.exp(outputs[1]) + outputs[1])
      else:
      else:
        outputs = [output]
        output_types = ['prediction', 'embedding']
        output_types = ['prediction']
        loss = L2Loss()
        loss = L2Loss()
    model = tf.keras.Model(
        inputs=[
            atom_features, degree_slice, membership, n_samples, dropout_switch
        ] + self.deg_adjs,
        outputs=outputs)
    super(GraphConvModel, self).__init__(
    super(GraphConvModel, self).__init__(
        model, loss, output_types=output_types, batch_size=batch_size, **kwargs)
        model, loss, output_types=output_types, batch_size=batch_size, **kwargs)


  def fit(self, *args, **kwargs):
    super(GraphConvModel, self).fit(*args, **kwargs)

  def default_generator(self,
  def default_generator(self,
                        dataset,
                        dataset,
                        epochs=1,
                        epochs=1,
@@ -638,13 +764,9 @@ class GraphConvModel(KerasModel):
              -1, self.n_tasks, self.n_classes)
              -1, self.n_tasks, self.n_classes)
        multiConvMol = ConvMol.agglomerate_mols(X_b)
        multiConvMol = ConvMol.agglomerate_mols(X_b)
        n_samples = np.array(X_b.shape[0])
        n_samples = np.array(X_b.shape[0])
        if mode == 'predict':
          dropout = np.array(0.0)
        else:
          dropout = np.array(1.0)
        inputs = [
        inputs = [
            multiConvMol.get_atom_features(), multiConvMol.deg_slice,
            multiConvMol.get_atom_features(), multiConvMol.deg_slice,
            np.array(multiConvMol.membership), n_samples, dropout
            np.array(multiConvMol.membership), n_samples
        ]
        ]
        for i in range(1, len(multiConvMol.get_deg_adjacency_lists())):
        for i in range(1, len(multiConvMol.get_deg_adjacency_lists())):
          inputs.append(multiConvMol.get_deg_adjacency_lists()[i])
          inputs.append(multiConvMol.get_deg_adjacency_lists()[i])
@@ -653,7 +775,19 @@ class GraphConvModel(KerasModel):


class MPNNModel(KerasModel):
class MPNNModel(KerasModel):
  """ Message Passing Neural Network,
  """ Message Passing Neural Network,
      default structures built according to https://arxiv.org/abs/1511.06391 """

  Message Passing Neural Networks treat graph convolutional
  operations as an instantiation of a more general message
  passing scheme. Recall that message passing in a graph is when
  nodes in a graph send each other "messages" and update their
  internal state as a consequence of these messages.

  Ordering structures in this model are built according to


Vinyals, Oriol, Samy Bengio, and Manjunath Kudlur. "Order matters: Sequence to sequence for sets." arXiv preprint arXiv:1511.06391 (2015).

  """


  def __init__(self,
  def __init__(self,
               n_tasks,
               n_tasks,
Loading