Unverified Commit 710f61a0 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #1191 from peastman/eagertg

Support for eager mode in TensorGraph
parents 30b03541 0d041310
Loading
Loading
Loading
Loading
+155 −204
Original line number Diff line number Diff line
@@ -38,8 +38,13 @@ class Layer(object):
    if tfe.in_eager_mode():
      self.variables = []
      self._built = False
      self._non_pickle_fields = ['variables', '_built']
    else:
      self.variable_scope = ''
      self._non_pickle_fields = [
          'out_tensor', 'rnn_initial_states', 'rnn_final_states',
          'rnn_zero_states'
      ]

  def _get_layer_number(self):
    class_name = self.__class__.__name__
@@ -49,12 +54,19 @@ class Layer(object):
    return "%s" % Layer.layer_number_dict[class_name]

  def none_tensors(self):
    out_tensor = self.out_tensor
    self.out_tensor = None
    return out_tensor
    saved_tensors = []
    for field in self._non_pickle_fields:
      value = self.__getattribute__(field)
      saved_tensors.append(value)
      if isinstance(value, list):
        self.__setattr__(field, [])
      else:
        self.__setattr__(field, None)
    return saved_tensors

  def set_tensors(self, tensor):
    self.out_tensor = tensor
  def set_tensors(self, tensors):
    for field, t in zip(self._non_pickle_fields, tensors):
      self.__setattr__(field, t)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    raise NotImplementedError("Subclasses must implement for themselves")
@@ -86,12 +98,35 @@ class Layer(object):
    -------
    Layer
    """
    if tfe.in_eager_mode():
      raise ValueError('shared() is not supported in eager mode')
    if self.variable_scope == '':
      return self.clone(in_layers)
    raise ValueError('%s does not implement shared()' % self.__class__.__name__)

  def __call__(self, *in_layers, **kwargs):
    return self.create_tensor(in_layers=in_layers, set_tensors=False, **kwargs)
  def __call__(self, *inputs, **kwargs):
    """Execute the layer in eager mode to compute its output as a function of inputs.

    If the layer defines any variables, they are created the first time it is invoked.

    Arbitrary keyword arguments may be specified after the list of inputs.  Most
    layers do not expect or use any additional arguments, but there are a few
    significant cases.

    - Recurrent layers usually accept an argument `initial_state` which can be
      used to specify the initial state for the recurrent cell.  When this
      argument is omitted, they use a default initial state, usually all zeros.
    - A few layers behave differently during training than during inference,
      such as Dropout and CombineMeanStd.  You can specify a boolean value with
      the `training` argument to tell it which mode it is being called in.

    Parameters
    ----------
    inputs: tensors
      the inputs to pass to the layer.  The values may be tensors, numpy arrays,
      or anything else that can be converted to tensors of the correct shape.
    """
    return self.create_tensor(in_layers=inputs, set_tensors=False, **kwargs)

  @property
  def shape(self):
@@ -250,8 +285,6 @@ class Layer(object):
      This means the newly created layers will share variables with the original
      ones.
    """
    if tfe.in_eager_mode():
      raise ValueError('copy() is not supported in eager mode')
    if self in replacements:
      return replacements[self]
    copied_inputs = [
@@ -268,6 +301,9 @@ class Layer(object):
      variables = variables_graph.get_layer_variables(self)
      if len(variables) > 0:
        with variables_graph._get_tf("Graph").as_default():
          if tfe.in_eager_mode():
            values = [v.numpy() for v in variables]
          else:
            values = variables_graph.session.run(variables)
          copy.set_variable_initial_values(values)
    return copy
@@ -369,6 +405,8 @@ class SharedVariableScope(Layer):
    self._shared_with = None

  def shared(self, in_layers):
    if tfe.in_eager_mode():
      raise ValueError('shared() is not supported in eager mode')
    copy = self.clone(in_layers)
    self._reuse = True
    copy._reuse = True
@@ -509,6 +547,7 @@ class Conv1D(Layer):
    if tfe.in_eager_mode():
      if not self._built:
        self._layer = self._build_layer()
        self._non_pickle_fields.append('_layer')
      layer = self._layer
    else:
      layer = self._build_layer()
@@ -589,6 +628,7 @@ class Dense(SharedVariableScope):
      if tfe.in_eager_mode():
        if not self._built:
          self._layer = self._build_layer(False)
          self._non_pickle_fields.append('_layer')
        layer = self._layer
      else:
        layer = self._build_layer(reuse)
@@ -675,6 +715,7 @@ class Highway(Layer):
    if tfe.in_eager_mode():
      if not self._built:
        self._layers = self._build_layers(out_channels)
        self._non_pickle_fields.append('_layers')
      layers = self._layers
    else:
      layers = self._build_layers(out_channels)
@@ -995,6 +1036,9 @@ class GRU(Layer):
    if tfe.in_eager_mode():
      self._cell = tf.contrib.rnn.GRUCell(n_hidden)
      self._zero_state = self._cell.zero_state(batch_size, tf.float32)
      self._non_pickle_fields += ['_cell', '_zero_state']
    else:
      self._non_pickle_fields.append('out_tensors')
    try:
      parent_shape = self.in_layers[0].shape
      self._shape = (batch_size, parent_shape[1], n_hidden)
@@ -1037,21 +1081,6 @@ class GRU(Layer):
    else:
      return out_tensor

  def none_tensors(self):
    saved_tensors = [
        self.out_tensor, self.rnn_initial_states, self.rnn_final_states,
        self.rnn_zero_states, self.out_tensors
    ]
    self.out_tensor = None
    self.rnn_initial_states = []
    self.rnn_final_states = []
    self.rnn_zero_states = []
    self.out_tensors = []
    return saved_tensors

  def set_tensors(self, tensor):
    self.out_tensor, self.rnn_initial_states, self.rnn_final_states, self.rnn_zero_states, self.out_tensors = tensor


class LSTM(Layer):
  """A Long Short Term Memory.
@@ -1086,6 +1115,7 @@ class LSTM(Layer):
    if tfe.in_eager_mode():
      self._cell = tf.contrib.rnn.LSTMCell(n_hidden)
      self._zero_state = self._cell.zero_state(batch_size, tf.float32)
      self._non_pickle_fields += ['_cell', '_zero_state']
    try:
      parent_shape = self.in_layers[0].shape
      self._shape = (batch_size, parent_shape[1], n_hidden)
@@ -1132,20 +1162,6 @@ class LSTM(Layer):
    else:
      return out_tensor

  def none_tensors(self):
    saved_tensors = [
        self.out_tensor, self.rnn_initial_states, self.rnn_final_states,
        self.rnn_zero_states
    ]
    self.out_tensor = None
    self.rnn_initial_states = []
    self.rnn_final_states = []
    self.rnn_zero_states = []
    return saved_tensors

  def set_tensors(self, tensor):
    self.out_tensor, self.rnn_initial_states, self.rnn_final_states, self.rnn_zero_states = tensor


class TimeSeriesDense(Layer):

@@ -1166,6 +1182,7 @@ class TimeSeriesDense(Layer):
    if tfe.in_eager_mode():
      if not self._built:
        self._layer = self._build_layer()
        self._non_pickle_fields.append('_layer')
      layer = self._layer
    else:
      layer = self._build_layer()
@@ -1762,11 +1779,11 @@ class ReduceMean(Layer):
  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)
    if len(inputs) > 1:
      self.out_tensor = tf.stack(inputs)
      out_tensor = tf.stack(inputs)
    else:
      self.out_tensor = inputs[0]
      out_tensor = inputs[0]

    out_tensor = tf.reduce_mean(self.out_tensor, axis=self.axis)
    out_tensor = tf.reduce_mean(out_tensor, axis=self.axis)
    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor
@@ -1793,11 +1810,11 @@ class ReduceMax(Layer):
  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)
    if len(inputs) > 1:
      self.out_tensor = tf.stack(inputs)
      out_tensor = tf.stack(inputs)
    else:
      self.out_tensor = inputs[0]
      out_tensor = inputs[0]

    out_tensor = tf.reduce_max(self.out_tensor, axis=self.axis)
    out_tensor = tf.reduce_max(out_tensor, axis=self.axis)
    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor
@@ -1843,11 +1860,11 @@ class ReduceSum(Layer):
  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)
    if len(inputs) > 1:
      self.out_tensor = tf.stack(inputs)
      out_tensor = tf.stack(inputs)
    else:
      self.out_tensor = inputs[0]
      out_tensor = inputs[0]

    out_tensor = tf.reduce_sum(self.out_tensor, axis=self.axis)
    out_tensor = tf.reduce_sum(out_tensor, axis=self.axis)
    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor
@@ -1974,6 +1991,7 @@ class Conv2D(SharedVariableScope):
        if tfe.in_eager_mode():
          if not self._built:
            self._layer = self._build_layer(False)
            self._non_pickle_fields.append('_layer')
          layer = self._layer
        else:
          layer = self._build_layer(reuse)
@@ -2091,6 +2109,7 @@ class Conv3D(SharedVariableScope):
        if tfe.in_eager_mode():
          if not self._built:
            self._layer = self._build_layer(False)
            self._non_pickle_fields.append('_layer')
          layer = self._layer
        else:
          layer = self._build_layer(reuse)
@@ -2208,6 +2227,7 @@ class Conv2DTranspose(SharedVariableScope):
        if tfe.in_eager_mode():
          if not self._built:
            self._layer = self._build_layer(False)
            self._non_pickle_fields.append('_layer')
          layer = self._layer
        else:
          layer = self._build_layer(reuse)
@@ -2325,6 +2345,7 @@ class Conv3DTranspose(SharedVariableScope):
        if tfe.in_eager_mode():
          if not self._built:
            self._layer = self._build_layer(False)
            self._non_pickle_fields.append('_layer')
          layer = self._layer
        else:
          layer = self._build_layer(reuse)
@@ -2491,14 +2512,7 @@ class InputFifoQueue(Layer):
    self.out_tensor = self.queue.enqueue(feed_dict)
    self.close_op = self.queue.close()
    self.out_tensors = self.queue.dequeue()

  def none_tensors(self):
    queue, out_tensors, out_tensor, close_op = self.queue, self.out_tensor, self.out_tensor, self.close_op
    self.queue, self.out_tensor, self.out_tensors, self.close_op = None, None, None, None
    return queue, out_tensors, out_tensor, close_op

  def set_tensors(self, tensors):
    self.queue, self.out_tensor, self.out_tensors, self.close_op = tensors
    self._non_pickle_fields += ['queue', 'out_tensors', 'close_op']


class GraphConv(Layer):
@@ -2518,22 +2532,31 @@ class GraphConv(Layer):

  def _create_variables(self, in_channels):
    # Generate the nb_affine weights and biases
    self.W_list = [
    W_list = [
        initializations.glorot_uniform([in_channels, self.out_channel])
        for k in range(self.num_deg)
    ]
    self.b_list = [
    b_list = [
        model_ops.zeros(shape=[
            self.out_channel,
        ]) for k in range(self.num_deg)
    ]
    return (W_list, b_list)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)
    # in_layers = [atom_features, deg_slice, membership, deg_adj_list placeholders...]
    in_channels = inputs[0].get_shape()[-1].value
    if not tfe.in_eager_mode() or not self._built:
      self._create_variables(in_channels)
    if tfe.in_eager_mode():
      if not self._built:
        W_list, b_list = self._create_variables(in_channels)
        self.variables = W_list + b_list
        self._built = True
      else:
        W_list = self.variables[:self.num_deg]
        b_list = self.variables[self.num_deg:]
    else:
      W_list, b_list = self._create_variables(in_channels)

    # Extract atom_features
    atom_features = inputs[0]
@@ -2544,11 +2567,11 @@ class GraphConv(Layer):

    # Perform the mol conv
    # atom_features = graph_conv(atom_features, deg_adj_lists, deg_slice,
    #                            self.max_deg, self.min_deg, self.W_list,
    #                            self.b_list)
    #                            self.max_deg, self.min_deg, W_list,
    #                            b_list)

    W = iter(self.W_list)
    b = iter(self.b_list)
    W = iter(W_list)
    b = iter(b_list)

    # Sum all neighbors using adjacency matrix
    deg_summed = self.sum_neigh(atom_features, deg_adj_lists)
@@ -2595,9 +2618,6 @@ class GraphConv(Layer):
    if set_tensors:
      self._record_variable_scope(self.name)
      self.out_tensor = out_tensor
    if tfe.in_eager_mode() and not self._built:
      self._built = True
      self.variables = self.W_list + self.b_list
    return out_tensor

  def sum_neigh(self, atoms, deg_adj_lists):
@@ -2613,14 +2633,6 @@ class GraphConv(Layer):

    return deg_summed

  def none_tensors(self):
    out_tensor, W_list, b_list = self.out_tensor, self.W_list, self.b_list
    self.out_tensor, self.W_list, self.b_list = None, None, None
    return out_tensor, W_list, b_list

  def set_tensors(self, tensors):
    self.out_tensor, self.W_list, self.b_list = tensors


class GraphPool(Layer):

@@ -2770,28 +2782,14 @@ class LSTMStep(Layer):
    """Constructs learnable weights for this layer."""
    init = self.init
    inner_init = self.inner_init
    self.W = init((self.input_dim, 4 * self.output_dim))
    self.U = inner_init((self.output_dim, 4 * self.output_dim))
    W = init((self.input_dim, 4 * self.output_dim))
    U = inner_init((self.output_dim, 4 * self.output_dim))

    self.b = create_variable(
    b = create_variable(
        np.hstack((np.zeros(self.output_dim), np.ones(self.output_dim),
                   np.zeros(self.output_dim), np.zeros(self.output_dim))),
        dtype=tf.float32)

  def none_tensors(self):
    """Zeros out stored tensors for pickling."""
    W, U, b, out_tensor = self.W, self.U, self.b, self.out_tensor
    h, c = self.h, self.c
    trainable_weights = self.trainable_weights
    self.W, self.U, self.b, self.out_tensor = None, None, None, None
    self.h, self.c = None, None
    self.trainable_weights = []
    return W, U, b, h, c, out_tensor, trainable_weights

  def set_tensors(self, tensor):
    """Sets all stored tensors."""
    (self.W, self.U, self.b, self.h, self.c, self.out_tensor,
     self.trainable_weights) = tensor
    return [W, U, b]

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    """Execute this layer on input tensors.
@@ -2809,18 +2807,18 @@ class LSTMStep(Layer):
    activation = self.activation
    inner_activation = self.inner_activation

    if tfe.in_eager_mode() and not self._built:
      self._create_variables()
      self.variables = [self.W, self.U, self.b]
    if tfe.in_eager_mode():
      if not self._built:
        self.variables = self._create_variables()
        self._built = True
    if not tfe.in_eager_mode():
      self._create_variables()
      self.trainable_weights = [self.W, self.U, self.b]
      W, U, b = self.variables
    else:
      W, U, b = self._create_variables()
    inputs = self._get_input_tensors(in_layers)
    x, h_tm1, c_tm1 = inputs

    # Taken from Keras code [citation needed]
    z = model_ops.dot(x, self.W) + model_ops.dot(h_tm1, self.U) + self.b
    z = model_ops.dot(x, W) + model_ops.dot(h_tm1, U) + b

    z0 = z[:, :self.output_dim]
    z1 = z[:, self.output_dim:2 * self.output_dim]
@@ -2835,8 +2833,6 @@ class LSTMStep(Layer):
    h = o * activation(c)

    if set_tensors:
      self.h = h
      self.c = c
      self.out_tensor = h
    return h, [h, c]

@@ -2901,10 +2897,10 @@ class AttnLSTMEmbedding(Layer):
  def _create_variables(self):
    n_feat = self.n_feat
    lstm = LSTMStep(n_feat, 2 * n_feat)
    self.q_init = model_ops.zeros([self.n_test, n_feat])
    self.r_init = model_ops.zeros([self.n_test, n_feat])
    self.states_init = lstm.get_initial_states([self.n_test, n_feat])
    return lstm
    q_init = model_ops.zeros([self.n_test, n_feat])
    r_init = model_ops.zeros([self.n_test, n_feat])
    states_init = lstm.get_initial_states([self.n_test, n_feat])
    return (lstm, q_init, r_init, states_init)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    """Execute this layer on input tensors.
@@ -2931,17 +2927,21 @@ class AttnLSTMEmbedding(Layer):

    if tfe.in_eager_mode():
      if not self._built:
        self._lstm = self._create_variables()
        self._lstm, self.q_init, self.r_init, self.states_init = self._create_variables(
        )
        self._non_pickle_fields += ['_lstm', 'q_init', 'r_init', 'states_init']
      lstm = self._lstm
      q_init = self.q_init
      r_init = self.r_init
      states_init = self.states_init
    else:
      lstm = self._create_variables()
      self.trainable_weights = [self.q_init, self.r_init]
      lstm, q_init, r_init, states_init = self._create_variables()

    ### Performs computations

    # Get initializations
    q = self.q_init
    states = self.states_init
    q = q_init
    states = states_init

    for d in range(self.max_depth):
      # Process using attention
@@ -2956,28 +2956,11 @@ class AttnLSTMEmbedding(Layer):

    if set_tensors:
      self.out_tensor = xp
      self.xq = x + q
      self.xp = xp
    if tfe.in_eager_mode() and not self._built:
      self._built = True
      self.variables = self._lstm.variables + [self.q_init, self.r_init]
      self.variables = lstm.variables + [q_init, r_init] + states_init
    return [x + q, xp]

  def none_tensors(self):
    q_init, r_init, states_init = self.q_init, self.r_init, self.states_init
    xq, xp = self.xq, self.xp
    out_tensor = self.out_tensor
    trainable_weights = self.trainable_weights
    self.q_init, self.r_init, self.states_init = None, None, None
    self.xq, self.xp = None, None
    self.out_tensor = None
    self.trainable_weights = []
    return q_init, r_init, states_init, xq, xp, out_tensor, trainable_weights

  def set_tensors(self, tensor):
    (self.q_init, self.r_init, self.states_init, self.xq, self.xp,
     self.out_tensor, self.trainable_weights) = tensor


class IterRefLSTMEmbedding(Layer):
  """Implements the Iterative Refinement LSTM.
@@ -3022,15 +3005,16 @@ class IterRefLSTMEmbedding(Layer):

    # Support set lstm
    support_lstm = LSTMStep(n_feat, 2 * n_feat)
    self.q_init = model_ops.zeros([self.n_support, n_feat])
    self.support_states_init = support_lstm.get_initial_states(
    q_init = model_ops.zeros([self.n_support, n_feat])
    support_states_init = support_lstm.get_initial_states(
        [self.n_support, n_feat])

    # Test lstm
    test_lstm = LSTMStep(n_feat, 2 * n_feat)
    self.p_init = model_ops.zeros([self.n_test, n_feat])
    self.test_states_init = test_lstm.get_initial_states([self.n_test, n_feat])
    return (support_lstm, test_lstm)
    p_init = model_ops.zeros([self.n_test, n_feat])
    test_states_init = test_lstm.get_initial_states([self.n_test, n_feat])
    return (support_lstm, q_init, support_states_init, test_lstm, p_init,
            test_states_init)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    """Execute this layer on input tensors.
@@ -3051,12 +3035,21 @@ class IterRefLSTMEmbedding(Layer):
    """
    if tfe.in_eager_mode():
      if not self._built:
        self._support_lstm, self._test_lstm = self._create_variables()
        self._support_lstm, self.q_init, self.support_states_init, self._test_lstm, self.p_init, self.test_states_init = self._create_variables(
        )
        self._non_pickle_fields += [
            '_support_lstm', 'q_init', 'support_states_init', '_test_lstm',
            'p_init', 'test_states_init'
        ]
      support_lstm = self._support_lstm
      q_init = self.q_init
      support_states_init = self.support_states_init
      test_lstm = self._test_lstm
      p_init = self.p_init
      test_states_init = self.test_states_init
    else:
      support_lstm, test_lstm = self._create_variables()
      self.trainable_weights = []
      support_lstm, q_init, support_states_init, test_lstm, p_init, test_states_init = self._create_variables(
      )

    # self.build()
    inputs = self._get_input_tensors(in_layers)
@@ -3066,12 +3059,12 @@ class IterRefLSTMEmbedding(Layer):
    x, xp = inputs

    # Get initializations
    p = self.p_init
    q = self.q_init
    p = p_init
    q = q_init
    # Rename support
    z = xp
    states = self.support_states_init
    x_states = self.test_states_init
    states = support_states_init
    x_states = test_states_init

    for d in range(self.max_depth):
      # Process support xp using attention
@@ -3097,36 +3090,15 @@ class IterRefLSTMEmbedding(Layer):
      z = r

    if set_tensors:
      self.xp = x + p
      self.xpq = xp + q
      self.out_tensor = self.xp
      self.out_tensor = xp
    if tfe.in_eager_mode() and not self._built:
      self.variables = self._support_lstm.variables + self._test_lstm.variables + [
          self.q_init, self.p_init
      ]
      self.variables = support_lstm.variables + test_lstm.variables + [
          q_init, p_init
      ] + support_states_init + test_states_init
      self._built = True

    return [x + p, xp + q]

  def none_tensors(self):
    p_init, q_init = self.p_init, self.q_init,
    support_states_init, test_states_init = (self.support_states_init,
                                             self.test_states_init)
    xp, xpq = self.xp, self.xpq
    out_tensor = self.out_tensor
    trainable_weights = self.trainable_weights
    (self.p_init, self.q_init, self.support_states_init,
     self.test_states_init) = (None, None, None, None)
    self.xp, self.xpq = None, None
    self.out_tensor = None
    self.trainable_weights = []
    return (p_init, q_init, support_states_init, test_states_init, xp, xpq,
            out_tensor, trainable_weights)

  def set_tensors(self, tensor):
    (self.p_init, self.q_init, self.support_states_init, self.test_states_init,
     self.xp, self.xpq, self.out_tensor, self.trainable_weights) = tensor


class BatchNorm(Layer):

@@ -3156,6 +3128,7 @@ class BatchNorm(Layer):
    if tfe.in_eager_mode():
      if not self._built:
        self._layer = self._build_layer()
        self._non_pickle_fields.append('_layer')
      layer = self._layer
    else:
      layer = self._build_layer()
@@ -3198,6 +3171,7 @@ class BatchNormalization(Layer):
        shape, initializer=self.gamma_init, name='{}_gamma'.format(self.name))
    self.beta = self.add_weight(
        shape, initializer=self.beta_init, name='{}_beta'.format(self.name))
    self._non_pickle_fields += ['gamma', 'beta']

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)
@@ -3214,14 +3188,6 @@ class BatchNormalization(Layer):
      self.out_tensor = out_tensor
    return out_tensor

  def none_tensors(self):
    gamma, beta, out_tensor = self.gamma, self.beta, self.out_tensor
    self.gamma, self.beta, self.out_tensor = None, None, None
    return gamma, beta, out_tensor

  def set_tensors(self, tensor):
    self.gamma, self.beta, self.out_tensor = tensor


class WeightedError(Layer):

@@ -3330,6 +3296,7 @@ class VinaFreeEnergy(Layer):
    if tfe.in_eager_mode():
      if not self._built:
        self._weighted_combo, self._w = self._build_layers()
        self._non_pickle_fields += ['_weighted_combo', '_w']
      weighted_combo = self._weighted_combo
      w = self._w
    else:
@@ -4099,31 +4066,22 @@ class AlphaShareLayer(Layer):
    # concatenate subspaces, reshape to size of original input, then stack
    # such that out_tensor has shape (2,?,original_cols)
    count = 0
    self.out_tensors = []
    out_tensors = []
    tmp_tensor = []
    for row in range(n_alphas):
      tmp_tensor.append(tf.reshape(subspaces[row,], [-1, subspace_size]))
      count += 1
      if (count == 2):
        self.out_tensors.append(tf.concat(tmp_tensor, 1))
        out_tensors.append(tf.concat(tmp_tensor, 1))
        tmp_tensor = []
        count = 0

    self.alphas = alphas
    if set_tensors:
      self.out_tensor = self.out_tensors[0]
    return self.out_tensors

  def none_tensors(self):
    num_outputs, out_tensor, out_tensors, alphas = self.num_outputs, self.out_tensor, self.out_tensors, self.alphas
    self.num_outputs = None
    self.out_tensor = None
    self.out_tensors = None
    self.alphas = None
    return num_outputs, out_tensor, self.out_tensors, alphas

  def set_tensors(self, tensor):
    self.num_outputs, self.out_tensor, self.out_tensors, self.alphas = tensor
      self.out_tensor = out_tensors[0]
      self.out_tensors = out_tensors
      self.alphas = alphas
      self._non_pickle_fields += ['out_tensors', 'alphas']
    return out_tensors


