Commit 13003b4d authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #636 from lilleswing/saving-tg

Saving/Loading Weave Layers
parents 88798e83 ed720894
Loading
Loading
Loading
Loading
+174 −143
Original line number Diff line number Diff line
@@ -110,6 +110,7 @@ class WeaveLayer(Layer):
    self.n_pair_input_feat = n_pair_input_feat
    self.n_atom_output_feat = n_atom_output_feat
    self.n_pair_output_feat = n_pair_output_feat
    self.W_AP, self.b_AP, self.W_PP, self.b_PP, self.W_P, self.b_P = None, None, None, None, None, None
    super(WeaveLayer, self).__init__(**kwargs)

  def build(self):
@@ -205,6 +206,25 @@ class WeaveLayer(Layer):
      self.out_tensor = out_tensor
    return out_tensor

  def none_tensors(self):
    W_AP, b_AP, W_PP, W_PP, W_P, b_P = self.W_AP, self.b_AP, self.W_PP, self.W_PP, self.W_P, self.b_P
    self.W_AP, self.b_AP, self.W_PP, self.b_PP, self.W_P, self.b_P = None, None, None, None, None, None

    W_AA, b_AA, W_PA, b_PA, W_A, b_A = self.W_AA, self.b_AA, self.W_PA, self.b_PA, self.W_A, self.b_A
    self.W_AA, self.b_AA, self.W_PA, self.b_PA, self.W_A, self.b_A = None, None, None, None, None, None

    out_tensor, trainable_weights, variables = self.out_tensor, self.trainable_weights, self.variables
    self.out_tensor, self.trainable_weights, self.variables, self.activation, self.init = None, [], [], None, None

    return W_AP, b_AP, W_PP, W_PP, W_P, b_P, \
           W_AA, b_AA, W_PA, b_PA, W_A, b_A, \
           out_tensor, trainable_weights, variables

  def set_tensors(self, tensor):
    self.W_AP, self.b_AP, self.W_PP, self.W_PP, self.W_P, self.b_P, \
    self.W_AA, self.b_AA, self.W_PA, self.b_PA, self.W_A, self.b_A, \
    self.out_tensor, self.trainable_weights, self.variables = tensor


class WeaveGather(Layer):
  """ TensorGraph style implementation
@@ -242,6 +262,7 @@ class WeaveGather(Layer):
    self.activation = activations.get(activation)  # Get activations
    self.epsilon = epsilon
    self.momentum = momentum
    self.W, self.b = None, None
    super(WeaveGather, self).__init__(**kwargs)

  def build(self):
@@ -296,6 +317,17 @@ class WeaveGather(Layer):
    outputs = tf.reshape(outputs, [-1, self.n_input * 11])
    return outputs

  def none_tensors(self):
    W, b = self.W, self.b
    self.W, self.b = None, None

    out_tensor, trainable_weights, variables = self.out_tensor, self.trainable_weights, self.variables
    self.out_tensor, self.trainable_weights, self.variables = None, [], []
    return W, b, out_tensor, trainable_weights, variables

  def set_tensors(self, tensor):
    self.W, self.b, self.out_tensor, self.trainable_weights, self.variables = tensor


class DTNNEmbedding(Layer):
  """ TensorGraph style implementation
@@ -324,7 +356,6 @@ class DTNNEmbedding(Layer):
    super(DTNNEmbedding, self).__init__(**kwargs)

  def build(self):

    self.embedding_list = self.init(
        [self.periodic_table_length, self.n_embedding])
    self.trainable_weights = [self.embedding_list]
+1 −1
Original line number Diff line number Diff line
@@ -87,7 +87,7 @@ class Layer(object):
      else:
        raise ValueError('Unexpected input: ' + str(input))
    if reshape and len(tensors) > 1:
      shapes = [t.shape for t in tensors]
      shapes = [t.get_shape() for t in tensors]
      if any(s != shapes[0] for s in shapes[1:]):
        # Reshape everything to match the input with the most dimensions.

+1 −1
Original line number Diff line number Diff line
@@ -102,7 +102,7 @@ class WeaveTensorGraph(TensorGraph):
        cost = L2Loss(in_layers=[label, regression])
        costs.append(cost)

    all_cost = Concat(in_layers=costs)
    all_cost = Concat(in_layers=costs, axis=0)
    self.weights = Weights(shape=(None, self.n_tasks))
    loss = WeightedError(in_layers=[all_cost, self.weights])
    self.set_loss(loss)
