Unverified Commit 358a54eb authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1566 from peastman/checkpoint

Replaced Saver with Checkpoint
parents 26b924c5 5d455729
Loading
Loading
Loading
Loading
+20 −13
Original line number Diff line number Diff line
@@ -159,7 +159,12 @@ class MAML(object):

      # Create variables for accumulating the gradients.

      gradients = tf.gradients(self._meta_loss, learner.variables)
      variables = list(learner.variables)
      gradients = tf.gradients(self._meta_loss, variables)
      for i in reversed(range(len(variables))):
        if gradients[i] is None:
          del variables[i]
          del gradients[i]
      zero_gradients = [tf.zeros(g.shape, g.dtype) for g in gradients]
      summed_gradients = [
          tf.Variable(z, trainable=False) for z in zero_gradients
@@ -172,7 +177,7 @@ class MAML(object):
      # Create the optimizers for meta-optimization and task optimization.

      self._global_step = tf.placeholder(tf.int32, [])
      grads_and_vars = list(zip(summed_gradients, learner.variables))
      grads_and_vars = list(zip(summed_gradients, variables))
      self._meta_train_op = optimizer._create_optimizer(
          self._global_step).apply_gradients(grads_and_vars)
      task_optimizer = GradientDescent(learning_rate=self._learning_rate)
@@ -180,6 +185,11 @@ class MAML(object):
          self._global_step).minimize(self._loss)
      self._session = tf.Session()

      # Create a Checkpoint for saving.

      self._checkpoint = tf.train.Checkpoint()
      self._checkpoint.listed = learner.variables

  def __del__(self):
    if '_model_dir_is_temp' in dir(self) and self._model_dir_is_temp:
      shutil.rmtree(self.model_dir)
@@ -208,9 +218,8 @@ class MAML(object):
      self._session.run(tf.global_variables_initializer())
      if restore:
        self.restore()
      saver = tf.train.Saver(
          self.learner.variables, max_to_keep=max_checkpoints_to_keep)
      checkpoint_index = 0
      manager = tf.train.CheckpointManager(self._checkpoint, self.model_dir,
                                           max_checkpoints_to_keep)
      checkpoint_time = time.time()

      # Main optimization loop.
@@ -230,9 +239,8 @@ class MAML(object):

        if i == steps - 1 or time.time(
        ) >= checkpoint_time + checkpoint_interval:
          saver.save(
              self._session, self.save_file, global_step=checkpoint_index)
          checkpoint_index += 1
          with self._session.as_default():
            manager.save()
          checkpoint_time = time.time()

  def restore(self):
