Commit b7b22477 authored by Atreya Majumdar's avatar Atreya Majumdar
Browse files

Scalenorm scale to float

parent 7fbd8dbb
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -31,12 +31,12 @@ class ScaleNorm(nn.Module):
  >>> output_tensor = layer(input_tensor)
  """

  def __init__(self, scale: int, eps: float = 1e-5):
  def __init__(self, scale: float, eps: float = 1e-5):
    """Initialize a ScaleNorm layer.

    Parameters
    ----------
    scale: Real number or single element tensor
    scale: float
      Scale magnitude.
    eps: float
      Epsilon value. Default = 1e-5.