Commit e649d199 authored by peastman's avatar peastman
Browse files

Uncertainty prediction of saliency mapping for KerasModel

parent cb9e10f5
Loading
Loading
Loading
Loading
+252 −31
Original line number Diff line number Diff line
@@ -32,8 +32,8 @@ class KerasModel(Model):
  model layers, such as weight decay penalties, are also added.

  For more complicated cases, you can instead provide a function that directly
  computes the total loss.  It must be of the form f(inputs, labels, weights),
  taking the list of inputs to the model, the expected outputs, and any weight
  computes the total loss.  It must be of the form f(outputs, labels, weights),
  taking the list of outputs from the model, the expected values, and any weight
  matrices.  It should return a scalar equal to the value of the loss function
  for the batch.  No additional processing is done to the result; it is up to
  you to do any weighting, averaging, adding of penalty terms, etc.
@@ -129,9 +129,12 @@ class KerasModel(Model):
        else:
          raise ValueError('Unknown output type "%s"' % type)
    self._built = False
    self._inputs_built = False
    self._training_ops_built = False
    self._initialized_vars = set()

  def _ensure_built(self):
    """The first time this is called, create internal data structures."""
    if self._built:
      return
    self._built = True
@@ -141,12 +144,14 @@ class KerasModel(Model):
    self._tf_optimizer = self.optimizer._create_optimizer(self._global_step)
    self._checkpoint = tf.train.Checkpoint(
        optimizer=self._tf_optimizer, model=self.model)
    self._init_new_vars()

  def _create_training_ops(self, example_batch):
    if self._training_ops_built:
  def _create_inputs(self, example_inputs):
    """The first time this is called, create tensors representing the inputs and outputs."""
    if self._inputs_built:
      return
    self._ensure_built()
    self._training_ops_built = True
    self._inputs_built = True
    if tf.executing_eagerly():
      return
    if len(self.model.inputs) > 0:
@@ -154,7 +159,7 @@ class KerasModel(Model):
    else:
      # The model doesn't specify inputs, so guess the input shapes based on the
      # example batch.
      input_shapes = [(None,) + i.shape[1:] for i in example_batch[0]]
      input_shapes = [(None,) + i.shape[1:] for i in example_inputs]
      self._input_placeholders = [
          tf.placeholder(dtype=tf.float32, shape=s) for s in input_shapes
      ]
@@ -162,25 +167,39 @@ class KerasModel(Model):
        self.model.build(input_shapes[0])
      else:
        self.model.build(input_shapes)
    self._label_placeholders = [
        tf.placeholder(dtype=tf.float32, shape=t.shape)
        for t in example_batch[1]
    ]
    self._weights_placeholders = [
        tf.placeholder(dtype=tf.float32, shape=t.shape)
        for t in example_batch[2]
    ]
    if len(self._input_placeholders) == 1:
      self._output_tensors = self.model(
          self._input_placeholders[0], training=False)
      self._uncertainty_tensors = self.model(
          self._input_placeholders[0], training=True)
    else:
      self._output_tensors = self.model(
          self._input_placeholders, training=False)
      self._uncertainty_tensors = self.model(
          self._input_placeholders, training=True)
    if isinstance(self._output_tensors, tf.Tensor):
      self._output_tensors = [self._output_tensors]
    if self._prediction_outputs is None:
      self._prediction_outputs = list(range(len(self._output_tensors)))
      self._loss_outputs = list(range(len(self._output_tensors)))
    self._init_new_vars()

  def _create_training_ops(self, example_batch):
    """The first time this is called, create tensors used in optimization."""
    if self._training_ops_built:
      return
    self._create_inputs(example_batch[0])
    self._training_ops_built = True
    if tf.executing_eagerly():
      return
    self._label_placeholders = [
        tf.placeholder(dtype=tf.float32, shape=t.shape)
        for t in example_batch[1]
    ]
    self._weights_placeholders = [
        tf.placeholder(dtype=tf.float32, shape=t.shape)
        for t in example_batch[2]
    ]
    self._loss_tensor = self._loss_fn(
        [self._output_tensors[i] for i in self._loss_outputs],
        self._label_placeholders, self._weights_placeholders)
