Commit ab01f844 authored by Atreya Majumdar's avatar Atreya Majumdar
Browse files

Removed the unused **kwargs parameter from MultiHeadedMATAttention's output computation and from the corresponding _single_attention call

parent 06a45509
Loading
Loading
Loading
Loading
+2 −4
Original line number Diff line number Diff line
@@ -178,8 +178,7 @@ class MultiHeadedMATAttention(nn.Module):
              adj_matrix: np.ndarray,
              distance_matrix: np.ndarray,
              eps: float = 1e-6,
              inf: float = 1e12,
              **kwargs):
              inf: float = 1e12):
    """Output computation for the MultiHeadedAttention layer.
    Parameters
    ----------
@@ -203,8 +202,7 @@ class MultiHeadedMATAttention(nn.Module):
    ]

    x, _ = self._single_attention(query, key, value, mask, dropout_p,
                                  adj_matrix, distance_matrix, eps, inf,
                                  **kwargs)
                                  adj_matrix, distance_matrix, eps, inf)
    x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

    return self.output_linear(x)