Commit ea879e41 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #496 from miaecle/molnet

New weave implementation and update for dc.molnet
parents fa9c0b8b 6f84277d
Loading
Loading
Loading
Loading
+75 −8

File changed.

Preview size limit exceeded, changes collapsed.

+33 −1
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ __license__ = "MIT"

import tensorflow as tf
from deepchem.nn.layers import GraphGather
from deepchem.models.tf_new_models.graph_topology import GraphTopology, DTNNGraphTopology, DAGGraphTopology, WeaveGraphTopology
from deepchem.models.tf_new_models.graph_topology import GraphTopology, DTNNGraphTopology, DAGGraphTopology, WeaveGraphTopology, AlternateWeaveGraphTopology


class SequentialGraph(object):
@@ -196,6 +196,38 @@ class SequentialWeaveGraph(SequentialGraph):
      self.layers.append(layer)


class AlternateSequentialWeaveGraph(SequentialGraph):
  """Sequential layer container for the alternate Weave implementation.

  Owns a dedicated `tf.Graph` and an `AlternateWeaveGraphTopology`, and tracks
  two running tensors while layers are stacked: `output` (atom features) and
  `output_P` (pair features).
  """

  def __init__(self, batch_size, max_atoms=50, n_atom_feat=75, n_pair_feat=14):
    """
    Parameters
    ----------
    batch_size: int
      forwarded to the topology
    max_atoms: int
      forwarded to the topology
    n_atom_feat: int
      number of features per atom
    n_pair_feat: int
      number of features per atom pair
    """
    self.batch_size = batch_size
    self.max_atoms = max_atoms
    self.n_atom_feat = n_atom_feat
    self.n_pair_feat = n_pair_feat
    self.layers = []
    self.graph = tf.Graph()
    # Placeholders must be created while this model's graph is the default.
    with self.graph.as_default():
      topology = AlternateWeaveGraphTopology(batch_size, max_atoms,
                                             n_atom_feat, n_pair_feat)
      self.graph_topology = topology
      # The raw input placeholders are the initial "outputs" of the stack.
      self.output = topology.get_atom_features_placeholder()
      self.output_P = topology.get_pair_features_placeholder()

  def add(self, layer):
    """Adds a new layer to model.

    The layer is applied to the current running tensor(s) inside this model's
    graph and then recorded in `self.layers`. Dispatch is done on the layer's
    class name:

    - 'AlternateWeaveLayer': receives atom features, pair features and the
      topology placeholders; updates both `output` and `output_P`.
    - 'AlternateWeaveGather': receives atom features and the atom split
      placeholder; updates `output`.
    - anything else: receives `output` only and replaces it.
    """
    with self.graph.as_default():
      layer_kind = type(layer).__name__
      if layer_kind == 'AlternateWeaveLayer':
        layer_inputs = [self.output, self.output_P]
        layer_inputs = (
            layer_inputs + self.graph_topology.get_topology_placeholders())
        self.output, self.output_P = layer(layer_inputs)
      elif layer_kind == 'AlternateWeaveGather':
        atom_split = self.graph_topology.atom_split_placeholder
        self.output = layer([self.output, atom_split])
      else:
        self.output = layer(self.output)
      self.layers.append(layer)


class SequentialSupportGraph(object):
  """An analog of Keras Sequential model for test/support models."""

+111 −0
Original line number Diff line number Diff line
@@ -492,3 +492,114 @@ class WeaveGraphTopology(GraphTopology):
        self.membership_placeholder: membership
    }
    return dict_DTNN


class AlternateWeaveGraphTopology(GraphTopology):
  """Manages placeholders associated with batch of graphs and their topology

  Atom features of every molecule in a batch are stacked along the first axis
  of one tensor and pair features along the first axis of another; the
  remaining placeholders carry the index bookkeeping for those stacked rows.
  """

  def __init__(self,
               batch_size,
               max_atoms,
               n_atom_feat,
               n_pair_feat,
               name='Weave_topology'):
    """
    Parameters
    ----------
    batch_size: int
      number of molecules per batch; fixes the length of the atom split
      placeholder
    max_atoms: int
      maximum number of atoms in a molecule
    n_atom_feat: int
      number of basic features of each atom
    n_pair_feat: int
      number of basic features of each pair
    name: str
      prefix used for the names of all placeholders
    """

    self.name = name
    self.batch_size = batch_size
    # NOTE: the stored attribute is max_atoms * batch_size (an upper bound on
    # atoms in a whole batch), not the per-molecule argument value.
    self.max_atoms = max_atoms * batch_size
    self.n_atom_feat = n_atom_feat
    self.n_pair_feat = n_pair_feat

    self.atom_features_placeholder = tf.placeholder(
        dtype='float32',
        shape=(None, self.n_atom_feat),
        name=self.name + '_atom_features')
    self.pair_features_placeholder = tf.placeholder(
        dtype='float32',
        shape=(None, self.n_pair_feat),
        name=self.name + '_pair_features')
    self.pair_split_placeholder = tf.placeholder(
        dtype='int32', shape=(None,), name=self.name + '_pair_split')
    self.atom_split_placeholder = tf.placeholder(
        dtype='int32', shape=(self.batch_size,), name=self.name + '_atom_split')
    self.atom_to_pair_placeholder = tf.placeholder(
        dtype='int32', shape=(None, 2), name=self.name + '_atom_to_pair')

    # Define the list of tensors to be used as topology
    self.topology = [
        self.pair_split_placeholder, self.atom_split_placeholder,
        self.atom_to_pair_placeholder
    ]
    # Pair features are not part of `inputs`; they are exposed separately via
    # get_pair_features_placeholder().
    self.inputs = [self.atom_features_placeholder]
    self.inputs += self.topology

  def get_pair_features_placeholder(self):
    """Returns the placeholder holding the stacked pair features."""
    return self.pair_features_placeholder

  def batch_to_feed_dict(self, batch):
    """Converts the current batch of WeaveMol into tensorflow feed_dict.

    Assigns the atom features and pair features to the
    placeholders tensors

    params
    ------
    batch : np.ndarray
      Array of WeaveMol. Each element must provide get_num_atoms(),
      get_atom_features() and get_pair_features(); the pair features must be
      reshapeable to (n_atoms * n_atoms, n_pair_feat).

    returns
    -------
    feed_dict : dict
      Can be merged with other feed_dicts for input into tensorflow
    """
    atom_feat = []
    pair_feat = []
    atom_split = []
    atom_to_pair = []
    pair_split = []
    # Offset of the current molecule's first atom in the stacked atom axis.
    start = 0
    for mol in batch:
      n_atoms = mol.get_num_atoms()
      # number of atoms in each molecule
      atom_split.append(n_atoms)

      # Enumerate all n_atoms * n_atoms ordered pairs (i, j) in row-major
      # order, expressed as batch-wide atom indices.
      local_index = np.arange(n_atoms)
      first_atom = np.repeat(local_index, n_atoms) + start
      second_atom = np.tile(local_index, n_atoms) + start
      atom_to_pair.append(np.transpose(np.array([first_atom, second_atom])))
      # For every pair row, the batch-wide index of its first atom.
      pair_split.extend(first_atom)
      start = start + n_atoms

      # atom features
      atom_feat.append(mol.get_atom_features())
      # pair features, flattened to one row per ordered atom pair
      pair_feat.append(
          np.reshape(mol.get_pair_features(), (n_atoms * n_atoms,
                                               self.n_pair_feat)))

    atom_feat = np.concatenate(atom_feat, axis=0)
    pair_feat = np.concatenate(pair_feat, axis=0)
    atom_to_pair = np.concatenate(atom_to_pair, axis=0)
    atom_split = np.array(atom_split)
    # Generate dicts
    feed_dict = {
        self.atom_features_placeholder: atom_feat,
        self.pair_features_placeholder: pair_feat,
        self.pair_split_placeholder: pair_split,
        self.atom_split_placeholder: atom_split,
        self.atom_to_pair_placeholder: atom_to_pair
    }
    return feed_dict