+29 −1
Original line number Diff line number Diff line
@@ -384,6 +384,30 @@ class TensorGraph(Model):
    """
    self.optimizer = optimizer

  def get_pickling_errors(self, obj, seen=None):
    if seen == None:
      seen = []
    try:
      state = obj.__getstate__()
    except AttributeError:
      return
    if state == None:
      return
    if isinstance(state, tuple):
      if not isinstance(state[0], dict):
        state = state[1]
      else:
        state = state[0].update(state[1])
    result = {}
    for i in state:
      try:
        pickle.dumps(state[i], protocol=2)
      except pickle.PicklingError:
        if not state[i] in seen:
          seen.append(state[i])
          result[i] = self.get_pickling_errors(state[i], seen)
    return result

  def save(self):
    # Remove out_tensor from the object to be pickled
    must_restore = False
@@ -402,7 +426,11 @@ class TensorGraph(Model):
    # Pickle itself
    pickle_name = os.path.join(self.model_dir, "model.pickle")
    with open(pickle_name, 'wb') as fout:
      try:
        pickle.dump(self, fout)
      except Exception as e:
        print(self.get_pickling_errors(self))
        raise e

    # add out_tensor back to everyone
    if must_restore:
+315 −0
Original line number Diff line number Diff line
import numpy as np
import tensorflow as tf
from deepchem.models import TensorGraph
from deepchem.models.tensorgraph.layers import Feature, Conv1D, Dense, Flatten, Reshape, Squeeze, Transpose, \
    CombineMeanStd, Repeat, GRU, L2Loss, Concat, SoftMax, Constant, Variable, Add, Multiply, InteratomicL2Distances, \
    SoftMaxCrossEntropy, ReduceMean, ToFloat, ReduceSquareDifference, Conv2D, MaxPool, ReduceSum, GraphConv, GraphPool, \
    GraphGather, BatchNorm, WeightedError


def test_Conv1D_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1, 1))
  conv = Conv1D(2, 1, in_layers=feature)
  tg.add_output(conv)
  tg.set_loss(conv)
  tg.build()
  tg.save()


def test_Dense_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  dense = Dense(out_channels=1, in_layers=feature)
  tg.add_output(dense)
  tg.set_loss(dense)
  tg.build()
  tg.save()


def test_Flatten_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = Flatten(in_layers=feature)
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_Reshape_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = Reshape(shape=(-1, 2), in_layers=feature)
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_Squeeze_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = Squeeze(squeeze_dims=-1, in_layers=feature)
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_Transpose_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = Transpose(perm=(1, 0), in_layers=feature)
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_CombineMeanStd_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = CombineMeanStd(in_layers=[feature, feature])
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_Repeat_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = Repeat(n_times=10, in_layers=feature)
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_GRU_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 10, 10))
  layer = GRU(n_hidden=10, batch_size=tg.batch_size, in_layers=feature)
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_L2loss_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = L2Loss(in_layers=[feature, feature])
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_Softmax_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = SoftMax(in_layers=feature)
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_Concat_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = Concat(in_layers=[feature, feature])
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_Constant_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = Constant(np.expand_dims([17] * tg.batch_size, -1))
  output = Add(in_layers=[feature, layer])
  tg.add_output(output)
  tg.set_loss(output)
  tg.build()
  tg.save()


def test_Variable_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = Variable(np.expand_dims([17] * tg.batch_size, -1))
  output = Multiply(in_layers=[feature, layer])
  tg.add_output(output)
  tg.set_loss(output)
  tg.build()
  tg.save()


def testInteratomicL2Distances():
  """
    TODO(LESWING) what is ndim here?
    :return:
    """
  tg = TensorGraph()
  n_atoms = tg.batch_size
  M_nbrs = 4
  n_dim = 3
  feature = Feature(shape=(tg.batch_size, 3))
  neighbors = Feature(shape=(tg.batch_size, M_nbrs), dtype=tf.int32)
  layer = InteratomicL2Distances(
      N_atoms=n_atoms,
      M_nbrs=M_nbrs,
      ndim=n_dim,
      in_layers=[feature, neighbors])
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_SoftmaxCrossEntropy_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = SoftMaxCrossEntropy(in_layers=[feature, feature])
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_ReduceMean_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = ReduceMean(in_layers=[feature])
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_ToFloat_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = ToFloat(in_layers=[feature])
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_ReduceSum_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = ReduceSum(in_layers=[feature])
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_ReduceSquareDifference_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 1))
  layer = ReduceSquareDifference(in_layers=[feature, feature])
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_Conv2D_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 10, 10))
  layer = Conv2D(num_outputs=3, in_layers=feature)
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_MaxPool_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 10, 10, 10))
  layer = MaxPool(in_layers=feature)
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_GraphConv_pickle():
  tg = TensorGraph()
  atom_features = Feature(shape=(None, 75))
  degree_slice = Feature(shape=(None, 2), dtype=tf.int32)
  membership = Feature(shape=(None,), dtype=tf.int32)

  deg_adjs = []
  for i in range(0, 10 + 1):
    deg_adj = Feature(shape=(None, i + 1), dtype=tf.int32)
    deg_adjs.append(deg_adj)
  layer = GraphConv(
      64,
      activation_fn=tf.nn.relu,
      in_layers=[atom_features, degree_slice, membership] + deg_adjs)
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_GraphPool_Pickle():
  tg = TensorGraph()
  atom_features = Feature(shape=(None, 75))
  degree_slice = Feature(shape=(None, 2), dtype=tf.int32)
  membership = Feature(shape=(None,), dtype=tf.int32)
  deg_adjs = []
  for i in range(0, 10 + 1):
    deg_adj = Feature(shape=(None, i + 1), dtype=tf.int32)
    deg_adjs.append(deg_adj)
  layer = GraphPool(
      in_layers=[atom_features, degree_slice, membership] + deg_adjs)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_GraphGather_Pickle():
  tg = TensorGraph()
  atom_features = Feature(shape=(None, 75))
  degree_slice = Feature(shape=(None, 2), dtype=tf.int32)
  membership = Feature(shape=(None,), dtype=tf.int32)
  deg_adjs = []
  for i in range(0, 10 + 1):
    deg_adj = Feature(shape=(None, i + 1), dtype=tf.int32)
    deg_adjs.append(deg_adj)
  layer = GraphGather(
      batch_size=tg.batch_size,
      activation_fn=tf.nn.tanh,
      in_layers=[atom_features, degree_slice, membership] + deg_adjs)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_BatchNorm_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 10))
  layer = BatchNorm(in_layers=feature)
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()


def test_WeightedError_pickle():
  tg = TensorGraph()
  feature = Feature(shape=(tg.batch_size, 10))
  layer = WeightedError(in_layers=[feature, feature])
  tg.add_output(layer)
  tg.set_loss(layer)
  tg.build()
  tg.save()