Commit 55e3df94 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

First steps to reload test

parent 31c9b6bf
Loading
Loading
Loading
Loading
+48 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ import tensorflow as tf
from flaky import flaky
from sklearn.ensemble import RandomForestClassifier
from deepchem.molnet.load_function.chembl25_datasets import chembl25_tasks
from deepchem.feat import create_char_to_idx


def test_sklearn_classifier_reload():
@@ -967,3 +968,50 @@ def test_chemception_reload():
  origpred = model.predict(predset)
  reloadpred = reloaded_model.predict(predset)
  assert np.all(origpred == reloadpred)


def test_smiles2vec_reload():
  """Test that smiles2vec models can be saved and reloaded."""
  max_len = 250
  pad_len = 10
  char_to_idx = create_char_to_idx(
      dataset_file, max_len=max_len, smiles_field="smiles")
  feat = dc.feat.SmilesToSeq(
      char_to_idx=char_to_idx, max_len=max_len, pad_len=pad_len)

  n_tasks = 1
  data_points = 10
  mols = ["CCCCCCCC"] * data_points
  X = featurizer(mols)

  y = np.random.randint(0, 2, size=(data_points, n_tasks))
  w = np.ones(shape=(data_points, n_tasks))
  dataset = dc.data.NumpyDataset(X, y, w, mols)
  classsification_metric = dc.metrics.Metric(
      dc.metrics.roc_auc_score, np.mean, mode="classification")

  model_dir = tempfile.mkdtemp()
  model = Smiles2Vec(
      char_to_idx=char_to_idx,
      max_seq_len=max_seq_len,
      use_conv=True,
      n_tasks=n_tasks,
      model_dir=model_dir,
      mode="classification")
  model.fit(dataset, nb_epoch=3)

  ## Reload Trained Model
  #reloaded_model = dc.models.ChemCeption(
  #    n_tasks=n_tasks,
  #    img_spec="engd",
  #    model_dir=model_dir,
  #    mode="classification")
  #reloaded_model.restore()

  ## Check predictions match on random sample
  #predmols = ["CCCC", "CCCCCO", "CCCCC"]
  #Xpred = featurizer(predmols)
  #predset = dc.data.NumpyDataset(Xpred)
  #origpred = model.predict(predset)
  #reloadpred = reloaded_model.predict(predset)
  #assert np.all(origpred == reloadpred)