Commit 032ac8b9 authored by hsjang001205's avatar hsjang001205
Browse files

DAG_reload

parent 95d32dfc
Loading
Loading
Loading
Loading
+158 −1
Original line number Diff line number Diff line
@@ -9,7 +9,9 @@ from tensorflow.keras.layers import Dropout, BatchNormalization

class InteratomicL2Distances(tf.keras.layers.Layer):
  """Compute (squared) L2 Distances between atoms given neighbors.

  This class computes pairwise distances between its inputs.

  Examples
  --------
  >>> import numpy as np
@@ -22,10 +24,12 @@ class InteratomicL2Distances(tf.keras.layers.Layer):
  >>> result = np.array(layer([coords, neighbor_list]))
  >>> result.shape
  (5, 2)

  """

  def __init__(self, N_atoms: int, M_nbrs: int, ndim: int, **kwargs):
    """Constructor for this layer.

    Parameters
    ----------
    N_atoms: int
@@ -50,11 +54,13 @@ class InteratomicL2Distances(tf.keras.layers.Layer):

  def call(self, inputs):
    """Invokes this layer.

    Parameters
    ----------
    inputs: list
      Should be of form `inputs=[coords, nbr_list]` where `coords` is a
      tensor of shape `(None, N, 3)` and `nbr_list` is a list.

    Returns
    -------
    Tensor of shape `(N_atoms, M_nbrs)` with interatomic distances.
@@ -79,6 +85,7 @@ class GraphConv(tf.keras.layers.Layer):
  convolution combines per-node feature vectors in a nonlinear fashion with
  the feature vectors for neighboring nodes.  This "blends" information in
  local neighborhoods of a graph.

  References
  ----------
  .. [1] Duvenaud, David K., et al. "Convolutional networks on graphs for learning molecular fingerprints." Advances in neural information processing systems. 2015. https://arxiv.org/abs/1509.09292
@@ -92,6 +99,7 @@ class GraphConv(tf.keras.layers.Layer):
               activation_fn: Callable = None,
               **kwargs):
    """Initialize a graph convolutional layer.

    Parameters
    ----------
    out_channel: int
@@ -207,10 +215,12 @@ class GraphConv(tf.keras.layers.Layer):

class GraphPool(tf.keras.layers.Layer):
  """A GraphPool gathers data from local neighborhoods of a graph.

  This layer does a max-pooling over the feature vectors of atoms in a
  neighborhood. You can think of this layer as analogous to a max-pooling
  layer for 2D convolutions but which operates on graphs instead. This
  technique is described in [1]_.

  References
  ----------
  .. [1] Duvenaud, David K., et al. "Convolutional networks on graphs for
@@ -221,6 +231,7 @@ class GraphPool(tf.keras.layers.Layer):

  def __init__(self, min_degree=0, max_degree=10, **kwargs):
    """Initialize this layer

    Parameters
    ----------
    min_deg: int, optional (default 0)
@@ -283,6 +294,7 @@ class GraphPool(tf.keras.layers.Layer):

class GraphGather(tf.keras.layers.Layer):
  """A GraphGather layer pools node-level feature vectors to create a graph feature vector.

  Many graph convolutional networks manipulate feature vectors per
  graph-node. For a molecule for example, each node might represent an
  atom, and the network would manipulate atomic feature vectors that
@@ -290,11 +302,13 @@ class GraphGather(tf.keras.layers.Layer):
  the application, we will likely want to work with a molecule level
  feature representation. The `GraphGather` layer creates a graph level
  feature vector by combining all the node-level feature vectors.

  One subtlety about this layer is that it depends on the
  `batch_size`. This is done for internal implementation reasons. The
  `GraphConv`, and `GraphPool` layers pool all nodes from all graphs
  in a batch that's being processed. The `GraphGather` reassembles
  these jumbled node feature vectors into per-graph feature vectors.

  References
  ----------
  .. [1] Duvenaud, David K., et al. "Convolutional networks on graphs for
