Commit 0c241128 authored by peastman's avatar peastman
Browse files

Continue converting graph models to KerasModel

parent 1533d1db
Loading
Loading
Loading
Loading
+119 −146
Original line number Diff line number Diff line
@@ -9,17 +9,10 @@ from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.mol_graphs import ConvMol
from deepchem.metrics import to_one_hot
from deepchem.models import KerasModel
from deepchem.models.losses import L2Loss, SoftmaxCrossEntropy
import deepchem.models.layers as layers
from deepchem.models.tensorgraph.graph_layers import WeaveGather, \
    DTNNEmbedding, DTNNStep, DTNNGather, DAGLayer, \
    DAGGather, DTNNExtract, MessagePassing, SetGather
from deepchem.models.tensorgraph.graph_layers import WeaveLayerFactory
from deepchem.models.tensorgraph.layers import Layer, Dense, SoftMax, Reshape, \
    SoftMaxCrossEntropy, GraphConv, BatchNorm, Exp, ReduceMean, ReduceSum, \
    GraphPool, GraphGather, WeightedError, Dropout, BatchNorm, Stack, Flatten, GraphCNN, GraphCNNPool
from deepchem.models.tensorgraph.layers import L2Loss, Label, Weights, Feature
from deepchem.models.tensorgraph.tensor_graph import TensorGraph
from deepchem.trans import undo_transforms
from tensorflow.keras.layers import Input, Dense, Reshape, Softmax, Dropout, Activation, BatchNormalization


class TrimGraphOutput(tf.keras.layers.Layer):
@@ -79,11 +72,11 @@ class WeaveModel(KerasModel):

    # Build the model.

    self.atom_features = tf.keras.Input(shape=(self.n_atom_feat,))
    self.pair_features = tf.keras.Input(shape=(self.n_pair_feat,))
    self.pair_split = tf.keras.Input(shape=tuple(), dtype=tf.int32)
    self.atom_split = tf.keras.Input(shape=tuple(), dtype=tf.int32)
    self.atom_to_pair = tf.keras.Input(shape=(2,), dtype=tf.int32)
    self.atom_features = Input(shape=(self.n_atom_feat,))
    self.pair_features = Input(shape=(self.n_pair_feat,))
    self.pair_split = Input(shape=tuple(), dtype=tf.int32)
    self.atom_split = Input(shape=tuple(), dtype=tf.int32)
    self.atom_to_pair = Input(shape=(2,), dtype=tf.int32)
    weave_layer1A, weave_layer1P = layers.WeaveLayer(
        n_atom_input_feat=self.n_atom_feat,
        n_pair_input_feat=self.n_pair_feat,
@@ -99,9 +92,8 @@ class WeaveModel(KerasModel):
        n_pair_output_feat=self.n_hidden,
        update_pair=False)(
            [weave_layer1A, weave_layer1P, self.pair_split, self.atom_to_pair])
    dense1 = tf.keras.layers.Dense(
        self.n_graph_feat, activation=tf.nn.tanh)(weave_layer2A)
    batch_norm1 = tf.keras.layers.BatchNormalization(epsilon=1e-5)(dense1)
    dense1 = Dense(self.n_graph_feat, activation=tf.nn.tanh)(weave_layer2A)
    batch_norm1 = BatchNormalization(epsilon=1e-5)(dense1)
    weave_gather = layers.WeaveGather(
        batch_size, n_input=self.n_graph_feat,
        gaussian_expand=True)([batch_norm1, self.atom_split])
