Commit d90214a4 authored by Bharath Ramsundar
Browse files

Changes

parent d3bf2ff4
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -788,7 +788,7 @@ class KerasModel(Model):
    if it produces multiple outputs
    """
    generator = self.default_generator(
        dataset, mode='predict', pad_batches=False)
        dataset, mode='predict', deterministic=True, pad_batches=False)
    return self.predict_on_generator(
        generator,
        transformers=transformers,
+67 −15
Original line number Diff line number Diff line
@@ -550,10 +550,9 @@ def test_weave_classification_reload():

  classification_metric = dc.metrics.Metric(dc.metrics.roc_auc_score)

  batch_size = 10
  batch_size = 3

  #model_dir = tempfile.mkdtemp()
  model_dir = "/tmp/foobarbaz7"
  model_dir = tempfile.mkdtemp()
  model = dc.models.WeaveModel(
      n_tasks,
      batch_size=batch_size,
@@ -563,7 +562,7 @@ def test_weave_classification_reload():
      model_dir=model_dir)

  # Fit trained model
  model.fit(dataset, nb_epoch=30)
  model.fit(dataset, nb_epoch=3)

  # Eval model on train
  scores = model.evaluate(dataset, [classification_metric])
@@ -572,13 +571,12 @@ def test_weave_classification_reload():
  # Check predictions match on random sample
  predmols = ["CCCC", "CCCCCO", "CCCCC"]
  Xpred = featurizer(predmols)

  predset = dc.data.NumpyDataset(Xpred)
  origpred = model.predict(predset)
  print("origpred")
  print(origpred)
  origpred2 = model.predict(predset)
  assert np.all(origpred == origpred2)

  del model.model
  del model
  reloaded_model = dc.models.WeaveModel(
      n_tasks,
      batch_size=batch_size,
@@ -587,13 +585,37 @@ def test_weave_classification_reload():
      dropouts=0.0,
      model_dir=model_dir)
  reloaded_model.restore()

  Xproc = reloaded_model.compute_features_on_batch(Xpred)
  reloadout = reloaded_model.model(Xproc)
  print("reloadout")
  print(reloadout)

  reloadpred = reloaded_model.predict(predset)
  assert np.all(origpred == reloadpred)
  print("reloadpred")
  print(reloadpred)

  print("origpred")
  print(origpred)

  ## Try re-restore
  #reloaded_model.restore()
  #reloadpred = reloaded_model.predict(predset)

  #assert np.all(origpred == reloadpred)
  print("np.amax(origpred - reloadpred)")
  print(np.amax(origpred - reloadpred))
  print("np.allclose(origpred, reloadpred)")
  print(np.allclose(origpred, reloadpred))

  # Eval model on train
  scores = reloaded_model.evaluate(dataset, [classification_metric])
  print("scores")
  print(scores)
  assert scores[classification_metric.name] > .9

  assert np.all(origpred == reloadpred)


# TODO: THIS IS FAILING!
def test_MPNN_regression_reload():
@@ -637,6 +659,13 @@ def test_MPNN_regression_reload():
  scores = model.evaluate(dataset, [regression_metric])
  assert scores[regression_metric.name] > .8

  # Custom save
  save_dir = tempfile.mkdtemp()
  model.model.save(save_dir)

  from tensorflow import keras
  reloaded = keras.models.load_model(save_dir)

  # Reload trained model
  reloaded_model = dc.models.MPNNModel(
      n_tasks,
@@ -649,7 +678,12 @@ def test_MPNN_regression_reload():
      use_queue=False,
      mode="regression",
      model_dir=model_dir)
  reloaded_model.restore()
  #reloaded_model.restore()
  reloaded_model.model = reloaded

  # Eval model on train
  scores = reloaded_model.evaluate(dataset, [regression_metric])
  assert scores[regression_metric.name] > .8

  # Check predictions match on random sample
  predmols = ["CCCC", "CCCCCO", "CCCCC"]
@@ -657,12 +691,10 @@ def test_MPNN_regression_reload():
  predset = dc.data.NumpyDataset(Xpred)
  origpred = model.predict(predset)
  reloadpred = reloaded_model.predict(predset)
  print("np.amax(origpred - reloadpred)")
  print(np.amax(origpred - reloadpred))
  assert np.all(origpred == reloadpred)

  # Eval model on train
  scores = reloaded_model.evaluate(dataset, [regression_metric])
  assert scores[regression_metric.name] > .8


# TODO: THIS IS FAILING!
def test_textCNN_classification_reload():
@@ -682,7 +714,7 @@ def test_textCNN_classification_reload():
  classification_metric = dc.metrics.Metric(dc.metrics.roc_auc_score)

  char_dict, length = dc.models.TextCNNModel.build_char_dict(dataset)
  batch_size = 10
  batch_size = 3

  model_dir = tempfile.mkdtemp()
  model = dc.models.TextCNNModel(
@@ -714,12 +746,25 @@ def test_textCNN_classification_reload():
      model_dir=model_dir)
  reloaded_model.restore()

  assert len(reloaded_model.model.get_weights()) == len(
      model.model.get_weights())
  for (reloaded, orig) in zip(reloaded_model.model.get_weights(),
                              model.model.get_weights()):
    assert np.all(reloaded == orig)

  # Check predictions match on random sample
  predmols = ["CCCC", "CCCCCO", "CCCCC"]
  Xpred = featurizer(predmols)
  predset = dc.data.NumpyDataset(Xpred, ids=predmols)
  origpred = model.predict(predset)
  reloadpred = reloaded_model.predict(predset)

  Xproc = reloaded_model.smiles_to_seq_batch(np.array(predmols))
  reloadout = reloaded_model.model(Xproc)
  origout = model.model(Xproc)

  assert len(model.model.layers) == len(reloaded_model.model.layers)

  assert np.all(origpred == reloadpred)

  # Eval model on train
@@ -809,6 +854,13 @@ def test_graphconvmodel_reload():
  scores = model.evaluate(dataset, [classification_metric])
  assert scores[classification_metric.name] >= 0.9

  # Custom save
  save_dir = tempfile.mkdtemp()
  model.model.save(save_dir)

  from tensorflow import keras
  reloaded = keras.models.load_model(save_dir)

  # Reload trained Model
  reloaded_model = dc.models.GraphConvModel(
      len(tasks),