+54 −1
Original line number Diff line number Diff line
@@ -4,46 +4,65 @@ CheckFeaturizer = {
    ('bace_c', 'tf_robust'): ['ECFP', 1024],
    ('bace_c', 'rf'): ['ECFP', 1024],
    ('bace_c', 'irv'): ['ECFP', 1024],
    ('bace_c', 'xgb'): ['ECFP', 1024],
    ('bace_c', 'graphconv'): ['GraphConv', 75],
    ('bace_c', 'dag'): ['GraphConv', 75],
    ('bace_c', 'weave'): ['Weave', 75],
    ('bbbp', 'logreg'): ['ECFP', 1024],
    ('bbbp', 'tf'): ['ECFP', 1024],
    ('bbbp', 'tf_robust'): ['ECFP', 1024],
    ('bbbp', 'rf'): ['ECFP', 1024],
    ('bbbp', 'irv'): ['ECFP', 1024],
    ('bbbp', 'xgb'): ['ECFP', 1024],
    ('bbbp', 'graphconv'): ['GraphConv', 75],
    ('bbbp', 'dag'): ['GraphConv', 75],
    ('bbbp', 'weave'): ['Weave', 75],
    ('clintox', 'logreg'): ['ECFP', 1024],
    ('clintox', 'tf'): ['ECFP', 1024],
    ('clintox', 'tf_robust'): ['ECFP', 1024],
    ('clintox', 'rf'): ['ECFP', 1024],
    ('clintox', 'irv'): ['ECFP', 1024],
    ('clintox', 'xgb'): ['ECFP', 1024],
    ('clintox', 'graphconv'): ['GraphConv', 75],
    ('clintox', 'dag'): ['GraphConv', 75],
    ('clintox', 'weave'): ['Weave', 75],
    ('hiv', 'logreg'): ['ECFP', 1024],
    ('hiv', 'tf'): ['ECFP', 1024],
    ('hiv', 'tf_robust'): ['ECFP', 1024],
    ('hiv', 'rf'): ['ECFP', 1024],
    ('hiv', 'irv'): ['ECFP', 1024],
    ('hiv', 'xgb'): ['ECFP', 1024],
    ('hiv', 'graphconv'): ['GraphConv', 75],
    ('hiv', 'dag'): ['GraphConv', 75],
    ('hiv', 'weave'): ['Weave', 75],
    ('muv', 'logreg'): ['ECFP', 1024],
    ('muv', 'tf'): ['ECFP', 1024],
    ('muv', 'tf_robust'): ['ECFP', 1024],
    ('muv', 'rf'): ['ECFP', 1024],
    ('muv', 'irv'): ['ECFP', 1024],
    ('muv', 'xgb'): ['ECFP', 1024],
    ('muv', 'graphconv'): ['GraphConv', 75],
    ('muv', 'siamese'): ['GraphConv', 75],
    ('muv', 'attn'): ['GraphConv', 75],
    ('muv', 'res'): ['GraphConv', 75],
    ('muv', 'weave'): ['Weave', 75],
    ('pcba', 'logreg'): ['ECFP', 1024],
    ('pcba', 'tf'): ['ECFP', 1024],
    ('pcba', 'tf_robust'): ['ECFP', 1024],
    ('pcba', 'rf'): ['ECFP', 1024],
    ('pcba', 'irv'): ['ECFP', 1024],
    ('pcba', 'xgb'): ['ECFP', 1024],
    ('pcba', 'graphconv'): ['GraphConv', 75],
    ('pcba', 'weave'): ['Weave', 75],
    ('sider', 'logreg'): ['ECFP', 1024],
    ('sider', 'tf'): ['ECFP', 1024],
    ('sider', 'tf_robust'): ['ECFP', 1024],
    ('sider', 'rf'): ['ECFP', 1024],
    ('sider', 'irv'): ['ECFP', 1024],
    ('sider', 'xgb'): ['ECFP', 1024],
    ('sider', 'graphconv'): ['GraphConv', 75],
    ('sider', 'dag'): ['GraphConv', 75],
    ('sider', 'weave'): ['Weave', 75],
    ('sider', 'siamese'): ['GraphConv', 75],
    ('sider', 'attn'): ['GraphConv', 75],
    ('sider', 'res'): ['GraphConv', 75],
@@ -52,7 +71,10 @@ CheckFeaturizer = {
    ('tox21', 'tf_robust'): ['ECFP', 1024],
    ('tox21', 'rf'): ['ECFP', 1024],
    ('tox21', 'irv'): ['ECFP', 1024],
    ('tox21', 'xgb'): ['ECFP', 1024],
    ('tox21', 'graphconv'): ['GraphConv', 75],
    ('tox21', 'dag'): ['GraphConv', 75],
    ('tox21', 'weave'): ['Weave', 75],
    ('tox21', 'siamese'): ['GraphConv', 75],
    ('tox21', 'attn'): ['GraphConv', 75],
    ('tox21', 'res'): ['GraphConv', 75],
@@ -61,42 +83,73 @@ CheckFeaturizer = {
    ('toxcast', 'tf_robust'): ['ECFP', 1024],
    ('toxcast', 'rf'): ['ECFP', 1024],
    ('toxcast', 'irv'): ['ECFP', 1024],
    ('toxcast', 'xgb'): ['ECFP', 1024],
    ('toxcast', 'graphconv'): ['GraphConv', 75],
    ('toxcast', 'weave'): ['Weave', 75],
    ('bace_r', 'tf_regression'): ['ECFP', 1024],
    ('bace_r', 'rf_regression'): ['ECFP', 1024],
    ('bace_r', 'xgb_regression'): ['ECFP', 1024],
    ('bace_r', 'graphconvreg'): ['GraphConv', 75],
    ('bace_r', 'dag_regression'): ['GraphConv', 75],
    ('bace_r', 'weave_regression'): ['Weave', 75],
    ('chembl', 'tf_regression'): ['ECFP', 1024],
    ('chembl', 'rf_regression'): ['ECFP', 1024],
    ('chembl', 'xgb_regression'): ['ECFP', 1024],
    ('chembl', 'graphconvreg'): ['GraphConv', 75],
    ('chembl', 'weave_regression'): ['Weave', 75],
    ('clearance', 'tf_regression'): ['ECFP', 1024],
    ('clearance', 'rf_regression'): ['ECFP', 1024],
    ('clearance', 'xgb_regression'): ['ECFP', 1024],
    ('clearance', 'graphconvreg'): ['GraphConv', 75],
    ('clearance', 'dag_regression'): ['GraphConv', 75],
    ('clearance', 'weave_regression'): ['Weave', 75],
    ('delaney', 'tf_regression'): ['ECFP', 1024],
    ('delaney', 'rf_regression'): ['ECFP', 1024],
    ('delaney', 'xgb_regression'): ['ECFP', 1024],
    ('delaney', 'graphconvreg'): ['GraphConv', 75],
    ('delaney', 'dag_regression'): ['GraphConv', 75],
    ('delaney', 'weave_regression'): ['Weave', 75],
    ('hopv', 'tf_regression'): ['ECFP', 1024],
    ('hopv', 'rf_regression'): ['ECFP', 1024],
    ('hopv', 'xgb_regression'): ['ECFP', 1024],
    ('hopv', 'graphconvreg'): ['GraphConv', 75],
    ('hopv', 'dag_regression'): ['GraphConv', 75],
    ('hopv', 'weave_regression'): ['Weave', 75],
    ('lipo', 'tf_regression'): ['ECFP', 1024],
    ('lipo', 'rf_regression'): ['ECFP', 1024],
    ('lipo', 'xgb_regression'): ['ECFP', 1024],
    ('lipo', 'graphconvreg'): ['GraphConv', 75],
    ('lipo', 'dag_regression'): ['GraphConv', 75],
    ('lipo', 'weave_regression'): ['Weave', 75],
    ('nci', 'tf_regression'): ['ECFP', 1024],
    ('nci', 'rf_regression'): ['ECFP', 1024],
    ('nci', 'xgb_regression'): ['ECFP', 1024],
    ('nci', 'graphconvreg'): ['GraphConv', 75],
    ('nci', 'weave_regression'): ['Weave', 75],
    ('ppb', 'tf_regression'): ['ECFP', 1024],
    ('ppb', 'rf_regression'): ['ECFP', 1024],
    ('ppb', 'xgb_regression'): ['ECFP', 1024],
    ('ppb', 'graphconvreg'): ['GraphConv', 75],
    ('ppb', 'dag_regression'): ['GraphConv', 75],
    ('ppb', 'weave_regression'): ['Weave', 75],
    ('sampl', 'tf_regression'): ['ECFP', 1024],
    ('sampl', 'rf_regression'): ['ECFP', 1024],
    ('sampl', 'xgb_regression'): ['ECFP', 1024],
    ('sampl', 'graphconvreg'): ['GraphConv', 75],
    ('sampl', 'dag_regression'): ['GraphConv', 75],
    ('sampl', 'weave_regression'): ['Weave', 75],
    ('kaggle', 'tf_regression'): [None, 14293],
    ('kaggle', 'rf_regression'): [None, 14293],
    ('pdbbind', 'tf_regression'): ['grid', 2052],
    ('pdbbind', 'rf_regression'): ['grid', 2052],
    ('qm7', 'tf_regression_ft'): [None, [23, 23]],
    ('qm7', 'dtnn'): [None, [23, 23]],
    ('qm7b', 'tf_regression_ft'): [None, [23, 23]],
    ('qm7b', 'dtnn'): [None, [23, 23]],
    ('qm8', 'tf_regression_ft'): [None, [26, 26]],
    ('qm9', 'tf_regression_ft'): [None, [29, 29]]
    ('qm8', 'dtnn'): [None, [26, 26]],
    ('qm9', 'tf_regression_ft'): [None, [29, 29]],
    ('qm9', 'dtnn'): [None, [29, 29]]
}

CheckSplit = {
+30 −2
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ import deepchem
from deepchem.molnet.load_function.bace_features import bace_user_specified_features


def load_bace_regression(featurizer=None, split='random'):
def load_bace_regression(featurizer=None, split='random', reload=True):
  """Load bace datasets."""
  # Featurize bace dataset
  print("About to featurize bace dataset.")
@@ -18,6 +18,8 @@ def load_bace_regression(featurizer=None, split='random'):
    data_dir = os.environ["DEEPCHEM_DATA_DIR"]
  else:
    data_dir = "/tmp"
  if reload:
    save_dir = os.path.join(data_dir, "bace_r/" + featurizer + "/" + split)

  dataset_file = os.path.join(data_dir, "bace.csv")

@@ -28,10 +30,18 @@ def load_bace_regression(featurizer=None, split='random'):
    )

  bace_tasks = ["pIC50"]
  if reload:
    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_dir)
    if loaded:
      return bace_tasks, all_dataset, transformers

  if featurizer == 'ECFP':
    featurizer = deepchem.feat.CircularFingerprint(size=1024)
  elif featurizer == 'GraphConv':
    featurizer = deepchem.feat.ConvMolFeaturizer()
  elif featurizer == 'Weave':
    featurizer = deepchem.feat.WeaveFeaturizer()
  elif featurizer == 'Raw':
    featurizer = deepchem.feat.RawFeaturizer()
  elif featurizer == None:
@@ -59,10 +69,14 @@ def load_bace_regression(featurizer=None, split='random'):
  }
  splitter = splitters[split]
  train, valid, test = splitter.train_valid_test_split(dataset)

  if reload:
    deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
                                             transformers)
  return bace_tasks, (train, valid, test), transformers


