Commit f9611f47 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

First changes

parent c2accdf3
Loading
Loading
Loading
Loading
+19 −1
Original line number Diff line number Diff line
@@ -1192,7 +1192,8 @@ def test_fittransform_regression_reload():
      batch_size=n_samples,
      fit_transformers=fit_transformers,
      n_evals=1,
      optimizer=Adam(learning_rate=0.003, beta1=0.9, beta2=0.999),
      optimizer=dc.models.optimizers.Adam(
          learning_rate=0.003, beta1=0.9, beta2=0.999),
      model_dir=model_dir)

  # Fit trained model
@@ -1201,3 +1202,20 @@ def test_fittransform_regression_reload():
  # Eval model on train
  scores = model.evaluate(dataset, [regression_metric])
  assert scores[regression_metric.name] < .1

  reloaded_model = dc.models.MultitaskFitTransformRegressor(
      n_tasks, [n_features, n_features],
      dropouts=[0.01],
      weight_init_stddevs=[np.sqrt(6) / np.sqrt(1000)],
      batch_size=n_samples,
      fit_transformers=fit_transformers,
      n_evals=1,
      optimizer=dc.models.optimizers.Adam(
          learning_rate=0.003, beta1=0.9, beta2=0.999),
      model_dir=model_dir)
  reloaded_model.restore()

  # Check predictions match on random sample
  origpred = model.predict(dataset)
  reloadpred = reloaded_model.predict(dataset)
  assert np.all(origpred == reloadpred)