Unverified Commit 4d32336d authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1872 from deepchem/layers_docs

Improving a few layers docstrings
parents 2d497b1c df3eb7e1
Loading
Loading
Loading
Loading
+155 −46
Original line number Diff line number Diff line
@@ -24,6 +24,13 @@ class InteratomicL2Distances(tf.keras.layers.Layer):
    return config

  def call(self, inputs):
    """Invokes this layer.

    Parameters
    ----------
    inputs: list
      Should be of form `inputs=[coords, nbr_list]` where `coords` is a tensor of shape `(None, N, 3)` and `nbr_list` is a list.
    """
    if len(inputs) != 2:
      raise ValueError("InteratomicDistances requires coords,nbr_list")
    coords, nbr_list = (inputs[0], inputs[1])
@@ -38,6 +45,16 @@ class InteratomicL2Distances(tf.keras.layers.Layer):


class GraphConv(tf.keras.layers.Layer):
  """Graph Convolutional Layers
  
  This layer implements the graph convolution introduced in 

  Duvenaud, David K., et al. "Convolutional networks on graphs for learning molecular fingerprints." Advances in neural information processing systems. 2015. https://arxiv.org/abs/1509.09292
  
  The graph convolution combines per-node feature vectors in a
  nonlinear fashion with the feature vectors for neighboring nodes.
  This "blends" information in local neighborhoods of a graph.
  """

  def __init__(self,
               out_channel,
@@ -45,6 +62,24 @@ class GraphConv(tf.keras.layers.Layer):
               max_deg=10,
               activation_fn=None,
               **kwargs):
    """Initialize a graph convolutional layer.

    Parameters
    ----------
    out_channel: int
      The number of output channels per graph node.
    min_deg: int, optional (default 0)
      The minimum allowed degree for each graph node.
    max_deg: int, optional (default 10)
      The maximum allowed degree for each graph node. Note that this
      is set to 10 to handle complex molecules (some organometallic
      compounds have strange structures). If you're using this for
      non-molecular applications, you may need to set this much higher
      depending on your dataset.
    activation_fn: function
      A nonlinear activation function to apply. If you're not sure,
      `tf.nn.relu` is probably a good default for your application.
    """
    super(GraphConv, self).__init__(**kwargs)
    self.out_channel = out_channel
    self.min_degree = min_deg
@@ -143,8 +178,27 @@ class GraphConv(tf.keras.layers.Layer):


class GraphPool(tf.keras.layers.Layer):
  """A GraphPool gathers data from local neighborhoods of a graph.

  This layer does a max-pooling over the feature vectors of atoms in a
  neighborhood. You can think of this layer as analogous to a max-pooling layer
  for 2D convolutions but which operates on graphs instead.
  """

  def __init__(self, min_degree=0, max_degree=10, **kwargs):
    """Initialize this layer

    Parameters
    ----------
    min_degree: int, optional (default 0)
      The minimum allowed degree for each graph node.
    max_degree: int, optional (default 10)
      The maximum allowed degree for each graph node. Note that this
      is set to 10 to handle complex molecules (some organometallic
      compounds have strange structures). If you're using this for
      non-molecular applications, you may need to set this much higher
      depending on your dataset.
    """
    super(GraphPool, self).__init__(**kwargs)
    self.min_degree = min_degree
    self.max_degree = max_degree
@@ -195,8 +249,36 @@ class GraphPool(tf.keras.layers.Layer):


class GraphGather(tf.keras.layers.Layer):
  """A GraphGather layer pools node-level feature vectors to create a graph feature vector.

  Many graph convolutional networks manipulate feature vectors per
  graph-node. For a molecule, for example, each node might represent an
  atom, and the network would manipulate atomic feature vectors that
  summarize the local chemistry of the atom. However, at the end of
  the application, we will likely want to work with a molecule level
  feature representation. The `GraphGather` layer creates a graph level
  feature vector by combining all the node-level feature vectors.

  One subtlety about this layer is that it depends on the
  `batch_size`. This is done for internal implementation reasons. The
  `GraphConv` and `GraphPool` layers pool all nodes from all graphs
  in a batch that's being processed. The `GraphGather` reassembles
  these jumbled node feature vectors into per-graph feature vectors.
  """

  def __init__(self, batch_size, activation_fn=None, **kwargs):
    """Initialize this layer.

    Parameters
    ----------
    batch_size: int
      The batch size for this layer. Note that the layer's behavior
      changes depending on the batch size.
    activation_fn: function
      A nonlinear activation function to apply. If you're not sure,
      `tf.nn.relu` is probably a good default for your application.
    """

    super(GraphGather, self).__init__(**kwargs)
    self.batch_size = batch_size
    self.activation_fn = activation_fn
