Commit 4a51afbf authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Improving types and simplifying return

parent 8d7f70ab
Loading
Loading
Loading
Loading
+14 −13
Original line number Diff line number Diff line
@@ -273,7 +273,7 @@ class KerasModel(Model):
          variables: Optional[List[tf.Variable]] = None,
          loss: Optional[KerasLossFn] = None,
          callbacks: Union[Callable, List[Callable]] = [],
          all_losses: Optional[list] = None) -> float:
          all_losses: Optional[List[float]] = None) -> float:
    """Train this model on a dataset.

    Parameters
@@ -303,7 +303,7 @@ class KerasModel(Model):
    callbacks: function or list of functions
      one or more functions of the form f(model, step) that will be invoked after
      every step.  This can be used to perform validation, logging, etc.
    all_losses: list, optional (default False)
    all_losses: Optional[List[float]], optional (default False)
      If specified, all logged losses are appended into this list. Note that
      you can call `fit()` repeatedly with the same list and losses will
      continue to be appended.
@@ -326,7 +326,7 @@ class KerasModel(Model):
                    variables: Optional[List[tf.Variable]] = None,
                    loss: Optional[KerasLossFn] = None,
                    callbacks: Union[Callable, List[Callable]] = [],
                    all_losses: Optional[list] = None) -> float:
                    all_losses: Optional[List[float]] = None) -> float:
    """Train this model on data from a generator.

    Parameters
@@ -352,7 +352,7 @@ class KerasModel(Model):
    callbacks: function or list of functions
      one or more functions of the form f(model, step) that will be invoked after
      every step.  This can be used to perform validation, logging, etc.
    all_losses: list, optional (default False)
    all_losses: Optional[List[float]], optional (default False)
      If specified, all logged losses are appended into this list. Note that
      you can call `fit()` repeatedly with the same list and losses will
      continue to be appended.
@@ -367,8 +367,8 @@ class KerasModel(Model):
    if checkpoint_interval > 0:
      manager = tf.train.CheckpointManager(self._checkpoint, self.model_dir,
                                           max_checkpoints_to_keep)
    avg_losses = []
    avg_loss = 0.0
    last_avg_loss = 0.0
    averaged_batches = 0
    train_op = None
    if loss is None:
@@ -414,7 +414,11 @@ class KerasModel(Model):
        avg_loss = float(avg_loss) / averaged_batches
        logger.info(
            'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
        avg_losses.append(avg_loss)
        if all_losses is not None:
          all_losses.append(avg_loss)
        # Capture the last avg_loss in case of return since we're resetting to
        # 0 now
        last_avg_loss = avg_loss
        avg_loss = 0.0
        averaged_batches = 0

@@ -433,19 +437,16 @@ class KerasModel(Model):
      avg_loss = float(avg_loss) / averaged_batches
      logger.info(
          'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
      avg_losses.append(avg_loss)
      if all_losses is not None:
        all_losses.append(avg_loss)
      last_avg_loss = avg_loss

    if checkpoint_interval > 0:
      manager.save()

    time2 = time.time()
    logger.info("TIMING: model fitting took %0.3f s" % (time2 - time1))
    if all_losses is not None:
      all_losses.extend(avg_losses)
    if len(avg_losses) > 0:
      return avg_losses[-1]
    else:
      return 0.0
    return last_avg_loss

  def _create_gradient_fn(self,
                          variables: Optional[List[tf.Variable]]) -> Callable:
+1 −0
Original line number Diff line number Diff line
@@ -78,6 +78,7 @@ def test_fit_use_all_losses():
  model.fit(dataset, nb_epoch=1000, all_losses=losses)
  # Each epoch is a single step for this model
  assert len(losses) == 100
  assert np.count_nonzero(np.array(losses)) == 100


def test_fit_on_batch():