Commit d87cbb1c authored by Atreya Majumdar's avatar Atreya Majumdar
Browse files

scalenorm forward annotation

parent fe97bf6c
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -61,7 +61,11 @@ class ScaleNorm(nn.Module):
    self.scale = nn.Parameter(torch.tensor(math.sqrt(scale)))
    self.eps = eps

<<<<<<< HEAD
  def forward(self, x: torch.Tensor):
=======
  def forward(self, x : torch.Tensor):
>>>>>>> scalenorm forward annotation
    norm = self.scale / torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
    return x * norm