Commit 1505edb6 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

changes

parent 73e00d79
Loading
Loading
Loading
Loading
+29 −16
Original line number Diff line number Diff line
@@ -4,6 +4,8 @@ import deepchem as dc
import numpy as np
import tensorflow as tf

from typing import List
from deepchem.utils.typing import OneOrMany 
from deepchem.data import NumpyDataset, pad_features
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.mol_graphs import ConvMol
@@ -114,25 +116,36 @@ class WeaveModel(KerasModel):

    # Build the model.

    atom_features = Input(shape=(self.n_atom_feat,))
    pair_features = Input(shape=(self.n_pair_feat,))
    atom_features = Input(shape=(self.n_atom_feat[0],))
    pair_features = Input(shape=(self.n_pair_feat[0],))
    pair_split = Input(shape=tuple(), dtype=tf.int32)
    atom_split = Input(shape=tuple(), dtype=tf.int32)
    atom_to_pair = Input(shape=(2,), dtype=tf.int32)
    weave_layer1A, weave_layer1P = layers.WeaveLayer(
        n_atom_input_feat=self.n_atom_feat,
        n_pair_input_feat=self.n_pair_feat,
        n_atom_output_feat=self.n_hidden,
        n_pair_output_feat=self.n_hidden)(
            [atom_features, pair_features, pair_split, atom_to_pair])
    weave_layer2A, weave_layer2P = layers.WeaveLayer(
        n_atom_input_feat=self.n_hidden,
        n_pair_input_feat=self.n_hidden,
        n_atom_output_feat=self.n_hidden,
        n_pair_output_feat=self.n_hidden,
        update_pair=False)(
            [weave_layer1A, weave_layer1P, pair_split, atom_to_pair])
    dense1 = Dense(self.n_graph_feat, activation=tf.nn.tanh)(weave_layer2A)
    inputs = [atom_features, pair_features, pair_split, atom_to_pair]
    for ind in range(n_weave):
      n_atom = self.n_atom_feat[ind]
      n_pair = self.n_pair_feat[ind]
      if ind < n_weave - 1:
        n_atom_next = self.n_atom_feat[ind+1]
        n_pair_next = self.n_pair_feat[ind+1]
      else:
        n_atom_next = n_hidden
        n_pair_next = n_hidden
      weave_layer_ind_A, weave_layer_ind_P = layers.WeaveLayer(
          n_atom_input_feat=n_atom,
          n_pair_input_feat=n_pair,
          n_atom_output_feat=n_atom_next,
          n_pair_output_feat=n_pair_next)(inputs)
      inputs = [weave_layer_ind_A, weave_layer_ind_P, pair_split, atom_to_pair]
    #weave_layer2A, weave_layer2P = layers.WeaveLayer(
    #    n_atom_input_feat=self.n_hidden,
    #    n_pair_input_feat=self.n_hidden,
    #    n_atom_output_feat=self.n_hidden,
    #    n_pair_output_feat=self.n_hidden,
    #    update_pair=False)(
    #        [weave_layer1A, weave_layer1P, pair_split, atom_to_pair])
    #dense1 = Dense(self.n_graph_feat, activation=tf.nn.tanh)(weave_layer2A)
    dense1 = Dense(self.n_graph_feat, activation=tf.nn.tanh)(weave_layer_ind_A)
    # Batch normalization causes issues, spitting out NaNs if
    # allowed to train
    batch_norm1 = BatchNormalization(epsilon=1e-5, trainable=False)(dense1)
+3 −1
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
import tensorflow as tf
import numpy as np
import collections
from typing import Callable, Dict, List
from tensorflow.keras import activations, initializers, backend
from tensorflow.keras.layers import Dropout

@@ -2121,7 +2122,8 @@ class WeaveLayer(tf.keras.layers.Layer):
    self.n_pair_output_feat = n_pair_output_feat
    self.W_AP, self.b_AP, self.W_PP, self.b_PP, self.W_P, self.b_P = None, None, None, None, None, None

  def get_config(self):
  def get_config(self) -> Dict:
    """Returns config dictionary for this layer."""
    config = super(WeaveLayer, self).get_config()
    config['n_atom_input_feat'] = self.n_atom_input_feat
    config['n_pair_input_feat'] = self.n_pair_input_feat