Commit cb9e10f5 authored by peastman's avatar peastman
Browse files

Improvements to defining loss for KerasModel

parent 8c678197
Loading
Loading
Loading
Loading
+110 −15
Original line number Diff line number Diff line
@@ -3,12 +3,13 @@ import tensorflow as tf
import time

from deepchem.data import NumpyDataset
from deepchem.models.losses import Loss
from deepchem.models.models import Model
from deepchem.models.tensorgraph.optimizers import Adam


class KerasModel(Model):
  """This is a DeepChem model implement by a Keras model.
  """This is a DeepChem model implemented by a Keras model.

  This class provides several advantages over using the Keras model's fitting
  and prediction methods directly.
@@ -22,11 +23,54 @@ class KerasModel(Model):

  3. It provides various additional features not found in the Keras Model class,
     such as uncertainty prediction and saliency mapping.

  The loss function for a model can be defined in two different ways.  For
  models that have only a single output and use a standard loss function, you
  can simply provide a dc.models.losses.Loss object.  This defines the loss for
  each sample or sample/task pair.  The result is automatically multiplied by
  the weights and averaged over the batch.  Any additional losses computed by
  model layers, such as weight decay penalties, are also added.

  For more complicated cases, you can instead provide a function that directly
  computes the total loss.  It must be of the form f(inputs, labels, weights),
  taking the list of inputs to the model, the expected outputs, and any weight
  matrices.  It should return a scalar equal to the value of the loss function
  for the batch.  No additional processing is done to the result; it is up to
  you to do any weighting, averaging, adding of penalty terms, etc.

  You can optionally provide an output_types argument, which describes how to
  interpret the model's outputs.  This should be a list of strings, one for each
  output.  Each entry must have one of the following values:

  - 'prediction': This is a normal output, and will be returned by predict().
    If output types are not specified, all outputs are assumed to be of this
    type.

  - 'loss': This output will be used in place of the normal outputs for
    computing the loss function.  For example, models that output probability
    distributions usually do it by computing unbounded numbers (the logits),
    then passing them through a softmax function to turn them into
    probabilities.  When computing the cross entropy, it is more numerically
    stable to use the logits directly rather than the probabilities.  You can
    do this by having the model produce both probabilities and logits as
    outputs, then specifying output_types=['prediction', 'loss'].  When
    predict() is called, only the first output (the probabilities) will be
    returned.  But during training, it is the second output (the logits) that
    will be passed to the loss function.

  - 'variance': This output is used for estimating the uncertainty in another
    output.  To create a model that can estimate uncertainty, there must be the
    same number of 'prediction' and 'variance' outputs.  Each variance output
    must have the same shape as the corresponding prediction output, and each
    element is an estimate of the variance in the corresponding prediction.
    Also be aware that if a model supports uncertainty, it MUST use dropout on
    every layer.  Otherwise, the uncertainties it computes will be inaccurate.
  """

  def __init__(self,
               model,
               loss_fn,
               loss,
               output_types=None,
               batch_size=100,
               model_dir=None,
               learning_rate=0.001,
@@ -38,13 +82,11 @@ class KerasModel(Model):
    ----------
    model: tf.keras.Model
      the Keras model implementing the calculation
    loss_fn: function
      a function defining the training loss for a batch.  It must be of the form
      f(inputs, labels, weights), taking the list of inputs to the model, the
      expected outputs, and weight matrices.  It should return a scalar equal to
      the value of the loss function for the batch.  (This is different from the
      loss function used by tf.keras.Model.compile(), which corresponds only to
      a single sample and does not include regularization terms.)
    loss: dc.models.losses.Loss or function
      a Loss or function defining how to compute the training loss for each
      batch, as described above
    output_types: list of strings
      the type of each output from the model, as described above
    batch_size: int
      default batch size for training and evaluating
    model_dir: str
@@ -60,12 +102,32 @@ class KerasModel(Model):
    super(KerasModel, self).__init__(
        model_instance=model, model_dir=model_dir, **kwargs)
    self.model = model
    self.loss_fn = loss_fn
    if isinstance(loss, Loss):
      self._loss_fn = _StandardLoss(model, loss)
    else:
      self._loss_fn = loss
    self.batch_size = batch_size
    if optimizer is None:
      self.optimizer = Adam(learning_rate=learning_rate)
    else:
      self.optimizer = optimizer
    if output_types is None:
      self._prediction_outputs = None
      self._loss_outputs = None
      self._variance_outputs = None
    else:
      self._prediction_outputs = []
      self._loss_outputs = []
      self._variance_outputs = []
      for i, type in enumerate(output_types):
        if type == 'prediction':
          self._prediction_outputs.append(i)
        elif type == 'loss':
          self._loss_outputs.append(i)
        elif type == 'variance':
          self._variance_outputs.append(i)
        else:
          raise ValueError('Unknown output type "%s"' % type)
    self._built = False
    self._training_ops_built = False

@@ -108,10 +170,20 @@ class KerasModel(Model):
        tf.placeholder(dtype=tf.float32, shape=t.shape)
        for t in example_batch[2]
    ]
    self._output_tensors = self.model(self._input_placeholders, training=False)
    self._loss_tensor = self.loss_fn(self._input_placeholders,
                                     self._label_placeholders,
                                     self._weights_placeholders)
    if len(self._input_placeholders) == 1:
      self._output_tensors = self.model(
          self._input_placeholders[0], training=False)
    else:
      self._output_tensors = self.model(
          self._input_placeholders, training=False)
    if isinstance(self._output_tensors, tf.Tensor):
      self._output_tensors = [self._output_tensors]
    if self._prediction_outputs is None:
      self._prediction_outputs = list(range(len(self._output_tensors)))
      self._loss_outputs = list(range(len(self._output_tensors)))
    self._loss_tensor = self._loss_fn(
        [self._output_tensors[i] for i in self._loss_outputs],
        self._label_placeholders, self._weights_placeholders)
    try:
      self._train_op = self._tf_optimizer.minimize(
          self._loss_tensor, global_step=self._global_step)
@@ -197,7 +269,12 @@ class KerasModel(Model):
        # In eager mode we execute the loss function, accumulating the gradients.

        with tf.GradientTape() as tape:
          loss = self.loss_fn(inputs, labels, weights)
          outputs = self.model(inputs[0])
          if isinstance(outputs, tf.Tensor):
            outputs = [outputs]
          if self._loss_outputs is not None:
            outputs = [outputs[i] for i in self._loss_outputs]
          loss = self._loss_fn(outputs, labels, weights)
          avg_loss += loss
          grads = tape.gradient(loss, self.model.trainable_variables)
          self._tf_optimizer.apply_gradients(
@@ -282,6 +359,8 @@ class KerasModel(Model):

      # Apply tranformers and record results.

      if self._prediction_outputs is not None:
        outputs = [outputs[i] for i in self._prediction_outputs]
      if len(transformers) > 0:
        if len(outputs) > 1:
          raise ValueError(
@@ -419,3 +498,19 @@ class KerasModel(Model):
    else:
      with self.session.as_default():
        self.model.load_weights(checkpoint)


class _StandardLoss(object):
  """The implements the loss function for models that use a dc.models.losses.Loss."""

  def __init__(self, model, loss):
    self.model = model
    self.loss = loss

  def __call__(self, outputs, labels, weights):
    if len(outputs) != 1 or len(labels) != 1 or len(weights) != 1:
      raise ValueError(
          "Loss functions expects exactly one each of outputs, labels, and weights"
      )
    loss = self.loss(outputs[0], labels[0]) * weights[0]
    return tf.reduce_mean(loss) + sum(self.model.losses)
+133 −0
Original line number Diff line number Diff line
import tensorflow as tf


class Loss:
  """A loss function for use in training models."""

  def __call__(self, output, labels):
    """Compute the loss function.

    The inputs are tensors containing the model's outputs and the labels for a
    batch.  The return value should be a tensor of shape (batch_size) or
    (batch_size, tasks) containing the value of the loss function on each
    sample or sample/task.

    Parameters
    ----------
    output: tensor
      the output of the model
    labels: tensor
      the expected output

    Returns
    -------
    The value of the loss function on each sample or sample/task pair
    """
    raise NotImplementedError("Subclasses must implement this")


class L1Loss(Loss):
  """The absolute difference between the true and predicted values."""

  def __call__(self, output, labels):
    output, labels = _make_shapes_consistent(output, labels)
    return tf.abs(output - labels)


class L2Loss(Loss):
  """The squared difference between the true and predicted values."""

  def __call__(self, output, labels):
    output, labels = _make_shapes_consistent(output, labels)
    return tf.square(output - labels)


class HingeLoss(Loss):
  """The hinge loss function.

  The 'output' argument should contain logits, and all elements of 'labels'
  should equal 0 or 1.
  """

  def __call__(self, output, labels):
    output, labels = _make_shapes_consistent(output, labels)
    return tf.losses.hinge_loss(
        labels, output, reduction=tf.losses.Reduction.NONE)


class BinaryCrossEntropy(Loss):
  """The cross entropy between pairs of probabilities.

  The arguments should each have shape (batch_size) or (batch_size, tasks) and
  contain probabilities.
  """

  def __call__(self, output, labels):
    output, labels = _make_shapes_consistent(output, labels)
    return tf.keras.losses.binary_crossentropy(labels, output)


class CategoricalCrossEntropy(Loss):
  """The cross entropy between two probability distributions.

  The arguments should each have shape (batch_size, classes) or
  (batch_size, tasks, classes), and represent a probability distribution over
  classes.
  """

  def __call__(self, output, labels):
    output, labels = _make_shapes_consistent(output, labels)
    return tf.keras.losses.categorical_crossentropy(labels, output)


class SigmoidCrossEntropy(Loss):
  """The cross entropy between pairs of probabilities.

  The arguments should each have shape (batch_size) or (batch_size, tasks).  The
  labels should be probabilities, while the outputs should be logits that are
  converted to probabilities using a sigmoid function.
  """

  def __call__(self, output, labels):
    output, labels = _make_shapes_consistent(output, labels)
    return tf.losses.sigmoid_cross_entropy(
        labels, output, reduction=tf.losses.Reduction.NONE)


class SoftmaxCrossEntropy(Loss):
  """The cross entropy between two probability distributions.

  The arguments should each have shape (batch_size, classes) or
  (batch_size, tasks, classes).  The labels should be probabilities, while the
  outputs should be logits that are converted to probabilities using a softmax
  function.
  """

  def __call__(self, output, labels):
    output, labels = _make_shapes_consistent(output, labels)
    return tf.losses.softmax_cross_entropy(
        labels, output, reduction=tf.losses.Reduction.NONE)


def _make_shapes_consistent(output, labels):
  """Try to make inputs have the same shape by adding dimensions of size 1."""
  shape1 = output.shape
  shape2 = labels.shape
  len1 = len(shape1)
  len2 = len(shape2)
  if len1 == len2:
    return (output, labels)
  if isinstance(shape1, tf.TensorShape):
    shape1 = tuple(shape1.as_list())
  if isinstance(shape2, tf.TensorShape):
    shape2 = tuple(shape2.as_list())
  if len1 > len2 and all(i == 1 for i in shape1[len2:]):
    for i in range(len1 - len2):
      labels = tf.expand_dims(labels, -1)
    return (output, labels)
  if len2 > len1 and all(i == 1 for i in shape2[len1:]):
    for i in range(len2 - len1):
      output = tf.expand_dims(output, -1)
    return (output, labels)
  raise ValueError("Incompatible shapes for outputs and labels: %s versus %s" %
                   (str(shape1), str(shape2)))
+12 −18
Original line number Diff line number Diff line
@@ -12,19 +12,18 @@ class TestKerasModel(unittest.TestCase):
    n_data_points = 10
    n_features = 2
    X = np.random.rand(n_data_points, n_features).astype(np.float32)
    y = np.expand_dims(X[:, 0] > X[:, 1], 1).astype(np.float32)
    y = (X[:, 0] > X[:, 1]).astype(np.float32)
    dataset = dc.data.NumpyDataset(X, y)
    inputs = tf.keras.Input(shape=(n_features,))
    hidden = tf.keras.layers.Dense(10, activation='relu')(inputs)
    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(hidden)
    keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)

    def loss_fn(inputs, labels, weights):
      return tf.reduce_mean(
          tf.keras.metrics.binary_crossentropy(labels[0],
                                               keras_model(inputs[0])))

    model = dc.models.KerasModel(keras_model, loss_fn, learning_rate=0.005)
    logits = tf.keras.layers.Dense(1)(hidden)
    outputs = tf.keras.layers.Activation('sigmoid')(logits)
    keras_model = tf.keras.Model(inputs=inputs, outputs=[outputs, logits])
    model = dc.models.KerasModel(
        keras_model,
        dc.models.losses.SigmoidCrossEntropy(),
        output_types=['prediction', 'loss'],
        learning_rate=0.005)
    model.fit(dataset, nb_epoch=1000)
    prediction = np.squeeze(model.predict_on_batch(X))
    assert np.all(np.isclose(prediction, y.flatten(), atol=0.4))
@@ -42,19 +41,14 @@ class TestKerasModel(unittest.TestCase):
    n_data_points = 10
    n_features = 2
    X = np.random.rand(n_data_points, n_features).astype(np.float32)
    y = np.expand_dims(X[:, 0] > X[:, 1], 1).astype(np.float32)
    y = (X[:, 0] > X[:, 1]).astype(np.float32)
    dataset = dc.data.NumpyDataset(X, y)
    keras_model = tf.keras.Sequential([
        tf.keras.layers.Dense(10, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])

    def loss_fn(inputs, labels, weights):
      return tf.reduce_mean(
          tf.keras.metrics.binary_crossentropy(labels[0],
                                               keras_model(inputs[0])))

    model = dc.models.KerasModel(keras_model, loss_fn, learning_rate=0.005)
    model = dc.models.KerasModel(
        keras_model, dc.models.losses.BinaryCrossEntropy(), learning_rate=0.005)
    model.fit(dataset, nb_epoch=1000)
    prediction = np.squeeze(model.predict_on_batch(X))
    assert np.all(np.isclose(prediction, y.flatten(), atol=0.4))