Commit 1f71d5de authored by Atreya Majumdar

Mypy fixes

parent eaa5d81a
+6 −6
@@ -733,7 +733,7 @@ class PositionwiseFeedForward(nn.Module):
               d_input: int = 1024,
               d_hidden: int = 1024,
               d_output: int = 1024,
-              activation: Any = 'leakyrelu',
+              activation: str = 'leakyrelu',
               n_layers: int = 1,
               dropout_p: float = 0.0):
    """Initialize a PositionwiseFeedForward layer.
@@ -757,7 +757,7 @@ class PositionwiseFeedForward(nn.Module):
    super(PositionwiseFeedForward, self).__init__()

    if activation == 'relu':
-     self.activation = nn.ReLU()
+     self.activation: Any = nn.ReLU()

    elif activation == 'leakyrelu':
      self.activation = nn.LeakyReLU(0.1)
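The two annotation changes so far are typical mypy appeasements: the constructor's `activation` parameter is narrowed from `Any` to `str`, while the first assignment to `self.activation` is widened to `Any`. A minimal sketch, not from this repository, of why the attribute needs `Any` (mypy pins an attribute's type to its first assignment, so the other branches would otherwise be rejected):

```python
from typing import Any

import torch.nn as nn


class Toy(nn.Module):
  """Illustrative stand-in for PositionwiseFeedForward's activation logic."""

  def __init__(self, activation: str = 'leakyrelu') -> None:
    super().__init__()
    if activation == 'relu':
      # Without ": Any", mypy infers the attribute as "ReLU" from this
      # first assignment and flags both branches below with
      # "Incompatible types in assignment".
      self.activation: Any = nn.ReLU()
    elif activation == 'leakyrelu':
      self.activation = nn.LeakyReLU(0.1)
    else:
      self.activation = lambda x: x


Toy(activation='relu')     # fine under the narrowed str annotation
Toy(activation=nn.ReLU())  # now a mypy arg-type error; Any allowed it
```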
@@ -778,17 +778,17 @@ class PositionwiseFeedForward(nn.Module):
    elif activation == "linear":
      self.activation = lambda x: x

-   self.n_layers = n_layers
-   d_output = d_output if d_output is not None else d_input
-   d_hidden = d_hidden if d_hidden is not None else d_input
+   self.n_layers: int = n_layers
+   d_output = d_output if d_output != 0 else d_input
+   d_hidden = d_hidden if d_hidden != 0 else d_input

    if n_layers == 1:
-     self.linears = [nn.Linear(d_input, d_output)]
+     self.linears: Any = [nn.Linear(d_input, d_output)]

    else:
      self.linears = [nn.Linear(d_input, d_hidden)] + \
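In the last hunk the `None` guards become `!= 0` checks: with `d_output` and `d_hidden` annotated as plain `int` rather than `Optional[int]`, `None` can never reach this code, so 0 serves as the "use `d_input`" sentinel; `self.linears: Any` presumably plays the same first-assignment role as `self.activation: Any` above. A hedged sketch of the sentinel behaviour (`resolve_dim` is a hypothetical helper, not in the repo):

```python
def resolve_dim(requested: int, d_input: int) -> int:
  # 0 is the sentinel for "default to d_input".  With the parameter
  # typed int instead of Optional[int], an `is not None` guard would
  # always be true, hence the switch to `!= 0` in the diff.
  return requested if requested != 0 else d_input


assert resolve_dim(0, 1024) == 1024   # sentinel: fall back to d_input
assert resolve_dim(512, 1024) == 512  # explicit width is kept
```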