@@ -208,7 +290,15 @@ class GraphGather(tf.keras.layers.Layer):
    return config

  def call(self, inputs):
    # x = [atom_features, deg_slice, membership, deg_adj_list placeholders...]
    """Invokes this layer.

    Parameters
    ----------
    inputs: list
      This list should consist of `inputs = [atom_features, deg_slice,
      membership, deg_adj_list placeholders...]`. These are all
      tensors that are created/processed by `GraphConv` and `GraphPool`.
    """
    atom_features = inputs[0]

    # Extract graph topology
@@ -507,16 +597,15 @@ class IterRefLSTMEmbedding(tf.keras.layers.Layer):
    Parameters
    ----------
    inputs: list
      List of two tensors (X, Xp). X should be of shape (n_test, n_feat) and
      Xp should be of shape (n_support, n_feat) where n_test is the size of
      the test set, n_support that of the support set, and n_feat is the number
      of per-atom features.
      List of two tensors (X, Xp). X should be of shape (n_test,
      n_feat) and Xp should be of shape (n_support, n_feat) where
      n_test is the size of the test set, n_support that of the
      support set, and n_feat is the number of per-atom features.

    Returns
    -------
    list
      Returns two tensors of same shape as input. Namely the output shape will
      be [(n_test, n_feat), (n_support, n_feat)]
    Returns two tensors of same shape as input. Namely the output
    shape will be [(n_test, n_feat), (n_support, n_feat)]
    """
    if len(inputs) != 2:
      raise ValueError(
@@ -560,10 +649,11 @@ class IterRefLSTMEmbedding(tf.keras.layers.Layer):
class SwitchedDropout(tf.keras.layers.Layer):
  """Apply dropout based on an input.

  This is required for uncertainty prediction.  The standard Keras Dropout
  layer only performs dropout during training, but we sometimes need to do it
  during prediction.  The second input to this layer should be a scalar equal to
  0 or 1, indicating whether to perform dropout.
  This is required for uncertainty prediction.  The standard Keras
  Dropout layer only performs dropout during training, but we
  sometimes need to do it during prediction.  The second input to this
  layer should be a scalar equal to 0 or 1, indicating whether to
  perform dropout.
  """

  def __init__(self, rate, **kwargs):
@@ -584,6 +674,13 @@ class WeightedLinearCombo(tf.keras.layers.Layer):
  """Computes a weighted linear combination of input layers, with the weights defined by trainable variables."""

  def __init__(self, std=0.3, **kwargs):
    """Initialize this layer.

    Parameters
    ----------
    std: float, optional (default 0.3)
      The standard deviation to use when randomly initializing weights.
    """
    super(WeightedLinearCombo, self).__init__(**kwargs)
    self.std = std

@@ -617,17 +714,18 @@ class CombineMeanStd(tf.keras.layers.Layer):
  def __init__(self, training_only=False, noise_epsilon=1.0, **kwargs):
    """Create a CombineMeanStd layer.

    This layer should have two inputs with the same shape, and its output also has the
    same shape.  Each element of the output is a Gaussian distributed random number
    whose mean is the corresponding element of the first input, and whose standard
    deviation is the corresponding element of the second input.
    This layer should have two inputs with the same shape, and its
    output also has the same shape.  Each element of the output is a
    Gaussian distributed random number whose mean is the corresponding
    element of the first input, and whose standard deviation is the
    corresponding element of the second input.

    Parameters
    ----------
    training_only: bool
      if True, noise is only generated during training.  During prediction, the output
      is simply equal to the first input (that is, the mean of the distribution used
      during training).
      if True, noise is only generated during training.  During
      prediction, the output is simply equal to the first input (that
      is, the mean of the distribution used during training).
    noise_epsilon: float
      The noise is scaled by this factor
    """
@@ -671,10 +769,10 @@ class Stack(tf.keras.layers.Layer):
class Variable(tf.keras.layers.Layer):
  """Output a trainable value.

  Due to a quirk of Keras, you must pass an input value when invoking this layer.
  It doesn't matter what value you pass.  Keras assumes every layer that is not
  an Input will have at least one parent, and violating this assumption causes
  errors during evaluation.
  Due to a quirk of Keras, you must pass an input value when invoking
  this layer.  It doesn't matter what value you pass.  Keras assumes
  every layer that is not an Input will have at least one parent, and
  violating this assumption causes errors during evaluation.
  """

  def __init__(self, initial_value, **kwargs):
@@ -830,8 +928,11 @@ class VinaFreeEnergy(tf.keras.layers.Layer):
class NeighborList(tf.keras.layers.Layer):
  """Computes a neighbor-list in Tensorflow.

  Neighbor-lists (also called Verlet Lists) are a tool for grouping atoms which
  are close to each other spatially
  Neighbor-lists (also called Verlet Lists) are a tool for grouping
  atoms which are close to each other spatially. This layer computes a
  Neighbor List from a provided tensor of atomic coordinates. You can
  think of this as a general "k-means" layer, but optimized for the
  case `k==3`.

  TODO(rbharath): Make this layer support batching.
  """
@@ -1121,9 +1222,12 @@ class NeighborList(tf.keras.layers.Layer):
class AtomicConvolution(tf.keras.layers.Layer):
  """Implements the atomic convolutional transform introduced in

  Gomes, Joseph, et al. "Atomic convolutional networks for predicting protein-ligand binding affinity." arXiv preprint arXiv:1703.10603 (2017).
  Gomes, Joseph, et al. "Atomic convolutional networks for predicting
  protein-ligand binding affinity." arXiv preprint arXiv:1703.10603
  (2017).

  At a high level, this transform performs a sort of graph convolution on the nearest neighbors graph in 3D space.
  At a high level, this transform performs a graph convolution
  on the nearest neighbors graph in 3D space.
  """

  def __init__(self,
@@ -1433,7 +1537,8 @@ class BetaShare(tf.keras.layers.Layer):
  Parameters
  ----------
  in_layers: list of Layers or tensors
    tensors in list must be the same size and list must include two or more tensors
    tensors in list must be the same size and list must include two or
    more tensors

  Returns
  -------
@@ -1656,15 +1761,15 @@ class GraphEmbedPoolLayer(tf.keras.layers.Layer):
  GraphCNNPool Layer from Robust Spatial Filtering with Graph Convolutional Neural Networks
  https://arxiv.org/abs/1703.00792

  This is a learnable pool operation
  It constructs a new adjacency matrix for a graph of specified number of nodes.
  This is a learnable pool operation. It constructs a new adjacency
  matrix for a graph of specified number of nodes.

  This differs from our other pool opertions which set vertices to a function value
  without altering the adjacency matrix.
  This differs from our other pool operations which set vertices to a
  function value without altering the adjacency matrix.

  $V_{emb} = SpatialGraphCNN({V_{in}})$\\
  $V_{out} = \sigma(V_{emb})^{T} * V_{in}$
  $A_{out} = V_{emb}^{T} * A_{in} * V_{emb}$
  .. math:: V_{emb} = SpatialGraphCNN({V_{in}})
  .. math:: V_{out} = \sigma(V_{emb})^{T} * V_{in}
  .. math:: A_{out} = V_{emb}^{T} * A_{in} * V_{emb}
  """

  def __init__(self, num_vertices, **kwargs):