def load_bace_classification(featurizer=None, split='random'):
def load_bace_classification(featurizer=None, split='random', reload=True):
  """Load bace datasets."""
  # Featurize bace dataset
  print("About to featurize bace dataset.")
@@ -70,6 +84,8 @@ def load_bace_classification(featurizer=None, split='random'):
    data_dir = os.environ["DEEPCHEM_DATA_DIR"]
  else:
    data_dir = "/tmp"
  if reload:
    save_dir = os.path.join(data_dir, "bace_c/" + featurizer + "/" + split)

  dataset_file = os.path.join(data_dir, "bace.csv")

@@ -80,10 +96,18 @@ def load_bace_classification(featurizer=None, split='random'):
    )

  bace_tasks = ["Class"]
  if reload:
    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_dir)
    if loaded:
      return bace_tasks, all_dataset, transformers

  if featurizer == 'ECFP':
    featurizer = deepchem.feat.CircularFingerprint(size=1024)
  elif featurizer == 'GraphConv':
    featurizer = deepchem.feat.ConvMolFeaturizer()
  elif featurizer == 'Weave':
    featurizer = deepchem.feat.WeaveFeaturizer()
  elif featurizer == 'Raw':
    featurizer = deepchem.feat.RawFeaturizer()
  elif featurizer == None:
@@ -110,4 +134,8 @@ def load_bace_classification(featurizer=None, split='random'):
  }
  splitter = splitters[split]
  train, valid, test = splitter.train_valid_test_split(dataset)

  if reload:
    deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
                                             transformers)
  return bace_tasks, (train, valid, test), transformers
Loading