Commit a3b4a3aa authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

changes

parent 40604945
Loading
Loading
Loading
Loading
+28 −21
Original line number Diff line number Diff line
@@ -33,9 +33,7 @@ class WeaveModel(KerasModel):
  """Implements Google-style Weave Graph Convolutions

  This model implements the Weave style graph convolutions
  from the following paper.

  Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond fingerprints." Journal of computer-aided molecular design 30.8 (2016): 595-608.
  from [1]_.

  The biggest difference between WeaveModel style convolutions
  and GraphConvModel style convolutions is that Weave
@@ -44,17 +42,24 @@ class WeaveModel(KerasModel):
  explicitly to model bond interactions. This may cause
  scaling issues, but may possibly allow for better modeling
  of subtle bond effects.

  References
  ----------
  .. [1] Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond
  fingerprints." Journal of computer-aided molecular design 30.8 (2016):
  595-608.

  """

  def __init__(self,
               n_tasks,
               n_atom_feat=75,
               n_pair_feat=14,
               n_hidden=50,
               n_graph_feat=128,
               mode="classification",
               n_classes=2,
               batch_size=100,
               n_tasks: int,
               n_atom_feat: int = 75,
               n_pair_feat: int = 14,
               n_hidden: int = 50,
               n_graph_feat: int = 128,
               mode: str = "classification",
               n_classes: int = 2,
               batch_size: int = 100,
               **kwargs):
    """
    Parameters
@@ -660,6 +665,8 @@ class GraphConvModel(KerasModel):
  following paper [1]_. These graph convolutions start with a per-atom set of
  descriptors for each atom in a molecule, then combine and recombine these
  descriptors over convolutional layers.
  following [1]_.


  References
  ----------
@@ -669,16 +676,16 @@ class GraphConvModel(KerasModel):
  """

  def __init__(self,
               n_tasks,
               graph_conv_layers=[64, 64],
               dense_layer_size=128,
               dropout=0.0,
               mode="classification",
               number_atom_features=75,
               n_classes=2,
               batch_size=100,
               batch_normalize=True,
               uncertainty=False,
               n_tasks: int,
               graph_conv_layers: List[int] = [64, 64],
               dense_layer_size: int = 128,
               dropout: float = 0.0,
               mode: str = "classification",
               number_atom_features: int = 75,
               n_classes: int = 2,
               batch_size: int = 100,
               batch_normalize: bool = True,
               uncertainty: bool = False,
               **kwargs):
    """The wrapper class for graph convolutions.

+69 −35
Original line number Diff line number Diff line
@@ -9,7 +9,18 @@ from tensorflow.keras.layers import Dropout
class InteratomicL2Distances(tf.keras.layers.Layer):
  """Compute (squared) L2 Distances between atoms given neighbors."""

  def __init__(self, N_atoms, M_nbrs, ndim, **kwargs):
  def __init__(self, N_atoms: int, M_nbrs: int, ndim: int, **kwargs):
    """Constructor for this layer.

    Parameters
    ----------
    N_atoms: int
      Number of atoms in the system total.
    M_nbrs: int
      Number of neighbors to consider when computing distances.
    n_dim:  int
      Number of descriptors for each atom.
    """
    super(InteratomicL2Distances, self).__init__(**kwargs)
    self.N_atoms = N_atoms
    self.M_nbrs = M_nbrs
@@ -48,18 +59,21 @@ class GraphConv(tf.keras.layers.Layer):
  
  This layer implements the graph convolution introduced in 

  Duvenaud, David K., et al. "Convolutional networks on graphs for learning molecular fingerprints." Advances in neural information processing systems. 2015. https://arxiv.org/abs/1509.09292
  
  The graph convolution combines per-node feature vectures in a
  nonlinear fashion with the feature vectors for neighboring nodes.
  This "blends" information in local neighborhoods of a graph.

  References
  ----------
  .. [1] Duvenaud, David K., et al. "Convolutional networks on graphs for learning molecular fingerprints." Advances in neural information processing systems. 2015. https://arxiv.org/abs/1509.09292
  
  """

  def __init__(self,
               out_channel,
               min_deg=0,
               max_deg=10,
               activation_fn=None,
               out_channel: int,
               min_deg: int = 0,
               max_deg: int = 10,
               activation_fn: Callable = None,
               **kwargs):
    """Initialize a graph convolutional layer.

@@ -2027,28 +2041,33 @@ class Highway(tf.keras.layers.Layer):

class WeaveLayer(tf.keras.layers.Layer):
  """This class implements the core Weave convolution from the
  Google graph convolution paper.

  Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond fingerprints." Journal of computer-aided molecular design 30.8 (2016): 595-608.
  Google graph convolution paper [1]_

  This model contains atom features and bond features
  separately.Here, bond features are also called pair features.
  There are 2 types of transformation, atom->atom, atom->pair,
  pair->atom, pair->pair that this model implements.

  References
  ----------
  .. [1] Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond
  fingerprints." Journal of computer-aided molecular design 30.8 (2016):
  595-608.

  """

  def __init__(self,
               n_atom_input_feat=75,
               n_pair_input_feat=14,
               n_atom_output_feat=50,
               n_pair_output_feat=50,
               n_hidden_AA=50,
               n_hidden_PA=50,
               n_hidden_AP=50,
               n_hidden_PP=50,
               update_pair=True,
               init='glorot_uniform',
               activation='relu',
               n_atom_input_feat: int = 75,
               n_pair_input_feat: int = 14,
               n_atom_output_feat: int = 50,
               n_pair_output_feat: int = 50,
               n_hidden_AA: int = 50,
               n_hidden_PA: int = 50,
               n_hidden_AP: int = 50,
               n_hidden_PP: int = 50,
               update_pair: bool = True,
               init: str = 'glorot_uniform',
               activation: str = 'relu',
               **kwargs):
    """
    Parameters
