Commit fec23179 authored by peastman's avatar peastman
Browse files

Replaced Saver with Checkpoint

parent 26b924c5
Loading
Loading
Loading
Loading
+20 −13
Original line number Diff line number Diff line
@@ -159,7 +159,12 @@ class MAML(object):

      # Create variables for accumulating the gradients.

      gradients = tf.gradients(self._meta_loss, learner.variables)
      variables = list(learner.variables)
      gradients = tf.gradients(self._meta_loss, variables)
      for i in reversed(range(len(variables))):
        if gradients[i] is None:
          del variables[i]
          del gradients[i]
      zero_gradients = [tf.zeros(g.shape, g.dtype) for g in gradients]
      summed_gradients = [
          tf.Variable(z, trainable=False) for z in zero_gradients
@@ -172,7 +177,7 @@ class MAML(object):
      # Create the optimizers for meta-optimization and task optimization.

      self._global_step = tf.placeholder(tf.int32, [])
      grads_and_vars = list(zip(summed_gradients, learner.variables))
      grads_and_vars = list(zip(summed_gradients, variables))
      self._meta_train_op = optimizer._create_optimizer(
          self._global_step).apply_gradients(grads_and_vars)
      task_optimizer = GradientDescent(learning_rate=self._learning_rate)
@@ -180,6 +185,11 @@ class MAML(object):
          self._global_step).minimize(self._loss)
      self._session = tf.Session()

      # Create a Checkpoint for saving.

      self._checkpoint = tf.train.Checkpoint()
      self._checkpoint.listed = learner.variables

  def __del__(self):
    if '_model_dir_is_temp' in dir(self) and self._model_dir_is_temp:
      shutil.rmtree(self.model_dir)
@@ -208,9 +218,8 @@ class MAML(object):
      self._session.run(tf.global_variables_initializer())
      if restore:
        self.restore()
      saver = tf.train.Saver(
          self.learner.variables, max_to_keep=max_checkpoints_to_keep)
      checkpoint_index = 0
      manager = tf.train.CheckpointManager(self._checkpoint, self.model_dir,
                                           max_checkpoints_to_keep)
      checkpoint_time = time.time()

      # Main optimization loop.
@@ -230,9 +239,8 @@ class MAML(object):

        if i == steps - 1 or time.time(
        ) >= checkpoint_time + checkpoint_interval:
          saver.save(
              self._session, self.save_file, global_step=checkpoint_index)
          checkpoint_index += 1
          with self._session.as_default():
            manager.save()
          checkpoint_time = time.time()

  def restore(self):