@@ -190,7 +209,15 @@ class KerasModel(Model):
    except ValueError:
      # The loss doesn't depend on any variables.
      self._train_op = 0
    self.session.run(tf.global_variables_initializer())
    self._init_new_vars()

  def _init_new_vars(self):
    """Run initializers for variables created since the previous call.

    Does nothing in eager mode.  In graph mode, every global variable that is
    not yet recorded in self._initialized_vars is initialized in self.session,
    and the record is then replaced by the full current set of variables.
    """
    if tf.executing_eagerly():
      return
    current = set(tf.global_variables())
    pending = current - self._initialized_vars
    self.session.run(tf.variables_initializer(pending))
    self._initialized_vars = current

  def fit(self,
          dataset,
@@ -316,12 +343,23 @@ class KerasModel(Model):
    return avg_loss

  def fit_on_batch(self, X, y, w):
    """Train the model on a single batch of data.

    The arrays are wrapped in a NumpyDataset and handed to fit() for exactly
    one epoch.

    Parameters
    ----------
    X: ndarray
      the inputs for the batch
    y: ndarray
      the labels for the batch
    w: ndarray
      the weights for the batch

    Returns
    -------
    the value returned by fit() for that single epoch
    """
    # NOTE(review): the visible parts of this class track construction with
    # `self._built` / `_ensure_built()`; `self.built` and `self.build()` are
    # not defined in the code shown here -- confirm they exist on a base class.
    if not self.built:
      self.build()
    return self.fit(NumpyDataset(X, y, w), nb_epoch=1)

  def _predict(self, generator, transformers, outputs):
  def _predict(self, generator, transformers, uncertainty):
    """
    Predict outputs for data provided by a generator.

@@ -335,30 +373,53 @@ class KerasModel(Model):
      (inputs, labels, weights).
    transformers: list
      List of dc.trans.Transformers.
    uncertainty: bool
      specifies whether this is being called as part of estimating uncertainty.
      If True, it sets the training flag so that dropout will be enabled, and
      returns the values of the uncertainty outputs.
    Returns:
      a NumPy array if the model produces a single output, or a list of arrays
      if it produces multiple outputs
    """
    results = None
    variances = None
    if uncertainty:
      if self._variance_outputs is None or len(self._variance_outputs) == 0:
        raise ValueError('This model cannot compute uncertainties')
      if len(self._variance_outputs) != len(self._prediction_outputs):
        raise ValueError(
            'The number of variances must exactly match the number of outputs')
    for batch in generator:
      self._create_training_ops(batch)
      inputs, labels, weights = batch
      self._create_inputs(inputs)
      if tf.executing_eagerly():

        # In eager mode we invoke the model directly.

        outputs = self.model(inputs)
        if len(inputs) == 1:
          inputs = inputs[0]
        outputs = self.model(inputs, training=uncertainty)
        outputs = [t.numpy() for t in outputs]
      else:

        # In graph mode we execute the output tensors.

        fetches = [self._train_op, self._loss_tensor, self._global_step]
        if uncertainty:
          fetches = self._uncertainty_tensors
        else:
          fetches = self._output_tensors
        feed_dict = dict(zip(self._input_placeholders, inputs))
        outputs = self.session.run(self._output_tensors, feed_dict=feed_dict)
        outputs = self.session.run(fetches, feed_dict=feed_dict)

      # Apply transformers and record results.

      if uncertainty:
        var = [outputs[i] for i in self._variance_outputs]
        if variances is None:
          variances = var
        else:
          for i, t in enumerate(var):
            variances[i].append(t)
      if self._prediction_outputs is not None:
        outputs = [outputs[i] for i in self._prediction_outputs]
      if len(transformers) > 0:
@@ -377,15 +438,20 @@ class KerasModel(Model):
    # Concatenate arrays to create the final results.

    final_results = []
    for result_list in results:
      final_results.append(np.concatenate(result_list, axis=0))
    final_variances = []
    for r in results:
      final_results.append(np.concatenate(r, axis=0))
    if uncertainty:
      for v in variances:
        final_variances.append(np.concatenate(v, axis=0))
      return zip(final_results, final_variances)
    # If only one output, just return array
    if len(final_results) == 1:
      return final_results[0]
    else:
      return final_results

  def predict_on_generator(self, generator, transformers=[], outputs=None):
  def predict_on_generator(self, generator, transformers=[]):
    """
    Parameters
    ----------
@@ -398,9 +464,9 @@ class KerasModel(Model):
      a NumPy array if the model produces a single output, or a list of arrays
      if it produces multiple outputs
    """
    return self._predict(generator, transformers, outputs)
    return self._predict(generator, transformers, False)

  def predict_on_batch(self, X, transformers=[], outputs=None):
  def predict_on_batch(self, X, transformers=[]):
    """Generates predictions for input samples, processing samples in a batch.

    Parameters
@@ -416,9 +482,36 @@ class KerasModel(Model):
    if it produces multiple outputs
    """
    dataset = NumpyDataset(X=X, y=None)
    return self.predict(dataset, transformers, outputs)
    return self.predict(dataset, transformers)

  def predict(self, dataset, transformers=[], outputs=None):
  def predict_uncertainty_on_batch(self, X, masks=50):
    """
    Predict outputs and their uncertainties for a single batch of inputs.

    This is a convenience wrapper: X is wrapped in a NumpyDataset (with no
    labels) and passed to predict_uncertainty().

    The uncertainty is computed as described in https://arxiv.org/abs/1703.04977.
    It involves repeating the prediction many times with different dropout masks.
    The prediction is computed as the average over all the predictions.  The
    uncertainty includes both the variation among the predicted values (epistemic
    uncertainty) and the model's own estimates for how well it fits the data
    (aleatoric uncertainty).  Not all models support uncertainty prediction.

    Parameters
    ----------
    X: ndarray
      the input data, as a Numpy array.
    masks: int
      the number of dropout masks to average over

    Returns
    -------
    for each output, a tuple (y_pred, y_std) where y_pred is the predicted
    value of the output, and each element of y_std estimates the standard
    deviation of the corresponding element of y_pred
    """
    return self.predict_uncertainty(NumpyDataset(X=X, y=None), masks)

  def predict(self, dataset, transformers=[]):
    """
    Uses self to make predictions on provided Dataset object.

@@ -435,7 +528,136 @@ class KerasModel(Model):
    if it produces multiple outputs
    """
    generator = self.default_generator(dataset, predict=True, pad_batches=False)
    return self.predict_on_generator(generator, transformers, outputs)
    return self.predict_on_generator(generator, transformers)

  def predict_uncertainty(self, dataset, masks=50):
    """
    Predict the model's outputs, along with the uncertainty in each one.

    The uncertainty is computed as described in https://arxiv.org/abs/1703.04977.
    It involves repeating the prediction many times with different dropout masks.
    The prediction is computed as the average over all the predictions.  The
    uncertainty includes both the variation among the predicted values (epistemic
    uncertainty) and the model's own estimates for how well it fits the data
    (aleatoric uncertainty).  Not all models support uncertainty prediction.

    Parameters
    ----------
    dataset: dc.data.Dataset
      Dataset to make prediction on
    masks: int
      the number of dropout masks to average over

    Returns
    -------
    If the model has a single output, a tuple (y_pred, y_std).  Otherwise a
    list with one (y_pred, y_std) tuple per output.  y_pred is the predicted
    value of the output, and each element of y_std estimates the standard
    deviation of the corresponding element of y_pred.
    """
    sum_pred = []
    sum_sq_pred = []
    sum_var = []
    for _ in range(masks):
      # A fresh generator is needed for every pass, since each pass consumes it.
      generator = self.default_generator(
          dataset, predict=True, pad_batches=False)
      results = self._predict(generator, [], True)
      if len(sum_pred) == 0:
        # First pass: the per-output arrays become the running totals.
        for p, v in results:
          sum_pred.append(p)
          sum_sq_pred.append(p * p)
          sum_var.append(v)
      else:
        for j, (p, v) in enumerate(results):
          sum_pred[j] += p
          sum_sq_pred[j] += p * p
          sum_var[j] += v
    output = []
    std = []
    for total, total_sq, total_var in zip(sum_pred, sum_sq_pred, sum_var):
      mean = total / masks
      # E[p^2] - E[p]^2 can come out slightly negative through floating point
      # cancellation; clamp it so the square root below never yields NaN.
      epistemic = np.maximum(total_sq / masks - mean * mean, 0)
      output.append(mean)
      std.append(np.sqrt(epistemic + total_var / masks))
    if len(output) == 1:
      return (output[0], std[0])
    # Return a concrete list rather than a one-shot zip iterator, so callers
    # can index it, take its length, and iterate over it more than once.
    return list(zip(output, std))

  def compute_saliency(self, X):
    """Compute the saliency map for an input sample.

    This computes the Jacobian matrix with the derivative of each output element
    with respect to each input element.  More precisely,

    - If this model has a single output, it returns a matrix of shape
      (output_shape, input_shape) with the derivatives.
    - If this model has multiple outputs, it returns a list of matrices, one
      for each output.

    In graph mode only the outputs listed in self._prediction_outputs are
    differentiated; in eager mode every output returned by the model is.

    This method cannot be used on models that take multiple inputs.

    Parameters
    ----------
    X: ndarray
      the input data for a single sample

    Returns
    -------
    the Jacobian matrix, or a list of matrices
    """
    input_shape = X.shape
    # Add a leading batch dimension of size 1.
    X = np.reshape(X, [1] + list(X.shape))
    self._create_inputs([X])
    if tf.executing_eagerly():
      # In eager mode we use a GradientTape to compute gradients.

      X = tf.constant(X)
      # The tape is persistent because gradient() is called once per output
      # element; only X is watched, not the model's variables.
      with tf.GradientTape(
          persistent=True, watch_accessed_variables=False) as tape:
        tape.watch(X)
        outputs = self.model(X)
        if isinstance(outputs, tf.Tensor):
          outputs = [outputs]
        final_result = []
        for output in outputs:
          output_shape = tuple(output.shape.as_list()[1:])
          output = tf.reshape(output, [-1])
          result = []
          for i in range(output.shape[0]):
            result.append(tape.gradient(output[i], X))
          final_result.append(
              tf.reshape(tf.stack(result), output_shape + input_shape).numpy())
    else:
      # In graph mode we use tf.gradients().

      def jacobian(y, x):
        # Adapted from https://github.com/tensorflow/tensorflow/issues/675#issuecomment-319891923.
        # Flatten the single sample's output, then take the gradient of each
        # element in turn inside a tf.while_loop.
        y = tf.reshape(tf.convert_to_tensor(y)[0], [-1])
        n = y.shape[0]
        loop_vars = [
            tf.constant(0, tf.int32),
            tf.TensorArray(tf.float32, size=n)
        ]
        _, rows = tf.while_loop(
            lambda j, _: j < n,
            lambda j, result: (j + 1, result.write(j, tf.gradients(y[j], x))),
            loop_vars)
        return rows.stack()

      # Select the prediction outputs once, so that the gradients and the
      # shapes used to reshape them always refer to the same tensors.
      prediction_tensors = [
          self._output_tensors[i] for i in self._prediction_outputs
      ]
      grads = [
          jacobian(t, self._input_placeholders[0]) for t in prediction_tensors
      ]
      feed_dict = {self._input_placeholders[0]: X}
      result = self.session.run(grads, feed_dict=feed_dict)
      output_shapes = [
          tuple(t.shape.as_list()[1:]) for t in prediction_tensors
      ]
      final_result = [
          x.reshape(s + input_shape) for x, s in zip(result, output_shapes)
      ]
    if len(final_result) == 1:
      return final_result[0]
    return final_result

  def default_generator(self,
                        dataset,
@@ -494,10 +716,9 @@ class KerasModel(Model):
    if checkpoint is None:
      raise ValueError('No checkpoint found')
    if tf.executing_eagerly():
      self.model.load_weights(checkpoint)
      self._checkpoint.restore(checkpoint)
    else:
      with self.session.as_default():
        self.model.load_weights(checkpoint)
      self._checkpoint.restore(checkpoint).run_restore_ops(self.session)


class _StandardLoss(object):
+129 −0
Original line number Diff line number Diff line
@@ -60,3 +60,132 @@ class TestKerasModel(unittest.TestCase):
    """Test fitting a KerasModel defined as a sequential model, in eager mode."""
    with context.eager_mode():
      self.test_overfit_sequential_model()

  def test_checkpointing(self):
    """Check that a checkpoint saved by one KerasModel restores into another."""
    # Two separately constructed models that share one model directory.

    net_a = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    net_b = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    saver = dc.models.KerasModel(net_a, dc.models.losses.L2Loss())
    loader = dc.models.KerasModel(
        net_b, dc.models.losses.L2Loss(), model_dir=saver.model_dir)

    # Before any checkpoint is shared, their predictions should differ.

    X = np.random.rand(5, 5)
    saver_before = saver.predict_on_batch(X)
    loader_before = loader.predict_on_batch(X)
    assert not np.array_equal(saver_before, loader_before)

    # After saving from the first model and restoring into the second, both
    # must reproduce the first model's original predictions.

    saver.save_checkpoint()
    loader.restore()
    saver_after = saver.predict_on_batch(X)
    loader_after = loader.predict_on_batch(X)
    assert np.array_equal(saver_before, saver_after)
    assert np.array_equal(saver_before, loader_after)

  def test_checkpointing_eager(self):
    """Test loading and saving checkpoints with KerasModel, in eager mode."""
    # Re-run the body of test_checkpointing() inside context.eager_mode().
    with context.eager_mode():
      self.test_checkpointing()

  def test_uncertainty(self):
    """Test estimating uncertainty with a KerasModel."""
    n_samples = 30
    n_features = 1
    noise = 0.1
    X = np.random.rand(n_samples, n_features).astype(np.float32)
    jitter = np.random.normal(scale=noise, size=(n_samples, n_features))
    y = (10 * X + jitter).astype(np.float32)
    dataset = dc.data.NumpyDataset(X, y)

    # Build a model with dropout that emits both a prediction and a variance.
    # The prediction and the log variance are repeated as the outputs that
    # get passed to the loss function.

    inputs = tf.keras.Input(shape=(n_features,))
    hidden = tf.keras.layers.Dense(200, activation='relu')(inputs)
    dropout = tf.keras.layers.Dropout(rate=0.1)(hidden)
    output = tf.keras.layers.Dense(n_features)(dropout)
    log_var = tf.keras.layers.Dense(n_features)(dropout)
    var = tf.keras.layers.Activation(tf.exp)(log_var)
    model_outputs = [output, var, output, log_var]
    keras_model = tf.keras.Model(inputs=inputs, outputs=model_outputs)

    def loss(outputs, labels, weights):
      # outputs holds the two 'loss' outputs: the prediction and log variance.
      predicted, log_variance = outputs[0], outputs[1]
      residual = labels[0] - predicted
      variance = tf.exp(log_variance)
      return tf.reduce_mean(residual * residual / variance + log_variance)

    output_types = ['prediction', 'variance', 'loss', 'loss']
    model = dc.models.KerasModel(
        keras_model, loss, output_types=output_types, learning_rate=0.003)

    # Fit the model, then check the prediction error and that the estimated
    # standard deviation lies between the injected noise level and 1.

    model.fit(dataset, nb_epoch=2500)
    pred, std = model.predict_uncertainty(dataset)
    assert np.mean(np.abs(y - pred)) < 1.0
    mean_std = np.mean(std)
    assert noise < mean_std and mean_std < 1.0

  def test_uncertainty_eager(self):
    """Test estimating uncertainty with a KerasModel, in eager mode."""
    # Re-run the body of test_uncertainty() inside context.eager_mode().
    with context.eager_mode():
      self.test_uncertainty()

  def test_saliency_mapping(self):
    """Check that compute_saliency() agrees with finite differences of predictions."""
    n_tasks = 3
    n_features = 5
    layers = [
        tf.keras.layers.Dense(20, activation='tanh'),
        tf.keras.layers.Dense(n_tasks)
    ]
    keras_model = tf.keras.Sequential(layers)
    model = dc.models.KerasModel(keras_model, dc.models.losses.L2Loss())
    x = np.random.random(n_features).astype(np.float32)
    saliency = model.compute_saliency(x)
    assert saliency.shape[0] == n_tasks
    assert saliency.shape[1] == n_features

    def predict_at(point):
      # Predict for one sample and flatten the result to a vector of tasks.
      return model.predict_on_batch(point.reshape((1, n_features))).flatten()

    # For each task, step half of `delta` forward and backward along that
    # task's gradient; the two predictions should differ by norm * delta.

    delta = 0.01
    for task in range(n_tasks):
      gradient = saliency[task]
      norm = np.sqrt(np.sum(gradient**2))
      step = 0.5 * delta / norm
      forward = predict_at(x + gradient * step)
      backward = predict_at(x - gradient * step)
      self.assertAlmostEqual(
          forward[task], (backward + norm * delta)[task], places=4)

  def test_saliency_mapping_eager(self):
    """Test computing a saliency map, in eager mode."""
    # Re-run the body of test_saliency_mapping() inside context.eager_mode().
    with context.eager_mode():
      self.test_saliency_mapping()

  def test_saliency_shapes(self):
    """Check saliency map shapes for a model with two multi-dimensional outputs."""
    # One (2, 3) input feeding two heads, reshaped to (4, 1) and (1, 5).
    inputs = tf.keras.Input(shape=(2, 3))
    flatten = tf.keras.layers.Flatten()(inputs)
    dense1 = tf.keras.layers.Dense(4)(flatten)
    output1 = tf.keras.layers.Reshape((4, 1))(dense1)
    dense2 = tf.keras.layers.Dense(5)(flatten)
    output2 = tf.keras.layers.Reshape((1, 5))(dense2)
    keras_model = tf.keras.Model(inputs=inputs, outputs=[output1, output2])
    model = dc.models.KerasModel(keras_model, dc.models.losses.L2Loss())
    x = np.random.random((2, 3)).astype(np.float32)
    saliency = model.compute_saliency(x)

    # Each map should have shape (output shape) + (input shape).
    assert len(saliency) == 2
    first, second = saliency[0], saliency[1]
    assert first.shape == (4, 1, 2, 3)
    assert second.shape == (1, 5, 2, 3)

  def test_saliency_shapes_eager(self):
    """Test computing saliency maps for multiple outputs with multiple dimensions, in eager mode."""
    # Re-run the body of test_saliency_shapes() inside context.eager_mode().
    with context.eager_mode():
      self.test_saliency_shapes()