Commit ecbdb24b authored by nd-02110114's avatar nd-02110114
Browse files

add reload tests for cgcnn and gat

parent f272f7cb
Loading
Loading
Loading
Loading
+75 −6
Original line number Diff line number Diff line
import unittest
import tempfile
from os import path, remove

import numpy as np

from deepchem.feat import CGCNNFeaturizer
from deepchem.molnet import load_perovskite, load_mp_metallicity
from deepchem.metrics import Metric, mae_score, roc_auc_score
@@ -15,8 +18,7 @@ except:


@unittest.skipIf(not has_pytorch_and_dgl, 'PyTorch and DGL are not installed')
def test_cgcnn():
  # regression test
def test_cgcnn_regression():
  # load datasets
  current_dir = path.dirname(path.abspath(__file__))
  config = {
@@ -47,17 +49,74 @@ def test_cgcnn():
  scores = model.evaluate(train, [regression_metric], transformers)
  assert scores[regression_metric.name] < 0.6

  # classification test
  if path.exists(path.join(current_dir, 'perovskite.json')):
    remove(path.join(current_dir, 'perovskite.json'))


@unittest.skipIf(not has_pytorch_and_dgl, 'PyTorch and DGL are not installed')
def test_cgcnn_classification():
  # load datasets
  current_dir = path.dirname(path.abspath(__file__))
  config = {
      "reload": False,
      "featurizer": CGCNNFeaturizer,
      # disable transformer
      "transformers": [],
      "data_dir": current_dir
  }
  tasks, datasets, transformers = load_mp_metallicity(**config)
  train, valid, test = datasets

  n_tasks = len(tasks)
  n_classes = 2
  model = CGCNNModel(
      n_tasks=n_tasks,
      n_classes=n_classes,
      mode='classification',
      batch_size=4,
      learning_rate=0.001)

  # check train
  model.fit(train, nb_epoch=20)

  # check predict shape
  valid_preds = model.predict_on_batch(valid.X)
  assert valid_preds.shape == (2, n_classes)
  test_preds = model.predict(test)
  assert test_preds.shape == (3, n_classes)

  # check overfit
  classification_metric = Metric(roc_auc_score, n_tasks=n_tasks)
  scores = model.evaluate(
      train, [classification_metric], transformers, n_classes=n_classes)
  assert scores[classification_metric.name] > 0.8

  if path.exists(path.join(current_dir, 'mp_is_metal.json')):
    remove(path.join(current_dir, 'mp_is_metal.json'))


@unittest.skipIf(not has_pytorch_and_dgl, 'PyTorch and DGL are not installed')
def test_cgcnn_reload():
  # load datasets
  current_dir = path.dirname(path.abspath(__file__))
  config = {
      "reload": False,
      "featurizer": CGCNNFeaturizer,
      # disable transformer
      "transformers": [],
      "data_dir": current_dir
  }
  tasks, datasets, transformers = load_mp_metallicity(**config)
  train, valid, test = datasets

  n_tasks = len(tasks)
  n_classes = 2
  model_dir = tempfile.mkdtemp()
  model = CGCNNModel(
      n_tasks=n_tasks,
      n_classes=n_classes,
      mode='classification',
      model_dir=model_dir,
      batch_size=4,
      learning_rate=0.001)

@@ -76,9 +135,19 @@ def test_cgcnn():
      train, [classification_metric], transformers, n_classes=n_classes)
  assert scores[classification_metric.name] > 0.8

  # TODO: Multi task classification test
  # reload
  reloaded_model = CGCNNModel(
      n_tasks=n_tasks,
      n_classes=n_classes,
      mode='classification',
      model_dir=model_dir,
      batch_size=4,
      learning_rate=0.001)
  reloaded_model.restore()

  original_pred = model.predict(test)
  reload_pred = reloaded_model.predict(test)
  assert np.all(original_pred == reload_pred)

  if path.exists(path.join(current_dir, 'perovskite.json')):
    remove(path.join(current_dir, 'perovskite.json'))
  if path.exists(path.join(current_dir, 'mp_is_metal.json')):
    remove(path.join(current_dir, 'mp_is_metal.json'))
