Unverified Commit 998382d3 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2293 from mufeili/master

Model Wrapper for MPNNPredictor from DGL-LifeSci 
parents 7c392f6f fcca3bde
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -35,7 +35,7 @@ def test_gat_regression():
      learning_rate=0.001)

  # overfit test
  model.fit(dataset, nb_epoch=400)
  model.fit(dataset, nb_epoch=500)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.5

+1 −1
Original line number Diff line number Diff line
@@ -35,7 +35,7 @@ def test_gcn_regression():
      learning_rate=0.003)

  # overfit test
  model.fit(dataset, nb_epoch=200)
  model.fit(dataset, nb_epoch=300)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.5

+95 −0
Original line number Diff line number Diff line
import unittest
import tempfile

import numpy as np

import deepchem as dc
from deepchem.feat import MolGraphConvFeaturizer
from deepchem.models.torch_models import MPNNModel
from deepchem.models.tests.test_graph_models import get_dataset

try:
  import dgl
  import dgllife
  import torch
  has_torch_and_dgl = True
except:
  has_torch_and_dgl = False


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_mpnn_regression():
  # load datasets
  featurizer = MolGraphConvFeaturizer(use_edges=True)
  tasks, dataset, transformers, metric = get_dataset(
      'regression', featurizer=featurizer)

  # initialize models
  n_tasks = len(tasks)
  model = MPNNModel(mode='regression', n_tasks=n_tasks, batch_size=10)

  # overfit test
  model.fit(dataset, nb_epoch=400)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.5


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_mpnn_classification():
  # load datasets
  featurizer = MolGraphConvFeaturizer(use_edges=True)
  tasks, dataset, transformers, metric = get_dataset(
      'classification', featurizer=featurizer)

  # initialize models
  n_tasks = len(tasks)
  model = MPNNModel(
      mode='classification',
      n_tasks=n_tasks,
      batch_size=10,
      learning_rate=0.001)

  # overfit test
  model.fit(dataset, nb_epoch=200)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_mpnn_reload():
  # load datasets
  featurizer = MolGraphConvFeaturizer(use_edges=True)
  tasks, dataset, transformers, metric = get_dataset(
      'classification', featurizer=featurizer)

  # initialize models
  n_tasks = len(tasks)
  model_dir = tempfile.mkdtemp()
  model = MPNNModel(
      mode='classification',
      n_tasks=n_tasks,
      model_dir=model_dir,
      batch_size=10,
      learning_rate=0.001)

  model.fit(dataset, nb_epoch=200)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

  reloaded_model = MPNNModel(
      mode='classification',
      n_tasks=n_tasks,
      model_dir=model_dir,
      batch_size=10,
      learning_rate=0.001)
  reloaded_model.restore()

  pred_mols = ["CCCC", "CCCCCO", "CCCCC"]
  X_pred = featurizer(pred_mols)
  random_dataset = dc.data.NumpyDataset(X_pred)
  original_pred = model.predict(random_dataset)
  reload_pred = reloaded_model.predict(random_dataset)
  assert np.all(original_pred == reload_pred)
+1 −0
Original line number Diff line number Diff line
@@ -4,3 +4,4 @@ from deepchem.models.torch_models.attentivefp import AttentiveFP, AttentiveFPMod
from deepchem.models.torch_models.cgcnn import CGCNN, CGCNNModel
from deepchem.models.torch_models.gat import GAT, GATModel
from deepchem.models.torch_models.gcn import GCN, GCNModel
from deepchem.models.torch_models.mpnn import MPNN, MPNNModel
+5 −19
Original line number Diff line number Diff line
@@ -197,7 +197,7 @@ class AttentiveFPModel(TorchModel):
  >> tasks, datasets, transformers = dc.molnet.load_tox21(
  ..     reload=False, featurizer=featurizer, transformers=[])
  >> train, valid, test = datasets
  >> model = dc.models.AttentiveFPModel(mode='classification', n_tasks=len(tasks),
  >> model = AttentiveFPModel(mode='classification', n_tasks=len(tasks),
  ..                          batch_size=32, learning_rate=0.001)
  >> model.fit(train, nb_epoch=50)

@@ -224,8 +224,6 @@ class AttentiveFPModel(TorchModel):
               number_atom_features: int = 30,
               number_bond_features: int = 11,
               n_classes: int = 2,
               nfeat_name: str = 'x',
               efeat_name: str = 'edge_attr',
               self_loop: bool = True,
               **kwargs):
    """
@@ -251,17 +249,10 @@ class AttentiveFPModel(TorchModel):
    n_classes: int
      The number of classes to predict per task
      (only used when ``mode`` is 'classification'). Default to 2.
    nfeat_name: str
      For an input graph ``g``, the model assumes that it stores node features in
      ``g.ndata[nfeat_name]`` and will retrieve input node features from that.
      Default to 'x'.
    efeat_name: str
      For an input graph ``g``, the model assumes that it stores edge features in
      ``g.edata[efeat_name]`` and will retrieve input edge features from that.
      Default to 'edge_attr'.
    self_loop: bool
      Whether to add self loops for the nodes, i.e. edges from nodes to themselves.
      Default to True.
      When input graphs have isolated nodes, self loops allow preserving the original feature
      of them in message passing. Default to True.
    kwargs
      This can include any keyword argument of TorchModel.
    """
@@ -274,9 +265,7 @@ class AttentiveFPModel(TorchModel):
        mode=mode,
        number_atom_features=number_atom_features,
        number_bond_features=number_bond_features,
        n_classes=n_classes,
        nfeat_name=nfeat_name,
        efeat_name=efeat_name)
        n_classes=n_classes)
    if mode == 'regression':
      loss: Loss = L2Loss()
      output_types = ['prediction']
@@ -295,9 +284,6 @@ class AttentiveFPModel(TorchModel):
    ----------
    batch: tuple
      The tuple is ``(inputs, labels, weights)``.
    self_loop: bool
      Whether to add self loops for the nodes, i.e. edges from nodes
      to themselves. Default to False.

    Returns
    -------
Loading