Commit c2accdf3 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Fittransform first cut

parent 4d48b577
Loading
Loading
Loading
Loading
+36 −0
Original line number Diff line number Diff line
@@ -1165,3 +1165,39 @@ def test_seq2seq_reload():
  assert len(pred4e) == len(reloaded_pred4e)
  for (p4e, r4e) in zip(pred4e, reloaded_pred4e):
    assert p4e == r4e


def test_fittransform_regression_reload():
  """Test that MultitaskFitTransformRegressor can reload simple regression datasets."""
  n_samples = 10
  n_features = 3
  n_tasks = 1

  # Generate dummy dataset
  np.random.seed(123)
  tf.random.set_seed(123)
  ids = np.arange(n_samples)
  X = np.random.rand(n_samples, n_features, n_features)
  y = np.zeros((n_samples, n_tasks))
  w = np.ones((n_samples, n_tasks))
  dataset = dc.data.NumpyDataset(X, y, w, ids)

  fit_transformers = [dc.trans.CoulombFitTransformer(dataset)]
  regression_metric = dc.metrics.Metric(dc.metrics.mean_squared_error)
  model_dir = tempfile.mkdtemp()
  model = dc.models.MultitaskFitTransformRegressor(
      n_tasks, [n_features, n_features],
      dropouts=[0.01],
      weight_init_stddevs=[np.sqrt(6) / np.sqrt(1000)],
      batch_size=n_samples,
      fit_transformers=fit_transformers,
      n_evals=1,
      optimizer=Adam(learning_rate=0.003, beta1=0.9, beta2=0.999),
      model_dir=model_dir)

  # Fit trained model
  model.fit(dataset, nb_epoch=100)

  # Eval model on train
  scores = model.evaluate(dataset, [regression_metric])
  assert scores[regression_metric.name] < .1