Commit 7dcf2ea4 authored by Atreya Majumdar's avatar Atreya Majumdar
Browse files

scalenorm forward annotation

parent fb5127db
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -46,7 +46,7 @@ class ScaleNorm(nn.Module):
    self.scale = nn.Parameter(torch.tensor(math.sqrt(scale)))
    self.eps = eps

  def forward(self, x):
  def forward(self, x : torch.Tensor):
    norm = self.scale / torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
    return x * norm