Unverified Commit 5984e9e1 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2060 from deepchem/keras_loss

Allow for reporting of loss curve from KerasModel.fit
parents b1e63162 3f45128c
Loading
Loading
Loading
Loading
+25 −6
Original line number Diff line number Diff line
@@ -272,7 +272,8 @@ class KerasModel(Model):
          restore: bool = False,
          variables: Optional[List[tf.Variable]] = None,
          loss: Optional[KerasLossFn] = None,
          callbacks: Union[Callable, List[Callable]] = []) -> float:
          callbacks: Union[Callable, List[Callable]] = [],
          all_losses: Optional[List[float]] = None) -> float:
    """Train this model on a dataset.

    Parameters
@@ -302,16 +303,20 @@ class KerasModel(Model):
    callbacks: function or list of functions
      one or more functions of the form f(model, step) that will be invoked after
      every step.  This can be used to perform validation, logging, etc.
    all_losses: Optional[List[float]], optional (default None)
      If specified, all logged losses are appended into this list. Note that
      you can call `fit()` repeatedly with the same list and losses will
      continue to be appended.

    Returns
    -------
    the average loss over the most recent checkpoint interval
    The average loss over the most recent checkpoint interval
   """
    return self.fit_generator(
        self.default_generator(
            dataset, epochs=nb_epoch,
            deterministic=deterministic), max_checkpoints_to_keep,
        checkpoint_interval, restore, variables, loss, callbacks)
        checkpoint_interval, restore, variables, loss, callbacks, all_losses)

  def fit_generator(self,
                    generator: Iterable[Tuple[Any, Any, Any]],
@@ -320,7 +325,8 @@ class KerasModel(Model):
                    restore: bool = False,
                    variables: Optional[List[tf.Variable]] = None,
                    loss: Optional[KerasLossFn] = None,
                    callbacks: Union[Callable, List[Callable]] = []) -> float:
                    callbacks: Union[Callable, List[Callable]] = [],
                    all_losses: Optional[List[float]] = None) -> float:
    """Train this model on data from a generator.

    Parameters
@@ -346,10 +352,14 @@ class KerasModel(Model):
    callbacks: function or list of functions
      one or more functions of the form f(model, step) that will be invoked after
      every step.  This can be used to perform validation, logging, etc.
    all_losses: Optional[List[float]], optional (default None)
      If specified, all logged losses are appended into this list. Note that
      you can call `fit()` repeatedly with the same list and losses will
      continue to be appended.

    Returns
    -------
    the average loss over the most recent checkpoint interval
    The average loss over the most recent checkpoint interval
    """
    if not isinstance(callbacks, SequenceCollection):
      callbacks = [callbacks]
@@ -358,6 +368,7 @@ class KerasModel(Model):
      manager = tf.train.CheckpointManager(self._checkpoint, self.model_dir,
                                           max_checkpoints_to_keep)
    avg_loss = 0.0
    last_avg_loss = 0.0
    averaged_batches = 0
    train_op = None
    if loss is None:
@@ -403,6 +414,11 @@ class KerasModel(Model):
        avg_loss = float(avg_loss) / averaged_batches
        logger.info(
            'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
        if all_losses is not None:
          all_losses.append(avg_loss)
        # Capture the last avg_loss in case of return since we're resetting to
        # 0 now
        last_avg_loss = avg_loss
        avg_loss = 0.0
        averaged_batches = 0

@@ -421,13 +437,16 @@ class KerasModel(Model):
      avg_loss = float(avg_loss) / averaged_batches
      logger.info(
          'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
      if all_losses is not None:
        all_losses.append(avg_loss)
      last_avg_loss = avg_loss

    if checkpoint_interval > 0:
      manager.save()

    time2 = time.time()
    logger.info("TIMING: model fitting took %0.3f s" % (time2 - time1))
    return avg_loss
    return last_avg_loss

  def _create_gradient_fn(self,
                          variables: Optional[List[tf.Variable]]) -> Callable:
+1 −1
Original line number Diff line number Diff line
@@ -140,7 +140,7 @@ class Model(BaseEstimator):

    Returns
    -------
    the average loss over the most recent epoch
    The average loss over the most recent checkpoint interval. 
    """
    for epoch in range(nb_epoch):
      logger.info("Starting epoch %s" % str(epoch + 1))
+23 −0
Original line number Diff line number Diff line
@@ -58,6 +58,29 @@ def test_overfit_sequential_model():
  assert scores[metric.name] > 0.9


def test_fit_use_all_losses():
  """Check that `KerasModel.fit` hands back a loss curve via `all_losses`.

  A small two-layer Keras classifier is fit on a handful of random points
  while a caller-owned list is passed as `all_losses`. After training, that
  list must hold one nonzero entry per logged loss.
  """
  num_samples, num_features = 10, 2
  num_epochs = 1000
  log_every = 10
  # Each epoch is a single step for this model, so the number of logged
  # losses is expected to be the epoch count divided by the log frequency.
  expected_entries = num_epochs // log_every

  features = np.random.rand(num_samples, num_features)
  # Binary label: 1.0 when the first feature exceeds the second, else 0.0.
  labels = (features[:, 0] > features[:, 1]).astype(np.float32)
  dataset = dc.data.NumpyDataset(features, labels)

  network = tf.keras.Sequential([
      tf.keras.layers.Dense(10, activation='relu'),
      tf.keras.layers.Dense(1, activation='sigmoid')
  ])
  model = dc.models.KerasModel(
      network,
      dc.models.losses.BinaryCrossEntropy(),
      learning_rate=0.005,
      log_frequency=log_every)

  loss_curve = []
  model.fit(dataset, nb_epoch=num_epochs, all_losses=loss_curve)

  assert len(loss_curve) == expected_entries
  assert np.count_nonzero(np.array(loss_curve)) == expected_entries


def test_fit_on_batch():
  """Test fitting a KerasModel to individual batches."""
  n_data_points = 10