Commit fb5127db authored by Atreya Majumdar

Type annotations

parent 3145032b
+30 −32
@@ -31,7 +31,7 @@ class ScaleNorm(nn.Module):
  >>> output_tensor = layer(input_tensor)
  """

-  def __init__(self, scale, eps=1e-5):
+  def __init__(self, scale: int, eps: float = 1e-5):
    """Initialize a ScaleNorm layer.

    Parameters
@@ -70,13 +70,13 @@ class MultiHeadedMATAttention(nn.Module):
  """

  def __init__(self,
-               dist_kernel,
-               lambda_attention,
-               lambda_distance,
-               h,
-               hsize,
-               dropout_p,
-               output_bias=True):
+               dist_kernel: str,
+               lambda_attention: float,
+               lambda_distance: float,
+               h: int,
+               hsize: int,
+               dropout_p: float,
+               output_bias: bool = True):
    """Initialize a multi-headed attention layer.

    Parameters
@@ -96,7 +96,6 @@ class MultiHeadedMATAttention(nn.Module):
    output_bias: bool
      If True, dense layers will use bias vectors.
    """

    super().__init__()
    if dist_kernel == "softmax":
      self.dist_kernel = lambda x: torch.softmax(-x, dim=-1)
@@ -112,16 +111,16 @@ class MultiHeadedMATAttention(nn.Module):
    self.dropout_p = nn.Dropout(dropout_p)
    self.output_linear = nn.Linear(hsize, hsize, output_bias)

-  def _singleAttention(self,
-                       query,
-                       key,
-                       value,
-                       mask,
-                       dropout_p,
-                       adj_matrix,
-                       distance_matrix,
-                       eps=1e-6,
-                       inf=1e12):
+  def _single_attention(self,
+                        query: torch.Tensor,
+                        key: torch.Tensor,
+                        value: torch.Tensor,
+                        mask: torch.Tensor,
+                        dropout_p: float,
+                        adj_matrix: np.ndarray,
+                        distance_matrix: np.ndarray,
+                        eps: float = 1e-6,
+                        inf: float = 1e12):
    """Defining and computing output for a single MAT attention layer.

    Parameters
@@ -145,7 +144,6 @@ class MultiHeadedMATAttention(nn.Module):
    inf: float
      Value of infinity to be used.
    """

    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

@@ -169,15 +167,15 @@ class MultiHeadedMATAttention(nn.Module):
    return torch.matmul(p_weighted, value), p_attn

  def forward(self,
-              query,
-              key,
-              value,
-              mask,
-              dropout_p,
-              adj_matrix,
-              distance_matrix,
-              eps=1e-6,
-              inf=1e12,
+              query: torch.Tensor,
+              key: torch.Tensor,
+              value: torch.Tensor,
+              mask: torch.Tensor,
+              dropout_p: float,
+              adj_matrix: np.ndarray,
+              distance_matrix: np.ndarray,
+              eps: float = 1e-6,
+              inf: float = 1e12,
              **kwargs):
    """Output computation for the MultiHeadedAttention layer.

@@ -192,7 +190,6 @@ class MultiHeadedMATAttention(nn.Module):
    mask: torch.Tensor
      Masks out padding values so that they are not taken into account when computing the attention score.
    """

    if mask is not None:
      mask = mask.unsqueeze(1)

@@ -203,8 +200,9 @@ class MultiHeadedMATAttention(nn.Module):
        for layer, x in zip(self.linear_layers, (query, key, value))
    ]

-    x, _ = self._singleAttention(query, key, value, mask, dropout_p, adj_matrix,
-                                 distance_matrix, eps, inf, **kwargs)
+    x, _ = self._single_attention(query, key, value, mask, dropout_p,
+                                  adj_matrix, distance_matrix, eps, inf,
+                                  **kwargs)
    x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

    return self.output_linear(x)
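
For reference, a minimal usage sketch against the newly annotated signatures is given below. The import path, dimensions, and input shapes are assumptions rather than part of this commit: node features are taken as (batch, n_atoms, hsize) tensors, the adjacency and distance matrices as (batch, n_atoms, n_atoms) arrays per the new np.ndarray annotations, and the mask as a (batch, n_atoms) padding mask. The exact input conventions should be checked against the full class docstring, which this diff only shows in part.

import numpy as np
import torch

# Hypothetical import path for the layer; adjust to the module's actual location.
from deepchem.models.torch_models.layers import MultiHeadedMATAttention

# Assumed toy dimensions; hsize is chosen to be divisible by the number of heads h.
batch_size, n_atoms, hsize, h = 1, 9, 1024, 16

attention = MultiHeadedMATAttention(
    dist_kernel="softmax",
    lambda_attention=0.33,
    lambda_distance=0.33,
    h=h,
    hsize=hsize,
    dropout_p=0.0)

# Placeholder inputs; in practice these come from a MAT featurizer.
node_features = torch.randn(batch_size, n_atoms, hsize)
adj_matrix = np.random.randint(0, 2, (batch_size, n_atoms, n_atoms)).astype(np.float32)
distance_matrix = np.random.rand(batch_size, n_atoms, n_atoms).astype(np.float32)
mask = torch.ones(batch_size, n_atoms)  # 1 for real atoms, 0 for padding

# Self-attention: query, key and value are all the node features.
output = attention(node_features, node_features, node_features, mask, 0.0,
                   adj_matrix, distance_matrix)
print(output.shape)  # expected: (batch_size, n_atoms, hsize)

Whether the adjacency and distance inputs are consumed as NumPy arrays or converted to tensors internally is not visible in this diff; the sketch simply follows the annotations added here.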