Commit 1aa112e2 authored by Atreya Majumdar

Review changes

parent 3dc4465f
+12 −9
@@ -42,8 +42,12 @@ class ScaleNorm(nn.Module):
  >>> output_tensor = layer(input_tensor)
  """

-  def __init__(self, scale: int, eps: float = 1e-5):
+  def __init__(self, scale: float, eps: float = 1e-5):
    """Initialize a ScaleNorm layer.

    Parameters
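
The hunk above corrects the type annotation of scale from int to float. For context, a minimal, self-contained sketch of a layer with that interface is given below; the class name ScaleNormSketch and the normalization formula are assumptions for illustration, not code taken from this file.

import math

import torch
import torch.nn as nn


class ScaleNormSketch(nn.Module):
  """Sketch of a ScaleNorm-style layer: rescale inputs to a learned norm.

  Assumed behaviour: divide the input by its L2 norm and multiply by a
  learnable scalar initialised from the float scale argument.
  """

  def __init__(self, scale: float, eps: float = 1e-5):
    super().__init__()
    self.scale = nn.Parameter(torch.tensor(math.sqrt(scale)))
    self.eps = eps

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # Per-row L2 norm, clamped to avoid division by zero.
    norm = self.scale / torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
    return x * norm


# Usage mirroring the docstring fragment above.
layer = ScaleNormSketch(scale=0.35)
input_tensor = torch.randn(4, 16)
output_tensor = layer(input_tensor)
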
@@ -457,10 +461,9 @@ class PositionwiseFeedForward(nn.Module):
  """

  def __init__(self,
-               *,
               d_input: int,
-               d_hidden: int = None,
-               d_output: int = None,
+               d_hidden: int,
+               d_output: int,
               activation: str,
               n_layers: int,
               dropout_p: float):
@@ -470,9 +473,9 @@ class PositionwiseFeedForward(nn.Module):
    ----------
    d_input: int
      Size of input layer.
-    d_hidden: int
+    d_hidden: int (same as d_input if d_hidden = 0)
      Size of hidden layer.
-    d_output: int
+    d_output: int (same as d_input if d_output = 0)
      Size of output layer.
    activation: str
      Activation function to be used. Can choose between 'relu' for ReLU, 'leakyrelu' for LeakyReLU, 'prelu' for PReLU,
@@ -506,8 +509,8 @@ class PositionwiseFeedForward(nn.Module):
      self.activation = lambda x: x

    self.n_layers = n_layers
-    d_output = d_output if d_output is not None else d_input
-    d_hidden = d_hidden if d_hidden is not None else d_input
+    d_output = d_output if d_output != 0 else d_input
+    d_hidden = d_hidden if d_hidden != 0 else d_input

    if n_layers == 1:
      self.linears = [nn.Linear(d_input, d_output)]
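
With the None defaults gone, 0 acts as the sentinel meaning "same size as d_input", which is what the updated docstring and the two assignments above encode. A standalone sketch of just that fallback logic follows; resolve_dims is a hypothetical helper, not part of this file.

def resolve_dims(d_input: int, d_hidden: int, d_output: int) -> tuple:
  """Resolve hidden/output sizes: 0 means 'same width as the input'."""
  d_hidden = d_hidden if d_hidden != 0 else d_input
  d_output = d_output if d_output != 0 else d_input
  return d_hidden, d_output


# 0 falls back to the input width; non-zero values are kept as given.
assert resolve_dims(1024, 0, 0) == (1024, 1024)
assert resolve_dims(1024, 2048, 512) == (2048, 512)
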
@@ -530,10 +533,10 @@ class PositionwiseFeedForward(nn.Module):
    x: torch.Tensor
      Input tensor.
    """
-    if self.n_layers == 0:
+    if not self.n_layers:
      return x

-    elif self.n_layers == 1:
+    if self.n_layers == 1:
      return self.dropout_p[0](self.act_func(self.linears[0](x)))

    else:
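
The final hunk simplifies the layer-count dispatch: zero layers return the input unchanged, and because both early branches return, the elif can become a plain if. Below is a condensed, runnable sketch of that dispatch; the class FeedForwardSketch, its construction, and the chaining for more than one layer are assumptions, and only the two early-return branches mirror the diff.

import torch
import torch.nn as nn


class FeedForwardSketch(nn.Module):
  """Hypothetical stand-in for the dispatch edited above."""

  def __init__(self, d_model: int, n_layers: int, dropout_p: float = 0.1):
    super().__init__()
    self.n_layers = n_layers
    self.act_func = nn.ReLU()
    n = max(n_layers, 1)
    self.linears = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n)])
    self.dropout_p = nn.ModuleList([nn.Dropout(dropout_p) for _ in range(n)])

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    if not self.n_layers:   # 0 layers: identity, as in the hunk above
      return x
    if self.n_layers == 1:  # one layer: linear -> activation -> dropout
      return self.dropout_p[0](self.act_func(self.linears[0](x)))
    for linear, dropout in zip(self.linears, self.dropout_p):  # assumed chaining for n_layers > 1
      x = dropout(self.act_func(linear(x)))
    return x


# Quick check of all three branches on a (2, 8) input.
x = torch.randn(2, 8)
assert FeedForwardSketch(8, 0)(x).shape == (2, 8)
assert FeedForwardSketch(8, 1)(x).shape == (2, 8)
assert FeedForwardSketch(8, 3)(x).shape == (2, 8)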