Commit 83b9ba1e authored by Nathan Frey's avatar Nathan Frey
Browse files

Add tfp dependency

parent 40f5ebee
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -6,6 +6,8 @@ import numpy as np
import logging
from typing import List

import tensorflow_probability as tfp

from deepchem.models.models import Model

logger = logging.getLogger(__name__)
@@ -50,7 +52,7 @@ class NormalizingFlowLayer(object):
    Parameters
    ----------
    model : object
      Model object from TensorFlowProbability, Pytorch, etc. The model
      Model object from `tensorflow_probability.bijectors`. The model
      should be a bijective transformation with forward, inverse, and 
      LDJ methods.
    kwargs : dict
@@ -245,7 +247,7 @@ class NormalizingFlowModel(Model):
  """

  def __init__(self,
               base_distribution,
               base_distribution: tfp.distributions.Distribution,
               normalizing_flow: NormalizingFlow,
               event_shape=None):
    """Creates a new NormalizingFlowModel.
+23 −8
Original line number Diff line number Diff line
@@ -21,23 +21,20 @@ class TestNormalizingFlow(unittest.TestCase):
  def setUp(self):

    self.ef = ExpFlow()
    self.nfm = TransformedNormal()

  def test_simple_flow(self):
    """Tests a simple flow of Exp layers."""

    dist = tfp.distributions.Normal(0, 1)  # univariate Gaussian
    X = dist.sample([10])
    g = self.ef
    flows = [g, g]
    nf = NormalizingFlow(flows)
    nfm = NormalizingFlowModel(dist, nf)
    X = self.nfm.sample([10])

    ys, ldjs = nfm(X)
    xs, ildjs = nf._inverse(ys[-1])
    ys, ldjs = self.nfm(X)
    xs, ildjs = self.nfm.normalizing_flow._inverse(ys[-1])

    assert len(xs) == 3
    assert len(ys) == 3
    assert xs[0].shape == 10
    assert np.isclose(self.nfm.log_prob(1), -1.4, atol=0.5)


class ExpFlow(NormalizingFlowLayer):
@@ -55,3 +52,21 @@ class ExpFlow(NormalizingFlowLayer):

  def _forward_log_det_jacobian(self, x):
    return self.model.forward_log_det_jacobian(x, 1)


class TransformedNormal(NormalizingFlowModel):
  """Univariate Gaussian base distribution."""

  def __init__(self, 
    base_distribution=tfp.distributions.Normal(0, 1),
    normalizing_flow=NormalizingFlow([ExpFlow(), ExpFlow()])
    ):

    super(TransformedNormal, self).__init__(base_distribution, normalizing_flow)

  def sample(self, shape, seed=None):
    return self.base_distribution.sample(sample_shape=shape, seed=seed)

  def log_prob(self, value):
    return self.base_distribution.log_prob(value=value)