Commit 1b99686f authored by Nathan Frey's avatar Nathan Frey
Browse files

Check reloaded density estimation

parent 81ee9064
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -309,6 +309,9 @@ def test_normalizing_flow_model_reload():
  dataset = dc.data.NumpyDataset(X=target_distribution.sample(96))
  final = nfm.fit(dataset, nb_epoch=1)

  x = np.zeros(2)
  lp1 = nfm.flow.log_prob(x).numpy()

  assert nfm.flow.sample().numpy().shape == (2,)

  reloaded_model = NormalizingFlowModel(nf, model_dir=model_dir)
@@ -317,6 +320,11 @@ def test_normalizing_flow_model_reload():
  # Check that reloaded model can sample from the distribution
  assert reloaded_model.flow.sample().numpy().shape == (2,)

  lp2 = reloaded_model.flow.log_prob(x).numpy()

  # Check that density estimation is same for reloaded model
  assert np.all(lp1 == lp2)


def test_robust_multitask_regressor_reload():
  """Test that RobustMultitaskRegressor can be reloaded correctly."""