Unverified Commit 04d90883 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1085 from lilleswing/weave-save

Weave save and load
parents 6a442b74 133ee7e7
Loading
Loading
Loading
Loading
+18 −16
Original line number Diff line number Diff line
@@ -65,8 +65,8 @@ class WeaveLayer(Layer):

    """
    super(WeaveLayer, self).__init__(**kwargs)
    self.init = initializations.get(init)  # Set weight initialization
    self.activation = activations.get(activation)  # Get activations
    self.init = init  # Set weight initialization
    self.activation = activation  # Get activations
    self.update_pair = update_pair  # last weave layer does not need to update
    self.n_hidden_AA = n_hidden_AA
    self.n_hidden_PA = n_hidden_PA
@@ -87,18 +87,19 @@ class WeaveLayer(Layer):
        TODO(rbharath): Need to make this not set instance variables to
        follow style in other layers.
        """
    init = initializations.get(self.init)  # Set weight initialization

    self.W_AA = self.init([self.n_atom_input_feat, self.n_hidden_AA])
    self.W_AA = init([self.n_atom_input_feat, self.n_hidden_AA])
    self.b_AA = model_ops.zeros(shape=[
        self.n_hidden_AA,
    ])

    self.W_PA = self.init([self.n_pair_input_feat, self.n_hidden_PA])
    self.W_PA = init([self.n_pair_input_feat, self.n_hidden_PA])
    self.b_PA = model_ops.zeros(shape=[
        self.n_hidden_PA,
    ])

    self.W_A = self.init([self.n_hidden_A, self.n_atom_output_feat])
    self.W_A = init([self.n_hidden_A, self.n_atom_output_feat])
    self.b_A = model_ops.zeros(shape=[
        self.n_atom_output_feat,
    ])
@@ -107,17 +108,17 @@ class WeaveLayer(Layer):
        self.W_AA, self.b_AA, self.W_PA, self.b_PA, self.W_A, self.b_A
    ]
    if self.update_pair:
      self.W_AP = self.init([self.n_atom_input_feat * 2, self.n_hidden_AP])
      self.W_AP = init([self.n_atom_input_feat * 2, self.n_hidden_AP])
      self.b_AP = model_ops.zeros(shape=[
          self.n_hidden_AP,
      ])

      self.W_PP = self.init([self.n_pair_input_feat, self.n_hidden_PP])
      self.W_PP = init([self.n_pair_input_feat, self.n_hidden_PP])
      self.b_PP = model_ops.zeros(shape=[
          self.n_hidden_PP,
      ])

      self.W_P = self.init([self.n_hidden_P, self.n_pair_output_feat])
      self.W_P = init([self.n_hidden_P, self.n_pair_output_feat])
      self.b_P = model_ops.zeros(shape=[
          self.n_pair_output_feat,
      ])
@@ -129,6 +130,7 @@ class WeaveLayer(Layer):
    """ description and explanation refer to deepchem.nn.WeaveLayer
        parent layers: [atom_features, pair_features], pair_split, atom_to_pair
        """
    activation = activations.get(self.activation)  # Get activations
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -142,30 +144,30 @@ class WeaveLayer(Layer):
    atom_to_pair = in_layers[3].out_tensor

    AA = tf.matmul(atom_features, self.W_AA) + self.b_AA
    AA = self.activation(AA)
    AA = activation(AA)
    PA = tf.matmul(pair_features, self.W_PA) + self.b_PA
    PA = self.activation(PA)
    PA = activation(PA)
    PA = tf.segment_sum(PA, pair_split)

    A = tf.matmul(tf.concat([AA, PA], 1), self.W_A) + self.b_A
    A = self.activation(A)
    A = activation(A)

    if self.update_pair:
      AP_ij = tf.matmul(
          tf.reshape(
              tf.gather(atom_features, atom_to_pair),
              [-1, 2 * self.n_atom_input_feat]), self.W_AP) + self.b_AP
      AP_ij = self.activation(AP_ij)
      AP_ij = activation(AP_ij)
      AP_ji = tf.matmul(
          tf.reshape(
              tf.gather(atom_features, tf.reverse(atom_to_pair, [1])),
              [-1, 2 * self.n_atom_input_feat]), self.W_AP) + self.b_AP
      AP_ji = self.activation(AP_ji)
      AP_ji = activation(AP_ji)

      PP = tf.matmul(pair_features, self.W_PP) + self.b_PP
      PP = self.activation(PP)
      PP = activation(PP)
      P = tf.matmul(tf.concat([AP_ij + AP_ji, PP], 1), self.W_P) + self.b_P
      P = self.activation(P)
      P = activation(P)
    else:
      P = pair_features

@@ -183,7 +185,7 @@ class WeaveLayer(Layer):
    self.W_AA, self.b_AA, self.W_PA, self.b_PA, self.W_A, self.b_A = None, None, None, None, None, None

    out_tensor, out_tensors, trainable_weights, variables = self.out_tensor, self.out_tensors, self.trainable_weights, self.variables
    self.out_tensor, self.out_tensors, self.trainable_weights, self.variables, self.activation, self.init = None, [], [], [], None, None
    self.out_tensor, self.out_tensors, self.trainable_weights, self.variables = None, [], [], []

    return W_AP, b_AP, W_PP, W_PP, W_P, b_P, \
           W_AA, b_AA, W_PA, b_PA, W_A, b_A, \
+8 −0
Original line number Diff line number Diff line
@@ -2929,6 +2929,14 @@ class BatchNormalization(Layer):
      self.out_tensor = out_tensor
    return out_tensor

  def none_tensors(self):
    gamma, beta, out_tensor = self.gamma, self.beta, self.out_tensor
    self.gamma, self.beta, self.out_tensor = None, None, None
    return gamma, beta, out_tensor

  def set_tensors(self, tensor):
    self.gamma, self.beta, self.out_tensor = tensor


class WeightedError(Layer):

+20 −0
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ from deepchem.models import GraphConvTensorGraph
from deepchem.models import TensorGraph
from deepchem.molnet.load_function.delaney_datasets import load_delaney
from deepchem.models.tensorgraph.layers import ReduceSum, L2Loss
from deepchem.models import WeaveTensorGraph


class TestGraphModels(unittest.TestCase):
@@ -99,3 +100,22 @@ class TestGraphModels(unittest.TestCase):
    module = model2.create_submodel(loss=loss)
    model2.restore()
    model2.fit(dataset, nb_epoch=1, submodel=module)

  def test_change_loss_function_weave(self):
    tasks, dataset, transformers, metric = self.get_dataset(
        'regression', 'Weave', num_tasks=1)

    batch_size = 50
    model = WeaveTensorGraph(
        len(tasks), batch_size=batch_size, mode='regression', use_queue=False)

    model.fit(dataset, nb_epoch=1)
    model.save()

    model2 = TensorGraph.load_from_dir(model.model_dir, restore=False)
    dummy_label = model2.labels[-1]
    dummy_ouput = model2.outputs[-1]
    loss = ReduceSum(L2Loss(in_layers=[dummy_label, dummy_ouput]))
    module = model2.create_submodel(loss=loss)
    model2.restore()
    model2.fit(dataset, nb_epoch=1, submodel=module)