class SluiceLoss(Layer):
@@ -4194,18 +4152,10 @@ class BetaShare(Layer):
    else:
      betas = create_variable(tf.random_normal([1, n_betas]), name='betas')
    out_tensor = tf.matmul(betas, subspaces)
    self.betas = betas
    self.out_tensor = tf.reshape(out_tensor, [-1, original_cols])
    return self.out_tensor

  def none_tensors(self):
    out_tensor, betas = self.out_tensor, self.betas
    self.out_tensor = None
    self.betas = None
    return out_tensor, betas

  def set_tensors(self, tensor):
    self.out_tensor, self.betas = tensor
    out_tensor = tf.reshape(out_tensor, [-1, original_cols])
    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor


class ANIFeat(Layer):
@@ -4482,6 +4432,7 @@ class GraphEmbedPoolLayer(Layer):
    if set_tensors:
      self.out_tensor = result[0]
      self.out_tensors = [result, result_A]
      self._non_pickle_fields.append('out_tensors')
    return result, result_A

  def _create_variables(self, no_features, no_filters, name):
+3 −3
Original line number Diff line number Diff line
@@ -23,12 +23,12 @@ def _to_tensor(x, dtype):
  return x


def create_variable(value, dtype=None, name=None):
def create_variable(value, dtype=None, name=None, trainable=True):
  """Create a tf.Variable or tfe.Variable, depending on the current mode."""
  if tfe.in_eager_mode():
    return tfe.Variable(value, dtype=dtype, name=name)
    return tfe.Variable(value, dtype=dtype, name=name, trainable=trainable)
  else:
    return tf.Variable(value, dtype=dtype, name=name)
    return tf.Variable(value, dtype=dtype, name=name, trainable=trainable)


def ones(shape, dtype=None, name=None):
+217 −37

File changed.

Preview size limit exceeded, changes collapsed.

+2 −2
Original line number Diff line number Diff line
@@ -713,7 +713,7 @@ class TestLayersEager(test_util.TensorFlowTestCase):
        test_out, support_out = layer(test, support)
        assert test_out.shape == (n_test, n_feat)
        assert support_out.shape == (n_support, n_feat)
        assert len(layer.variables) == 5
        assert len(layer.variables) == 7

  def test_iter_ref_lstm_embedding(self):
    """Test invoking AttnLSTMEmbedding in eager mode."""
@@ -730,7 +730,7 @@ class TestLayersEager(test_util.TensorFlowTestCase):
        test_out, support_out = layer(test, support)
        assert test_out.shape == (n_test, n_feat)
        assert support_out.shape == (n_support, n_feat)
        assert len(layer.variables) == 8
        assert len(layer.variables) == 12

  def test_batch_norm(self):
    """Test invoking BatchNorm in eager mode."""
+114 −2

File changed.

Preview size limit exceeded, changes collapsed.