+43 −1
Original line number Diff line number Diff line
import unittest
import tempfile

import numpy as np

import deepchem as dc
from deepchem.feat import MolGraphConvFeaturizer
from deepchem.models import GATModel
from deepchem.models.tests.test_graph_models import get_dataset
@@ -51,4 +55,42 @@ def test_gat_classification():
  # GAT's convergence is a little slow
  model.fit(dataset, nb_epoch=150)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.9
  assert scores['mean-roc_auc_score'] >= 0.85


@unittest.skipIf(not has_pytorch_and_pyg,
                 'PyTorch and PyTorch Geometric are not installed')
def test_gat_reload():
  # load datasets
  featurizer = MolGraphConvFeaturizer()
  tasks, dataset, transformers, metric = get_dataset(
      'classification', featurizer=featurizer)

  # initialize models
  n_tasks = len(tasks)
  model_dir = tempfile.mkdtemp()
  model = GATModel(
      mode='classification',
      n_tasks=n_tasks,
      model_dir=model_dir,
      batch_size=10,
      learning_rate=0.001)

  model.fit(dataset, nb_epoch=150)
  scores = model.evaluate(dataset, [metric], transformers)
  assert scores['mean-roc_auc_score'] >= 0.85

  reloaded_model = GATModel(
      mode='classification',
      n_tasks=n_tasks,
      model_dir=model_dir,
      batch_size=10,
      learning_rate=0.001)
  reloaded_model.restore()

  pred_mols = ["CCCC", "CCCCCO", "CCCCC"]
  X_pred = featurizer(pred_mols)
  random_dataset = dc.data.NumpyDataset(X_pred)
  original_pred = model.predict(random_dataset)
  reload_pred = reloaded_model.predict(random_dataset)
  assert np.all(original_pred == reload_pred)
+5 −5
Original line number Diff line number Diff line
@@ -42,10 +42,6 @@ DeepChem has a number of "soft" requirements.
|                                |               |                                                   |
|                                |               |                                                   |
+--------------------------------+---------------+---------------------------------------------------+
| `OpenAI Gym`_                  | Not Testing   | :code:`dc.rl`                                     |
|                                |               |                                                   |
|                                |               |                                                   |
+--------------------------------+---------------+---------------------------------------------------+
| `matminer`_                    | latest        | :code:`dc.feat.materials_featurizers`             |
|                                |               |                                                   |
|                                |               |                                                   |
@@ -66,6 +62,10 @@ DeepChem has a number of "soft" requirements.
|                                |               |                                                   |
|                                |               |                                                   |
+--------------------------------+---------------+---------------------------------------------------+
| `OpenAI Gym`_                  | Not Testing   | :code:`dc.rl`                                     |
|                                |               |                                                   |
|                                |               |                                                   |
+--------------------------------+---------------+---------------------------------------------------+
| `OpenMM`_                      | latest        | :code:`dc.utils.rdkit_utils`                      |
|                                |               |                                                   |
|                                |               |                                                   |
@@ -125,12 +125,12 @@ DeepChem has a number of "soft" requirements.
.. _`Deep Graph Library`: https://www.dgl.ai/
.. _`HuggingFace Transformers`: https://huggingface.co/transformers/
.. _`LightGBM`: https://lightgbm.readthedocs.io/en/latest/index.html
.. _`OpenAI Gym`: https://gym.openai.com/
.. _`matminer`: https://hackingmaterials.lbl.gov/matminer/
.. _`MDTraj`: http://mdtraj.org/
.. _`Mol2vec`: https://github.com/samoturk/mol2vec
.. _`Mordred`: http://mordred-descriptor.github.io/documentation/master/
.. _`NetworkX`: https://networkx.github.io/documentation/stable/index.html
.. _`OpenAI Gym`: https://gym.openai.com/
.. _`OpenMM`: http://openmm.org/
.. _`PDBFixer`: https://github.com/pandegroup/pdbfixer
.. _`Pillow`: https://pypi.org/project/Pillow/