@@ -304,6 +318,7 @@ class GraphGather(tf.keras.layers.Layer):

  def __init__(self, batch_size, activation_fn=None, **kwargs):
    """Initialize this layer.

    Parameters
    ----------
    batch_size: int
@@ -326,6 +341,7 @@ class GraphGather(tf.keras.layers.Layer):

  def call(self, inputs):
    """Invoking this layer.

    Parameters
    ----------
    inputs: list
@@ -353,6 +369,7 @@ class GraphGather(tf.keras.layers.Layer):

class LSTMStep(tf.keras.layers.Layer):
  """Layer that performs a single step LSTM update.

  This layer performs a single step LSTM update. Note that it is *not*
  a full LSTM recurrent network. The LSTMStep layer is useful as a
  primitive for designing layers such as the AttnLSTMEmbedding or the
@@ -423,10 +440,12 @@ class LSTMStep(tf.keras.layers.Layer):

  def call(self, inputs):
    """Execute this layer on input tensors.

    Parameters
    ----------
    inputs: list
      List of three tensors (x, h_tm1, c_tm1). h_tm1 means "h, t-1".

    Returns
    -------
    list
@@ -464,11 +483,13 @@ def cosine_dist(x, y):
  input tensors would be different test vectors or sentences. The input tensors
  themselves could be different batches. Using vectors or tensors of all 0s
  should be avoided.

  Methods
  -------
  The vectors in the input tensors are first l2-normalized such that each vector
  has length or magnitude of 1. The inner product (dot product) is then taken 
  between corresponding pairs of row vectors in the input tensors and returned.

  Examples
  --------
  The cosine similarity between two equivalent vectors will be 1. The cosine
@@ -482,18 +503,22 @@ def cosine_dist(x, y):
  >>> x = tf.ones((6, 4), dtype=tf.dtypes.float32, name=None)
  >>> y_same = tf.ones((6, 4), dtype=tf.dtypes.float32, name=None)
  >>> cos_sim_same = layers.cosine_dist(x,y_same)

  `x` and `y_same` are the same tensor (equivalent at every element, in this 
  case 1). As such, the pairwise inner product of the rows in `x` and `y` will
  always be 1. The output tensor will be of shape (6,6).

  >>> diff = cos_sim_same - tf.ones((6, 6), dtype=tf.dtypes.float32, name=None)
  >>> tf.reduce_sum(diff) == 0 # True
  <tf.Tensor: shape=(), dtype=bool, numpy=True>
  >>> cos_sim_same.shape
  TensorShape([6, 6])

  The cosine similarity between two orthogonal vectors will be 0 (by definition).
  If every row in `x` is orthogonal to every row in `y`, then the output will be a
  tensor of 0s. In the following example, each row in the tensor `x1` is orthogonal
  to each row in `x2` because they are halves of an identity matrix.

  >>> identity_tensor = tf.eye(512, dtype=tf.dtypes.float32)
  >>> x1 = identity_tensor[0:256,:]
  >>> x2 = identity_tensor[256:512,:]
@@ -508,6 +533,7 @@ def cosine_dist(x, y):
  <tf.Tensor: shape=(), dtype=bool, numpy=True>
  >>> cos_sim_orth.shape
  TensorShape([256, 256])

  Parameters
  ----------
  x: tf.Tensor
@@ -518,6 +544,7 @@ def cosine_dist(x, y):
    Input Tensor of shape `(m, p)`
    The shape of this input tensor should be `m` rows by `p` columns.
    Note that `m` need not equal `n` (the number of rows in `x`).

  Returns
  -------
  tf.Tensor
@@ -533,6 +560,7 @@ def cosine_dist(x, y):

class AttnLSTMEmbedding(tf.keras.layers.Layer):
  """Implements AttnLSTM as in matching networks paper.

  The AttnLSTM embedding adjusts two sets of vectors, the "test" and
  "support" sets. The "support" consists of a set of evidence vectors.
  Think of these as the small training set for low-data machine
@@ -542,7 +570,9 @@ class AttnLSTMEmbedding(tf.keras.layers.Layer):
  the "support".  The AttnLSTMEmbedding is thus a type of learnable
  metric that allows a network to modify its internal notion of
  distance.

  See references [1]_ [2]_ for more details.

  References
  ----------
  .. [1] Vinyals, Oriol, et al. "Matching networks for one shot learning." 
@@ -588,6 +618,7 @@ class AttnLSTMEmbedding(tf.keras.layers.Layer):

  def call(self, inputs):
    """Execute this layer on input tensors.

    Parameters
    ----------
    inputs: list
@@ -595,6 +626,7 @@ class AttnLSTMEmbedding(tf.keras.layers.Layer):
      n_feat) and Xp should be of shape (n_support, n_feat) where
      n_test is the size of the test set, n_support that of the support
      set, and n_feat is the number of per-atom features.

    Returns
    -------
    list
@@ -625,6 +657,7 @@ class AttnLSTMEmbedding(tf.keras.layers.Layer):

class IterRefLSTMEmbedding(tf.keras.layers.Layer):
  """Implements the Iterative Refinement LSTM.

  Much like AttnLSTMEmbedding, the IterRefLSTMEmbedding is another type
  of learnable metric which adjusts "test" and "support." Recall that
  "support" is the small amount of data available in a low data machine
@@ -641,6 +674,7 @@ class IterRefLSTMEmbedding(tf.keras.layers.Layer):
    additively, this model allows for an additive update to be
    performed to both test and support using information from each
    other.

    Parameters
    ----------
    n_support: int
@@ -684,6 +718,7 @@ class IterRefLSTMEmbedding(tf.keras.layers.Layer):

  def call(self, inputs):
    """Execute this layer on input tensors.

    Parameters
    ----------
    inputs: list
@@ -691,6 +726,7 @@ class IterRefLSTMEmbedding(tf.keras.layers.Layer):
      n_feat) and Xp should be of shape (n_support, n_feat) where
      n_test is the size of the test set, n_support that of the
      support set, and n_feat is the number of per-atom features.

    Returns
    -------
    Returns two tensors of same shape as input. Namely the output
@@ -737,6 +773,7 @@ class IterRefLSTMEmbedding(tf.keras.layers.Layer):

class SwitchedDropout(tf.keras.layers.Layer):
  """Apply dropout based on an input.

  This is required for uncertainty prediction.  The standard Keras
  Dropout layer only performs dropout during training, but we
  sometimes need to do it during prediction.  The second input to this
@@ -763,6 +800,7 @@ class WeightedLinearCombo(tf.keras.layers.Layer):

  def __init__(self, std=0.3, **kwargs):
    """Initialize this layer.

    Parameters
    ----------
    std: float, optional (default 0.3)
@@ -800,11 +838,13 @@ class CombineMeanStd(tf.keras.layers.Layer):

  def __init__(self, training_only=False, noise_epsilon=1.0, **kwargs):
    """Create a CombineMeanStd layer.

    This layer should have two inputs with the same shape, and its
    output also has the same shape.  Each element of the output is a
    Gaussian distributed random number whose mean is the corresponding
    element of the first input, and whose standard deviation is the
    corresponding element of the second input.

    Parameters
    ----------
    training_only: bool
@@ -853,6 +893,7 @@ class Stack(tf.keras.layers.Layer):

class Variable(tf.keras.layers.Layer):
  """Output a trainable value.

  Due to a quirk of Keras, you must pass an input value when invoking
  this layer.  It doesn't matter what value you pass.  Keras assumes
  every layer that is not an Input will have at least one parent, and
