Commit 904c8fad authored by Atreya Majumdar's avatar Atreya Majumdar
Browse files

Update

parent a8a22c22
Loading
Loading
Loading
Loading
+3 −9
Original line number Diff line number Diff line
@@ -389,14 +389,8 @@ class MATEncoderLayer(nn.Module):
    self.sublayer = nn.ModuleList([layer for _ in range(2)])
    self.size = encoder_hsize

  def forward(
      self,
      x: torch.Tensor,
      mask: torch.Tensor,
      sa_dropout_p: float,
      adj_matrix: np.ndarray,
      distance_matrix: np.ndarray
  ):
  def forward(self, x: torch.Tensor, mask: torch.Tensor, sa_dropout_p: float,
              adj_matrix: np.ndarray, distance_matrix: np.ndarray):
    """Output computation for the MATEncoder layer.

    Parameters