Commit 20859008 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

Removing an unwanted function

parent 34436fc0
Loading
Loading
Loading
Loading
+3 −50
Original line number Diff line number Diff line
@@ -109,52 +109,6 @@ class LCNNBlock(nn.Module):
    return node_feats


class AtomWiseLinear(nn.Module):
  """
  Performs Matrix Multiplication
  It is used to transform each node wise feature into a scalar.

  """

  def __init__(self,
               input_feature: int,
               output_feature: int,
               dropout: float = 0.0,
               UseBN: bool = True):
    """
    Parameters
    ----------
    input_feature: int
        Size of input feature size
    output_feature: int
        Size of output feature size
    dropout: float
        p value for dropout between 0.0 to 1.0
    UseBN: bool
        Setting it to True will perform Batch Normalisation
    """

    super(AtomWiseLinear, self).__init__()
    self.conv_weights = nn.Linear(input_feature, output_feature)

  def forward(self, node_feats):
    """
    Update node representations.

    Parameters
    ----------
    node_feats: torch.Tensor
        The node features. The shape is `(N, Node_feature_size)`.

    Returns
    -------
    node_feats: torch.Tensor
        The updated node features. The shape is `(N, Node_feature_size)`.
    """
    node_feats = self.conv_weights(node_feats)
    return node_feats


class Atom_Wise_Convolution(nn.Module):
  """
  Performs self convolution to each node
@@ -401,8 +355,9 @@ class LCNN(nn.Module):

    self.LCNN_blocks = nn.Sequential(*modules)
    self.Atom_wise_Conv = Atom_Wise_Convolution(n_features, sitewise_n_feature)
    self.Atom_wise_Lin = AtomWiseLinear(sitewise_n_feature, sitewise_n_feature)
    self.Atom_wise_Lin = nn.Linear(sitewise_n_feature, sitewise_n_feature)
    self.fc = nn.Linear(sitewise_n_feature, n_task)
    self.activation = Shifted_softplus()

  def forward(self, G):
    """
@@ -423,13 +378,11 @@ class LCNN(nn.Module):
      raise ImportError("This class requires DGL to be installed.")
    G = G.local_var()
    node_feats = G.ndata.pop('x')

    for conv in self.LCNN_blocks:
      node_feats = conv(G, node_feats)
    node_feats = self.Atom_wise_Conv(node_feats)
    node_feats = self.Atom_wise_Lin(node_feats)
    G.ndata['new'] = node_feats

    G.ndata['new'] = self.activation(node_feats)
    y = dgl.mean_nodes(G, 'new')
    y = self.fc(y)
    return y