Unverified commit 8aa48aed authored by Bharath Ramsundar, committed by GitHub

Merge pull request #2623 from atreyamaj/encoderPR

[WIP] MAT Layers: Encoder
parents 500467de 939161ff
+64 −2
@@ -636,10 +636,72 @@ def test_multi_headed_mat_attention():
      dropout_p=0.0)
  input_tensor = torch.tensor([[1., 2.], [5., 6.]])
  mask = torch.tensor([[1., 1.], [1., 1.]])
  result = layer(input_tensor, input_tensor, input_tensor, mask, 0.0,
                 adj_matrix, distance_matrix)
  result = layer(input_tensor, input_tensor, input_tensor, mask, adj_matrix,
                 distance_matrix, 0.0)
  output_ar = torch.tensor([[[0.0492, -0.0792], [-0.9971, -0.3172],
                             [0.0492, -0.0792], [-0.9971, -0.3172]],
                            [[0.8671, 0.1069], [-3.4075, -0.8656],
                             [0.8671, 0.1069], [-3.4075, -0.8656]]])
  assert torch.allclose(result, output_ar, rtol=1e-3)


@pytest.mark.torch
def test_position_wise_feed_forward():
  """Test invoking PositionwiseFeedForward."""
  torch.manual_seed(0)
  input_ar = torch.tensor([[1., 2.], [5., 6.]])
  layer = torch_layers.PositionwiseFeedForward(
      d_input=2,
      d_hidden=2,
      d_output=2,
      activation='relu',
      n_layers=1,
      dropout_p=0.0)
  result = layer(input_ar)
  output_ar = torch.tensor([[0.4810, 0.0000], [1.9771, 0.0000]])
  assert torch.allclose(result, output_ar, rtol=1e-4)


@pytest.mark.torch
def test_sub_layer_connection():
  """Test invoking SublayerConnection."""
  torch.manual_seed(0)
  input_ar = torch.tensor([[1., 2.], [5., 6.]])
  layer = torch_layers.SublayerConnection(2, 0.0)
  result = layer(input_ar, input_ar)
  output_ar = torch.tensor([[2.0027e-05, 3.0000e+00], [4.0000e+00, 7.0000e+00]])
  assert torch.allclose(result, output_ar)


@pytest.mark.torch
def test_mat_encoder_layer():
  """Test invoking MATEncoderLayer."""
  torch.manual_seed(0)
  from rdkit import Chem
  input_ar = torch.Tensor([[1., 2.], [5., 6.]])
  mask = torch.Tensor([[1., 1.], [1., 1.]])
  mol = Chem.MolFromSmiles("CC")
  adj_matrix = Chem.GetAdjacencyMatrix(mol)
  distance_matrix = Chem.GetDistanceMatrix(mol)
  layer = torch_layers.MATEncoderLayer(
      dist_kernel='softmax',
      lambda_attention=0.33,
      lambda_distance=0.33,
      h=2,
      sa_hsize=2,
      sa_dropout_p=0.0,
      output_bias=True,
      d_input=2,
      d_hidden=2,
      d_output=2,
      activation='relu',
      n_layers=2,
      ff_dropout_p=0.0,
      encoder_hsize=2,
      encoder_dropout_p=0.0)
  result = layer(input_ar, mask, adj_matrix, distance_matrix, 0.0)
  output_ar = torch.tensor([[[0.9988, 2.0012], [-0.9999, 3.9999],
                             [0.9988, 2.0012], [-0.9999, 3.9999]],
                            [[5.0000, 6.0000], [3.0000, 8.0000],
                             [5.0000, 6.0000], [3.0000, 8.0000]]])
  assert torch.allclose(result, output_ar, rtol=1e-4)
+318 −22
import math
import numpy as np
from typing import Any, Tuple
try:
  import torch
  import torch.nn as nn
@@ -24,9 +25,9 @@ class ScaleNorm(nn.Module):

  Examples
  --------
  >>> import deepchem as dc
  >>> import torch
  >>> from deepchem.models.torch_models.layers import ScaleNorm
  >>> scale = 0.35
  >>> layer = dc.models.torch_models.layers.ScaleNorm(scale)
  >>> layer = ScaleNorm(scale)
  >>> input_tensor = torch.tensor([[1.269, 39.36], [0.00918, -9.12]])
  >>> output_tensor = layer(input_tensor)
  """
@@ -41,12 +42,11 @@ class ScaleNorm(nn.Module):
    eps: float
      Epsilon value. Default = 1e-5.
    """

    super(ScaleNorm, self).__init__()
    self.scale = nn.Parameter(torch.tensor(math.sqrt(scale)))
    self.eps = eps

  def forward(self, x: torch.Tensor):
  def forward(self, x: torch.Tensor) -> torch.Tensor:
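    # Rescale each row of x so its L2 norm equals the learned scale; the clamp
    # on the norm guards against division by zero.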
    norm = self.scale / torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
    return x * norm

@@ -62,25 +62,24 @@ class MultiHeadedMATAttention(nn.Module):
  .. [1] Lukasz Maziarka et al. "Molecule Attention Transformer" Graph Representation Learning workshop and Machine Learning and the Physical Sciences workshop at NeurIPS 2019. 2020. https://arxiv.org/abs/2002.08264
  Examples
  --------
  >>> import deepchem as dc
  >>> import torch
  >>> from deepchem.models.torch_models.layers import MultiHeadedMATAttention
  >>> from rdkit import Chem
  >>> mol = Chem.MolFromSmiles("CC")
  >>> adj_matrix = Chem.GetAdjacencyMatrix(mol)
  >>> distance_matrix = Chem.GetDistanceMatrix(mol)
  >>> import deepchem.models.torch_models.layers
  >>> layer = dc.models.torch_models.layers.MultiHeadedMATAttention(dist_kernel='softmax', lambda_attention=0.33, lambda_distance=0.33, h=2, hsize=2, dropout_p=0.0)
  >>> layer = MultiHeadedMATAttention(dist_kernel='softmax', lambda_attention=0.33, lambda_distance=0.33, h=2, hsize=2, dropout_p=0.0)
  >>> input_tensor = torch.tensor([[1., 2.], [5., 6.]])
  >>> mask = torch.tensor([[1., 1.], [1., 1.]])
  >>> result = layer(input_tensor, input_tensor, input_tensor, mask, 0.0, adj_matrix, distance_matrix)
  >>> result = layer(input_tensor, input_tensor, input_tensor, mask, adj_matrix, distance_matrix, 0.0)
  """

  def __init__(self,
               dist_kernel: str,
               lambda_attention: float,
               lambda_distance: float,
               h: int,
               hsize: int,
               dropout_p: float,
               dist_kernel: str = 'softmax',
               lambda_attention: float = 0.33,
               lambda_distance: float = 0.33,
               h: int = 16,
               hsize: int = 1024,
               dropout_p: float = 0.0,
               output_bias: bool = True):
    """Initialize a multi-headed attention layer.
    Parameters
@@ -120,11 +119,11 @@ class MultiHeadedMATAttention(nn.Module):
                        key: torch.Tensor,
                        value: torch.Tensor,
                        mask: torch.Tensor,
                        dropout_p: float,
                        adj_matrix: np.ndarray,
                        distance_matrix: np.ndarray,
                        dropout_p: float = 0.0,
                        eps: float = 1e-6,
                        inf: float = 1e12):
                        inf: float = 1e12) -> Tuple[torch.Tensor, torch.Tensor]:
    """Defining and computing output for a single MAT attention layer.
    Parameters
    ----------
@@ -136,12 +135,12 @@ class MultiHeadedMATAttention(nn.Module):
      Standard value parameter for attention.
    mask: torch.Tensor
      Masks out padding values so that they are not taken into account when computing the attention score.
    dropout_p: float
      Dropout probability.
    adj_matrix: np.ndarray
      Adjacency matrix of the input molecule, returned from dc.feat.MATFeaturizer()
    distance_matrix: np.ndarray
      Distance matrix of the input molecule, returned from dc.feat.MATFeaturizer()
    dropout_p: float
      Dropout probability.
    eps: float
      Epsilon value
    inf: float
@@ -175,11 +174,11 @@ class MultiHeadedMATAttention(nn.Module):
              key: torch.Tensor,
              value: torch.Tensor,
              mask: torch.Tensor,
              dropout_p: float,
              adj_matrix: np.ndarray,
              distance_matrix: np.ndarray,
              dropout_p: float = 0.0,
              eps: float = 1e-6,
              inf: float = 1e12):
              inf: float = 1e12) -> torch.Tensor:
    """Output computation for the MultiHeadedAttention layer.
    Parameters
    ----------
@@ -191,6 +190,16 @@ class MultiHeadedMATAttention(nn.Module):
      Standard value parameter for attention.
    mask: torch.Tensor
      Masks out padding values so that they are not taken into account when computing the attention score.
    adj_matrix: np.ndarray
      Adjacency matrix of the input molecule, returned from dc.feat.MATFeaturizer()
    distance_matrix: np.ndarray
      Distance matrix of the input molecule, returned from dc.feat.MATFeaturizer()
    dropout_p: float
      Dropout probability.
    eps: float
      Epsilon value
    inf: float
      Value of infinity to be used.
    """
    if mask is not None:
      mask = mask.unsqueeze(1)
@@ -202,8 +211,295 @@ class MultiHeadedMATAttention(nn.Module):
        for layer, x in zip(self.linear_layers, (query, key, value))
    ]

    x, _ = self._single_attention(query, key, value, mask, dropout_p,
                                  adj_matrix, distance_matrix, eps, inf)
    x, _ = self._single_attention(query, key, value, mask, adj_matrix,
                                  distance_matrix, dropout_p, eps, inf)
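    # Recombine the attention heads: (batch, h, seq_len, d_k) ->
    # (batch, seq_len, h * d_k) before the final output projection.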
    x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

    return self.output_linear(x)


class MATEncoderLayer(nn.Module):
  """Encoder layer for use in the Molecular Attention Transformer [1]_.

  The MATEncoderLayer primarily consists of a self-attention layer (MultiHeadedMATAttention) and a feed-forward layer (PositionwiseFeedForward).
  This layer can be stacked multiple times to form an encoder; a commented stacking sketch follows this class definition.

  References
  ----------
  .. [1] Lukasz Maziarka et al. "Molecule Attention Transformer" Graph Representation Learning workshop and Machine Learning and the Physical Sciences workshop at NeurIPS 2019. 2020. https://arxiv.org/abs/2002.08264

  Examples
  --------
  >>> import torch
  >>> from deepchem.models.torch_models.layers import MATEncoderLayer
  >>> from rdkit import Chem
  >>> mol = Chem.MolFromSmiles("CC")
  >>> adj_matrix = Chem.GetAdjacencyMatrix(mol)
  >>> distance_matrix = Chem.GetDistanceMatrix(mol)
  >>> layer = MATEncoderLayer(dist_kernel='softmax', lambda_attention=0.33, lambda_distance=0.33, h=2, sa_hsize=2, sa_dropout_p=0.0, output_bias=True, d_input=2, d_hidden=2, d_output=2, activation='relu', n_layers=2, ff_dropout_p=0.0, encoder_hsize=2, encoder_dropout_p=0.0)
  >>> x = torch.Tensor([[1., 2.], [5., 6.]])
  >>> mask = torch.Tensor([[1., 1.], [1., 1.]])
  >>> output = layer(x, mask, adj_matrix = adj_matrix, distance_matrix = distance_matrix, sa_dropout_p = 0.0)
  """

  def __init__(self,
               dist_kernel: str = 'softmax',
               lambda_attention: float = 0.33,
               lambda_distance: float = 0.33,
               h: int = 16,
               sa_hsize: int = 1024,
               sa_dropout_p: float = 0.0,
               output_bias: bool = True,
               d_input: int = 1024,
               d_hidden: int = 1024,
               d_output: int = 1024,
               activation: str = 'leakyrelu',
               n_layers: int = 1,
               ff_dropout_p: float = 0.0,
               encoder_hsize: int = 1024,
               encoder_dropout_p: float = 0.0):
    """Initialize a MATEncoder layer.

    Parameters
    ----------
    dist_kernel: str
      Kernel activation to be used. Can be either 'softmax' for softmax or 'exp' for exponential, for the self-attention layer.
    lambda_attention: float
      Constant to be multiplied with the attention matrix in the self-attention layer.
    lambda_distance: float
      Constant to be multiplied with the distance matrix in the self-attention layer.
    h: int
      Number of attention heads for the self-attention layer.
    sa_hsize: int
      Size of dense layer in the self-attention layer.
    sa_dropout_p: float
      Dropout probability for the self-attention layer.
    output_bias: bool
      If True, dense layers will use bias vectors in the self-attention layer.
    d_input: int
      Size of input layer in the feed-forward layer.
    d_hidden: int
      Size of hidden layer in the feed-forward layer.
    d_output: int
      Size of output layer in the feed-forward layer.
    activation: str
      Activation function to be used in the feed-forward layer.
      Can choose between 'relu' for ReLU, 'leakyrelu' for LeakyReLU, 'prelu' for PReLU,
      'tanh' for TanH, 'selu' for SELU, 'elu' for ELU and 'linear' for linear activation.
    n_layers: int
      Number of layers in the feed-forward layer.
    ff_dropout_p: float
      Dropout probability in the feed-forward layer.
    encoder_hsize: int
      Size of Dense layer for the encoder itself.
    encoder_dropout_p: float
      Dropout probability for connections in the encoder layer.
    """
    super(MATEncoderLayer, self).__init__()
    self.self_attn = MultiHeadedMATAttention(dist_kernel, lambda_attention,
                                             lambda_distance, h, sa_hsize,
                                             sa_dropout_p, output_bias)
    self.feed_forward = PositionwiseFeedForward(
        d_input, d_hidden, d_output, activation, n_layers, ff_dropout_p)
    self.sublayer = nn.ModuleList([
        SublayerConnection(size=encoder_hsize, dropout_p=encoder_dropout_p)
        for _ in range(2)
    ])
    self.size = encoder_hsize

  def forward(self,
              x: torch.Tensor,
              mask: torch.Tensor,
              adj_matrix: np.ndarray,
              distance_matrix: np.ndarray,
              sa_dropout_p: float = 0.0) -> torch.Tensor:
    """Output computation for the MATEncoder layer.

    In the MATEncoderLayer initialization, self.sublayer is defined as an nn.ModuleList of two layers; the computation is passed through these layers sequentially.
    nn.ModuleList is subscriptable and thus we can access it as self.sublayer[0], for example.

    Parameters
    ----------
    x: torch.Tensor
      Input tensor.
    mask: torch.Tensor
      Masks out padding values so that they are not taken into account when computing the attention score.
    adj_matrix: np.ndarray
      Adjacency matrix of a molecule.
    distance_matrix: np.ndarray
      Distance matrix of a molecule.
    sa_dropout_p: float
      Dropout probability for the self-attention layer (MultiHeadedMATAttention).
    """
    x = self.sublayer[0](x,
                         self.self_attn(
                             x,
                             x,
                             x,
                             mask=mask,
                             dropout_p=sa_dropout_p,
                             adj_matrix=adj_matrix,
                             distance_matrix=distance_matrix))
    return self.sublayer[1](x, self.feed_forward(x))
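
# A minimal, commented sketch of stacking MATEncoderLayer blocks into a full
# encoder, as mentioned in the class docstring above. The hyperparameter values
# and the input tensors (x, mask, adj_matrix, distance_matrix) are assumptions
# for illustration only:
#
#   encoder = nn.ModuleList([
#       MATEncoderLayer(h=2, sa_hsize=2, d_input=2, d_hidden=2, d_output=2,
#                       activation='relu', encoder_hsize=2)
#       for _ in range(4)
#   ])
#   hidden = x
#   for enc_layer in encoder:
#       hidden = enc_layer(hidden, mask, adj_matrix, distance_matrix)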


class SublayerConnection(nn.Module):
  """SublayerConnection layer which establishes a residual connection, as used in the Molecular Attention Transformer [1]_.

  The SublayerConnection layer establishes a residual connection with layer normalization: the output of a sublayer is normalized, passed through dropout, and added to the original input tensor.

  References
  ----------
  .. [1] Lukasz Maziarka et al. "Molecule Attention Transformer" Graph Representation Learning workshop and Machine Learning and the Physical Sciences workshop at NeurIPS 2019. 2020. https://arxiv.org/abs/2002.08264

  Examples
  --------
  >>> import torch
  >>> from deepchem.models.torch_models.layers import SublayerConnection
  >>> layer = SublayerConnection(2, 0.)
  >>> input_ar = torch.tensor([[1., 2.], [5., 6.]])
  >>> output = layer(input_ar, input_ar)
  """

  def __init__(self, size: int, dropout_p: float = 0.0):
    """Initialize a SublayerConnection Layer.

    Parameters
    ----------
    size: int
      Size of layer.
    dropout_p: float
      Dropout probability.
    """
    super(SublayerConnection, self).__init__()
    self.norm = nn.LayerNorm(size)
    self.dropout_p = nn.Dropout(dropout_p)

  def forward(self, x: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    """Output computation for the SublayerConnection layer.

    Takes an input tensor x and the corresponding sublayer output, normalizes the output, applies dropout, and adds the result to x.
    This establishes the residual connection around each sublayer (see the commented sketch after this class).

    Parameters
    ----------
    x: torch.Tensor
      Input tensor.
    output: torch.Tensor
      Layer whose normalized output will be added to x.
    """
    if x is None:
      return self.dropout_p(self.norm(output))

    if len(x.shape) < len(output.shape):
      temp_ar = x
      op_ar = output
      adjusted = temp_ar.unsqueeze(1).repeat(1, op_ar.shape[1], 1)
    elif len(x.shape) > len(output.shape):
      temp_ar = output
      op_ar = x
      adjusted = temp_ar.unsqueeze(1).repeat(1, op_ar.shape[1], 1)
    else:
      return x + self.dropout_p(self.norm(output))

    return adjusted + self.dropout_p(self.norm(op_ar))
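
# When x and output have the same shape (the usual case), the residual above is
# x + Dropout(LayerNorm(output)). A minimal illustration, assuming 2-D tensors
# and dropout disabled:
#
#   layer = SublayerConnection(size=2, dropout_p=0.0)
#   x = torch.tensor([[1., 2.], [5., 6.]])
#   out = layer(x, x)  # equals x + layer.dropout_p(layer.norm(x))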


class PositionwiseFeedForward(nn.Module):
  """PositionwiseFeedForward is a layer used to define the position-wise feed-forward (FFN) algorithm for the Molecular Attention Transformer [1]_

  Each layer in the MAT encoder contains a fully connected feed-forward network which applies two linear transformations and the given activation function; a commented sketch follows this class definition.
  This is done in addition to the SublayerConnection module.

  References
  ----------
  .. [1] Lukasz Maziarka et al. "Molecule Attention Transformer" Graph Representation Learning workshop and Machine Learning and the Physical Sciences workshop at NeurIPS 2019. 2020. https://arxiv.org/abs/2002.08264

  Examples
  --------
  >>> import torch
  >>> from deepchem.models.torch_models.layers import PositionwiseFeedForward
  >>> feed_fwd_layer = PositionwiseFeedForward(d_input = 2, d_hidden = 2, d_output = 2, activation = 'relu', n_layers = 1, dropout_p = 0.1)
  >>> input_tensor = torch.tensor([[1., 2.], [5., 6.]])
  >>> output_tensor = feed_fwd_layer(input_tensor)
  """

  def __init__(self,
               d_input: int = 1024,
               d_hidden: int = 1024,
               d_output: int = 1024,
               activation: str = 'leakyrelu',
               n_layers: int = 1,
               dropout_p: float = 0.0):
    """Initialize a PositionwiseFeedForward layer.

    Parameters
    ----------
    d_input: int
      Size of input layer.
    d_hidden: int (same as d_input if d_hidden = 0)
      Size of hidden layer.
    d_output: int (same as d_input if d_output = 0)
      Size of output layer.
    activation: str
      Activation function to be used. Can choose between 'relu' for ReLU, 'leakyrelu' for LeakyReLU, 'prelu' for PReLU,
      'tanh' for TanH, 'selu' for SELU, 'elu' for ELU and 'linear' for linear activation.
    n_layers: int
      Number of layers.
    dropout_p: float
      Dropout probability.
    """
    super(PositionwiseFeedForward, self).__init__()
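    # Map the activation name to the corresponding torch module below;
    # 'linear' falls through to an identity function.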

    if activation == 'relu':
      self.activation: Any = nn.ReLU()

    elif activation == 'leakyrelu':
      self.activation = nn.LeakyReLU(0.1)

    elif activation == 'prelu':
      self.activation = nn.PReLU()

    elif activation == 'tanh':
      self.activation = nn.Tanh()

    elif activation == 'selu':
      self.activation = nn.SELU()

    elif activation == 'elu':
      self.activation = nn.ELU()

    elif activation == "linear":
      self.activation = lambda x: x

    self.n_layers: int = n_layers
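    # A value of 0 for d_output or d_hidden means "use the input width d_input".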
    d_output = d_output if d_output != 0 else d_input
    d_hidden = d_hidden if d_hidden != 0 else d_input

    if n_layers == 1:
      self.linears: Any = [nn.Linear(d_input, d_output)]

    else:
      self.linears = [nn.Linear(d_input, d_hidden)] + \
                      [nn.Linear(d_hidden, d_hidden) for _ in range(n_layers - 2)] + \
                      [nn.Linear(d_hidden, d_output)]

    self.linears = nn.ModuleList(self.linears)
    dropout_layer = nn.Dropout(dropout_p)
    self.dropout_p = nn.ModuleList([dropout_layer for _ in range(n_layers)])

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Output Computation for the PositionwiseFeedForward layer.

    Parameters
    ----------
    x: torch.Tensor
      Input tensor.
    """
    if not self.n_layers:
      return x

    if self.n_layers == 1:
      return self.dropout_p[0](self.activation(self.linears[0](x)))

    else:
      for i in range(self.n_layers - 1):
        x = self.dropout_p[i](self.activation(self.linears[i](x)))
      return self.linears[-1](x)
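
# For n_layers=2 with dropout disabled, PositionwiseFeedForward computes
# FFN(x) = W2 * activation(W1 * x + b1) + b2, the standard transformer
# position-wise feed-forward block. A minimal usage sketch with assumed toy
# sizes:
#
#   ffn = PositionwiseFeedForward(d_input=2, d_hidden=2, d_output=2,
#                                 activation='relu', n_layers=2, dropout_p=0.0)
#   y = ffn(torch.tensor([[1., 2.], [5., 6.]]))  # shape (2, 2)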
+9 −0
@@ -125,9 +125,18 @@ Torch Layers
.. autoclass:: deepchem.models.torch_models.layers.ScaleNorm
  :members:

.. autoclass:: deepchem.models.torch_models.layers.MATEncoderLayer
  :members:

.. autoclass:: deepchem.models.torch_models.layers.MultiHeadedMATAttention
  :members:

.. autoclass:: deepchem.models.torch_models.layers.SublayerConnection
  :members:

.. autoclass:: deepchem.models.torch_models.layers.PositionwiseFeedForward
  :members:

.. autofunction:: deepchem.models.layers.cosine_dist

Jax Layers