Commit fac130d6 authored by leswing's avatar leswing
Browse files

Save things for queue creation

parent 0a03b138
Loading
Loading
Loading
Loading
+19 −2
Original line number Diff line number Diff line
@@ -5,7 +5,6 @@ import tensorflow as tf


class Layer(object):

  def __init__(self, **kwargs):
    if "name" not in kwargs:
      self.name = "%s%s" % (self.__class__.__name__, self._random_name())
@@ -181,7 +180,25 @@ class Input(Layer):
    super().__init__(**kwargs)

  def __call__(self, *parents):
    self.out_tensor = tf.placeholder(tf.float32, shape=self.t_shape)
    if not self.pre_queue:
      queue = parents[0]
      placeholder = queue.out_tensors[self.get_pre_q_name()]
      self.out_tensor = tf.placeholder_with_default(placeholder, self.shape)
      return self.out_tensor
    self.out_tensor = tf.placeholder(tf.float32, shape=self.shape)
    return self.out_tensor

  def create_pre_q(self, batch_size):
    if self.pre_queue:
      raise ValueError("Input is already pre_q")
    q_shape = (batch_size,) + self.shape[1:]
    return Input(shape=q_shape, name="%s_pre_q" % self.name, pre_queue=True)

  def get_pre_q_name(self):
    if self.pre_queue:
      raise ValueError("You are already pre_q")
    return "%s_pre_q" % self.name



class LossLayer(Layer):