Commit 1b7866b1 authored by Peter Eastman's avatar Peter Eastman
Browse files

Different method for synchronizing queue

parent 5019ad5d
Loading
Loading
Loading
Loading
+0 −3
Original line number Diff line number Diff line
@@ -1485,9 +1485,6 @@ class InputFifoQueue(Layer):
  def set_tensors(self, tensors):
    self.queue, self.out_tensor, self.out_tensors, self.close_op = tensors

  def close(self):
    self.queue.close()


class GraphConv(Layer):

+26 −25
Original line number Diff line number Diff line
@@ -181,17 +181,27 @@ class TensorGraph(Model):
      if restore:
        self.restore()
      avg_loss, n_batches = 0.0, 0.0
      coord = tf.train.Coordinator()
      n_samples = 0
      n_enqueued = [0]
      final_sample = [None]
      if self.use_queue:
        enqueue_thread = threading.Thread(
            target=_enqueue_batch,
            args=(self, feed_dict_generator, self._get_tf("Graph"),
                  self.session, coord))
                  self.session, n_enqueued, final_sample))
        enqueue_thread.start()
      fetches = [train_op, self.loss.out_tensor]
      for feed_dict in create_feed_dict():
        try:
        if self.use_queue:
          # Don't let this thread get ahead of the enqueue thread, since if
          # we try to read more batches than the total number that get queued,
          # this thread will hang indefinitely.
          while n_enqueued[0] <= n_samples:
            if n_samples == final_sample[0]:
              break
            time.sleep(0)
          if n_samples == final_sample[0]:
            break
        fetched_values = self.session.run(fetches, feed_dict=feed_dict)
        loss = fetched_values[-1]
        avg_loss += loss
@@ -202,8 +212,6 @@ class TensorGraph(Model):
          summary = self.session.run(
              self._get_tf("summary_op"), feed_dict=feed_dict)
          self._log_tensorboard(summary)
        except OutOfRangeError:
          break
        if self.global_step % checkpoint_interval == checkpoint_interval - 1:
          saver.save(self.session, self.save_file, global_step=self.global_step)
          avg_loss = float(avg_loss) / n_batches
@@ -688,7 +696,7 @@ class TensorGraph(Model):
    pass


def _enqueue_batch(tg, generator, graph, sess, coord):
def _enqueue_batch(tg, generator, graph, sess, n_enqueued, final_sample):
  """
  Function to load data into
  Parameters
@@ -697,7 +705,6 @@ def _enqueue_batch(tg, generator, graph, sess, coord):
  dataset
  graph
  sess
  coord

  Returns
  -------
@@ -711,14 +718,8 @@ def _enqueue_batch(tg, generator, graph, sess, coord):
      for layer in tg.features + tg.labels + tg.task_weights:
        enq[tg.get_pre_q_input(layer).out_tensor] = feed_dict[layer]
      sess.run(tg.input_queue.out_tensor, feed_dict=enq)
      num_samples += 1
      if tg.tensorboard and num_samples % tg.tensorboard_log_frequency == 0:
        enq = {k.out_tensor: v for k, v in six.iteritems(feed_dict)}
        summary = sess.run(tg._get_tf("summary_op"), feed_dict=enq)
        tg._log_tensorboard(summary)
    sess.run(tg.input_queue.close_op)
    coord.num_samples = num_samples
    coord.request_stop()
      n_enqueued[0] += 1
    final_sample[0] = n_enqueued[0]


class TFWrapper(object):