@@ -109,17 +101,17 @@ class WeaveModel(KerasModel):
    n_tasks = self.n_tasks
    if self.mode == 'classification':
      n_classes = self.n_classes
      logits = tf.keras.layers.Reshape((n_tasks, n_classes))(
          tf.keras.layers.Dense(n_tasks * n_classes)(weave_gather))
      output = tf.keras.layers.Softmax()(logits)
      logits = Reshape((n_tasks,
                        n_classes))(Dense(n_tasks * n_classes)(weave_gather))
      output = Softmax()(logits)
      outputs = [output, logits]
      output_types = ['prediction', 'loss']
      loss = dc.models.losses.SoftmaxCrossEntropy()
      loss = SoftmaxCrossEntropy()
    else:
      output = tf.keras.layers.Dense(n_tasks)(weave_gather)
      output = Dense(n_tasks)(weave_gather)
      outputs = [output]
      output_types = ['prediction']
      loss = dc.models.losses.L2Loss()
      loss = L2Loss()
    model = tf.keras.Model(
        inputs=[
            self.atom_features, self.pair_features, self.pair_split,
@@ -233,31 +225,30 @@ class DTNNModel(KerasModel):

    # Build the model.

    self.atom_number = tf.keras.Input(shape=tuple(), dtype=tf.int32)
    self.distance = tf.keras.Input(shape=(self.n_distance,))
    self.atom_membership = tf.keras.Input(shape=tuple(), dtype=tf.int32)
    self.distance_membership_i = tf.keras.Input(shape=tuple(), dtype=tf.int32)
    self.distance_membership_j = tf.keras.Input(shape=tuple(), dtype=tf.int32)
    self.atom_number = Input(shape=tuple(), dtype=tf.int32)
    self.distance = Input(shape=(self.n_distance,))
    self.atom_membership = Input(shape=tuple(), dtype=tf.int32)
    self.distance_membership_i = Input(shape=tuple(), dtype=tf.int32)
    self.distance_membership_j = Input(shape=tuple(), dtype=tf.int32)

    dtnn_embedding = layers.DTNNEmbedding(n_embedding=self.n_embedding)(
        self.atom_number)
    if self.dropout > 0.0:
      dtnn_embedding = tf.keras.layers.Dropout(
          rate=self.dropout)(dtnn_embedding)
      dtnn_embedding = Dropout(rate=self.dropout)(dtnn_embedding)
    dtnn_layer1 = layers.DTNNStep(
        n_embedding=self.n_embedding, n_distance=self.n_distance)([
            dtnn_embedding, self.distance, self.distance_membership_i,
            self.distance_membership_j
        ])
    if self.dropout > 0.0:
      dtnn_layer1 = tf.keras.layers.Dropout(rate=self.dropout)(dtnn_layer1)
      dtnn_layer1 = Dropout(rate=self.dropout)(dtnn_layer1)
    dtnn_layer2 = layers.DTNNStep(
        n_embedding=self.n_embedding, n_distance=self.n_distance)([
            dtnn_layer1, self.distance, self.distance_membership_i,
            self.distance_membership_j
        ])
    if self.dropout > 0.0:
      dtnn_layer2 = tf.keras.layers.Dropout(rate=self.dropout)(dtnn_layer2)
      dtnn_layer2 = Dropout(rate=self.dropout)(dtnn_layer2)
    dtnn_gather = layers.DTNNGather(
        n_embedding=self.n_embedding,
        layer_sizes=[self.n_hidden],
@@ -265,17 +256,17 @@ class DTNNModel(KerasModel):
        output_activation=self.output_activation)(
            [dtnn_layer2, self.atom_membership])
    if self.dropout > 0.0:
      dtnn_gather = tf.keras.layers.Dropout(rate=self.dropout)(dtnn_gather)
      dtnn_gather = Dropout(rate=self.dropout)(dtnn_gather)

    n_tasks = self.n_tasks
    output = tf.keras.layers.Dense(n_tasks)(dtnn_gather)
    output = Dense(n_tasks)(dtnn_gather)
    model = tf.keras.Model(
        inputs=[
            self.atom_number, self.distance, self.atom_membership,
            self.distance_membership_i, self.distance_membership_j
        ],
        outputs=[output])
    super(DTNNModel, self).__init__(model, dc.models.losses.L2Loss(), **kwargs)
    super(DTNNModel, self).__init__(model, L2Loss(), **kwargs)

  def compute_features_on_batch(self, X_b):
    """Computes the values for different Feature Layers on given batch
@@ -401,15 +392,12 @@ class DAGModel(KerasModel):

    # Build the model.

    self.atom_features = tf.keras.Input(shape=(self.n_atom_feat,))
    self.parents = tf.keras.Input(
        shape=(self.max_atoms, self.max_atoms), dtype=tf.int32)
    self.calculation_orders = tf.keras.Input(
        shape=(self.max_atoms,), dtype=tf.int32)
    self.calculation_masks = tf.keras.Input(
        shape=(self.max_atoms,), dtype=tf.bool)
    self.membership = tf.keras.Input(shape=tuple(), dtype=tf.int32)
    self.n_atoms = tf.keras.Input(shape=tuple(), dtype=tf.int32)
    self.atom_features = Input(shape=(self.n_atom_feat,))
    self.parents = Input(shape=(self.max_atoms, self.max_atoms), dtype=tf.int32)
    self.calculation_orders = Input(shape=(self.max_atoms,), dtype=tf.int32)
    self.calculation_masks = Input(shape=(self.max_atoms,), dtype=tf.bool)
    self.membership = Input(shape=tuple(), dtype=tf.int32)
    self.n_atoms = Input(shape=tuple(), dtype=tf.int32)
    dag_layer1 = layers.DAGLayer(
        n_graph_feat=self.n_graph_feat,
        n_atom_feat=self.n_atom_feat,
@@ -429,17 +417,17 @@ class DAGModel(KerasModel):
    n_tasks = self.n_tasks
    if self.mode == 'classification':
      n_classes = self.n_classes
      logits = tf.keras.layers.Reshape((n_tasks, n_classes))(
          tf.keras.layers.Dense(n_tasks * n_classes)(dag_gather))
      output = tf.keras.layers.Softmax()(logits)
      logits = Reshape((n_tasks,
                        n_classes))(Dense(n_tasks * n_classes)(dag_gather))
      output = Softmax()(logits)
      outputs = [output, logits]
      output_types = ['prediction', 'loss']
      loss = dc.models.losses.SoftmaxCrossEntropy()
      loss = SoftmaxCrossEntropy()
    else:
      output = tf.keras.layers.Dense(n_tasks)(dag_gather)
      output = Dense(n_tasks)(dag_gather)
      if self.uncertainty:
        log_var = Dense(n_tasks)(dag_gather)
        var = tf.keras.layers.Activation(tf.exp)(log_var)
        var = Activation(tf.exp)(log_var)
        outputs = [output, var, output, log_var]
        output_types = ['prediction', 'variance', 'loss', 'loss']

@@ -449,7 +437,7 @@ class DAGModel(KerasModel):
      else:
        outputs = [output]
        output_types = ['prediction']
        loss = dc.models.losses.L2Loss()
        loss = L2Loss()
    model = tf.keras.Model(
        inputs=[
            self.atom_features, self.parents, self.calculation_orders,
@@ -569,29 +557,28 @@ class GraphConvModel(KerasModel):

    # Build the model.

    self.atom_features = tf.keras.Input(shape=(self.number_atom_features,))
    self.degree_slice = tf.keras.Input(shape=(2,), dtype=tf.int32)
    self.membership = tf.keras.Input(shape=tuple(), dtype=tf.int32)
    self.n_samples = tf.keras.Input(shape=tuple(), dtype=tf.int32)
    self.atom_features = Input(shape=(self.number_atom_features,))
    self.degree_slice = Input(shape=(2,), dtype=tf.int32)
    self.membership = Input(shape=tuple(), dtype=tf.int32)
    self.n_samples = Input(shape=tuple(), dtype=tf.int32)

    self.deg_adjs = []
    for i in range(0, 10 + 1):
      deg_adj = tf.keras.Input(shape=(i + 1,), dtype=tf.int32)
      deg_adj = Input(shape=(i + 1,), dtype=tf.int32)
      self.deg_adjs.append(deg_adj)
    in_layer = self.atom_features
    for layer_size, dropout in zip(self.graph_conv_layers, self.dropout):
      gc1_in = [in_layer, self.degree_slice, self.membership] + self.deg_adjs
      gc1 = layers.GraphConv(layer_size, activation_fn=tf.nn.relu)(gc1_in)
      batch_norm1 = tf.keras.layers.BatchNormalization(fused=False)(gc1)
      batch_norm1 = BatchNormalization(fused=False)(gc1)
      if dropout > 0.0:
        batch_norm1 = tf.keras.layers.Dropout(rate=dropout)(batch_norm1)
        batch_norm1 = Dropout(rate=dropout)(batch_norm1)
      gp_in = [batch_norm1, self.degree_slice, self.membership] + self.deg_adjs
      in_layer = layers.GraphPool()(gp_in)
    dense = tf.keras.layers.Dense(
        self.dense_layer_size, activation=tf.nn.relu)(in_layer)
    batch_norm3 = tf.keras.layers.BatchNormalization(fused=False)(dense)
    dense = Dense(self.dense_layer_size, activation=tf.nn.relu)(in_layer)
    batch_norm3 = BatchNormalization(fused=False)(dense)
    if self.dropout[-1] > 0.0:
      batch_norm3 = tf.keras.layers.Dropout(rate=self.dropout[-1])(batch_norm3)
      batch_norm3 = Dropout(rate=self.dropout[-1])(batch_norm3)
    self.neural_fingerprint = layers.GraphGather(
        batch_size=batch_size, activation_fn=tf.nn.tanh)(
            [batch_norm3, self.degree_slice, self.membership] + self.deg_adjs)
@@ -599,20 +586,20 @@ class GraphConvModel(KerasModel):
    n_tasks = self.n_tasks
    if self.mode == 'classification':
      n_classes = self.n_classes
      logits = tf.keras.layers.Reshape((n_tasks, n_classes))(
          tf.keras.layers.Dense(n_tasks * n_classes)(self.neural_fingerprint))
      logits = Reshape((n_tasks, n_classes))(Dense(n_tasks * n_classes)(
          self.neural_fingerprint))
      logits = TrimGraphOutput()([logits, self.n_samples])
      output = tf.keras.layers.Softmax()(logits)
      output = Softmax()(logits)
      outputs = [output, logits]
      output_types = ['prediction', 'loss']
      loss = dc.models.losses.SoftmaxCrossEntropy()
      loss = SoftmaxCrossEntropy()
    else:
      output = tf.keras.layers.Dense(n_tasks)(self.neural_fingerprint)
      output = Dense(n_tasks)(self.neural_fingerprint)
      output = TrimGraphOutput()([output, self.n_samples])
      if self.uncertainty:
        log_var = tf.keras.layers.Dense(n_tasks)(self.neural_fingerprint)
        log_var = Dense(n_tasks)(self.neural_fingerprint)
        log_var = TrimGraphOutput()([log_var, self.n_samples])
        var = tf.keras.layers.Activation(tf.exp)(log_var)
        var = Activation(tf.exp)(log_var)
        outputs = [output, var, output, log_var]
        output_types = ['prediction', 'variance', 'loss', 'loss']

@@ -622,7 +609,7 @@ class GraphConvModel(KerasModel):
      else:
        outputs = [output]
        output_types = ['prediction']
        loss = dc.models.losses.L2Loss()
        loss = L2Loss()
    model = tf.keras.Model(
        inputs=[
            self.atom_features, self.degree_slice, self.membership,
@@ -657,7 +644,7 @@ class GraphConvModel(KerasModel):
        yield (inputs, [y_b], [w_b])


class MPNNModel(TensorGraph):
class MPNNModel(KerasModel):
  """ Message Passing Neural Network,
      default structures built according to https://arxiv.org/abs/1511.06391 """

@@ -672,6 +659,7 @@ class MPNNModel(TensorGraph):
               dropout=0.0,
               n_classes=2,
               uncertainty=False,
               batch_size=100,
               **kwargs):
    """
    Parameters
@@ -710,70 +698,61 @@ class MPNNModel(TensorGraph):
        raise ValueError("Uncertainty is only supported in regression mode")
      if dropout == 0.0:
        raise ValueError('Dropout must be included to predict uncertainty')
    super(MPNNModel, self).__init__(**kwargs)
    self.build_graph()

  def build_graph(self):
    # Build placeholders
    self.atom_features = Feature(shape=(None, self.n_atom_feat))
    self.pair_features = Feature(shape=(None, self.n_pair_feat))
    self.atom_split = Feature(shape=(None,), dtype=tf.int32)
    self.atom_to_pair = Feature(shape=(None, 2), dtype=tf.int32)

    message_passing = MessagePassing(
        self.T,
        message_fn='enn',
        update_fn='gru',
        n_hidden=self.n_hidden,
        in_layers=[self.atom_features, self.pair_features, self.atom_to_pair])

    atom_embeddings = Dense(self.n_hidden, in_layers=[message_passing])

    mol_embeddings = SetGather(
        self.M,
        self.batch_size,
        n_hidden=self.n_hidden,
        in_layers=[atom_embeddings, self.atom_split])

    dense1 = Dense(
        out_channels=2 * self.n_hidden,
        activation_fn=tf.nn.relu,
        in_layers=[mol_embeddings])

    # Build the model.

    self.atom_features = Input(shape=(self.n_atom_feat,))
    self.pair_features = Input(shape=(self.n_pair_feat,))
    self.atom_split = Input(shape=tuple(), dtype=tf.int32)
    self.atom_to_pair = Input(shape=(2,), dtype=tf.int32)
    self.n_samples = Input(shape=tuple(), dtype=tf.int32)

    message_passing = layers.MessagePassing(
        self.T, message_fn='enn', update_fn='gru', n_hidden=self.n_hidden)(
            [self.atom_features, self.pair_features, self.atom_to_pair])

    atom_embeddings = Dense(self.n_hidden)(message_passing)

    mol_embeddings = layers.SetGather(
        self.M, batch_size,
        n_hidden=self.n_hidden)([atom_embeddings, self.atom_split])

    dense1 = Dense(2 * self.n_hidden, activation=tf.nn.relu)(mol_embeddings)

    n_tasks = self.n_tasks
    weights = Weights(shape=(None, n_tasks))
    if self.mode == 'classification':
      n_classes = self.n_classes
      labels = Label(shape=(None, n_tasks, n_classes))
      logits = Reshape(
          shape=(None, n_tasks, n_classes),
          in_layers=[Dense(in_layers=dense1, out_channels=n_tasks * n_classes)])
      logits = TrimGraphOutput([logits, weights])
      output = SoftMax(logits)
      self.add_output(output)
      loss = SoftMaxCrossEntropy(in_layers=[labels, logits])
      weighted_loss = WeightedError(in_layers=[loss, weights])
      self.set_loss(weighted_loss)
      logits = Reshape((n_tasks, n_classes))(Dense(n_tasks * n_classes)(dense1))
      logits = TrimGraphOutput()([logits, self.n_samples])
      output = Softmax()(logits)
      outputs = [output, logits]
      output_types = ['prediction', 'loss']
      loss = SoftmaxCrossEntropy()
    else:
      labels = Label(shape=(None, n_tasks))
      output = Reshape(
          shape=(None, n_tasks),
          in_layers=[Dense(in_layers=dense1, out_channels=n_tasks)])
      output = TrimGraphOutput([output, weights])
      self.add_output(output)
      output = Dense(n_tasks)(dense1)
      output = TrimGraphOutput()([output, self.n_samples])
      if self.uncertainty:
        log_var = Reshape(
            shape=(None, n_tasks),
            in_layers=[Dense(in_layers=dense1, out_channels=n_tasks)])
        log_var = TrimGraphOutput([log_var, weights])
        var = Exp(log_var)
        self.add_variance(var)
        diff = labels - output
        weighted_loss = weights * (diff * diff / var + log_var)
        weighted_loss = ReduceSum(ReduceMean(weighted_loss, axis=[1]))
        log_var = Dense(n_tasks)(dense1)
        log_var = TrimGraphOutput()([log_var, self.n_samples])
        var = Activation(tf.exp)(log_var)
        outputs = [output, var, output, log_var]
        output_types = ['prediction', 'variance', 'loss', 'loss']

        def loss(outputs, labels, weights):
          diff = labels[0] - outputs[0]
          return tf.reduce_mean(diff * diff / tf.exp(outputs[1]) + outputs[1])
      else:
        weighted_loss = ReduceSum(L2Loss(in_layers=[labels, output, weights]))
      self.set_loss(weighted_loss)
        outputs = [output]
        output_types = ['prediction']
        loss = L2Loss()
    model = tf.keras.Model(
        inputs=[
            self.atom_features, self.pair_features, self.atom_split,
            self.atom_to_pair, self.n_samples
        ],
        outputs=outputs)
    super(MPNNModel, self).__init__(
        model, loss, output_types=output_types, batch_size=batch_size, **kwargs)

  def default_generator(self,
                        dataset,
@@ -781,25 +760,17 @@ class MPNNModel(TensorGraph):
                        predict=False,
                        deterministic=True,
                        pad_batches=True):
    """ Same generator as Weave models """
    for epoch in range(epochs):
      for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
          batch_size=self.batch_size,
          deterministic=deterministic,
          pad_batches=False):
          pad_batches=pad_batches):

        n_samples = np.array(X_b.shape[0])
        X_b = pad_features(self.batch_size, X_b)
        feed_dict = dict()
        if y_b is not None:
          if self.mode == 'classification':
            feed_dict[self.labels[0]] = to_one_hot(y_b.flatten(),
                                                   self.n_classes).reshape(
                                                       -1, self.n_tasks,
                                                       self.n_classes)
          else:
            feed_dict[self.labels[0]] = y_b
        if w_b is not None:
          feed_dict[self.task_weights[0]] = w_b
        if y_b is not None and self.mode == 'classification':
          y_b = to_one_hot(y_b.flatten(), self.n_classes).reshape(
              -1, self.n_tasks, self.n_classes)

        atom_feat = []
        pair_feat = []
@@ -828,11 +799,13 @@ class MPNNModel(TensorGraph):
              np.reshape(mol.get_pair_features(),
                         (n_atoms * n_atoms, self.n_pair_feat)))

        feed_dict[self.atom_features] = np.concatenate(atom_feat, axis=0)
        feed_dict[self.pair_features] = np.concatenate(pair_feat, axis=0)
        feed_dict[self.atom_split] = np.array(atom_split)
        feed_dict[self.atom_to_pair] = np.concatenate(atom_to_pair, axis=0)
        yield feed_dict
        inputs = [
            np.concatenate(atom_feat, axis=0),
            np.concatenate(pair_feat, axis=0),
            np.array(atom_split),
            np.concatenate(atom_to_pair, axis=0), n_samples
        ]
        yield (inputs, [y_b], [w_b])


#################### Deprecation warnings for renamed TensorGraph models ####################
+0 −12
Original line number Diff line number Diff line
@@ -238,12 +238,6 @@ class TestGraphModels(unittest.TestCase):
    scores = model.evaluate(dataset, [metric], transformers)
    assert scores['mean-roc_auc_score'] >= 0.9

    model.save()
    model = TensorGraph.load_from_dir(model.model_dir)
    scores2 = model.evaluate(dataset, [metric], transformers)
    assert np.allclose(scores['mean-roc_auc_score'],
                       scores2['mean-roc_auc_score'])

  @attr("slow")
  def test_mpnn_regression_model(self):
    tasks, dataset, transformers, metric = self.get_dataset(
@@ -262,12 +256,6 @@ class TestGraphModels(unittest.TestCase):
    scores = model.evaluate(dataset, [metric], transformers)
    assert all(s < 0.1 for s in scores['mean_absolute_error'])

    model.save()
    model = TensorGraph.load_from_dir(model.model_dir)
    scores2 = model.evaluate(dataset, [metric], transformers)
    assert np.allclose(scores['mean_absolute_error'],
                       scores2['mean_absolute_error'])

  @attr("slow")
  def test_mpnn_regression_uncertainty(self):
    tasks, dataset, transformers, metric = self.get_dataset(