@@ -241,8 +249,7 @@ class MAML(object):
    if last_checkpoint is None:
      raise ValueError('No checkpoint found')
    with self._graph.as_default():
      saver = tf.train.Saver(self.learner.variables)
      saver.restore(self._session, last_checkpoint)
      self._checkpoint.restore(last_checkpoint).run_restore_ops(self._session)

  def train_on_current_task(self, optimization_steps=1, restore=True):
    """Perform a few steps of gradient descent to fine tune the model on the current task.
+6 −6
Original line number Diff line number Diff line
@@ -172,9 +172,8 @@ class GAN(TensorGraph):
      weight_products = layers.Reshape(
          (n_generators * n_discriminators,),
          in_layers=layers.Reshape(
              (n_discriminators,
               1), in_layers=discrim_weights) * layers.Reshape(
                   (1, n_generators), in_layers=gen_weights))
              (n_discriminators, 1), in_layers=discrim_weights) *
          layers.Reshape((1, n_generators), in_layers=gen_weights))
      total_gen_loss = layers.WeightedError((layers.Stack(gen_losses, axis=0),
                                             weight_products))
      total_discrim_loss = layers.WeightedError((layers.Stack(
@@ -376,7 +375,8 @@ class GAN(TensorGraph):
    time1 = time.time()
    with self._get_tf("Graph").as_default():
      if checkpoint_interval > 0:
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
        manager = tf.train.CheckpointManager(
            self._get_tf('Checkpoint'), self.model_dir, max_checkpoints_to_keep)
      for feed_dict in batches:
        # Every call to fit_generator() will increment global_step, but we only
        # want it to get incremented once for the entire batch, so record the
@@ -413,7 +413,7 @@ class GAN(TensorGraph):
        # Write checkpoints and report progress.

        if discrim_average_steps == checkpoint_interval:
          saver.save(self.session, self.save_file, global_step=self.global_step)
          self._exec_with_session(lambda: manager.save())
          discrim_loss = discrim_error / max(1, discrim_average_steps)
          gen_loss = gen_error / max(1, gen_average_steps)
          print(
@@ -433,7 +433,7 @@ class GAN(TensorGraph):
          print(
              'Ending global_step %d: generator average loss %g, discriminator average loss %g'
              % (self.global_step, gen_loss, discrim_loss))
        saver.save(self.session, self.save_file, global_step=self.global_step)
        self._exec_with_session(lambda: manager.save())
        time2 = time.time()
        print("TIMING: model fitting took %0.3f s" % (time2 - time1))

+2 −1
Original line number Diff line number Diff line
@@ -487,6 +487,7 @@ class ANIRegression(TensorGraph):

        val = npo[k]
        tensor = g.get_tensor_by_name(k)
        if tensor.dtype != tf.resource:  # workaround for save_counter incorrectly being marked as a trainable variable
          all_ops.append(tf.assign(tensor, val))

      obj.session.run(all_ops)
+21 −20
Original line number Diff line number Diff line
@@ -213,10 +213,8 @@ class TensorGraph(Model):
        else:
          train_op = submodel.get_train_op()
      if checkpoint_interval > 0:
        saver = tf.train.Saver(
            self.get_variables(),
            max_to_keep=max_checkpoints_to_keep,
            save_relative_paths=True)
        manager = tf.train.CheckpointManager(
            self._get_tf('Checkpoint'), self.model_dir, max_checkpoints_to_keep)
      if restore:
        self.restore()
      avg_loss, n_averaged_batches = 0.0, 0.0
@@ -262,7 +260,7 @@ class TensorGraph(Model):
        n_averaged_batches += 1
        self.global_step += 1
        if checkpoint_interval > 0 and self.global_step % checkpoint_interval == checkpoint_interval - 1:
          saver.save(self.session, self.save_file, global_step=self.global_step)
          self._exec_with_session(lambda: manager.save())
          avg_loss = float(avg_loss) / n_averaged_batches
          logger.info('Ending global_step %d: Average loss %g' %
                      (self.global_step, avg_loss))
@@ -273,7 +271,7 @@ class TensorGraph(Model):
        if n_averaged_batches > 0:
          logger.info('Ending global_step %d: Average loss %g' %
                      (self.global_step, avg_loss))
        saver.save(self.session, self.save_file, global_step=self.global_step)
        self._exec_with_session(lambda: manager.save())
        time2 = time.time()
        logger.info("TIMING: model fitting took %0.3f s" % (time2 - time1))
    return avg_loss
@@ -733,6 +731,7 @@ class TensorGraph(Model):
      self._get_tf('train_op')
      for submodel in self.submodels:
        train_op = submodel.get_train_op()
      self._get_tf('Checkpoint').save_counter

      # Initialize variables.

@@ -1053,6 +1052,10 @@ class TensorGraph(Model):
    elif obj == 'GlobalStep':
      with self._get_tf("Graph").as_default():
        self.tensor_objects['GlobalStep'] = tf.Variable(0, trainable=False)
    elif obj == 'Checkpoint':
      checkpoint = tf.train.Checkpoint()
      checkpoint.listed = self.get_variables()
      self.tensor_objects['Checkpoint'] = checkpoint
    return self._get_tf(obj)

  def save_checkpoint(self, max_checkpoints_to_keep=5):
@@ -1067,9 +1070,16 @@ class TensorGraph(Model):
    max_checkpoints_to_keep: int
      the maximum number of checkpoints to keep.  Older checkpoints are discarded.
    """
    saver = tf.train.Saver(
        self.get_variables(), max_to_keep=max_checkpoints_to_keep)
    saver.save(self.session, self.save_file, global_step=self.global_step)
    manager = tf.train.CheckpointManager(
        self._get_tf('Checkpoint'), self.model_dir, max_checkpoints_to_keep)
    self._exec_with_session(lambda: manager.save())

  def _exec_with_session(self, f):
    if tf.executing_eagerly():
      f()
    else:
      with self.session.as_default():
        f()

  def get_checkpoints(self):
    """Get a list of all available checkpoint files."""
@@ -1093,17 +1103,8 @@ class TensorGraph(Model):
    if checkpoint is None:
      raise ValueError('No checkpoint found')
    with self._get_tf("Graph").as_default():
      reader = NewCheckpointReader(checkpoint)
      var_names = set([x for x in reader.get_variable_to_shape_map()])
      var_list = []
      for var in self.get_variables():
        name = var.name
        if ':' in name:
          name = name[:name.rfind(':')]
        if name in var_names:
          var_list.append(var)
      saver = tf.train.Saver(var_list=var_list)
      saver.restore(self.session, checkpoint)
      self._get_tf('Checkpoint').restore(checkpoint).run_restore_ops(
          self.session)

  def get_num_tasks(self):
    return len(self.default_outputs)
+0 −1
Original line number Diff line number Diff line
@@ -336,7 +336,6 @@ class TestLayers(test_util.TensorFlowTestCase):
    value = np.random.uniform(size=(2, 3)).astype(np.float32)
    with self.session() as sess:
      result = Log()(value).eval()
      assert np.array_equal(np.log(value), result)
      assert np.all(np.isclose(np.log(value), result, atol=0.001))

  def test_exp(self):
Loading