Unverified Commit 968e521d authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2280 from mufeili/master

Model Wrapper for GATPredictor and AttentiveFPPredictor from DGL-LifeSci
parents ab07c0b6 9822f076
Loading
Loading
Loading
Loading
+5 −3
Original line number Diff line number Diff line
@@ -146,9 +146,6 @@ class GraphData:

    src = self.edge_index[0]
    dst = self.edge_index[1]
    if self_loop:
      src = np.concatenate([src, np.arange(self.num_nodes)])
      dst = np.concatenate([dst, np.arange(self.num_nodes)])

    g = dgl.graph(
        (torch.from_numpy(src).long(), torch.from_numpy(dst).long()),
@@ -161,6 +158,11 @@ class GraphData:
    if self.edge_features is not None:
      g.edata['edge_attr'] = torch.from_numpy(self.edge_features).float()

    if self_loop:
      # This assumes that the edge features for self loops are full-zero tensors
      # In the future we may want to support featurization for self loops
      g.add_edges(np.arange(self.num_nodes), np.arange(self.num_nodes))

    return g


+1 −0
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ from deepchem.models.gbdt_models import GBDTModel
# PyTorch models
try:
  from deepchem.models.torch_models import TorchModel
  from deepchem.models.torch_models import AttentiveFP, AttentiveFPModel
  from deepchem.models.torch_models import CGCNN, CGCNNModel
  from deepchem.models.torch_models import GAT, GATModel
  from deepchem.models.torch_models import GCN, GCNModel
+95 −0
Original line number Diff line number Diff line
import unittest
import tempfile

import numpy as np

import deepchem as dc
from deepchem.feat import MolGraphConvFeaturizer
from deepchem.models import AttentiveFPModel
from deepchem.models.tests.test_graph_models import get_dataset

# Probe for the optional PyTorch / DGL / DGL-LifeSci stack. The tests in this
# module are skipped when any of these packages fails to import.
try:
  import dgl  # noqa: F401
  import dgllife  # noqa: F401
  import torch  # noqa: F401
  has_torch_and_dgl = True
except Exception:
  # ``Exception`` rather than a bare ``except`` so that KeyboardInterrupt and
  # SystemExit still propagate, while any ordinary import-time failure
  # (not only ImportError) still just disables the tests.
  has_torch_and_dgl = False


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_attentivefp_regression():
  """Overfit an AttentiveFPModel in regression mode.

  The model is fit for 100 epochs on the dataset returned by
  ``get_dataset('regression', ...)`` and then evaluated on that same
  dataset; the mean absolute error must be below 0.5.
  """
  # Featurize with edge features enabled.
  edge_featurizer = MolGraphConvFeaturizer(use_edges=True)
  task_names, train_set, transformers, mae_metric = get_dataset(
      'regression', featurizer=edge_featurizer)

  # One output per task.
  regressor = AttentiveFPModel(
      mode='regression', n_tasks=len(task_names), batch_size=10)

  # Train and score on the same data: this is an overfitting check.
  regressor.fit(train_set, nb_epoch=100)
  results = regressor.evaluate(train_set, [mae_metric], transformers)
  assert results['mean_absolute_error'] < 0.5


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_attentivefp_classification():
  """Overfit an AttentiveFPModel in classification mode.

  The model is fit for 100 epochs on the dataset returned by
  ``get_dataset('classification', ...)`` and then evaluated on that same
  dataset; the mean ROC-AUC score must be at least 0.85.
  """
  # Featurize with edge features enabled.
  edge_featurizer = MolGraphConvFeaturizer(use_edges=True)
  task_names, train_set, transformers, auc_metric = get_dataset(
      'classification', featurizer=edge_featurizer)

  # One output per task.
  classifier = AttentiveFPModel(
      mode='classification',
      n_tasks=len(task_names),
      batch_size=10,
      learning_rate=0.001)

  # Train and score on the same data: this is an overfitting check.
  classifier.fit(train_set, nb_epoch=100)
  results = classifier.evaluate(train_set, [auc_metric], transformers)
  assert results['mean-roc_auc_score'] >= 0.85


@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_attentivefp_reload():
  """Check that a restored AttentiveFPModel reproduces the original's output.

  A classification model is fit for 100 epochs and must reach a mean ROC-AUC
  of at least 0.85 on its training data. A second model built with the same
  arguments and the same ``model_dir`` is then restored, and both models must
  give identical predictions on three small molecules.
  """
  # load datasets
  featurizer = MolGraphConvFeaturizer(use_edges=True)
  tasks, dataset, transformers, metric = get_dataset(
      'classification', featurizer=featurizer)

  # The context manager removes the checkpoint directory when the test ends
  # (pass or fail) instead of leaking a mkdtemp() directory on every run.
  with tempfile.TemporaryDirectory() as model_dir:
    # Both models share one set of arguments so they cannot drift apart.
    model_kwargs = dict(
        mode='classification',
        n_tasks=len(tasks),
        model_dir=model_dir,
        batch_size=10,
        learning_rate=0.001)

    model = AttentiveFPModel(**model_kwargs)
    model.fit(dataset, nb_epoch=100)
    scores = model.evaluate(dataset, [metric], transformers)
    assert scores['mean-roc_auc_score'] >= 0.85

    reloaded_model = AttentiveFPModel(**model_kwargs)
    reloaded_model.restore()

    pred_mols = ["CCCC", "CCCCCO", "CCCCC"]
    X_pred = featurizer(pred_mols)
    random_dataset = dc.data.NumpyDataset(X_pred)
    original_pred = model.predict(random_dataset)
    reload_pred = reloaded_model.predict(random_dataset)
    assert np.all(original_pred == reload_pred)
+26 −19
Original line number Diff line number Diff line
@@ -9,15 +9,16 @@ from deepchem.models import GATModel
from deepchem.models.tests.test_graph_models import get_dataset

try:
  import torch  # noqa
  import torch_geometric  # noqa
  has_pytorch_and_pyg = True
  import dgl
  import dgllife
  import torch
  has_torch_and_dgl = True
except:
  has_pytorch_and_pyg = False
  has_torch_and_dgl = False


@unittest.skipIf(not has_pytorch_and_pyg,
                 'PyTorch and PyTorch Geometric are not installed')
@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_gat_regression():
  # load datasets
  featurizer = MolGraphConvFeaturizer()
@@ -26,17 +27,21 @@ def test_gat_regression():

  # initialize models
  n_tasks = len(tasks)
  model = GATModel(mode='regression', n_tasks=n_tasks, batch_size=10)
  model = GATModel(
      mode='regression',
      n_tasks=n_tasks,
      number_atom_features=30,
      batch_size=10,
      learning_rate=0.001)

  # overfit test
  # GAT's convergence is a little slow
  model.fit(dataset, nb_epoch=300)
  model.fit(dataset, nb_epoch=400)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.75
  assert scores['mean_absolute_error'] < 0.5


@unittest.skipIf(not has_pytorch_and_pyg,
                 'PyTorch and PyTorch Geometric are not installed')
@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_gat_classification():
  # load datasets
  featurizer = MolGraphConvFeaturizer()
@@ -48,18 +53,18 @@ def test_gat_classification():
  model = GATModel(
      mode='classification',
      n_tasks=n_tasks,
      number_atom_features=30,
      batch_size=10,
      learning_rate=0.001)

  # overfit test
  # GAT's convergence is a little slow
  model.fit(dataset, nb_epoch=150)
  model.fit(dataset, nb_epoch=100)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.70
  assert scores['mean-roc_auc_score'] >= 0.85


@unittest.skipIf(not has_pytorch_and_pyg,
                 'PyTorch and PyTorch Geometric are not installed')
@unittest.skipIf(not has_torch_and_dgl,
                 'PyTorch, DGL, or DGL-LifeSci are not installed')
def test_gat_reload():
  # load datasets
  featurizer = MolGraphConvFeaturizer()
@@ -72,17 +77,19 @@ def test_gat_reload():
  model = GATModel(
      mode='classification',
      n_tasks=n_tasks,
      number_atom_features=30,
      model_dir=model_dir,
      batch_size=10,
      learning_rate=0.001)

  model.fit(dataset, nb_epoch=150)
  model.fit(dataset, nb_epoch=100)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.70
  assert scores['mean-roc_auc_score'] >= 0.85

  reloaded_model = GATModel(
      mode='classification',
      n_tasks=n_tasks,
      number_atom_features=30,
      model_dir=model_dir,
      batch_size=10,
      learning_rate=0.001)
+8 −7
Original line number Diff line number Diff line
@@ -31,10 +31,11 @@ def test_gcn_regression():
      mode='regression',
      n_tasks=n_tasks,
      number_atom_features=30,
      batch_size=10)
      batch_size=10,
      learning_rate=0.003)

  # overfit test
  model.fit(dataset, nb_epoch=100)
  model.fit(dataset, nb_epoch=200)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean_absolute_error'] < 0.5

@@ -54,10 +55,10 @@ def test_gcn_classification():
      n_tasks=n_tasks,
      number_atom_features=30,
      batch_size=10,
      learning_rate=0.001)
      learning_rate=0.0003)

  # overfit test
  model.fit(dataset, nb_epoch=50)
  model.fit(dataset, nb_epoch=70)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

@@ -79,9 +80,9 @@ def test_gcn_reload():
      number_atom_features=30,
      model_dir=model_dir,
      batch_size=10,
      learning_rate=0.001)
      learning_rate=0.0003)

  model.fit(dataset, nb_epoch=50)
  model.fit(dataset, nb_epoch=70)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

@@ -91,7 +92,7 @@ def test_gcn_reload():
      number_atom_features=30,
      model_dir=model_dir,
      batch_size=10,
      learning_rate=0.001)
      learning_rate=0.0003)
  reloaded_model.restore()

  pred_mols = ["CCCC", "CCCCCO", "CCCCC"]
Loading