Unverified Commit 63af0c9b authored by Nathan Frey's avatar Nathan Frey Committed by GitHub
Browse files

Merge pull request #2258 from ncfrey/nf-saving-reloading

Reload tests for normalizing flow models
parents 12b35a1d 1b99686f
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@ from deepchem.models.models import Model
from deepchem.models.keras_model import KerasModel
from deepchem.models.optimizers import Optimizer, Adam
from deepchem.utils.typing import OneOrMany
from deepchem.utils.data_utils import load_from_disk, save_to_disk

logger = logging.getLogger(__name__)

@@ -183,6 +184,14 @@ class NormalizingFlowModel(KerasModel):

    return -tf.reduce_mean(self.flow.log_prob(input, training=True))

  def save(self):
    """Write `self.model` to disk as a joblib file.

    The destination path is obtained from
    `self.get_model_filename(self.model_dir)`, and the actual write is
    delegated to `save_to_disk`.
    """
    model_path = self.get_model_filename(self.model_dir)
    save_to_disk(self.model, model_path)

  def reload(self):
    """Replace `self.model` with the object stored in the joblib file on disk.

    The file is looked up at `self.get_model_filename(self.model_dir)` and
    read back with `load_from_disk`.
    """
    model_path = self.get_model_filename(self.model_dir)
    self.model = load_from_disk(model_path)

  def _create_gradient_fn(self,
                          variables: Optional[List[tf.Variable]]) -> Callable:
    """Create a function that computes gradients and applies them to the model.
+44 −0
Original line number Diff line number Diff line
@@ -282,6 +282,50 @@ def test_robust_multitask_classification_reload():
  assert scores[classification_metric.name] > .9


def test_normalizing_flow_model_reload():
  """Test that NormalizingFlowModel can be reloaded correctly.

  Fits a small masked-autoregressive flow for one epoch, restores a second
  `NormalizingFlowModel` from the same `model_dir`, and checks that the
  restored model can still sample and yields the same log-probability at a
  fixed point.
  """
  from deepchem.models.normalizing_flows import NormalizingFlow, NormalizingFlowModel
  import tensorflow_probability as tfp
  tfd = tfp.distributions
  tfb = tfp.bijectors
  tfk = tf.keras

  # `set_floatx` changes a process-wide Keras setting, so remember the current
  # value and put it back afterwards to avoid leaking float64 into other tests.
  previous_floatx = tfk.backend.floatx()
  tfk.backend.set_floatx('float64')
  try:
    model_dir = tempfile.mkdtemp()

    Made = tfb.AutoregressiveNetwork(
        params=2, hidden_units=[512, 512], activation='relu')

    flow_layers = [tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=Made)]
    # 2D Multivariate Gaussian base distribution
    nf = NormalizingFlow(
        base_distribution=tfd.MultivariateNormalDiag(
            loc=np.zeros(2), scale_diag=np.ones(2)),
        flow_layers=flow_layers)

    nfm = NormalizingFlowModel(nf, model_dir=model_dir)

    target_distribution = tfd.MultivariateNormalDiag(loc=np.array([1., 0.]))
    dataset = dc.data.NumpyDataset(X=target_distribution.sample(96))
    nfm.fit(dataset, nb_epoch=1)

    x = np.zeros(2)
    lp1 = nfm.flow.log_prob(x).numpy()

    assert nfm.flow.sample().numpy().shape == (2,)

    # NOTE(review): the second model wraps the same `nf` object as the first,
    # so the two models may share variables -- confirm this actually exercises
    # the on-disk state rather than the in-memory weights.
    reloaded_model = NormalizingFlowModel(nf, model_dir=model_dir)
    reloaded_model.restore()

    # Check that reloaded model can sample from the distribution
    assert reloaded_model.flow.sample().numpy().shape == (2,)

    lp2 = reloaded_model.flow.log_prob(x).numpy()

    # Check that density estimation is same for reloaded model
    assert np.all(lp1 == lp2)
  finally:
    tfk.backend.set_floatx(previous_floatx)


def test_robust_multitask_regressor_reload():
  """Test that RobustMultitaskRegressor can be reloaded correctly."""
  n_tasks = 10