Commit 767ff35a authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

First commit

parent 1b7d83bd
Loading
Loading
Loading
Loading
+109 −10
Original line number Diff line number Diff line
@@ -2856,11 +2856,7 @@ class DAGLayer(tf.keras.layers.Layer):
  """DAG computation layer.

  This layer generates a directed acyclic graph for each atom
  in a molecule. This layer is based on the algorithm from the
  following paper:

  Lusci, Alessandro, Gianluca Pollastri, and Pierre Baldi. "Deep architectures and deep learning in chemoinformatics: the prediction of aqueous solubility for drug-like molecules." Journal of chemical information and modeling 53.7 (2013): 1563-1575.

  in a molecule. This layer is based on the algorithm from [1]_

  This layer performs a sort of inward sweep. Recall that for
  each atom, a DAG is generated that "points inward" to that
@@ -2870,6 +2866,13 @@ class DAGLayer(tf.keras.layers.Layer):
  inwards" from the leaf nodes of the DAG upwards to the
  atom. This is batched so the transformation is done for
  each atom.

  References
  ----------
  ..[1] Lusci, Alessandro, Gianluca Pollastri, and Pierre Baldi. "Deep
  architectures and deep learning in chemoinformatics: the prediction of
  aqueous solubility for drug-like molecules." Journal of chemical information
  and modeling 53.7 (2013): 1563-1575.
  """

  def __init__(self,
@@ -3104,8 +3107,77 @@ class DAGGather(tf.keras.layers.Layer):


class MessagePassing(tf.keras.layers.Layer):
  """ General class for MPNN
  default structures built according to https://arxiv.org/abs/1511.06391 """
  """ General class for MPNN message passing.

  Default structures are built according to [1]_. This class performs `T`
  steps of message passing.

  Examples
  --------
  >>> import numpy as np
  >>> import deepchem as dc

  Suppose you have a batch of molecules

  >>> smiles = ["CCC", "C"]

  Note that there are 4 atoms in total in this system. This layer expects its
  input molecules to be batched together.

  >>> total_n_atoms = 4

  Let's suppose that we have a featurizer that computes `n_atom_feat` features
  per atom.

  >>> n_atom_feat = 75

  Then conceptually, `atom_feat` is the array of shape `(total_n_atoms,
  n_atom_feat)` of atomic features. For simplicity, let's just go with a
  random such matrix.

  >>> atom_feat = np.random.rand(total_n_atoms, n_atom_feat)

  Let's suppose we have `n_pair_feat` pairwise features

  >>> n_pair_feat = 14

  For each molecule, we compute a matrix of shape `(n_atoms*n_atoms,
  n_pair_feat)` of pairwise features for each pair of atoms in the molecule.
  Let's construct this conceptually for our example.

  >>> pair_feat = [np.random.rand(3*3, n_pair_feat), np.random.rand(1*1, n_pair_feat)]
  >>> pair_feat = np.concatenate(pair_feat, axis=0)
  >>> pair_feat.shape
  (10, 14)

  `atom_to_pair` tells us the precise pair each pair feature belongs to. In
  our case

  >>> atom_to_pair = np.array([[0, 0],
  ...                          [0, 1],
  ...                          [0, 2],
  ...                          [1, 0],
  ...                          [1, 1],
  ...                          [1, 2],
  ...                          [2, 0],
  ...                          [2, 1],
  ...                          [2, 2],
  ...                          [3, 3]])

  Let's now define the actual layer with `T=2` layers of message passing.

  >>> layer = MessagePassing(T=2)

  And invoke it

  >>> out = layer([atom_feat, pair_feat, atom_to_pair])


  References
  ----------
  .. [1] Vinyals, Oriol, Samy Bengio, and Manjunath Kudlur. "Order matters:
  Sequence to sequence for sets." arXiv preprint arXiv:1511.06391 (2015).
  """

  def __init__(self,
               T,
@@ -3150,7 +3222,17 @@ class MessagePassing(tf.keras.layers.Layer):
    self.built = True

  def call(self, inputs):
    """ Perform T steps of message passing """
    """ Perform T steps of message passing 

    Parameters
    ----------
    inputs: List
      Should contain 3 tensors [atom_features, pair_features, atom_to_pair]

    Returns
    -------
    tf.Tensor
    """
    atom_features, pair_features, atom_to_pair = inputs
    n_atom_features = atom_features.get_shape().as_list()[-1]
    if n_atom_features < self.n_hidden:
@@ -3167,7 +3249,15 @@ class MessagePassing(tf.keras.layers.Layer):


class EdgeNetwork(tf.keras.layers.Layer):
  """ Submodule for Message Passing """
  """ Submodule for Message Passing 

  Implements an edge network as described in [1]_

  References
  ----------
  .. [1] Vinyals, Oriol, Samy Bengio, and Manjunath Kudlur. "Order matters:
  Sequence to sequence for sets." arXiv preprint arXiv:1511.06391 (2015).
  """

  def __init__(self,
               n_pair_features=8,
@@ -3204,7 +3294,16 @@ class EdgeNetwork(tf.keras.layers.Layer):


class GatedRecurrentUnit(tf.keras.layers.Layer):
  """ Submodule for Message Passing """
  """ Submodule for Message Passing

  Implements a GRU for message passing as described in [1]_

  References
  ----------
  .. [1] Vinyals, Oriol, Samy Bengio, and Manjunath Kudlur. "Order matters:
  Sequence to sequence for sets." arXiv preprint arXiv:1511.06391 (2015).

  """

  def __init__(self, n_hidden=100, init='glorot_uniform', **kwargs):
    super(GatedRecurrentUnit, self).__init__(**kwargs)
+40 −0
Original line number Diff line number Diff line
@@ -99,6 +99,46 @@ def test_interatomic_l2_distances():
      assert np.allclose(dist2, result[atom, neighbor])


def test_message_passing_layer():
  """Test invoking WeaveLayer."""
  out_channels = 2
  n_atoms = 4  # In CCC and C, there are 4 atoms
  raw_smiles = ['CCC', 'C']
  from rdkit import Chem
  mols = [Chem.MolFromSmiles(s) for s in raw_smiles]
  featurizer = dc.feat.WeaveFeaturizer()
  mols = featurizer.featurize(mols)
  mp = layers.MessagePassing(T=2)
  atom_feat = []
  pair_feat = []
  atom_to_pair = []
  start = 0
  n_pair_feat = 14
  for im, mol in enumerate(mols):
    n_atoms = mol.get_num_atoms()
    # index of pair features
    C0, C1 = np.meshgrid(np.arange(n_atoms), np.arange(n_atoms))
    atom_to_pair.append(
        np.transpose(np.array([C1.flatten() + start,
                               C0.flatten() + start])))
    # number of pairs for each atom
    start = start + n_atoms

    # atom features
    atom_feat.append(mol.get_atom_features())
    # pair features
    pair_feat.append(
        np.reshape(mol.get_pair_features(), (n_atoms * n_atoms, n_pair_feat)))
  inputs = [
      np.array(np.concatenate(atom_feat, axis=0), dtype=np.float32),
      np.concatenate(pair_feat, axis=0),
      np.concatenate(atom_to_pair, axis=0)
  ]
  # Outputs should be what?
  outputs = mp(inputs)
  assert len(outputs) == 4


def test_weave_layer():
  """Test invoking WeaveLayer."""
  out_channels = 2