Commit aa4add51 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Fixing some tweak in Weave models

parent 1cbafde0
Loading
Loading
Loading
Loading
+52 −35
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ import deepchem as dc
import numpy as np
import tensorflow as tf

from typing import List, Union, Tuple, Iterable, Dict
from typing import List, Union, Tuple, Iterable, Dict, Optional
from deepchem.utils.typing import OneOrMany, LossFn, KerasActivationFn
from deepchem.data import Dataset, NumpyDataset, pad_features
from deepchem.feat.graph_features import ConvMolFeaturizer
@@ -78,7 +78,8 @@ class WeaveModel(KerasModel):

  """

  def __init__(self,
  def __init__(
      self,
      n_tasks: int,
      n_atom_feat: OneOrMany[int] = 75,
      n_pair_feat: OneOrMany[int] = 14,
@@ -86,11 +87,13 @@ class WeaveModel(KerasModel):
      n_graph_feat: int = 128,
      n_weave: int = 2,
      fully_connected_layer_sizes: List[int] = [2000, 100],
               weight_init_stddevs: OneOrMany[float] = [0.01, 0.04],
               bias_init_consts: OneOrMany[float] = [0.5, 3.0],
      conv_weight_init_stddevs: OneOrMany[float] = 0.03,
      weight_init_stddevs: OneOrMany[float] = 0.01,
      bias_init_consts: OneOrMany[float] = 0.0,
      weight_decay_penalty: float = 0.0,
      weight_decay_penalty_type: str = "l2",
      dropouts: OneOrMany[float] = 0.25,
      final_conv_activation_fn: Optional[KerasActivationFn] = tf.nn.tanh,
      activation_fns: OneOrMany[KerasActivationFn] = tf.nn.relu,
      batch_normalize: bool = True,
      batch_normalize_kwargs: Dict = {
@@ -121,14 +124,18 @@ class WeaveModel(KerasModel):
    fully_connected_layer_sizes: list
      The size of each dense layer in the network.  The length of
      this list determines the number of layers.
    conv_weight_init_stddevs: list or float
      The standard deviation of the distribution to use for weight
      initialization of each convolutional layer. The length of this lisst
      should equal `n_weave`. Alternatively, this may be a single value instead
      of a list, in which case the same value is used for each layer.
    weight_init_stddevs: list or float
      The standard deviation of the distribution to use for weight
      initialization of each layer.  The length of this list should
      equal len(layer_sizes).  Alternatively this may be a single
      value instead of a list, in which case the same value is used
      for every layer.
      initialization of each fully connected layer.  The length of this list
      should equal len(layer_sizes).  Alternatively this may be a single value
      instead of a list, in which case the same value is used for every layer.
    bias_init_consts: list or float
      The value to initialize the biases in each layer to.  The
      The value to initialize the biases in each fully connected layer.  The
      length of this list should equal len(layer_sizes).
      Alternatively this may be a single value instead of a list, in
      which case the same value is used for every layer.
@@ -137,11 +144,15 @@ class WeaveModel(KerasModel):
    weight_decay_penalty_type: str
      The type of penalty to use for weight decay, either 'l1' or 'l2'
    dropouts: list or float
      The dropout probablity to use for each layer.  The length of this list
      The dropout probablity to use for each fully connected layer.  The length of this list
      should equal len(layer_sizes).  Alternatively this may be a single value
      instead of a list, in which case the same value is used for every layer.
    final_conv_activation_fn: Optional[KerasActivationFn]
      The Tensorflow activation funcntion to apply to the final
      convolution at the end of the weave convolutions. If `None`, then no
      convolution (linear) is applied.
    activation_fns: list or object
      The Tensorflow activation function to apply to each layer.  The length
      The Tensorflow activation function to apply to each fully connected layer.  The length
      of this list should equal len(layer_sizes).  Alternatively this may be a
      single value instead of a list, in which case the same value is used for
      every layer.
@@ -172,6 +183,8 @@ class WeaveModel(KerasModel):
    if not isinstance(n_pair_feat, collections.Sequence):
      n_pair_feat = [n_pair_feat] * n_weave
    n_layers = len(fully_connected_layer_sizes)
    if not isinstance(conv_weight_init_stddevs, collections.Sequence):
      conv_weight_init_stddevs = [conv_weight_init_stddevs] * n_weave
    if not isinstance(weight_init_stddevs, collections.Sequence):
      weight_init_stddevs = [weight_init_stddevs] * n_layers
    if not isinstance(bias_init_consts, collections.Sequence):
@@ -217,12 +230,16 @@ class WeaveModel(KerasModel):
          n_pair_input_feat=n_pair,
          n_atom_output_feat=n_atom_next,
          n_pair_output_feat=n_pair_next,
          init=tf.keras.initializers.TruncatedNormal(
              stddev=conv_weight_init_stddevs[ind]),
          batch_normalize=batch_normalize)(inputs)
      inputs = [weave_layer_ind_A, weave_layer_ind_P, pair_split, atom_to_pair]
    # Final atom-layer convolution. Note this differs slightly from the paper
    # since we use a tanh activation. This seems necessary for numerical
    # since we use a tanh activation as default. This seems necessary for numerical
    # stability.
    dense1 = Dense(self.n_graph_feat, activation=tf.nn.tanh)(weave_layer_ind_A)
    dense1 = Dense(
        self.n_graph_feat,
        activation=final_conv_activation_fn)(weave_layer_ind_A)
    if batch_normalize:
      dense1 = BatchNormalization(**batch_normalize_kwargs)(dense1)
    weave_gather = layers.WeaveGather(
+2 −1
Original line number Diff line number Diff line
@@ -108,7 +108,8 @@ def test_weave_layer():
  mols = [Chem.MolFromSmiles(s) for s in raw_smiles]
  featurizer = dc.feat.WeaveFeaturizer()
  mols = featurizer.featurize(mols)
  weave = layers.WeaveLayer()
  weave = layers.WeaveLayer(
      init=tf.keras.initializers.TruncatedNormal(stddev=0.03))
  atom_feat = []
  pair_feat = []
  atom_to_pair = []
+15 −18
Original line number Diff line number Diff line
@@ -13,8 +13,10 @@ from deepchem.feat import ConvMolFeaturizer
from flaky import flaky


def get_dataset(mode='classification', featurizer='GraphConv', num_tasks=2):
  data_points = 20
def get_dataset(mode='classification',
                featurizer='GraphConv',
                num_tasks=2,
                data_points=20):
  if mode == 'classification':
    tasks, all_dataset, transformers = load_bace_classification(
        featurizer, reload=False)
@@ -121,22 +123,18 @@ def test_compute_features_on_distance_1():
@flaky
@pytest.mark.slow
def test_weave_model():
  tasks, dataset, transformers, metric = get_dataset('classification', 'Weave')
  tasks, dataset, transformers, metric = get_dataset(
      'classification', 'Weave', data_points=10)

  batch_size = 20
  batch_size = 10
  model = WeaveModel(
      len(tasks),
      batch_size=batch_size,
      mode='classification',
      fully_connected_layer_sizes=[2000, 1000],
      batch_normalize=True,
      batch_normalize_kwargs={
          "fused": False,
          "trainable": True,
          "renorm": True
      },
      learning_rage=0.0005)
  model.fit(dataset, nb_epoch=200)
      final_conv_activation_fn=None,
      dropouts=0,
      learning_rage=0.0003)
  model.fit(dataset, nb_epoch=100)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.9

@@ -147,18 +145,17 @@ def test_weave_regression_model():
  import tensorflow as tf
  tf.random.set_seed(123)
  np.random.seed(123)
  tasks, dataset, transformers, metric = get_dataset('regression', 'Weave')
  tasks, dataset, transformers, metric = get_dataset(
      'regression', 'Weave', data_points=10)

  batch_size = 10
  model = WeaveModel(
      len(tasks),
      batch_size=batch_size,
      mode='regression',
      batch_normalize=False,
      fully_connected_layer_sizes=[],
      dropouts=0,
      learning_rate=0.0005)
  model.fit(dataset, nb_epoch=200)
      learning_rate=0.00003)
  model.fit(dataset, nb_epoch=400)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.1