@@ -861,6 +902,7 @@ class Variable(tf.keras.layers.Layer):

  def __init__(self, initial_value, **kwargs):
    """Construct a variable layer.

    Parameters
    ----------
    initial_value: array or Tensor
@@ -884,6 +926,7 @@ class Variable(tf.keras.layers.Layer):

class VinaFreeEnergy(tf.keras.layers.Layer):
  """Computes free-energy as defined by Autodock Vina.

  TODO(rbharath): Make this layer support batching.
  """

@@ -972,6 +1015,7 @@ class VinaFreeEnergy(tf.keras.layers.Layer):
      Coordinates/features.
    Z: tf.Tensor of shape (N)
      Atomic numbers of neighbor atoms.

    Returns
    -------
    layer: tf.Tensor of shape (B)
@@ -1008,11 +1052,13 @@ class VinaFreeEnergy(tf.keras.layers.Layer):

class NeighborList(tf.keras.layers.Layer):
  """Computes a neighbor-list in Tensorflow.

  Neighbor-lists (also called Verlet Lists) are a tool for grouping
  atoms which are close to each other spatially. This layer computes a
  Neighbor List from a provided tensor of atomic coordinates. You can
  think of this as a general "k-means" layer, but optimized for the
  case `k==3`.

  TODO(rbharath): Make this layer support batching.
  """

@@ -1063,11 +1109,14 @@ class NeighborList(tf.keras.layers.Layer):

  def compute_nbr_list(self, coords):
    """Get closest neighbors for atoms.

    Needs to handle padding for atoms with no neighbors.

    Parameters
    ----------
    coords: tf.Tensor
      Shape (N_atoms, ndim)

    Returns
    -------
    nbr_list: tf.Tensor
@@ -1122,6 +1171,7 @@ class NeighborList(tf.keras.layers.Layer):

  def get_atoms_in_nbrs(self, coords, cells):
    """Get the atoms in neighboring cells for each cells.

    Returns
    -------
    atoms_in_nbrs = (N_atoms, n_nbr_cells, M_nbrs)
@@ -1162,13 +1212,16 @@ class NeighborList(tf.keras.layers.Layer):

  def get_closest_atoms(self, coords, cells):
    """For each cell, find M_nbrs closest atoms.

    Let N_atoms be the number of atoms.

    Parameters
    ----------
    coords: tf.Tensor
      (N_atoms, ndim) shape.
    cells: tf.Tensor
      (n_cells, ndim) shape.

    Returns
    -------
    closest_inds: tf.Tensor
@@ -1197,6 +1250,7 @@ class NeighborList(tf.keras.layers.Layer):

  def get_cells_for_atoms(self, coords, cells):
    """Compute the cells each atom belongs to.

    Parameters
    ----------
    coords: tf.Tensor
@@ -1238,11 +1292,13 @@ class NeighborList(tf.keras.layers.Layer):

  def get_neighbor_cells(self, cells):
    """Compute neighbors of cells in grid.

    # TODO(rbharath): Do we need to handle periodic boundary conditions
    properly here?
    # TODO(rbharath): This doesn't handle boundaries well. We hard-code
    # looking for n_nbr_cells neighbors, which isn't right for boundary cells in
    # the cube.

    Parameters
    ----------
    cells: tf.Tensor
@@ -1270,9 +1326,11 @@ class NeighborList(tf.keras.layers.Layer):

  def get_cells(self):
    """Returns the locations of all grid points in box.

    Suppose start is -10 Angstrom, stop is 10 Angstrom, nbr_cutoff is 1.
    Then would return a list of length 20^3 whose entries would be
    [(-10, -10, -10), (-10, -10, -9), ..., (9, 9, 9)]

    Returns
    -------
    cells: tf.Tensor
@@ -1288,9 +1346,11 @@ class NeighborList(tf.keras.layers.Layer):

class AtomicConvolution(tf.keras.layers.Layer):
  """Implements the atomic convolutional transform introduced in

  Gomes, Joseph, et al. "Atomic convolutional networks for predicting
  protein-ligand binding affinity." arXiv preprint arXiv:1703.10603
  (2017).

  At a high level, this transform performs a graph convolution
  on the nearest neighbors graph in 3D space.
  """
@@ -1301,8 +1361,10 @@ class AtomicConvolution(tf.keras.layers.Layer):
               boxsize=None,
               **kwargs):
    """Atomic convolution layer

    N = max_num_atoms, M = max_num_neighbors, B = batch_size, d = num_features
    l = num_radial_filters * num_atom_types

    Parameters
    ----------
    atom_types: list or None