@@ -241,8 +249,7 @@ class MAML(object):
    if last_checkpoint is None:
      raise ValueError('No checkpoint found')
    with self._graph.as_default():
      saver = tf.train.Saver(self.learner.variables)
      saver.restore(self._session, last_checkpoint)
      self._checkpoint.restore(last_checkpoint).run_restore_ops(self._session)

  def train_on_current_task(self, optimization_steps=1, restore=True):
    """Perform a few steps of gradient descent to fine tune the model on the current task.
+6 −6
Original line number Diff line number Diff line
@@ -172,9 +172,8 @@ class GAN(TensorGraph):
      weight_products = layers.Reshape(
          (n_generators * n_discriminators,),
          in_layers=layers.Reshape(
              (n_discriminators,
               1), in_layers=discrim_weights) * layers.Reshape(
                   (1, n_generators), in_layers=gen_weights))
              (n_discriminators, 1), in_layers=discrim_weights) *
          layers.Reshape((1, n_generators), in_layers=gen_weights))
      total_gen_loss = layers.WeightedError((layers.Stack(gen_losses, axis=0),
                                             weight_products))
      total_discrim_loss = layers.WeightedError((layers.Stack(
@@ -376,7 +375,8 @@ class GAN(TensorGraph):
    time1 = time.time()
    with self._get_tf("Graph").as_default():
      if checkpoint_interval > 0:
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
        manager = tf.train.CheckpointManager(
            self._get_tf('Checkpoint'), self.model_dir, max_checkpoints_to_keep)
      for feed_dict in batches:
        # Every call to fit_generator() will increment global_step, but we only
        # want it to get incremented once for the entire batch, so record the
@@ -413,7 +413,7 @@ class GAN(TensorGraph):
        # Write checkpoints and report progress.

        if discrim_average_steps == checkpoint_interval:
          saver.save(self.session, self.save_file, global_step=self.global_step)
          self._exec_with_session(lambda: manager.save())
          discrim_loss = discrim_error / max(1, discrim_average_steps)
          gen_loss = gen_error / max(1, gen_average_steps)
          print(
@@ -433,7 +433,7 @@ class GAN(TensorGraph):
          print(
              'Ending global_step %d: generator average loss %g, discriminator average loss %g'
              % (self.global_step, gen_loss, discrim_loss))
        saver.save(self.session, self.save_file, global_step=self.global_step)
        self._exec_with_session(lambda: manager.save())
        time2 = time.time()
        print("TIMING: model fitting took %0.3f s" % (time2 - time1))

+21 −20
Original line number Diff line number Diff line
@@ -213,10 +213,8 @@ class TensorGraph(Model):
        else:
          train_op = submodel.get_train_op()
      if checkpoint_interval > 0:
        saver = tf.train.Saver(
            self.get_variables(),
            max_to_keep=max_checkpoints_to_keep,
            save_relative_paths=True)
        manager = tf.train.CheckpointManager(
            self._get_tf('Checkpoint'), self.model_dir, max_checkpoints_to_keep)
      if restore:
        self.restore()
      avg_loss, n_averaged_batches = 0.0, 0.0
@@ -262,7 +260,7 @@ class TensorGraph(Model):
        n_averaged_batches += 1
        self.global_step += 1
        if checkpoint_interval > 0 and self.global_step % checkpoint_interval == checkpoint_interval - 1:
          saver.save(self.session, self.save_file, global_step=self.global_step)
          self._exec_with_session(lambda: manager.save())
          avg_loss = float(avg_loss) / n_averaged_batches
          logger.info('Ending global_step %d: Average loss %g' %
                      (self.global_step, avg_loss))
@@ -273,7 +271,7 @@ class TensorGraph(Model):
        if n_averaged_batches > 0:
          logger.info('Ending global_step %d: Average loss %g' %
                      (self.global_step, avg_loss))
        saver.save(self.session, self.save_file, global_step=self.global_step)
        self._exec_with_session(lambda: manager.save())
        time2 = time.time()
        logger.info("TIMING: model fitting took %0.3f s" % (time2 - time1))
    return avg_loss
@@ -733,6 +731,7 @@ class TensorGraph(Model):
      self._get_tf('train_op')
      for submodel in self.submodels:
        train_op = submodel.get_train_op()
      self._get_tf('Checkpoint').save_counter

      # Initialize variables.

@@ -1053,6 +1052,10 @@ class TensorGraph(Model):
    elif obj == 'GlobalStep':
      with self._get_tf("Graph").as_default():
        self.tensor_objects['GlobalStep'] = tf.Variable(0, trainable=False)
    elif obj == 'Checkpoint':
      checkpoint = tf.train.Checkpoint()
      checkpoint.listed = self.get_variables()
      self.tensor_objects['Checkpoint'] = checkpoint
    return self._get_tf(obj)

  def save_checkpoint(self, max_checkpoints_to_keep=5):
@@ -1067,9 +1070,16 @@ class TensorGraph(Model):
    max_checkpoints_to_keep: int
      the maximum number of checkpoints to keep.  Older checkpoints are discarded.
    """
    saver = tf.train.Saver(
        self.get_variables(), max_to_keep=max_checkpoints_to_keep)
    saver.save(self.session, self.save_file, global_step=self.global_step)
    manager = tf.train.CheckpointManager(
        self._get_tf('Checkpoint'), self.model_dir, max_checkpoints_to_keep)
    self._exec_with_session(lambda: manager.save())

  def _exec_with_session(self, f):
    if tf.executing_eagerly():
      f()
    else:
      with self.session.as_default():
        f()

  def get_checkpoints(self):
    """Get a list of all available checkpoint files."""
@@ -1093,17 +1103,8 @@ class TensorGraph(Model):
    if checkpoint is None:
      raise ValueError('No checkpoint found')
    with self._get_tf("Graph").as_default():
      reader = NewCheckpointReader(checkpoint)
      var_names = set([x for x in reader.get_variable_to_shape_map()])
      var_list = []
      for var in self.get_variables():
        name = var.name
        if ':' in name:
          name = name[:name.rfind(':')]
        if name in var_names:
          var_list.append(var)
      saver = tf.train.Saver(var_list=var_list)
      saver.restore(self.session, checkpoint)
      self._get_tf('Checkpoint').restore(checkpoint).run_restore_ops(
          self.session)

  def get_num_tasks(self):
    return len(self.default_outputs)
+0 −1
Original line number Diff line number Diff line
@@ -336,7 +336,6 @@ class TestLayers(test_util.TensorFlowTestCase):
    value = np.random.uniform(size=(2, 3)).astype(np.float32)
    with self.session() as sess:
      result = Log()(value).eval()
      assert np.array_equal(np.log(value), result)
      assert np.all(np.isclose(np.log(value), result, atol=0.001))

  def test_exp(self):
+13 −11
Original line number Diff line number Diff line
@@ -167,6 +167,13 @@ class A3C(object):
    with self._graph._get_tf("Graph").as_default():
      self._session = tf.Session()
    self._rnn_states = self._graph.rnn_zero_states
    with self._graph._get_tf("Graph").as_default():
      with tf.variable_scope('global'):
        self._checkpoint = tf.train.Checkpoint()
        self._checkpoint.save_counter  # Ensure the variable has been created
      self._checkpoint.listed = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope='global')
      self._session.run(self._checkpoint.save_counter.initializer)

  def _build_graph(self, tf_graph, scope, model_dir):
    """Construct a TensorGraph containing the policy and loss calculations."""
@@ -260,15 +267,14 @@ class A3C(object):
        thread.start()
      variables = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope='global')
      saver = tf.train.Saver(variables, max_to_keep=max_checkpoints_to_keep)
      checkpoint_index = 0
      manager = tf.train.CheckpointManager(
          self._checkpoint, self._graph.model_dir, max_checkpoints_to_keep)
      while True:
        threads = [t for t in threads if t.isAlive()]
        if len(threads) > 0:
          threads[0].join(checkpoint_interval)
        checkpoint_index += 1
        saver.save(
            self._session, self._graph.save_file, global_step=checkpoint_index)
        with self._session.as_default():
          manager.save()
        if len(threads) == 0:
          break

@@ -349,10 +355,7 @@ class A3C(object):
    if last_checkpoint is None:
      raise ValueError('No checkpoint found')
    with self._graph._get_tf("Graph").as_default():
      variables = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope='global')
      saver = tf.train.Saver(variables)
      saver.restore(self._session, last_checkpoint)
      self._checkpoint.restore(last_checkpoint).run_restore_ops(self._session)

  def _create_feed_dict(self, state, use_saved_states):
    """Create a feed dict for use by predict() or select_action()."""
@@ -501,8 +504,7 @@ class _Worker(object):
                         1] += self.a3c.discount_factor * discounted_rewards[j]
      advantages[
          j -
          1] += self.a3c.discount_factor * self.a3c.advantage_lambda * advantages[
              j]
          1] += self.a3c.discount_factor * self.a3c.advantage_lambda * advantages[j]

    # Record the actions, converting to one-hot if necessary.

Loading