@@ -2140,10 +2159,14 @@ class WeaveLayer(tf.keras.layers.Layer):
      ])
    self.built = True

  def call(self, inputs):
  def call(self, inputs: List):
    """Creates weave tensors.

    inputs: [atom_features, pair_features, pair_split, atom_to_pair]
    Parameters
    ----------
    inputs: List
      Should contain 4 tensors [atom_features, pair_features, pair_split,
      atom_to_pair]
    """
    atom_features = inputs[0]
    pair_features = inputs[1]
@@ -2187,24 +2210,27 @@ class WeaveLayer(tf.keras.layers.Layer):
class WeaveGather(tf.keras.layers.Layer):
  """Implements the weave-gathering section of weave convolutions.

  Implements the gathering layer from the following paper:

  Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond
  fingerprints." Journal of computer-aided molecular design 30.8 (2016): 595-608.
  Implements the gathering layer from [1]_.

  The weave gathering layer gathers per-atom features to create a
  molecule-level fingerprint in a weave convolutional network. This layer can
  also perform Gaussian histogram expansion as detailed in the original paper.

  References
  ----------
  .. [1] Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond
  fingerprints." Journal of computer-aided molecular design 30.8 (2016):
  595-608.
  """

  def __init__(self,
               batch_size,
               n_input=128,
               gaussian_expand=False,
               init='glorot_uniform',
               activation='tanh',
               epsilon=1e-3,
               momentum=0.99,
               batch_size: int,
               n_input: int = 128,
               gaussian_expand: bool = False,
               init: str = 'glorot_uniform',
               activation: str = 'tanh',
               epsilon: float = 1e-3,
               momentum: float = 0.99,
               **kwargs):
    """
    Parameters
@@ -2254,6 +2280,14 @@ class WeaveGather(tf.keras.layers.Layer):
    self.built = True

  def call(self, inputs):
    """Creates weave tensors.

    Parameters
    ----------
    inputs: List
      Should contain 4 tensors [atom_features, pair_features, pair_split,
      atom_to_pair]
    """
    outputs = inputs[0]
    atom_split = inputs[1]