Commit fe2909ea authored by Atreya Majumdar

Type annotations

parent a6b6d95c
+18 −19
@@ -28,11 +28,11 @@ class ScaleNorm(nn.Module):
  >>> from deepchem.models.torch_models.layers import ScaleNorm
  >>> scale = 0.35
  >>> layer = ScaleNorm(scale)
-  >>> input_tensor = torch.Tensor([[1.269, 39.36], [0.00918, -9.12]])
+  >>> input_tensor = torch.tensor([[1.269, 39.36], [0.00918, -9.12]])
  >>> output_tensor = layer(input_tensor)
  """

-  def __init__(self, scale: float, eps: float = 1e-5):
+  def __init__(self, scale: int, eps: float = 1e-5):
    """Initialize a ScaleNorm layer.

    Parameters
@@ -46,7 +46,7 @@ class ScaleNorm(nn.Module):
    self.scale = nn.Parameter(torch.tensor(math.sqrt(scale)))
    self.eps = eps

-  def forward(self, x: torch.Tensor) -> torch.Tensor:
+  def forward(self, x: torch.Tensor):
    norm = self.scale / torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
    return x * norm

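Note: as a sanity check of the ScaleNorm change above, a minimal sketch (assumes torch and deepchem are importable in the doctest environment; the input values are illustrative):

import torch
from deepchem.models.torch_models.layers import ScaleNorm

# Each output row is the input row rescaled to L2 norm sqrt(scale),
# since forward computes x * sqrt(scale) / max(||x||_2, eps).
layer = ScaleNorm(0.35)
x = torch.tensor([[1.269, 39.36], [0.00918, -9.12]])
out = layer(x)
print(torch.norm(out, dim=-1))  # both entries ~= sqrt(0.35) ~= 0.5916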
@@ -67,7 +67,7 @@ class MATEmbedding(nn.Module):
  >>> layer = MATEmbedding(d_input = 1024, d_output = 1024, dropout_p = 0.2)
  """

-  def __init__(self, *, d_input: int, d_output: int, dropout_p: float):
+  def __init__(self, d_input: int, d_output: int, dropout_p: float):
    """Initialize a MATEmbedding layer.

    Parameters
@@ -79,12 +79,11 @@ class MATEmbedding(nn.Module):
    dropout_p: float
      Dropout probability for layer.
    """

    super(MATEmbedding, self).__init__()
    self.lut = nn.Linear(d_input, d_output)
    self.linear_unit = nn.Linear(d_input, d_output)
    self.dropout = nn.Dropout(dropout_p)

  def forward(self, x):
  def forward(self, x: torch.Tensor):
    """Computation for the MATEmbedding layer.

    Parameters
@@ -92,7 +91,7 @@ class MATEmbedding(nn.Module):
    x: torch.Tensor
      Input tensor to be converted into a vector.
    """
-    return self.dropout(self.lut(x))
+    return self.dropout(self.linear_unit(x))

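Note: a minimal usage sketch of MATEmbedding after the lut -> linear_unit rename (the shape is hypothetical; keyword construction follows the docstring example above):

import torch
from deepchem.models.torch_models.layers import MATEmbedding

# MATEmbedding is a linear projection followed by dropout:
# forward(x) = Dropout(Linear(x)), applied over the last dimension.
embed = MATEmbedding(d_input=36, d_output=1024, dropout_p=0.2)
x = torch.rand(4, 10, 36)  # (batch, nodes, d_input) -- illustrative shape
print(embed(x).shape)      # torch.Size([4, 10, 1024])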

class MATGenerator(nn.Module):
@@ -112,14 +111,13 @@ class MATGenerator(nn.Module):
  """

  def __init__(self,
-               *,
-               hsize,
-               aggregation_type,
-               d_output,
-               n_layers,
-               dropout_p,
-               attn_hidden=128,
-               attn_out=4):
+               hsize: int,
+               aggregation_type: str,
+               d_output: int,
+               n_layers: int,
+               dropout_p: float,
+               attn_hidden: int = 128,
+               attn_out: int = 4):
    """Initialize a MATGenerator.

    Parameters
@@ -137,7 +135,6 @@ class MATGenerator(nn.Module):
    attn_out: int
      Size of output attention layer.
    """

    super(MATGenerator, self).__init__()

    if aggregation_type == 'grover':
@@ -153,6 +150,7 @@ class MATGenerator(nn.Module):

    else:
      self.proj = []
+
      for i in range(n_layers - 1):
        self.proj.append(nn.Linear(hsize, attn_hidden))
        self.proj.append(nn.LeakyReLU(negative_slope=0.1))
@@ -162,7 +160,7 @@ class MATGenerator(nn.Module):
      self.proj = torch.nn.Sequential(*self.proj)
    self.aggregation_type = aggregation_type

-  def forward(self, x, mask):
+  def forward(self, x: torch.Tensor, mask: torch.Tensor):
    """Computation for the MATGenerator layer.

    Parameters
@@ -172,13 +170,13 @@ class MATGenerator(nn.Module):
    mask: torch.Tensor
      Mask for padding so that padded values do not get included in attention score calculation.
    """

    mask = mask.unsqueeze(-1).float()
    out_masked = x * mask
    if self.aggregation_type == 'mean':
      out_sum = out_masked.sum(dim=1)
      mask_sum = mask.sum(dim=(1))
      out_avg_pooling = out_sum / mask_sum

    elif self.aggregation_type == 'grover':
      out_attn = self.att_net(out_masked)
      out_attn = out_attn.masked_fill(mask == 0, -1e9)
@@ -186,6 +184,7 @@ class MATGenerator(nn.Module):
      out_avg_pooling = torch.matmul(
          torch.transpose(out_attn, -1, -2), out_masked)
      out_avg_pooling = out_avg_pooling.view(out_avg_pooling.size(0), -1)
+
    elif self.aggregation_type == 'contextual':
      out_avg_pooling = x
    projected = self.proj(out_avg_pooling)
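
Note: a sketch of the 'mean' aggregation path above (parameter values and shapes are illustrative; n_layers=1 assumes the single-layer projection branch that is not shown in this hunk):

import torch
from deepchem.models.torch_models.layers import MATGenerator

# 'mean' zeroes padded nodes via the mask, averages over the node axis,
# then projects to d_output; 'grover' instead uses attention pooling and
# 'contextual' skips pooling and projects per-node features.
gen = MATGenerator(hsize=8, aggregation_type='mean', d_output=1,
                   n_layers=1, dropout_p=0.0)
x = torch.rand(2, 5, 8)                 # (batch, nodes, hsize)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])  # 1 = real node, 0 = padding
print(gen(x, mask).shape)               # torch.Size([2, 1])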