Unverified Commit 204fbd9e authored by Bharath Ramsundar, committed by GitHub

Merge pull request #2624 from atreyamaj/GeneratorPR

[WIP] MAT Layers: Embedding + Generator
parents 92947e03 6bc0d70b
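A minimal sketch of how the two new layers compose, assuming the signatures shown in the diff below (the tensor values and hyperparameters here are illustrative only, not part of the PR):

```python
import torch
from deepchem.models.torch_models.layers import MATEmbedding, MATGenerator

# Embed a toy 3-feature input, then pool and project it with the generator.
x = torch.tensor([1., 2., 3.])
embedding = MATEmbedding(d_input=3, d_output=3, dropout_p=0.0)
generator = MATGenerator(hsize=3, aggregation_type='mean', d_output=1,
                         n_layers=1, dropout_p=0.0)
mask = torch.ones(3)       # no padding in this toy example
h = embedding(x)           # (3,) -> (3,)
out = generator(h, mask)   # pooled and projected to a single value
```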
+23 −0
@@ -705,3 +705,26 @@ def test_mat_encoder_layer():
                            [[5.0000, 6.0000], [3.0000, 8.0000],
                             [5.0000, 6.0000], [3.0000, 8.0000]]])
  assert torch.allclose(result, output_ar, rtol=1e-4)


@pytest.mark.torch
def test_mat_embedding():
  """Test invoking MATEmbedding."""
  torch.manual_seed(0)
  input_ar = torch.tensor([1., 2., 3.])
  layer = torch_layers.MATEmbedding(3, 1, 0.0)
  result = layer(input_ar).detach()
  output_ar = torch.tensor([-1.2353])
  assert torch.allclose(result, output_ar, rtol=1e-4)


@pytest.mark.torch
def test_mat_generator():
  """Test invoking MATGenerator."""
  torch.manual_seed(0)
  input_ar = torch.tensor([1., 2., 3.])
  layer = torch_layers.MATGenerator(3, 'mean', 1, 1, 0.0)
  mask = torch.tensor([1., 1., 1.])
  result = layer(input_ar, mask)
  output_ar = torch.tensor([-1.4436])
  assert torch.allclose(result, output_ar, rtol=1e-4)
+148 −0
@@ -503,3 +503,151 @@ class PositionwiseFeedForward(nn.Module):
      for i in range(self.n_layers - 1):
        x = self.dropout_p[i](self.activation(self.linears[i](x)))
      return self.linears[-1](x)


class MATEmbedding(nn.Module):
  """Embedding layer to create embedding for inputs.

  In an embedding layer, input is taken and converted to a vector representation for each input.
  In the MATEmbedding layer, an input tensor is processed through a dropout-adjusted linear layer and the resultant vector is returned.

  References
  ----------
  .. [1] Lukasz Maziarka et al. "Molecule Attention Transformer" Graph Representation Learning workshop and Machine Learning and the Physical Sciences workshop at NeurIPS 2019. 2020. https://arxiv.org/abs/2002.08264

  Examples
  --------
  >>> import torch
  >>> from deepchem.models.torch_models.layers import MATEmbedding
  >>> layer = MATEmbedding(d_input=3, d_output=3, dropout_p=0.2)
  >>> input_tensor = torch.tensor([1., 2., 3.])
  >>> output = layer(input_tensor)
  """

  def __init__(self,
               d_input: int = 36,
               d_output: int = 1024,
               dropout_p: float = 0.0):
    """Initialize a MATEmbedding layer.

    Parameters
    ----------
    d_input: int
      Size of input layer.
    d_output: int
      Size of output layer.
    dropout_p: float
      Dropout probability for layer.
    """
    super(MATEmbedding, self).__init__()
    self.linear_unit = nn.Linear(d_input, d_output)
    self.dropout = nn.Dropout(dropout_p)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Computation for the MATEmbedding layer.

    Parameters
    ----------
    x: torch.Tensor
      Input tensor to be converted into a vector.
    """
    return self.dropout(self.linear_unit(x))


class MATGenerator(nn.Module):
  """MATGenerator defines the linear and softmax generator step for the Molecular Attention Transformer [1]_.

  In the MATGenerator, a Generator is defined which performs the Linear + Softmax generation step.
  Depending on the aggregation type selected ('mean', 'grover' or 'contextual'), the attention output is pooled differently before the final projection.

  References
  ----------
  .. [1] Lukasz Maziarka et al. "Molecule Attention Transformer" Graph Representation Learning workshop and Machine Learning and the Physical Sciences workshop at NeurIPS 2019. 2020. https://arxiv.org/abs/2002.08264

  Examples
  --------
  >>> import torch
  >>> from deepchem.models.torch_models.layers import MATGenerator
  >>> layer = MATGenerator(hsize=3, aggregation_type='mean', d_output=1,
  ...                      n_layers=1, dropout_p=0.3, attn_hidden=128, attn_out=4)
  >>> input_tensor = torch.tensor([1., 2., 3.])
  >>> mask = torch.tensor([1., 1., 1.])
  >>> output = layer(input_tensor, mask)
  """

  def __init__(self,
               hsize: int = 1024,
               aggregation_type: str = 'mean',
               d_output: int = 1,
               n_layers: int = 1,
               dropout_p: float = 0.0,
               attn_hidden: int = 128,
               attn_out: int = 4):
    """Initialize a MATGenerator.

    Parameters
    ----------
    hsize: int
      Size of input layer.
    aggregation_type: str
      Type of aggregation to be used. Can be 'grover', 'mean' or 'contextual'.
    d_output: int
      Size of output layer.
    n_layers: int
      Number of layers in the output projection network.
    dropout_p: float
      Dropout probability for layer.
    attn_hidden: int
      Size of hidden attention layer.
    attn_out: int
      Size of output attention layer.
    """
    super(MATGenerator, self).__init__()

    if aggregation_type == 'grover':
      self.att_net = nn.Sequential(
          nn.Linear(hsize, attn_hidden, bias=False),
          nn.Tanh(),
          nn.Linear(attn_hidden, attn_out, bias=False),
      )
      hsize *= attn_out

    if n_layers == 1:
      self.proj: Any = nn.Linear(hsize, d_output)

    else:
      self.proj = []

      for i in range(n_layers - 1):
        self.proj.append(nn.Linear(hsize, attn_hidden))
        self.proj.append(nn.LeakyReLU(negative_slope=0.1))
        self.proj.append(nn.LayerNorm(attn_hidden))
        self.proj.append(nn.Dropout(dropout_p))
      self.proj.append(nn.Linear(attn_hidden, d_output))
      self.proj = torch.nn.Sequential(*self.proj)
    self.aggregation_type = aggregation_type

  def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Computation for the MATGenerator layer.

    Parameters
    ----------
    x: torch.Tensor
      Input tensor.
    mask: torch.Tensor
      Mask for padding so that padded values do not get included in attention score calculation.
    """
    mask = mask.unsqueeze(-1).float()
    out_masked = x * mask
    if self.aggregation_type == 'mean':
      out_sum = out_masked.sum(dim=1)
      mask_sum = mask.sum(dim=1)
      out_avg_pooling = out_sum / mask_sum

    elif self.aggregation_type == 'grover':
      out_attn = self.att_net(out_masked)
      out_attn = out_attn.masked_fill(mask == 0, -1e9)
      out_attn = F.softmax(out_attn, dim=1)
      out_avg_pooling = torch.matmul(
          torch.transpose(out_attn, -1, -2), out_masked)
      out_avg_pooling = out_avg_pooling.view(out_avg_pooling.size(0), -1)

    elif self.aggregation_type == 'contextual':
      out_avg_pooling = x
    projected = self.proj(out_avg_pooling)
    return projected
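A hypothetical shape walk-through of the 'grover' aggregation path above, assuming a batch of 2 molecules with 5 atoms each and hsize=8 (all sizes chosen only for illustration):

```python
import torch
from deepchem.models.torch_models.layers import MATGenerator

batch, n_atoms, hsize = 2, 5, 8
layer = MATGenerator(hsize=hsize,
                     aggregation_type='grover',
                     d_output=1,
                     n_layers=1,
                     dropout_p=0.0,
                     attn_hidden=16,
                     attn_out=4)
x = torch.randn(batch, n_atoms, hsize)  # per-atom features
mask = torch.ones(batch, n_atoms)       # 1 = real atom, 0 = padding
out = layer(x, mask)
# att_net maps (batch, n_atoms, hsize) -> (batch, n_atoms, attn_out);
# softmax over the atom dimension gives attn_out pooling heads, and the
# transposed matmul yields (batch, attn_out * hsize) before the final
# Linear(attn_out * hsize, d_output) projection.
print(out.shape)  # torch.Size([2, 1])
```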
+6 −1
@@ -137,6 +137,12 @@ Torch Layers
.. autoclass:: deepchem.models.torch_models.layers.PositionwiseFeedForward
  :members:

.. autoclass:: deepchem.models.torch_models.layers.MATEmbedding
  :members:

.. autoclass:: deepchem.models.torch_models.layers.MATGenerator
  :members:

.. autofunction:: deepchem.models.layers.cosine_dist

Jax Layers
@@ -144,4 +150,3 @@ Jax Layers

.. autoclass:: deepchem.models.jax_models.layers.Linear
  :members: