Commit c451665c authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #853 from peastman/queue

Different method for synchronizing queue
parents 8d449974 8f133c30
Loading
Loading
Loading
Loading
+0 −3
Original line number Diff line number Diff line
@@ -1588,9 +1588,6 @@ class InputFifoQueue(Layer):
  def set_tensors(self, tensors):
    self.queue, self.out_tensor, self.out_tensors, self.close_op = tensors

  def close(self):
    self.queue.close()


class GraphConv(Layer):

+34 −30
Original line number Diff line number Diff line
@@ -187,38 +187,49 @@ class TensorGraph(Model):
            variables = self.get_layer_variables(layer)
            for var, val in zip(variables, layer.variable_values):
              self.session.run(var.assign(val))
      avg_loss, n_batches = 0.0, 0.0
      avg_loss, n_averaged_batches = 0.0, 0.0
      coord = tf.train.Coordinator()
      n_samples = 0
      n_enqueued = [0]
      final_sample = [None]
      if self.use_queue:
        enqueue_thread = threading.Thread(
            target=_enqueue_batch,
            args=(self, feed_dict_generator, self._get_tf("Graph"),
                  self.session, coord))
                  self.session, n_enqueued, final_sample))
        enqueue_thread.start()
      fetches = [train_op, self.loss.out_tensor]
      for feed_dict in create_feed_dict():
        try:
        if self.use_queue:
          # Don't let this thread get ahead of the enqueue thread, since if
          # we try to read more batches than the total number that get queued,
          # this thread will hang indefinitely.
          while n_enqueued[0] <= n_samples:
            if n_samples == final_sample[0]:
              break
            time.sleep(0)
          if n_samples == final_sample[0]:
            break
        n_samples += 1
        should_log = (self.tensorboard and
                      n_samples % self.tensorboard_log_frequency == 0)
        fetches = [train_op, self.loss.out_tensor]
        if should_log:
          fetches.append(self._get_tf("summary_op"))
        fetched_values = self.session.run(fetches, feed_dict=feed_dict)
          loss = fetched_values[-1]
        if should_log:
          self._log_tensorboard(fetches[2])
        loss = fetched_values[1]
        avg_loss += loss
          n_batches += 1
        n_averaged_batches += 1
        self.global_step += 1
          n_samples += 1
          if self.tensorboard and n_samples % self.tensorboard_log_frequency == 0:
            summary = self.session.run(
                self._get_tf("summary_op"), feed_dict=feed_dict)
            self._log_tensorboard(summary)
        except OutOfRangeError:
          break
        if self.global_step % checkpoint_interval == checkpoint_interval - 1:
          saver.save(self.session, self.save_file, global_step=self.global_step)
          avg_loss = float(avg_loss) / n_batches
          avg_loss = float(avg_loss) / n_averaged_batches
          print('Ending global_step %d: Average loss %g' % (self.global_step,
                                                            avg_loss))
          avg_loss, n_batches = 0.0, 0.0
      if n_batches > 0:
        avg_loss = float(avg_loss) / n_batches
          avg_loss, n_averaged_batches = 0.0, 0.0
      if n_averaged_batches > 0:
        avg_loss = float(avg_loss) / n_averaged_batches
        print('Ending global_step %d: Average loss %g' % (self.global_step,
                                                          avg_loss))
      saver.save(self.session, self.save_file, global_step=self.global_step)
@@ -695,7 +706,7 @@ class TensorGraph(Model):
    pass


def _enqueue_batch(tg, generator, graph, sess, coord):
def _enqueue_batch(tg, generator, graph, sess, n_enqueued, final_sample):
  """
  Function to load data into
  Parameters
@@ -704,7 +715,6 @@ def _enqueue_batch(tg, generator, graph, sess, coord):
  dataset
  graph
  sess
  coord

  Returns
  -------
@@ -718,14 +728,8 @@ def _enqueue_batch(tg, generator, graph, sess, coord):
      for layer in tg.features + tg.labels + tg.task_weights:
        enq[tg.get_pre_q_input(layer).out_tensor] = feed_dict[layer]
      sess.run(tg.input_queue.out_tensor, feed_dict=enq)
      num_samples += 1
      if tg.tensorboard and num_samples % tg.tensorboard_log_frequency == 0:
        enq = {k.out_tensor: v for k, v in six.iteritems(feed_dict)}
        summary = sess.run(tg._get_tf("summary_op"), feed_dict=enq)
        tg._log_tensorboard(summary)
    sess.run(tg.input_queue.close_op)
    coord.num_samples = num_samples
    coord.request_stop()
      n_enqueued[0] += 1
    final_sample[0] = n_enqueued[0]


class TFWrapper(object):