Unverified Commit 1f02f283 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #965 from peastman/checkpoint

Can restore from any checkpoint, not just latest
parents 9af0485d ac21b5ec
Loading
Loading
Loading
Loading
+19 −5
Original line number Diff line number Diff line
@@ -766,16 +766,30 @@ class TensorGraph(Model):
    saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
    saver.save(self.session, self.save_file, global_step=self.global_step)

  def restore(self):
    """Reload the values of all variables from the most recent checkpoint file."""
  def get_checkpoints(self):
    """Get a list of all available checkpoint files."""
    return tf.train.get_checkpoint_state(
        self.model_dir).all_model_checkpoint_paths

  def restore(self, checkpoint=None):
    """Reload the values of all variables from a checkpoint file.

    Parameters
    ----------
    checkpoint: str
      the path to the checkpoint file to load.  If this is None, the most recent
      checkpoint will be chosen automatically.  Call get_checkpoints() to get a
      list of all available checkpoints.
    """
    if not self.built:
      self.build()
    last_checkpoint = tf.train.latest_checkpoint(self.model_dir)
    if last_checkpoint is None:
    if checkpoint is None:
      checkpoint = tf.train.latest_checkpoint(self.model_dir)
    if checkpoint is None:
      raise ValueError('No checkpoint found')
    with self._get_tf("Graph").as_default():
      saver = tf.train.Saver()
      saver.restore(self.session, last_checkpoint)
      saver.restore(self.session, checkpoint)

  def get_num_tasks(self):
    return len(self.outputs)