Commit b92b368c authored by Nathan Frey's avatar Nathan Frey
Browse files

Formatting

parent f18739ee
Loading
Loading
Loading
Loading
+10 −9
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ Normalizing flows for transforming probability distributions.

import numpy as np
import logging
from typing import List, Iterable, Optional, Tuple, Sequence
from typing import List, Iterable, Optional, Tuple, Sequence, Any

import tensorflow as tf

@@ -45,7 +45,7 @@ class NormalizingFlow(tf.keras.models.Model):

  @tf.function
  def fit_on_batch(self, x: np.ndarray,
                   optimizer: dc.models.optimizers.Optimizer,
                   optimizer: tf.keras.optimizers.Optimizer,
                   loss: dc.models.losses.Loss) -> float:
    """Fit on batch of samples.
    
@@ -66,7 +66,8 @@ class NormalizingFlow(tf.keras.models.Model):
    """

    with tf.GradientTape() as tape:
      batch_loss = loss(x)
      dummy_labels = np.ones(len(x))
      batch_loss = loss(x, dummy_labels)
      grads = tape.gradient(batch_loss, self.trainable_variables)
      optimizer.apply_gradients(zip(grads, self.trainable_variables))
    return batch_loss
@@ -104,8 +105,8 @@ class NormalizingFlowModel(NormalizingFlow):
  def __init__(self,
               base_distribution,
               flow_layers: Sequence,
               optimizer: Optional[dc.models.optimizers.Optimizer] = None,
               loss: Optional[dc.models.losses.Loss] = None,
               optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
               loss: Optional[Any] = None,
               **kwargs):
    """Creates a new NormalizingFlowModel.

@@ -116,10 +117,10 @@ class NormalizingFlowModel(NormalizingFlow):
      Typically an N dimensional multivariate Gaussian.
    flow_layers : Sequence[tfb.Bijector]
      An iterable of bijectors that comprise the flow.
    optimizer: dc.models.optimizers.Optimizer
    optimizer: Optional[tf.keras.optimizers.Optimizer]
      An instance of Optimizer.
    loss: dc.models.losses.Loss
      An instance of Loss.
    loss: Optional[Any]
      Loss function, e.g. an instance of dc.models.losses.Loss.

    Examples
    --------
@@ -222,7 +223,7 @@ class NormalizingFlowModel(NormalizingFlow):
    final_loss = batch_loss
    return (final_loss, avg_loss)

  def nll(self, X):
  def nll(self, X, labels):
    """Negative log loss."""

    return -tf.reduce_mean(self.flow.log_prob(X, training=True))