Commit 6d5adc89 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

doc fix after first review

parent 6f257639
Loading
Loading
Loading
Loading
+5 −5
Original line number Diff line number Diff line
@@ -230,10 +230,10 @@ class PagtnMolGraphFeaturizer(MolecularFeaturizer):

  The featurization is based on `PAGTN model <https://arxiv.org/abs/1905.12712>`_. It is
  slightly more computationally intensive than default Graph Convolution Featuriser, but it
  builds a Molecular Graph connecting all tom pairs accounting for interactions of atom with
  builds a Molecular Graph connecting all atom pairs accounting for interactions of an atom with
  every other atom in the Molecule. According to the paper, interactions between two pairs
  of an atom are dependent on the relative distance between them and calculating the shortest
  path between them.
  of atom are dependent on the relative distance between them and and hence, the function needs
  to calculate the shortest path between them.

  The default node representation is constructed by concatenating the following values,
  and the feature length is 94.
@@ -247,9 +247,9 @@ class PagtnMolGraphFeaturizer(MolecularFeaturizer):
    include ``0 - 5``.
  - Aromaticity: Boolean representing if an atom is aromatic.

  The default edge representation are constructed by concatenating the following values,
  The default edge representation is constructed by concatenating the following values,
  and the feature length is 42. It builds a complete graph where each node is connected to
  every other node. The edge representations are calculated the shortest path between two nodes
  every other node. The edge representations are calculated based on the shortest path between two nodes
  (choose any one if multiple exist). Each bond encountered in the shortest path is used to
  calculate edge features.

+19 −11
Original line number Diff line number Diff line
@@ -17,8 +17,9 @@ class Pagtn(nn.Module):
    linear additive form of attention is applied. Attention Weights are derived
    by concatenating the node and edge features for each bond.
  * Update node representations with multiple rounds of message passing.
  * For each layer has, residual connections with its previous layer and finally compute its
    representation by combining the representations of all nodes in it.
  * For each layer has, residual connections with its previous layer.
  * The final molecular representation is computed by combining the representations
    of all nodes in the molecule.
  * Perform the final prediction using a linear layer

  Examples
@@ -58,7 +59,7 @@ class Pagtn(nn.Module):
               number_bond_features: int = 42,
               mode: str = 'regression',
               n_classes: int = 2,
               ouput_node_features: int = 256,
               output_node_features: int = 256,
               hidden_features: int = 32,
               num_layers: int = 5,
               num_heads: int = 1,
@@ -80,7 +81,7 @@ class Pagtn(nn.Module):
    n_classes: int
      The number of classes to predict per task
      (only used when ``mode`` is 'classification'). Default to 2.
    ouput_node_features : int
    output_node_features : int
      Size for the output node features in PAGTN layers. Default to 256.
    hidden_features : int
      Size for the hidden node features in PAGTN layers. Default to 32.
@@ -129,7 +130,7 @@ class Pagtn(nn.Module):

    self.model = DGLPAGTNPredictor(
        node_in_feats=number_atom_features,
        node_out_feats=ouput_node_features,
        node_out_feats=output_node_features,
        node_hid_feats=hidden_features,
        edge_feats=number_bond_features,
        depth=num_layers,
@@ -189,8 +190,9 @@ class PagtnModel(TorchModel):
    linear additive form of attention is applied. Attention Weights are derived
    by concatenating the node and edge features for each bond.
  * Update node representations with multiple rounds of message passing.
  * For each layer has, residual connections with its previous layer and finally compute its
    representation by combining the representations of all nodes in it.
  * For each layer has, residual connections with its previous layer.
  * The final molecular representation is computed by combining the representations
    of all nodes in the molecule.
  * Perform the final prediction using a linear layer

  Examples
@@ -224,6 +226,8 @@ class PagtnModel(TorchModel):
               number_bond_features: int = 42,
               mode: str = 'regression',
               n_classes: int = 2,
               output_node_features: int = 256,
               hidden_features: int = 32,
               num_layers: int = 5,
               num_heads: int = 1,
               dropout: float = 0.1,
@@ -243,6 +247,10 @@ class PagtnModel(TorchModel):
    n_classes: int
      The number of classes to predict per task
      (only used when ``mode`` is 'classification'). Default to 2.
    output_node_features : int
      Size for the output node features in PAGTN layers. Default to 256.
    hidden_features : int
      Size for the hidden node features in PAGTN layers. Default to 32.
    num_layers: int
      Number of graph neural network layers, i.e. number of rounds of message passing.
      Default to 2.
@@ -261,6 +269,8 @@ class PagtnModel(TorchModel):
        number_bond_features=number_bond_features,
        mode=mode,
        n_classes=n_classes,
        output_node_features=output_node_features,
        hidden_features=hidden_features,
        num_layers=num_layers,
        num_heads=num_heads,
        dropout=dropout,
@@ -275,7 +285,7 @@ class PagtnModel(TorchModel):
        model, loss=loss, output_types=output_types, **kwargs)

  def _prepare_batch(self, batch):
    """Create batch data for AttentiveFP.
    """Create batch data for Pagtn.

    Parameters
    ----------
@@ -297,9 +307,7 @@ class PagtnModel(TorchModel):
      raise ImportError('This class requires dgl.')

    inputs, labels, weights = batch
    dgl_graphs = [
        graph.to_dgl_graph(self_loop=self._self_loop) for graph in inputs[0]
    ]
    dgl_graphs = [graph.to_dgl_graph() for graph in inputs[0]]
    inputs = dgl.batch(dgl_graphs).to(self.device)
    _, labels, weights = super(PagtnModel, self)._prepare_batch(([], labels,
                                                                 weights))