@@ -1344,6 +1406,7 @@ class AtomicConvolution(tf.keras.layers.Layer):
      Neighbor list.
    Nbrs_Z: tf.Tensor of shape (B, N, M)
      Atomic numbers of neighbor atoms.

    Returns
    -------
    layer: tf.Tensor of shape (B, N, l)
@@ -1386,7 +1449,9 @@ class AtomicConvolution(tf.keras.layers.Layer):

  def radial_symmetry_function(self, R, rc, rs, e):
    """Calculates radial symmetry function.

    B = batch_size, N = max_num_atoms, M = max_num_neighbors, d = num_filters

    Parameters
    ----------
    R: tf.Tensor of shape (B, N, M)
@@ -1397,6 +1462,7 @@ class AtomicConvolution(tf.keras.layers.Layer):
      Gaussian distance matrix mean.
    e: float
      Gaussian distance matrix width.

    Returns
    -------
    retval: tf.Tensor of shape (B, N, M)
@@ -1408,13 +1474,16 @@ class AtomicConvolution(tf.keras.layers.Layer):

  def radial_cutoff(self, R, rc):
    """Calculates radial cutoff matrix.

    B = batch_size, N = max_num_atoms, M = max_num_neighbors

    Parameters
    ----------
      R [B, N, M]: tf.Tensor
        Distance matrix.
      rc: tf.Variable
        Interaction cutoff [Angstrom].

    Returns
    -------
    FC [B, N, M]: tf.Tensor
@@ -1428,7 +1497,9 @@ class AtomicConvolution(tf.keras.layers.Layer):

  def gaussian_distance_matrix(self, R, rs, e):
    """Calculates gaussian distance matrix.

    B = batch_size, N = max_num_atoms, M = max_num_neighbors

    Parameters
    ----------
    R [B, N, M]: tf.Tensor
@@ -1437,6 +1508,7 @@ class AtomicConvolution(tf.keras.layers.Layer):
      Gaussian distance matrix mean.
    e: tf.Variable
      Gaussian distance matrix width (e = .5/std**2).

    Returns
    -------
    retval [B, N, M]: tf.Tensor
@@ -1446,7 +1518,9 @@ class AtomicConvolution(tf.keras.layers.Layer):

  def distance_tensor(self, X, Nbrs, boxsize, B, N, M, d):
    """Calculates distance tensor for batch of molecules.

    B = batch_size, N = max_num_atoms, M = max_num_neighbors, d = num_features

    Parameters
    ----------
    X: tf.Tensor of shape (B, N, d)
@@ -1455,6 +1529,7 @@ class AtomicConvolution(tf.keras.layers.Layer):
      Neighbor list tensor.
    boxsize: float or None
      Simulation box length [Angstrom].

    Returns
    -------
    D: tf.Tensor of shape (B, N, M, d)
@@ -1471,11 +1546,14 @@ class AtomicConvolution(tf.keras.layers.Layer):

  def distance_matrix(self, D):
    """Calcuates the distance matrix from the distance tensor

    B = batch_size, N = max_num_atoms, M = max_num_neighbors, d = num_features

    Parameters
    ----------
    D: tf.Tensor of shape (B, N, M, d)
      Distance tensor.

    Returns
    -------
    R: tf.Tensor of shape (B, N, M)
@@ -1490,11 +1568,14 @@ class AlphaShareLayer(tf.keras.layers.Layer):
  """
  Part of a sluice network. Adds alpha parameters to control
  sharing between the main and auxiliary tasks

  Factory method AlphaShare should be used for construction

  Parameters
  ----------
  in_layers: list of Layers or tensors
    tensors in list must be the same size and list must include two or more tensors

  Returns
  -------
  out_tensor: a tensor with shape [len(in_layers), x, y] where x, y were the original layer dimensions