@@ -1693,7 +1798,6 @@ class GraphEmbedPoolLayer(tf.keras.layers.Layer):
    ----------
    num_filters: int
      Number of filters to have in the output

    in_layers: list of Layers or tensors
      [V, A, mask]
      V are the vertex features must be of shape (batch, vertex, channel)
@@ -1704,9 +1808,10 @@ class GraphEmbedPoolLayer(tf.keras.layers.Layer):
      mask is optional, to be used when not every graph has the
      same number of vertices

    Returns: tf.tensor
    Returns a tf.tensor with a graph convolution applied
    The shape will be (batch, vertex, self.num_filters)
    Returns
    -------
    Returns a `tf.Tensor` with a graph convolution applied.
    The shape will be `(batch, vertex, self.num_filters)`.
    """
    if len(inputs) == 3:
      V, A, mask = inputs
@@ -2761,7 +2866,9 @@ class GatedRecurrentUnit(tf.keras.layers.Layer):

class SetGather(tf.keras.layers.Layer):
  """set2set gather layer for graph-based model
  model using this layer must set pad_batches=True """

  Models using this layer must set `pad_batches=True`.
  """

  def __init__(self, M, batch_size, n_hidden=100, init='orthogonal', **kwargs):
    """
@@ -2799,7 +2906,9 @@ class SetGather(tf.keras.layers.Layer):

  def call(self, inputs):
    """Perform M steps of set2set gather.
        detailed descriptions in: https://arxiv.org/abs/1511.06391 """

    Detailed descriptions in: https://arxiv.org/abs/1511.06391
    """
    atom_features, atom_split = inputs
    c = tf.zeros((self.batch_size, self.n_hidden))
    h = tf.zeros((self.batch_size, self.n_hidden))