Commit 95e69566 authored by peastman's avatar peastman
Browse files

GRU layer can be pickled

parent 17139e15
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -413,6 +413,20 @@ class GRU(Layer):
      self.rnn_zero_states.append(np.zeros(zero_state.get_shape(), np.float32))
    return out_tensor

  def none_tensors(self):
    saved_tensors = [
        self.out_tensor, self.rnn_initial_states, self.rnn_final_states,
        self.rnn_zero_states
    ]
    self.out_tensor = None
    self.rnn_initial_states = []
    self.rnn_final_states = []
    self.rnn_zero_states = []
    return saved_tensors

  def set_tensors(self, tensor):
    self.out_tensor, self.rnn_initial_states, self.rnn_final_states, self.rnn_zero_states = tensor


class TimeSeriesDense(Layer):

+9 −0
Original line number Diff line number Diff line
@@ -422,7 +422,13 @@ class TensorGraph(Model):
    # Remove out_tensor from the object to be pickled
    must_restore = False
    tensor_objects = self.tensor_objects
    rnn_initial_states = self.rnn_initial_states
    rnn_final_states = self.rnn_final_states
    rnn_zero_states = self.rnn_zero_states
    self.tensor_objects = {}
    self.rnn_initial_states = []
    self.rnn_final_states = []
    self.rnn_zero_states = []
    out_tensors = []
    if self.built:
      must_restore = True
@@ -450,6 +456,9 @@ class TensorGraph(Model):
      self._training_placeholder = training_placeholder
      self.built = True
    self.tensor_objects = tensor_objects
    self.rnn_initial_states = rnn_initial_states
    self.rnn_final_states = rnn_final_states
    self.rnn_zero_states = rnn_zero_states

  def evaluate_generator(self,
                         feed_dict_generator,