@@ -1577,11 +1658,13 @@ class BetaShare(tf.keras.layers.Layer):
  """
  Part of a sluice network. Adds beta params to control which layer
  outputs are used for prediction

  Parameters
  ----------
  in_layers: list of Layers or tensors
    tensors in list must be the same size and list must include two or
    more tensors

  Returns
  -------
  output_layers: list of Layers or tensors with same size as in_layers
@@ -1802,10 +1885,13 @@ class GraphEmbedPoolLayer(tf.keras.layers.Layer):
  r"""
  GraphCNNPool Layer from Robust Spatial Filtering with Graph Convolutional Neural Networks
  https://arxiv.org/abs/1703.00792

  This is a learnable pool operation It constructs a new adjacency
  matrix for a graph of specified number of nodes.

  This differs from our other pool operations which set vertices to a
  function value without altering the adjacency matrix.

  ..math:: V_{emb} = SpatialGraphCNN({V_{in}})
  ..math:: V_{out} = \sigma(V_{emb})^{T} * V_{in}
  ..math:: A_{out} = V_{emb}^{T} * A_{in} * V_{emb}
@@ -1840,10 +1926,13 @@ class GraphEmbedPoolLayer(tf.keras.layers.Layer):
    in_layers: list of Layers or tensors
      [V, A, mask]
      V are the vertex features must be of shape (batch, vertex, channel)

      A are the adjacency matrixes for each graph
        Shape (batch, from_vertex, adj_matrix, to_vertex)

      mask is optional, to be used when not every graph has the
      same number of vertices

    Returns
    -------
    Returns a `tf.tensor` with a graph convolution applied
@@ -1889,14 +1978,18 @@ class GraphCNN(tf.keras.layers.Layer):
  r"""
  GraphCNN Layer from Robust Spatial Filtering with Graph Convolutional Neural Networks
  https://arxiv.org/abs/1703.00792

  Spatial-domain convolutions can be defined as
  H = h_0I + h_1A + h_2A^2 + ... + hkAk, H ∈ R**(N×N)

  We approximate it by
  H ≈ h_0I + h_1A

  We can define a convolution as applying multiple these linear filters
  over edges of different types (think up, down, left, right, diagonal in images)
  Where each edge type has its own adjacency matrix
  H ≈ h_0I + h_1A_1 + h_2A_2 + . . . h_(L−1)A_(L−1)

  V_out = \sum_{c=1}^{C} H^{c} V^{c} + b
  """

@@ -1906,13 +1999,17 @@ class GraphCNN(tf.keras.layers.Layer):
    ----------
    num_filters: int
      Number of filters to have in the output

    in_layers: list of Layers or tensors
      [V, A, mask]
      V are the vertex features must be of shape (batch, vertex, channel)

      A are the adjacency matrixes for each graph
        Shape (batch, from_vertex, adj_matrix, to_vertex)

      mask is optional, to be used when not every graph has the
      same number of vertices

    Returns: tf.tensor
    Returns a tf.tensor with a graph convolution applied
    The shape will be (batch, vertex, self.num_filters)
@@ -1978,10 +2075,15 @@ class GraphCNN(tf.keras.layers.Layer):

