Unverified Commit b3bc14ac authored by Bharath Ramsundar, committed by GitHub

Merge pull request #2026 from deepchem/weave_fix

Fix Weave Models
parents 40604945 6f87f962
+1 −1
@@ -19,6 +19,6 @@ class TestCSVLoader(TestCase):
    loader = dc.data.CSVLoader(
        tasks=tasks, smiles_field="smiles", featurizer=featurizer)

    X = loader.featurize(fin.name)
    X = loader.create_dataset(fin.name)
    self.assertEqual(1, len(X))
    os.remove(fin.name)
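
As a quick illustration of the renamed API above, here is a minimal, hedged sketch of loading a CSV with the new `create_dataset` call; the file name, task name, and featurizer below are illustrative assumptions, not part of this diff:

import deepchem as dc

# The featurizer and CSV path/columns here are assumptions for the sketch.
featurizer = dc.feat.CircularFingerprint(size=16)
loader = dc.data.CSVLoader(
    tasks=["task1"], smiles_field="smiles", featurizer=featurizer)
# featurize() was the old entry point; create_dataset() is the replacement.
dataset = loader.create_dataset("example.csv")
print(len(dataset))  # number of featurized molecules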
+240 −69
@@ -4,12 +4,14 @@ import deepchem as dc
import numpy as np
import tensorflow as tf

from deepchem.data import NumpyDataset, pad_features
from typing import List, Union, Tuple, Iterable, Dict
from deepchem.utils.typing import OneOrMany, KerasLossFn, KerasActivationFn
from deepchem.data import Dataset, NumpyDataset, pad_features
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.mol_graphs import ConvMol
from deepchem.metrics import to_one_hot
from deepchem.models import KerasModel, layers
from deepchem.models.losses import L2Loss, SoftmaxCrossEntropy
from deepchem.models.losses import L2Loss, SoftmaxCrossEntropy, Loss
from deepchem.trans import undo_transforms
from tensorflow.keras.layers import Input, Dense, Reshape, Softmax, Dropout, Activation, BatchNormalization

@@ -33,9 +35,7 @@ class WeaveModel(KerasModel):
  """Implements Google-style Weave Graph Convolutions

  This model implements the Weave style graph convolutions
  from the following paper.

  Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond fingerprints." Journal of computer-aided molecular design 30.8 (2016): 595-608.
  from [1]_.

  The biggest difference between WeaveModel style convolutions
  and GraphConvModel style convolutions is that Weave
@@ -44,17 +44,64 @@ class WeaveModel(KerasModel):
  explicitly to model bond interactions. This may cause
  scaling issues, but may possibly allow for better modeling
  of subtle bond effects.

  Note that [1]_ introduces a whole variety of different architectures for
  Weave models. The default settings in this class correspond to the W2N2
  variant from [1]_, which is the most commonly used variant.

  Examples
  --------

  Here's an example of how to fit a `WeaveModel` on a tiny sample dataset.
  
  >>> import numpy as np
  >>> import deepchem as dc
  >>> featurizer = dc.feat.WeaveFeaturizer()
  >>> X = featurizer(["C", "CC"])
  >>> y = np.array([1, 0])
  >>> dataset = dc.data.NumpyDataset(X, y)
  >>> model = dc.models.WeaveModel(n_tasks=1, n_weave=2, fully_connected_layer_sizes=[2000, 1000], mode="classification")
  >>> loss = model.fit(dataset)

  Note
  ----
  In general, the use of batch normalization can cause issues with NaNs. If
  you're having trouble with NaNs while using this model, consider setting
  `batch_normalize_kwargs={"trainable": False}` or turning off batch
  normalization entirely with `batch_normalize=False`.

  References
  ----------
  .. [1] Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond
  fingerprints." Journal of computer-aided molecular design 30.8 (2016):
  595-608.

  """

  def __init__(self,
               n_tasks,
               n_atom_feat=75,
               n_pair_feat=14,
               n_hidden=50,
               n_graph_feat=128,
               mode="classification",
               n_classes=2,
               batch_size=100,
               n_tasks: int,
               n_atom_feat: OneOrMany[int] = 75,
               n_pair_feat: OneOrMany[int] = 14,
               n_hidden: int = 50,
               n_graph_feat: int = 128,
               n_weave: int = 2,
               fully_connected_layer_sizes: List[int] = [2000, 100],
               weight_init_stddevs: OneOrMany[float] = [0.01, 0.04],
               bias_init_consts: OneOrMany[float] = [0.5, 3.0],
               weight_decay_penalty: float = 0.0,
               weight_decay_penalty_type: str = "l2",
               dropouts: OneOrMany[float] = 0.25,
               activation_fns: OneOrMany[KerasActivationFn] = tf.nn.relu,
               batch_normalize: bool = True,
               batch_normalize_kwargs: Dict = {
                   "renorm": True,
                   "fused": False
               },
               gaussian_expand: bool = True,
               compress_post_gaussian_expansion: bool = False,
               mode: str = "classification",
               n_classes: int = 2,
               batch_size: int = 100,
               **kwargs):
    """
    Parameters
@@ -69,6 +116,49 @@ class WeaveModel(KerasModel):
      Number of units (convolution depths) in the corresponding hidden layer
    n_graph_feat: int, optional
      Number of output features for each molecule (graph)
    n_weave: int, optional
      The number of weave layers in this model.
    fully_connected_layer_sizes: list
      The size of each dense layer in the network.  The length of
      this list determines the number of layers.
    weight_init_stddevs: list or float
      The standard deviation of the distribution to use for weight
      initialization of each layer.  The length of this list should equal
      len(fully_connected_layer_sizes).  Alternatively, this may be a single
      value instead of a list, in which case the same value is used for every
      layer.
    bias_init_consts: list or float
      The value to initialize the biases in each layer to.  The length of this
      list should equal len(fully_connected_layer_sizes).  Alternatively, this
      may be a single value instead of a list, in which case the same value is
      used for every layer.
    weight_decay_penalty: float
      The magnitude of the weight decay penalty to use
    weight_decay_penalty_type: str
      The type of penalty to use for weight decay, either 'l1' or 'l2'
    dropouts: list or float
      The dropout probability to use for each layer.  The length of this list
      should equal len(fully_connected_layer_sizes).  Alternatively, this may
      be a single value instead of a list, in which case the same value is
      used for every layer.
    activation_fns: list or object
      The TensorFlow activation function to apply to each layer.  The length
      of this list should equal len(fully_connected_layer_sizes).
      Alternatively, this may be a single value instead of a list, in which
      case the same value is used for every layer.
    batch_normalize: bool, optional (default True)
      If this is turned on, apply batch normalization before applying
      activation functions on convolutional and fully connected layers.
    batch_normalize_kwargs: Dict, optional (default `{"renorm": True, "fused": False}`)
      Batch normalization is a complex layer which has many potential
      arguments which change behavior. This layer accepts user-defined
      parameters which are passed to all `BatchNormalization` layers in
      `WeaveModel`, `WeaveLayer`, and `WeaveGather`.
    gaussian_expand: boolean, optional (default True)
      Whether to expand each dimension of atomic features by gaussian
      histogram
    compress_post_gaussian_expansion: bool, optional (default False)
      If True, compress the results of the Gaussian expansion back to the
      original dimensions of the input.
    mode: str
      Either "classification" or "regression" for type of model.
    n_classes: int
@@ -76,6 +166,28 @@ class WeaveModel(KerasModel):
    """
    if mode not in ['classification', 'regression']:
      raise ValueError("mode must be either 'classification' or 'regression'")

    if not isinstance(n_atom_feat, collections.Sequence):
      n_atom_feat = [n_atom_feat] * n_weave
    if not isinstance(n_pair_feat, collections.Sequence):
      n_pair_feat = [n_pair_feat] * n_weave
    n_layers = len(fully_connected_layer_sizes)
    if not isinstance(weight_init_stddevs, collections.Sequence):
      weight_init_stddevs = [weight_init_stddevs] * n_layers
    if not isinstance(bias_init_consts, collections.Sequence):
      bias_init_consts = [bias_init_consts] * n_layers
    if not isinstance(dropouts, collections.Sequence):
      dropouts = [dropouts] * n_layers
    if not isinstance(activation_fns, collections.Sequence):
      activation_fns = [activation_fns] * n_layers
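    # For illustration: with the default fully_connected_layer_sizes=[2000, 100],
    # a scalar such as dropouts=0.25 is broadcast to [0.25, 0.25], one value
    # per fully connected layer.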
    if weight_decay_penalty != 0.0:
      if weight_decay_penalty_type == 'l1':
        regularizer = tf.keras.regularizers.l1(weight_decay_penalty)
      else:
        regularizer = tf.keras.regularizers.l2(weight_decay_penalty)
    else:
      regularizer = None

    self.n_tasks = n_tasks
    self.n_atom_feat = n_atom_feat
    self.n_pair_feat = n_pair_feat
@@ -86,41 +198,72 @@ class WeaveModel(KerasModel):

    # Build the model.

    atom_features = Input(shape=(self.n_atom_feat,))
    pair_features = Input(shape=(self.n_pair_feat,))
    atom_features = Input(shape=(self.n_atom_feat[0],))
    pair_features = Input(shape=(self.n_pair_feat[0],))
    pair_split = Input(shape=tuple(), dtype=tf.int32)
    atom_split = Input(shape=tuple(), dtype=tf.int32)
    atom_to_pair = Input(shape=(2,), dtype=tf.int32)
    weave_layer1A, weave_layer1P = layers.WeaveLayer(
        n_atom_input_feat=self.n_atom_feat,
        n_pair_input_feat=self.n_pair_feat,
        n_atom_output_feat=self.n_hidden,
        n_pair_output_feat=self.n_hidden)(
            [atom_features, pair_features, pair_split, atom_to_pair])
    weave_layer2A, weave_layer2P = layers.WeaveLayer(
        n_atom_input_feat=self.n_hidden,
        n_pair_input_feat=self.n_hidden,
        n_atom_output_feat=self.n_hidden,
        n_pair_output_feat=self.n_hidden,
        update_pair=False)(
            [weave_layer1A, weave_layer1P, pair_split, atom_to_pair])
    dense1 = Dense(self.n_graph_feat, activation=tf.nn.tanh)(weave_layer2A)
    # Batch normalization causes issues, spitting out NaNs if
    # allowed to train
    batch_norm1 = BatchNormalization(epsilon=1e-5, trainable=False)(dense1)
    inputs = [atom_features, pair_features, pair_split, atom_to_pair]
    for ind in range(n_weave):
      n_atom = self.n_atom_feat[ind]
      n_pair = self.n_pair_feat[ind]
      if ind < n_weave - 1:
        n_atom_next = self.n_atom_feat[ind + 1]
        n_pair_next = self.n_pair_feat[ind + 1]
      else:
        n_atom_next = n_hidden
        n_pair_next = n_hidden
      weave_layer_ind_A, weave_layer_ind_P = layers.WeaveLayer(
          n_atom_input_feat=n_atom,
          n_pair_input_feat=n_pair,
          n_atom_output_feat=n_atom_next,
          n_pair_output_feat=n_pair_next,
          batch_normalize=batch_normalize)(inputs)
      inputs = [weave_layer_ind_A, weave_layer_ind_P, pair_split, atom_to_pair]
    # Final atom-layer convolution. Note this differs slightly from the paper
    # since we use a tanh activation. This seems necessary for numerical
    # stability.
    dense1 = Dense(self.n_graph_feat, activation=tf.nn.tanh)(weave_layer_ind_A)
    if batch_normalize:
      dense1 = BatchNormalization(**batch_normalize_kwargs)(dense1)
    weave_gather = layers.WeaveGather(
        batch_size, n_input=self.n_graph_feat,
        gaussian_expand=True)([batch_norm1, atom_split])
        batch_size,
        n_input=self.n_graph_feat,
        gaussian_expand=gaussian_expand,
        compress_post_gaussian_expansion=compress_post_gaussian_expansion)(
            [dense1, atom_split])

    if n_layers > 0:
      # Now fully connected layers
      input_layer = weave_gather
      for layer_size, weight_stddev, bias_const, dropout, activation_fn in zip(
          fully_connected_layer_sizes, weight_init_stddevs, bias_init_consts,
          dropouts, activation_fns):
        layer = Dense(
            layer_size,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=weight_stddev),
            bias_initializer=tf.constant_initializer(value=bias_const),
            kernel_regularizer=regularizer)(weave_gather)
        if dropout > 0.0:
          layer = Dropout(rate=dropout)(layer)
        if batch_normalize:
          # Should this allow for training?
          layer = BatchNormalization(**batch_normalize_kwargs)(layer)
        layer = Activation(activation_fn)(layer)
        input_layer = layer
      output = input_layer
    else:
      output = weave_gather

    n_tasks = self.n_tasks
    if self.mode == 'classification':
      n_classes = self.n_classes
      logits = Reshape((n_tasks,
                        n_classes))(Dense(n_tasks * n_classes)(weave_gather))
      logits = Reshape((n_tasks, n_classes))(Dense(n_tasks * n_classes)(output))
      output = Softmax()(logits)
      outputs = [output, logits]
      output_types = ['prediction', 'loss']
      loss = SoftmaxCrossEntropy()
      loss: Loss = SoftmaxCrossEntropy()
    else:
      output = Dense(n_tasks)(weave_gather)
      outputs = [output]
@@ -134,12 +277,33 @@ class WeaveModel(KerasModel):
    super(WeaveModel, self).__init__(
        model, loss, output_types=output_types, batch_size=batch_size, **kwargs)

  def default_generator(self,
                        dataset,
                        epochs=1,
                        mode='fit',
                        deterministic=True,
                        pad_batches=True):
  def default_generator(
      self,
      dataset: Dataset,
      epochs: int = 1,
      mode: str = 'fit',
      deterministic: bool = True,
      pad_batches: bool = True) -> Iterable[Tuple[List, List, List]]:
    """Convert a dataset into the tensors needed for learning.

    Parameters
    ----------
    dataset: `dc.data.Dataset`
      Dataset to convert
    epochs: int, optional (Default 1)
      Number of times to walk over `dataset`
    mode: str, optional (Default 'fit')
      Ignored in this implementation.
    deterministic: bool, optional (Default True)
      Whether the dataset should be walked in a deterministic fashion
    pad_batches: bool, optional (Default True)
      If true, each returned batch will have size `self.batch_size`.

    Returns
    -------
    Iterator which walks over the batches
    """

    for epoch in range(epochs):
      for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
          batch_size=self.batch_size,
@@ -174,7 +338,7 @@ class WeaveModel(KerasModel):
          # pair features
          pair_feat.append(
              np.reshape(mol.get_pair_features(),
                         (n_atoms * n_atoms, self.n_pair_feat)))
                         (n_atoms * n_atoms, self.n_pair_feat[0])))

        inputs = [
            np.concatenate(atom_feat, axis=0),
@@ -189,9 +353,12 @@ class WeaveModel(KerasModel):
class DTNNModel(KerasModel):
  """Deep Tensor Neural Networks

  This class implements deep tensor neural networks as first defined in
  This class implements deep tensor neural networks as first defined in [1]_

  Schütt, Kristof T., et al. "Quantum-chemical insights from deep tensor neural networks." Nature communications 8.1 (2017): 1-8.
  References
  ----------
  .. [1] Schütt, Kristof T., et al. "Quantum-chemical insights from deep
  tensor neural networks." Nature communications 8.1 (2017): 1-8.
  """

  def __init__(self,
@@ -497,7 +664,7 @@ class DAGModel(KerasModel):
                        mode='fit',
                        deterministic=True,
                        pad_batches=True):
    """TensorGraph style implementation"""
    """Convert a dataset into the tensors needed for learning"""
    for epoch in range(epochs):
      for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
          batch_size=self.batch_size,
@@ -660,6 +827,8 @@ class GraphConvModel(KerasModel):
  following paper [1]_. These graph convolutions start with a per-atom set of
  descriptors for each atom in a molecule, then combine and recombine these
  descriptors over convolutional layers.
  following [1]_.


  References
  ----------
@@ -669,16 +838,16 @@ class GraphConvModel(KerasModel):
  """

  def __init__(self,
               n_tasks,
               graph_conv_layers=[64, 64],
               dense_layer_size=128,
               dropout=0.0,
               mode="classification",
               number_atom_features=75,
               n_classes=2,
               batch_size=100,
               batch_normalize=True,
               uncertainty=False,
               n_tasks: int,
               graph_conv_layers: List[int] = [64, 64],
               dense_layer_size: int = 128,
               dropout: float = 0.0,
               mode: str = "classification",
               number_atom_features: int = 75,
               n_classes: int = 2,
               batch_size: int = 100,
               batch_normalize: bool = True,
               uncertainty: bool = False,
               **kwargs):
    """The wrapper class for graph convolutions.

@@ -695,10 +864,11 @@ class GraphConvModel(KerasModel):
    dense_layer_size: int
      Width of channels for Atom Level Dense Layer before GraphPool
    dropout: list or float
      the dropout probablity to use for each layer.  The length of this list should equal
      len(graph_conv_layers)+1 (one value for each convolution layer, and one for the
      dense layer).  Alternatively this may be a single value instead of a list, in which
      case the same value is used for every layer.
      the dropout probability to use for each layer.  The length of this list
      should equal len(graph_conv_layers)+1 (one value for each convolution
      layer, and one for the dense layer).  Alternatively this may be a single
      value instead of a list, in which case the same value is used for every
      layer.
    mode: str
      Either "classification" or "regression"
    number_atom_features: int
@@ -731,7 +901,7 @@ class GraphConvModel(KerasModel):
        batch_size=batch_size)
    if mode == "classification":
      output_types = ['prediction', 'loss', 'embedding']
      loss = SoftmaxCrossEntropy()
      loss: Union[Loss, KerasLossFn] = SoftmaxCrossEntropy()
    else:
      if self.uncertainty:
        output_types = ['prediction', 'variance', 'loss', 'loss', 'embedding']
@@ -779,11 +949,12 @@ class MPNNModel(KerasModel):
  nodes in a graph send each other "messages" and update their
  internal state as a consequence of these messages.

  Ordering structures in this model are built according to


Vinyals, Oriol, Samy Bengio, and Manjunath Kudlur. "Order matters: Sequence to sequence for sets." arXiv preprint arXiv:1511.06391 (2015).
  Ordering structures in this model are built according to [1]_

  References
  ----------
  .. [1] Vinyals, Oriol, Samy Bengio, and Manjunath Kudlur. "Order matters:
  Sequence to sequence for sets." arXiv preprint arXiv:1511.06391 (2015).
  """

  def __init__(self,
+355 −68


Preview size limit exceeded, changes collapsed.

+45 −0
@@ -45,6 +45,51 @@ class LearningRateSchedule(object):
    raise NotImplementedError("Subclasses must implement this")


class AdaGrad(Optimizer):
  """The AdaGrad optimization algorithm.

  Adagrad is an optimizer with parameter-specific learning rates, which are
  adapted relative to how frequently a parameter gets updated during training.
  The more updates a parameter receives, the smaller the updates. See [1]_
  for a full reference on the algorithm.

  References
  ----------
  .. [1] Duchi, John, Elad Hazan, and Yoram Singer. "Adaptive subgradient
  methods for online learning and stochastic optimization." Journal of machine
  learning research 12.7 (2011).
  """

  def __init__(self,
               learning_rate=0.001,
               initial_accumulator_value=0.1,
               epsilon=1e-07):
    """Construct an AdaGrad optimizer.
    Parameters
    ----------
    learning_rate: float or LearningRateSchedule
      the learning rate to use for optimization
    initial_accumulator_value: float
      a parameter of the AdaGrad algorithm
    epsilon: float
      a parameter of the AdaGrad algorithm

    """
    self.learning_rate = learning_rate
    self.initial_accumulator_value = initial_accumulator_value
    self.epsilon = epsilon

  def _create_optimizer(self, global_step):
    if isinstance(self.learning_rate, LearningRateSchedule):
      learning_rate = self.learning_rate._create_tensor(global_step)
    else:
      learning_rate = self.learning_rate
    return tf.keras.optimizers.Adagrad(
        learning_rate=learning_rate,
        initial_accumulator_value=self.initial_accumulator_value,
        epsilon=self.epsilon)
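
  # Usage sketch (an assumption, not part of this diff): the optimizer can be
  # passed to a DeepChem KerasModel through its `optimizer` argument, e.g.
  #
  #   from deepchem.models.optimizers import AdaGrad
  #   model = dc.models.MultitaskRegressor(
  #       n_tasks=1, n_features=16, optimizer=AdaGrad(learning_rate=0.01))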


class Adam(Optimizer):
  """The Adam optimization algorithm."""

+30 −7
@@ -14,7 +14,7 @@ from flaky import flaky


def get_dataset(mode='classification', featurizer='GraphConv', num_tasks=2):
  data_points = 10
  data_points = 20
  if mode == 'classification':
    tasks, all_dataset, transformers = load_bace_classification(featurizer)
  else:
@@ -141,24 +141,47 @@ def test_graph_conv_atom_features():
  y_pred1 = model.predict(dataset)


@flaky
@pytest.mark.slow
def test_weave_model():
  tasks, dataset, transformers, metric = get_dataset('classification', 'Weave')

  batch_size = 10
  model = WeaveModel(len(tasks), batch_size=batch_size, mode='classification')
  model.fit(dataset, nb_epoch=50)
  batch_size = 20
  model = WeaveModel(
      len(tasks),
      batch_size=batch_size,
      mode='classification',
      fully_connected_layer_sizes=[2000, 1000],
      batch_normalize=True,
      batch_normalize_kwargs={
          "fused": False,
          "trainable": True,
          "renorm": True
      },
      learning_rate=0.0005)
  model.fit(dataset, nb_epoch=200)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.9


@flaky
@pytest.mark.slow
def test_weave_regression_model():
  import numpy as np
  import tensorflow as tf
  tf.random.set_seed(123)
  np.random.seed(123)
  tasks, dataset, transformers, metric = get_dataset('regression', 'Weave')

  batch_size = 10
  model = WeaveModel(len(tasks), batch_size=batch_size, mode='regression')
  model.fit(dataset, nb_epoch=80)
  model = WeaveModel(
      len(tasks),
      batch_size=batch_size,
      mode='regression',
      batch_normalize=False,
      fully_connected_layer_sizes=[],
      dropouts=0,
      learning_rate=0.0005)
  model.fit(dataset, nb_epoch=200)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.1
