Commit 5c80fb77 authored by peastman
Browse files

Can disable automatic checkpointing

parent d768e653
Loading
Loading
Loading
Loading
+32 −8
Original line number Diff line number Diff line
@@ -131,6 +131,7 @@ class TensorGraph(Model):
      the maximum number of checkpoints to keep.  Older checkpoints are discarded.
    checkpoint_interval: int
      the frequency at which to write checkpoints, measured in training steps.
      Set this to 0 to disable automatic checkpointing.
    deterministic: bool
      if True, the samples are processed in order.  If False, a different random
      order is used for each epoch.
@@ -163,12 +164,17 @@ class TensorGraph(Model):
      the maximum number of checkpoints to keep.  Older checkpoints are discarded.
    checkpoint_interval: int
      the frequency at which to write checkpoints, measured in training steps.
      Set this to 0 to disable automatic checkpointing.
    restore: bool
      if True, restore the model from the most recent checkpoint and continue training
      from there.  If False, retrain the model from scratch.
    submodel: Submodel
      an alternate training objective to use.  This should have been created by
      calling create_submodel().

    Returns
    -------
    the average loss over the most recent checkpoint interval
    """

    def create_feed_dict():
@@ -188,11 +194,11 @@ class TensorGraph(Model):
        train_op = self._get_tf('train_op')
      else:
        train_op = submodel.get_train_op()
      if checkpoint_interval > 0:
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
      if restore:
        self.restore()
      avg_loss, n_averaged_batches = 0.0, 0.0
      coord = tf.train.Coordinator()
      n_samples = 0
      n_enqueued = [0]
      final_sample = [None]
@@ -226,7 +232,7 @@ class TensorGraph(Model):
        avg_loss += loss
        n_averaged_batches += 1
        self.global_step += 1
        if self.global_step % checkpoint_interval == checkpoint_interval - 1:
        if checkpoint_interval > 0 and self.global_step % checkpoint_interval == checkpoint_interval - 1:
          saver.save(self.session, self.save_file, global_step=self.global_step)
          avg_loss = float(avg_loss) / n_averaged_batches
          print('Ending global_step %d: Average loss %g' % (self.global_step,
@@ -234,11 +240,14 @@ class TensorGraph(Model):
          avg_loss, n_averaged_batches = 0.0, 0.0
      if n_averaged_batches > 0:
        avg_loss = float(avg_loss) / n_averaged_batches
      if checkpoint_interval > 0:
        if n_averaged_batches > 0:
          print('Ending global_step %d: Average loss %g' % (self.global_step,
                                                            avg_loss))
        saver.save(self.session, self.save_file, global_step=self.global_step)
        time2 = time.time()
        print("TIMING: model fitting took %0.3f s" % (time2 - time1))
    return avg_loss

  def _log_tensorboard(self, summary):
    """
@@ -729,6 +738,21 @@ class TensorGraph(Model):
        self.tensor_objects['GlobalStep'] = tf.Variable(0, trainable=False)
    return self._get_tf(obj)

  def save_checkpoint(self, max_checkpoints_to_keep=5):
    """Write the model's current state to a checkpoint on disk.

    fit() normally writes checkpoints on its own, so calling this directly is
    rarely necessary.  It is intended for the case where automatic
    checkpointing has been turned off during fitting and checkpoints must be
    written by hand.

    Parameters
    ----------
    max_checkpoints_to_keep: int
      the maximum number of checkpoints to keep.  Older checkpoints are discarded.
    """
    # A new Saver is built on every call so that the retention limit passed by
    # the caller is the one applied to this write.
    checkpoint_writer = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
    checkpoint_writer.save(
        self.session, self.save_file, global_step=self.global_step)

  def restore(self):
    """Reload the values of all variables from the most recent checkpoint file."""
    if not self.built: