Commit 2bb35800 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by Bharath Ramsundar
Browse files

First cut of chemnet reload tests

parent 6ec9ebeb
Loading
Loading
Loading
Loading
+5 −3
Original line number Diff line number Diff line
@@ -561,7 +561,9 @@ class KerasModel(Model):
      returns the values of the uncertainty outputs.
    other_output_types: list, optional
      Provides a list of other output_types (strings) to predict from model.
    Returns:

    Returns
    -------
    a NumPy array of the model produces a single output, or a list of arrays
    if it produces multiple outputs
    """
+55 −0
Original line number Diff line number Diff line
@@ -525,6 +525,10 @@ def test_weave_classification_reload():
  predset = dc.data.NumpyDataset(Xpred)
  origpred = model.predict(predset)
  reloadpred = reloaded_model.predict(predset)

  # Try re-restore
  reloaded_model.restore()
  reloadpred = reloaded_model.predict(predset)
  assert np.all(origpred == reloadpred)

  # Eval model on train
@@ -755,6 +759,57 @@ def test_graphconvmodel_reload():
      model_dir=model_dir)
  reloaded_model.restore()

  # Check predictions match on random sample
  predmols = ["CCCC", "CCCCCO", "CCCCC"]
  Xpred = featurizer(predmols)
  predset = dc.data.NumpyDataset(Xpred)
  origpred = model.predict(predset)
  reloadpred = reloaded_model.predict(predset)
  #assert np.all(origpred == reloadpred)

  # Try re-restore
  reloaded_model.restore()
  reloadpred = reloaded_model.predict(predset)
  assert np.all(origpred == reloadpred)

  # Eval model on train
  scores = reloaded_model.evaluate(dataset, [classification_metric])
  assert scores[classification_metric.name] > .9


def test_chemception_reload():
  """Test that chemception models can be saved and reloaded."""
  img_size = 80
  img_spec = "engd"
  res = 0.5
  n_tasks = 1
  featurizer = dc.feat.SmilesToImage(
      img_size=img_size, img_spec=img_spec, res=res)
  mols = ["C", "CC", "CCC"]
  X = featurizer(mols)
  y = np.array([0, 1, 0])
  dataset = dc.data.NumpyDataset(X, y, ids=mols)
  classsification_metric = dc.metrics.Metric(
      dc.metrics.roc_auc_score, np.mean, mode="classification")

  model_dir = tempfile.mkdtemp()
  model = dc.models.ChemCeption(
      n_tasks=n_tasks,
      img_spec="engd",
      model_dir=model_dir,
      mode="classification")
  model.fit(dataset, nb_epoch=300)
  scores = model.evaluate(dataset, [metric], [])
  assert scores[classification_metric.name] >= 0.9

  # Reload Trained Model
  reloaded_model = dc.models.ChemCeption(
      n_tasks=n_tasks,
      img_spec="engd",
      model_dir=model_dir,
      mode="classification")
  reloaded_model.restore()

  # Check predictions match on random sample
  predmols = ["CCCC", "CCCCCO", "CCCCC"]
  Xpred = featurizer(predmols)