Commit ae9a19a0 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Fix to loss returns

parent 9436f338
Loading
Loading
Loading
Loading
+19 −11
Original line number Diff line number Diff line
@@ -272,7 +272,7 @@ class KerasModel(Model):
          restore: bool = False,
          variables: Optional[List[tf.Variable]] = None,
          loss: Optional[KerasLossFn] = None,
          callbacks: Union[Callable, List[Callable]] = []) -> float:
          callbacks: Union[Callable, List[Callable]] = []) -> List[float]:
    """Train this model on a dataset.

    Parameters
@@ -313,14 +313,15 @@ class KerasModel(Model):
            deterministic=deterministic), max_checkpoints_to_keep,
        checkpoint_interval, restore, variables, loss, callbacks)

  def fit_generator(self,
  def fit_generator(
      self,
      generator: Iterable[Tuple[Any, Any, Any]],
      max_checkpoints_to_keep: int = 5,
      checkpoint_interval: int = 1000,
      restore: bool = False,
      variables: Optional[List[tf.Variable]] = None,
      loss: Optional[KerasLossFn] = None,
                    callbacks: Union[Callable, List[Callable]] = []) -> float:
      callbacks: Union[Callable, List[Callable]] = []) -> List[float]:
    """Train this model on data from a generator.

    Parameters
@@ -357,6 +358,7 @@ class KerasModel(Model):
    if checkpoint_interval > 0:
      manager = tf.train.CheckpointManager(self._checkpoint, self.model_dir,
                                           max_checkpoints_to_keep)
    avg_losses = []
    avg_loss = 0.0
    averaged_batches = 0
    train_op = None
@@ -403,6 +405,7 @@ class KerasModel(Model):
        avg_loss = float(avg_loss) / averaged_batches
        logger.info(
            'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
        avg_losses.append(avg_loss)
        avg_loss = 0.0
        averaged_batches = 0

@@ -421,13 +424,14 @@ class KerasModel(Model):
      avg_loss = float(avg_loss) / averaged_batches
      logger.info(
          'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
      avg_losses.append(avg_loss)

    if checkpoint_interval > 0:
      manager.save()

    time2 = time.time()
    logger.info("TIMING: model fitting took %0.3f s" % (time2 - time1))
    return avg_loss
    return avg_losses

  def _create_gradient_fn(self,
                          variables: Optional[List[tf.Variable]]) -> Callable:
@@ -496,7 +500,7 @@ class KerasModel(Model):
    """
    self._ensure_built()
    dataset = NumpyDataset(X, y, w)
    return self.fit(
    losses = self.fit(
        dataset,
        nb_epoch=1,
        max_checkpoints_to_keep=max_checkpoints_to_keep,
@@ -504,6 +508,10 @@ class KerasModel(Model):
        variables=variables,
        loss=loss,
        callbacks=callbacks)
    if len(losses) != 1:
      raise ValueError(
          "Each batch should take only one global step to fit. Unknown error.")
    return losses[0]

  def _predict(
      self, generator: Iterable[Tuple[Any, Any, Any]],
+2 −2
Original line number Diff line number Diff line
@@ -127,7 +127,7 @@ class Model(BaseEstimator):
    """
    raise NotImplementedError

  def fit(self, dataset: Dataset, nb_epoch: int = 10) -> float:
  def fit(self, dataset: Dataset, nb_epoch: int = 10) -> List[float]:
    """
    Fits a model on data in a Dataset object.

@@ -140,7 +140,7 @@ class Model(BaseEstimator):

    Returns
    -------
    the average loss over the most recent epoch
    The average losses over course of training. 
    """
    for epoch in range(nb_epoch):
      logger.info("Starting epoch %s" % str(epoch + 1))