Commit 5e7f3cea authored by mufeili's avatar mufeili
Browse files

Update

parent c22d6fac
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -75,7 +75,7 @@ def test_attentivefp_reload():
      batch_size=10,
      learning_rate=0.001)

  model.fit(dataset, nb_epoch=60)
  model.fit(dataset, nb_epoch=70)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

+1 −1
Original line number Diff line number Diff line
@@ -57,7 +57,7 @@ def test_gat_classification():
      learning_rate=0.001)

  # overfit test
  model.fit(dataset, nb_epoch=60)
  model.fit(dataset, nb_epoch=70)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

+1 −1
Original line number Diff line number Diff line
@@ -34,7 +34,7 @@ def test_gcn_regression():
      batch_size=10)

  # overfit test
  model.fit(dataset, nb_epoch=100)
  model.fit(dataset, nb_epoch=110)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.5

+1 −0
Original line number Diff line number Diff line
@@ -4,3 +4,4 @@ from deepchem.models.torch_models.attentivefp import AttentiveFP, AttentiveFPMod
from deepchem.models.torch_models.cgcnn import CGCNN, CGCNNModel
from deepchem.models.torch_models.gat import GAT, GATModel
from deepchem.models.torch_models.gcn import GCN, GCNModel
from deepchem.models.torch_models.mpnn import MPNN
+162 −0
Original line number Diff line number Diff line
"""
DGL-based MPNN for graph property prediction.
"""
import torch.nn as nn
import torch.nn.functional as F

from deepchem.models.losses import Loss, L2Loss, SparseSoftmaxCrossEntropy
from deepchem.models.torch_models.torch_model import TorchModel


class MPNN(nn.Module):
  """Model for Graph Property Prediction.

  This model proceeds as follows:

  * Combine latest node representations and edge features in updating node representations,
    which involves multiple rounds of message passing
  * For each graph, compute its representation by combining the representations
    of all nodes in it, which involves a Set2Set layer.
  * Perform the final prediction using an MLP

  Examples
  --------

  >>> import deepchem as dc
  >>> import dgl
  TODO

  References
  ----------
  .. [1] Justin Gilmer, Samuel S. Schoenholz, Patrick F. Riley, Oriol Vinyals, George E. Dahl.
         "Neural Message Passing for Quantum Chemistry." ICML 2017.

  Notes
  -----
  This class requires DGL (https://github.com/dmlc/dgl) and DGL-LifeSci
  (https://github.com/awslabs/dgl-lifesci) to be installed.
  """

  def __init__(self,
               n_tasks: int,
               node_out_feats: int = 64,
               edge_hidden_feats: int = 128,
               num_step_message_passing: int = 3,
               num_step_set2set: int = 6,
               num_layer_set2set: int = 3,
               mode: str = 'regression',
               number_atom_features: int = 30,
               number_bond_features: int = 11,
               n_classes: int = 2,
               nfeat_name: str = 'x',
               efeat_name: str = 'edge_attr'):
    """
    Parameters
    ----------
    n_tasks: int
      Number of tasks.
    node_out_feats: int
      The length of the final node representation vectors. Default to 64.
    edge_hidden_feats: int
      The length of the hidden edge representation vectors. Default to 128.
    num_step_message_passing: int
      The number of rounds of message passing. Default to 3.
    num_step_set2set: int
      The number of set2set steps. Default to 6.
    num_layer_set2set: int
      The number of set2set layers. Default to 3.
    mode: str
      The model type, 'classification' or 'regression'. Default to 'regression'.
    number_atom_features: int
      The length of the initial atom feature vectors. Default to 30.
    number_bond_features: int
      The length of the initial bond feature vectors. Default to 11.
    n_classes: int
      The number of classes to predict per task
      (only used when ``mode`` is 'classification'). Default to 2.
    nfeat_name: str
      For an input graph ``g``, the model assumes that it stores node features in
      ``g.ndata[nfeat_name]`` and will retrieve input node features from that.
      Default to 'x'.
    efeat_name: str
      For an input graph ``g``, the model assumes that it stores edge features in
      ``g.edata[efeat_name]`` and will retrieve input edge features from that.
      Default to 'edge_attr'.
    """
    try:
      import dgl
    except:
      raise ImportError('This class requires dgl.')
    try:
      import dgllife
    except:
      raise ImportError('This class requires dgllife.')

    if mode not in ['classification', 'regression']:
      raise ValueError("mode must be either 'classification' or 'regression'")

    super(MPNN, self).__init__()

    self.n_tasks = n_tasks
    self.mode = mode
    self.n_classes = n_classes
    self.nfeat_name = nfeat_name
    self.efeat_name = efeat_name
    if mode == 'classification':
      out_size = n_tasks * n_classes
    else:
      out_size = n_tasks

    from dgllife.model import MPNNPredictor as DGLMPNNPredictor

    self.model = DGLMPNNPredictor(
        node_in_feats=number_atom_features,
        edge_in_feats=number_bond_features,
        node_out_feats=node_out_feats,
        edge_hidden_feats=edge_hidden_feats,
        n_tasks=out_size,
        num_step_message_passing=num_step_message_passing,
        num_step_set2set=num_step_set2set,
        num_layer_set2set=num_layer_set2set
    )

  def forward(self, g):
    """Predict graph labels

    Parameters
    ----------
    g: DGLGraph
      A DGLGraph for a batch of graphs. It stores the node features in
      ``dgl_graph.ndata[self.nfeat_name]`` and edge features in
      ``dgl_graph.edata[self.efeat_name]``.

    Returns
    -------
    torch.Tensor
      The model output.

      * When self.mode = 'regression',
        its shape will be ``(dgl_graph.batch_size, self.n_tasks)``.
      * When self.mode = 'classification', the output consists of probabilities
        for classes. Its shape will be
        ``(dgl_graph.batch_size, self.n_tasks, self.n_classes)`` if self.n_tasks > 1;
        its shape will be ``(dgl_graph.batch_size, self.n_classes)`` if self.n_tasks is 1.
    torch.Tensor, optional
      This is only returned when self.mode = 'classification', the output consists of the
      logits for classes before softmax.
    """
    node_feats = g.ndata[self.nfeat_name]
    edge_feats = g.edata[self.efeat_name]
    out = self.model(g, node_feats, edge_feats)

    if self.mode == 'classification':
      if self.n_tasks == 1:
        logits = out.view(-1, self.n_classes)
        softmax_dim = 1
      else:
        logits = out.view(-1, self.n_tasks, self.n_classes)
        softmax_dim = 2
      proba = F.softmax(logits, dim=softmax_dim)
      return proba, logits
    else:
      return out