Commit 7b99647f authored by Atreya Majumdar's avatar Atreya Majumdar
Browse files

Code changes + Test + docs

parent b7b22477
Loading
Loading
Loading
Loading
+29 −2
Original line number Diff line number Diff line
@@ -612,7 +612,34 @@ def test_scale_norm():
  """Test invoking ScaleNorm."""
  input_ar = torch.tensor([[1., 99., 10000.], [0.003, 999.37, 23.]])
  layer = torch_layers.ScaleNorm(0.35)
  result1 = layer.forward(input_ar)
  result1 = layer(input_ar)
  output_ar = torch.tensor([[5.9157897e-05, 5.8566318e-03, 5.9157896e-01],
                            [1.7754727e-06, 5.9145141e-01, 1.3611957e-02]])
  assert torch.allclose(result1, output_ar)


@pytest.mark.torch
def test_multi_headed_mat_attention():
  """Test invoking MultiHeadedMATAttention."""
  import rdkit
  # Seed the RNG so the layer's randomly initialized weights (and hence the
  # expected output below) are reproducible.
  torch.manual_seed(0)
  # Build graph descriptors for ethane directly from its SMILES string.
  mol = rdkit.Chem.rdmolfiles.MolFromSmiles("CC")
  adjacency = rdkit.Chem.rdmolops.GetAdjacencyMatrix(mol)
  distances = rdkit.Chem.rdmolops.GetDistanceMatrix(mol)
  attention = torch_layers.MultiHeadedMATAttention(dist_kernel='softmax',
                                                   lambda_attention=0.33,
                                                   lambda_distance=0.33,
                                                   h=2,
                                                   hsize=2,
                                                   dropout_p=0.0)
  features = torch.tensor([[1., 2.], [5., 6.]])
  mask = torch.tensor([[1., 1.], [1., 1.]])
  # Self-attention: query, key and value all share the same feature tensor;
  # dropout_p of 0.0 keeps the forward pass deterministic.
  result = attention(features, features, features, mask, 0.0, adjacency,
                     distances)
  # Expected values precomputed under torch.manual_seed(0) — regenerate if
  # the layer's weight initialization changes.
  expected = torch.tensor([[[0.0492, -0.0792], [-0.9971, -0.3172],
                            [0.0492, -0.0792], [-0.9971, -0.3172]],
                           [[0.8671, 0.1069], [-3.4075, -0.8656],
                            [0.8671, 0.1069], [-3.4075, -0.8656]]])
  assert torch.allclose(result, expected, rtol=1e-3)
+16 −9
Original line number Diff line number Diff line
@@ -65,8 +65,14 @@ class MultiHeadedMATAttention(nn.Module):
  Examples
  --------
  >>> import deepchem as dc
  >>> attention = dc.models.torch_models.layers.MATAttention('softmax', 0.33, 0.33')
  >>> self_attn_layer = dc.models.torch_models.layers.MultiHeadedMATAttention(dist_kernel = 'softmax', lambda_attention = 0.33, lambda_distance = 0.33, h = 8, hsize = 1024, dropout_p = 0.1)
  >>> import rdkit
  >>> mol = rdkit.Chem.rdmolfiles.MolFromSmiles("CC")
  >>> adj_matrix = rdkit.Chem.rdmolops.GetAdjacencyMatrix(mol)
  >>> distance_matrix = rdkit.Chem.rdmolops.GetDistanceMatrix(mol)
  >>> layer = dc.models.torch_models.layers.MultiHeadedMATAttention(dist_kernel='softmax', lambda_attention=0.33, lambda_distance=0.33, h=2, hsize=2, dropout_p=0.0)
  >>> import torch
  >>> input_tensor = torch.tensor([[1., 2.], [5., 6.]])
  >>> mask = torch.tensor([[1., 1.], [1., 1.]])
  >>> result = layer(input_tensor, input_tensor, input_tensor, mask, 0.0, adj_matrix, distance_matrix)
  """

  def __init__(self,
@@ -153,18 +159,19 @@ class MultiHeadedMATAttention(nn.Module):
          -inf)
    p_attn = F.softmax(scores, dim=-1)

    adj_matrix = adj_matrix / (adj_matrix.sum(dim=-1).unsqueeze(2) + eps)
    p_adj = adj_matrix.unsqueeze(1).repeat(1, query.shape[1], 1, 1)
    adj_matrix = adj_matrix / (
        torch.sum(torch.tensor(adj_matrix), dim=-1).unsqueeze(1) + eps)
    p_adj = adj_matrix.repeat(1, query.shape[1], 1, 1)

    distance_matrix = distance_matrix.masked_fill(
    distance_matrix = torch.tensor(distance_matrix).masked_fill(
        mask.repeat(1, mask.shape[-1], 1) == 0, np.inf)
    distance_matrix = self.dist_kernel(distance_matrix)
    p_dist = distance_matrix.unsqueeze(1).repeat(1, query.shape[1], 1, 1)

    p_weighted = self.lambda_attention * p_attn + self.lambda_distance * p_dist + self.lambda_adjacency * p_adj
    p_weighted = dropout_p(p_weighted)
    p_weighted = self.dropout_p(p_weighted)

    return torch.matmul(p_weighted, value), p_attn
    bd = value.broadcast_to(p_weighted.shape)
    return torch.matmul(p_weighted.float(), bd.float()), p_attn

  def forward(self,
              query: torch.Tensor,
+3 −0
Original line number Diff line number Diff line
@@ -115,4 +115,7 @@ another tensor. DeepChem maintains an extensive collection of layers which perfo
.. autoclass:: deepchem.models.torch_models.layers.ScaleNorm
  :members:

.. autoclass:: deepchem.models.torch_models.layers.MultiHeadedMATAttention
  :members:

.. autofunction:: deepchem.models.layers.cosine_dist