class Highway(tf.keras.layers.Layer):
  """ Create a highway layer. y = H(x) * T(x) + x * (1 - T(x))

  H(x) = activation_fn(matmul(W_H, x) + b_H) is the non-linear transformed output
  T(x) = sigmoid(matmul(W_T, x) + b_T) is the transform gate

  Implementation based on paper

  Srivastava, Rupesh Kumar, Klaus Greff, and Jürgen Schmidhuber. "Highway networks." arXiv preprint arXiv:1505.00387 (2015).


  This layer expects its input to be a two dimensional tensor
  of shape (batch size, # input features).  Outputs will be in
  the same shape.
@@ -2050,43 +2152,62 @@ class Highway(tf.keras.layers.Layer):
class WeaveLayer(tf.keras.layers.Layer):
  """This class implements the core Weave convolution from the
  Google graph convolution paper [1]_

  This model contains atom features and bond features
  separately. Here, bond features are also called pair features.
  There are 2 types of transformation, atom->atom, atom->pair,
  pair->atom, pair->pair that this model implements.

  Examples
  --------
  This layer expects 4 inputs in a list of the form `[atom_features,
  pair_features, pair_split, atom_to_pair]`. We'll walk through the structure
  of these inputs. Let's start with some basic definitions.

  >>> import deepchem as dc
  >>> import numpy as np

  Suppose you have a batch of molecules

  >>> smiles = ["CCC", "C"]

  Note that there are 4 atoms in total in this system. This layer expects its
  input molecules to be batched together.

  >>> total_n_atoms = 4

  Let's suppose that we have a featurizer that computes `n_atom_feat` features
  per atom.

  >>> n_atom_feat = 75

  Then conceptually, `atom_feat` is the array of shape `(total_n_atoms,
  n_atom_feat)` of atomic features. For simplicity, let's just go with a
  random such matrix.

  >>> atom_feat = np.random.rand(total_n_atoms, n_atom_feat)

  Let's suppose we have `n_pair_feat` pairwise features

  >>> n_pair_feat = 14

  For each molecule, we compute a matrix of shape `(n_atoms*n_atoms,
  n_pair_feat)` of pairwise features for each pair of atoms in the molecule.
  Let's construct this conceptually for our example.

  >>> pair_feat = [np.random.rand(3*3, n_pair_feat), np.random.rand(1*1, n_pair_feat)]
  >>> pair_feat = np.concatenate(pair_feat, axis=0)
  >>> pair_feat.shape
  (10, 14)

  `pair_split` is an index into `pair_feat` which tells us which atom each row belongs to. In our case, we have

  >>> pair_split = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3])

  That is, the first 9 entries belong to "CCC" and the last entry to "C". The
  final entry `atom_to_pair` goes in a little more in-depth than `pair_split`
  and tells us the precise pair each pair feature belongs to. In our case

  >>> atom_to_pair = np.array([[0, 0],
  ...                          [0, 1],
  ...                          [0, 2],
@@ -2097,25 +2218,34 @@ class WeaveLayer(tf.keras.layers.Layer):
  ...                          [2, 1],
  ...                          [2, 2],
  ...                          [3, 3]])

  Let's now define the actual layer

  >>> layer = WeaveLayer()

  And invoke it

  >>> [A, P] = layer([atom_feat, pair_feat, pair_split, atom_to_pair])

  The weave layer produces new atom/pair features. Let's check their shapes

  >>> A = np.array(A)
  >>> A.shape
  (4, 50)
  >>> P = np.array(P)
  >>> P.shape
  (10, 50)

  The 4 is `total_num_atoms` and the 10 is the total number of pairs. Where
  does `50` come from? It's from the default arguments `n_atom_input_feat` and
  `n_pair_input_feat`.

  References
  ----------
  .. [1] Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond
  fingerprints." Journal of computer-aided molecular design 30.8 (2016):
  595-608.

  """

  def __init__(self,
@@ -2208,6 +2338,7 @@ class WeaveLayer(tf.keras.layers.Layer):

  def build(self, input_shape):
    """ Construct internal trainable weights.

    Parameters
    ----------
    input_shape: tuple
@@ -2255,6 +2386,7 @@ class WeaveLayer(tf.keras.layers.Layer):

  def call(self, inputs: List) -> List:
    """Creates weave tensors.

    Parameters
    ----------
    inputs: List
@@ -2318,42 +2450,59 @@ class WeaveLayer(tf.keras.layers.Layer):

class WeaveGather(tf.keras.layers.Layer):
  """Implements the weave-gathering section of weave convolutions.

  Implements the gathering layer from [1]_. The weave gathering layer gathers
  per-atom features to create a molecule-level fingerprint in a weave
  convolutional network. This layer can also performs Gaussian histogram
  expansion as detailed in [1]_. Note that the gathering function here is
  simply addition as in [1]_.

  Examples
  --------
  This layer expects 2 inputs in a list of the form `[atom_features,
  pair_features]`. We'll walk through the structure
  of these inputs. Let's start with some basic definitions.

  >>> import deepchem as dc
  >>> import numpy as np

  Suppose you have a batch of molecules

  >>> smiles = ["CCC", "C"]

  Note that there are 4 atoms in total in this system. This layer expects its
  input molecules to be batched together.

  >>> total_n_atoms = 4

  Let's suppose that we have `n_atom_feat` features per atom. 

  >>> n_atom_feat = 75

  Then conceptually, `atom_feat` is the array of shape `(total_n_atoms,
  n_atom_feat)` of atomic features. For simplicity, let's just go with a
  random such matrix.

  >>> atom_feat = np.random.rand(total_n_atoms, n_atom_feat)

  We then need to provide a mapping of indices to the atoms they belong to. In
  our case this would be

  >>> atom_split = np.array([0, 0, 0, 1])

  Let's now define the actual layer

  >>> gather = WeaveGather(batch_size=2, n_input=n_atom_feat)
  >>> output_molecules = gather([atom_feat, atom_split])
  >>> len(output_molecules)
  2

  References
  ----------
  .. [1] Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond
  fingerprints." Journal of computer-aided molecular design 30.8 (2016):
  595-608.

  Note
  ----
  This class requires `tensorflow_probability` to be installed.
@@ -2424,10 +2573,12 @@ class WeaveGather(tf.keras.layers.Layer):

  def call(self, inputs: List) -> List:
    """Creates weave tensors.

    Parameters
    ----------
    inputs: List
      Should contain 2 tensors [atom_features, atom_split]

    Returns
    -------
    output_molecules: List 
@@ -2450,20 +2601,24 @@ class WeaveGather(tf.keras.layers.Layer):

  def gaussian_histogram(self, x):
    """Expands input into a set of gaussian histogram bins.

    Parameters
    ----------
    x: tf.Tensor
      Of shape `(N, n_feat)`

    Examples
    --------
    This method uses 11 bins spanning portions of a Gaussian with zero mean
    and unit standard deviation.

    >>> gaussian_memberships = [(-1.645, 0.283), (-1.080, 0.170),
    ...                         (-0.739, 0.134), (-0.468, 0.118),
    ...                         (-0.228, 0.114), (0., 0.114),
    ...                         (0.228, 0.114), (0.468, 0.118),
    ...                         (0.739, 0.134), (1.080, 0.170),
    ...                         (1.645, 0.283)]

    We construct a Gaussian at `gaussian_memberships[i][0]` with standard
    deviation `gaussian_memberships[i][1]`. Each feature in `x` is assigned
    the probability of falling in each Gaussian, and probabilities are
@@ -3111,6 +3266,7 @@ class GatedRecurrentUnit(tf.keras.layers.Layer):

class SetGather(tf.keras.layers.Layer):
  """set2set gather layer for graph-based model

  Models using this layer must set `pad_batches=True`.
  """

@@ -3150,6 +3306,7 @@ class SetGather(tf.keras.layers.Layer):

  def call(self, inputs):
    """Perform M steps of set2set gather,

    Detailed descriptions in: https://arxiv.org/abs/1511.06391
    """
    atom_features, atom_split = inputs