Commit 3f8afbe0 authored by leswing's avatar leswing
Browse files

Run simple save/load test

parent 1af65bfc
Loading
Loading
Loading
Loading
+53 −0
Original line number Diff line number Diff line
@@ -21,5 +21,58 @@ class TestDeepchemBuild(unittest.TestCase):
    import pandas as pd
    print(pd.__version__)

  def get_dataset(self,
                  mode='classification',
                  featurizer='GraphConv',
                  num_tasks=2):
    from deepchem.molnet import load_bace_classification, load_delaney
    import numpy as np
    import deepchem as dc
    from deepchem.data import NumpyDataset
    data_points = 10
    if mode == 'classification':
      tasks, all_dataset, transformers = load_bace_classification(featurizer)
    else:
      tasks, all_dataset, transformers = load_delaney(featurizer)

    train, valid, test = all_dataset
    for i in range(1, num_tasks):
      tasks.append("random_task")
    w = np.ones(shape=(data_points, len(tasks)))

    if mode == 'classification':
      y = np.random.randint(0, 2, size=(data_points, len(tasks)))
      metric = dc.metrics.Metric(
        dc.metrics.roc_auc_score, np.mean, mode="classification")
    else:
      y = np.random.normal(size=(data_points, len(tasks)))
      metric = dc.metrics.Metric(
        dc.metrics.mean_absolute_error, mode="regression")

    ds = NumpyDataset(train.X[:data_points], y, w, train.ids[:data_points])

    return tasks, ds, transformers, metric

  def test_graph_conv_model(self):
    from deepchem.models import GraphConvModel, TensorGraph
    import numpy as np
    tasks, dataset, transformers, metric = self.get_dataset(
      'classification', 'GraphConv')

    batch_size = 50
    model = GraphConvModel(
      len(tasks), batch_size=batch_size, mode='classification')

    model.fit(dataset, nb_epoch=10)
    scores = model.evaluate(dataset, [metric], transformers)
    assert scores['mean-roc_auc_score'] >= 0.9

    model.save()
    model = TensorGraph.load_from_dir(model.model_dir)
    scores2 = model.evaluate(dataset, [metric], transformers)
    assert np.allclose(scores['mean-roc_auc_score'],
                       scores2['mean-roc_auc_score'])


if __name__ == '__main__':
  unittest.main()