Commit d092e7ef authored by Peter Eastman's avatar Peter Eastman
Browse files

Fixed bug when logging to tensorboard

parent 1b7866b1
Loading
Loading
Loading
Loading
+15 −13
Original line number Diff line number Diff line
@@ -180,7 +180,7 @@ class TensorGraph(Model):
      self.session.run(tf.global_variables_initializer())
      if restore:
        self.restore()
      avg_loss, n_batches = 0.0, 0.0
      avg_loss, n_averaged_batches = 0.0, 0.0
      n_samples = 0
      n_enqueued = [0]
      final_sample = [None]
@@ -190,7 +190,6 @@ class TensorGraph(Model):
            args=(self, feed_dict_generator, self._get_tf("Graph"),
                  self.session, n_enqueued, final_sample))
        enqueue_thread.start()
      fetches = [train_op, self.loss.out_tensor]
      for feed_dict in create_feed_dict():
        if self.use_queue:
          # Don't let this thread get ahead of the enqueue thread, since if
@@ -202,24 +201,27 @@ class TensorGraph(Model):
            time.sleep(0)
          if n_samples == final_sample[0]:
            break
        n_samples += 1
        should_log = (self.tensorboard and
                      n_samples % self.tensorboard_log_frequency == 0)
        fetches = [train_op, self.loss.out_tensor]
        if should_log:
          fetches.append(self._get_tf("summary_op"))
        fetched_values = self.session.run(fetches, feed_dict=feed_dict)
        loss = fetched_values[-1]
        if should_log:
          self._log_tensorboard(fetches[2])
        loss = fetched_values[1]
        avg_loss += loss
        n_batches += 1
        n_averaged_batches += 1
        self.global_step += 1
        n_samples += 1
        if self.tensorboard and n_samples % self.tensorboard_log_frequency == 0:
          summary = self.session.run(
              self._get_tf("summary_op"), feed_dict=feed_dict)
          self._log_tensorboard(summary)
        if self.global_step % checkpoint_interval == checkpoint_interval - 1:
          saver.save(self.session, self.save_file, global_step=self.global_step)
          avg_loss = float(avg_loss) / n_batches
          avg_loss = float(avg_loss) / n_averaged_batches
          print('Ending global_step %d: Average loss %g' % (self.global_step,
                                                            avg_loss))
          avg_loss, n_batches = 0.0, 0.0
      if n_batches > 0:
        avg_loss = float(avg_loss) / n_batches
          avg_loss, n_averaged_batches = 0.0, 0.0
      if n_averaged_batches > 0:
        avg_loss = float(avg_loss) / n_averaged_batches
        print('Ending global_step %d: Average loss %g' % (self.global_step,
                                                          avg_loss))
      saver.save(self.session, self.save_file, global_step=self.global_step)