Commit 82a0a753 authored by leswing's avatar leswing
Browse files

Saving/Loading Weave Layers

parent 88798e83
Loading
Loading
Loading
Loading
+174 −143
Original line number Diff line number Diff line
@@ -110,6 +110,7 @@ class WeaveLayer(Layer):
    self.n_pair_input_feat = n_pair_input_feat
    self.n_atom_output_feat = n_atom_output_feat
    self.n_pair_output_feat = n_pair_output_feat
    self.W_AP, self.b_AP, self.W_PP, self.b_PP, self.W_P, self.b_P = None, None, None, None, None, None
    super(WeaveLayer, self).__init__(**kwargs)

  def build(self):
@@ -205,6 +206,25 @@ class WeaveLayer(Layer):
      self.out_tensor = out_tensor
    return out_tensor

  def none_tensors(self):
    W_AP, b_AP, W_PP, W_PP, W_P, b_P = self.W_AP, self.b_AP, self.W_PP, self.W_PP, self.W_P, self.b_P
    self.W_AP, self.b_AP, self.W_PP, self.b_PP, self.W_P, self.b_P = None, None, None, None, None, None

    W_AA, b_AA, W_PA, b_PA, W_A, b_A = self.W_AA, self.b_AA, self.W_PA, self.b_PA, self.W_A, self.b_A
    self.W_AA, self.b_AA, self.W_PA, self.b_PA, self.W_A, self.b_A = None, None, None, None, None, None

    out_tensor, trainable_weights, variables = self.out_tensor, self.trainable_weights, self.variables
    self.out_tensor, self.trainable_weights, self.variables, self.activation, self.init = None, [], [], None, None

    return W_AP, b_AP, W_PP, W_PP, W_P, b_P, \
           W_AA, b_AA, W_PA, b_PA, W_A, b_A, \
           out_tensor, trainable_weights, variables

  def set_tensors(self, tensor):
    self.W_AP, self.b_AP, self.W_PP, self.W_PP, self.W_P, self.b_P, \
    self.W_AA, self.b_AA, self.W_PA, self.b_PA, self.W_A, self.b_A, \
    self.out_tensor, self.trainable_weights, self.variables = tensor


class WeaveGather(Layer):
  """ TensorGraph style implementation
@@ -242,6 +262,7 @@ class WeaveGather(Layer):
    self.activation = activations.get(activation)  # Get activations
    self.epsilon = epsilon
    self.momentum = momentum
    self.W, self.b = None, None
    super(WeaveGather, self).__init__(**kwargs)

  def build(self):
@@ -296,6 +317,17 @@ class WeaveGather(Layer):
    outputs = tf.reshape(outputs, [-1, self.n_input * 11])
    return outputs

  def none_tensors(self):
    W, b = self.W, self.b
    self.W, self.b = None, None

    out_tensor, trainable_weights, variables = self.out_tensor, self.trainable_weights, self.variables
    self.out_tensor, self.trainable_weights, self.variables = None, [], []
    return W, b, out_tensor, trainable_weights, variables

  def set_tensors(self, tensor):
    self.W, self.b, self.out_tensor, self.trainable_weights, self.variables = tensor


class DTNNEmbedding(Layer):
  """ TensorGraph style implementation
@@ -324,7 +356,6 @@ class DTNNEmbedding(Layer):
    super(DTNNEmbedding, self).__init__(**kwargs)

  def build(self):

    self.embedding_list = self.init(
        [self.periodic_table_length, self.n_embedding])
    self.trainable_weights = [self.embedding_list]
+30 −1
Original line number Diff line number Diff line
@@ -384,6 +384,30 @@ class TensorGraph(Model):
    """
    self.optimizer = optimizer

  def get_pickling_errors(self, obj, seen=None):
    if seen == None:
      seen = []
    try:
      state = obj.__getstate__()
    except AttributeError:
      return
    if state == None:
      return
    if isinstance(state, tuple):
      if not isinstance(state[0], dict):
        state = state[1]
      else:
        state = state[0].update(state[1])
    result = {}
    for i in state:
      try:
        pickle.dumps(state[i], protocol=2)
      except pickle.PicklingError:
        if not state[i] in seen:
          seen.append(state[i])
          result[i] = self.get_pickling_errors(state[i], seen)
    return result

  def save(self):
    # Remove out_tensor from the object to be pickled
    must_restore = False
@@ -401,8 +425,13 @@ class TensorGraph(Model):

    # Pickle itself
    pickle_name = os.path.join(self.model_dir, "model.pickle")
    self.get_pickling_errors(self)
    with open(pickle_name, 'wb') as fout:
      try:
        pickle.dump(self, fout)
      except Exception as e:
        print(self.get_pickling_errors(self))
        raise e

    # add out_tensor back to everyone
    if must_restore: