Commit 2745d99d authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #860 from lilleswing/858-badweave

Weave has list of out_tensor
parents 64d3ba49 5f1e2eb3
Loading
Loading
Loading
Loading
+22 −49
Original line number Diff line number Diff line
@@ -16,46 +16,14 @@ from deepchem.nn import activations
from deepchem.nn import initializations
from deepchem.nn import model_ops

from deepchem.models.tensorgraph.layers import Layer
from deepchem.models.tensorgraph.layers import Layer, LayerSplitter
from deepchem.models.tensorgraph.layers import convert_to_layers


class Combine_AP(Layer):
  """Bundle two input layers' tensors into a single [A, P] list.

  Used to pack an atom-feature tensor (first input) and a pair-feature
  tensor (second input) into one output so a downstream weave layer can
  consume them as a single in_layer.
  """

  def __init__(self, **kwargs):
    super(Combine_AP, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    """Return [atom_tensor, pair_tensor] taken from the first two inputs."""
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
    # Pair up the tensors of the first two inputs: index 0 -> A, index 1 -> P.
    combined = [layer.out_tensor for layer in in_layers[:2]]
    if set_tensors:
      self.out_tensor = combined
    return combined


class Separate_AP(Layer):
  """Pick out the first element of a combined [A, P] output.

  Inverse of Combine_AP for the atom half: returns element 0 of the
  input layer's out_tensor list.
  """

  def __init__(self, **kwargs):
    super(Separate_AP, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    """Return in_layers[0].out_tensor[0]."""
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
    # The upstream layer's out_tensor is a [A, P] pair; keep only A.
    atom_tensor = in_layers[0].out_tensor[0]
    if set_tensors:
      self.out_tensor = atom_tensor
    return atom_tensor


class WeaveLayer(Layer):
  """ TensorGraph style implementation
    The same as deepchem.nn.WeaveLayer

    Note: Use WeaveLayerFactory to construct this layer
    """

  def __init__(self,
@@ -167,11 +135,11 @@ class WeaveLayer(Layer):

    self.build()

    atom_features = in_layers[0].out_tensor[0]
    pair_features = in_layers[0].out_tensor[1]
    atom_features = in_layers[0].out_tensor
    pair_features = in_layers[1].out_tensor

    pair_split = in_layers[1].out_tensor
    atom_to_pair = in_layers[2].out_tensor
    pair_split = in_layers[2].out_tensor
    atom_to_pair = in_layers[3].out_tensor

    AA = tf.matmul(atom_features, self.W_AA) + self.b_AA
    AA = self.activation(AA)
@@ -201,11 +169,11 @@ class WeaveLayer(Layer):
    else:
      P = pair_features

    out_tensor = [A, P]
    self.out_tensors = [A, P]
    if set_tensors:
      self.variables = self.trainable_weights
      self.out_tensor = out_tensor
    return out_tensor
      self.out_tensor = A
    return self.out_tensors

  def none_tensors(self):
    W_AP, b_AP, W_PP, W_PP, W_P, b_P = self.W_AP, self.b_AP, self.W_PP, self.W_PP, self.W_P, self.b_P
@@ -214,17 +182,22 @@ class WeaveLayer(Layer):
    W_AA, b_AA, W_PA, b_PA, W_A, b_A = self.W_AA, self.b_AA, self.W_PA, self.b_PA, self.W_A, self.b_A
    self.W_AA, self.b_AA, self.W_PA, self.b_PA, self.W_A, self.b_A = None, None, None, None, None, None

    out_tensor, trainable_weights, variables = self.out_tensor, self.trainable_weights, self.variables
    self.out_tensor, self.trainable_weights, self.variables, self.activation, self.init = None, [], [], None, None
    out_tensor, out_tensors, trainable_weights, variables = self.out_tensor, self.out_tensors, self.trainable_weights, self.variables
    self.out_tensor, self.out_tensors, self.trainable_weights, self.variables, self.activation, self.init = None, [], [], [], None, None

    return W_AP, b_AP, W_PP, W_PP, W_P, b_P, \
           W_AA, b_AA, W_PA, b_PA, W_A, b_A, \
           out_tensor, trainable_weights, variables
           out_tensor, out_tensors, trainable_weights, variables

  def set_tensors(self, tensor):
    self.W_AP, self.b_AP, self.W_PP, self.W_PP, self.W_P, self.b_P, \
    self.W_AA, self.b_AA, self.W_PA, self.b_PA, self.W_A, self.b_A, \
    self.out_tensor, self.trainable_weights, self.variables = tensor
    self.out_tensor, self.out_tensors, self.trainable_weights, self.variables = tensor


def WeaveLayerFactory(**kwargs):
  """Construct a WeaveLayer and expose its two outputs as separate layers.

  WeaveLayer produces two tensors (atom and pair features) in
  self.out_tensors; this factory wraps each one in a LayerSplitter so
  callers receive a [atom_layer, pair_layer] pair of ordinary layers.
  """
  weave = WeaveLayer(**kwargs)
  splitters = []
  # Index 0 -> atom output, index 1 -> pair output of the weave layer.
  for output_index in (0, 1):
    splitters.append(LayerSplitter(output_index, in_layers=weave))
  return splitters


class WeaveGather(Layer):
+20 −49
Original line number Diff line number Diff line
@@ -3031,10 +3031,7 @@ def AlphaShare(in_layers=None, **kwargs):
  output_layers = []
  alpha_share = AlphaShareLayer(in_layers=in_layers, **kwargs)
  num_outputs = len(in_layers)
  for num_layer in range(0, num_outputs):
    ls = LayerSplitter(output_num=num_layer, in_layers=alpha_share)
    output_layers.append(ls)
  return output_layers
  return [LayerSplitter(x, in_layers=alpha_share) for x in range(num_outputs)]


class AlphaShareLayer(Layer):
@@ -3052,7 +3049,6 @@ class AlphaShareLayer(Layer):
  Returns
  -------
  out_tensor: a tensor with shape [len(in_layers), x, y] where x, y were the original layer dimensions
    out_tensor should be fed into LayerSplitter
  Distance matrix.
  """

@@ -3083,63 +3079,31 @@ class AlphaShareLayer(Layer):
    # concatenate subspaces, reshape to size of original input, then stack
    # such that out_tensor has shape (2,?,original_cols)
    count = 0
    out_tensor = []
    self.out_tensors = []
    tmp_tensor = []
    for row in range(n_alphas):
      tmp_tensor.append(tf.reshape(subspaces[row,], [-1, subspace_size]))
      count += 1
      if (count == 2):
        out_tensor.append(tf.concat(tmp_tensor, 1))
        self.out_tensors.append(tf.concat(tmp_tensor, 1))
        tmp_tensor = []
        count = 0

    out_tensor = tf.stack(out_tensor)

    self.alphas = alphas
    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor
      self.out_tensor = self.out_tensors[0]
    return self.out_tensors

  def none_tensors(self):
    num_outputs, out_tensor, alphas = self.num_outputs, self.out_tensor, self.alphas
    num_outputs, out_tensor, out_tensors, alphas = self.num_outputs, self.out_tensor, self.out_tensors, self.alphas
    self.num_outputs = None
    self.out_tensor = None
    self.out_tensors = None
    self.alphas = None
    return num_outputs, out_tensor, alphas

  def set_tensors(self, tensor):
    self.num_outputs, self.out_tensor, self.alphas = tensor


class LayerSplitter(Layer):
  """
  Returns the nth output of a layer
  Assumes out_tensor has shape [x, :] where x is the total number of intended output tensors
  """

  def __init__(self, output_num, **kwargs):
    """
    Parameters
    ----------
    output_num: int
        returns the out_tensor[output_num, :] of a layer
    """
    # Index of the stacked output slice this layer selects.
    self.output_num = output_num
    super(LayerSplitter, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    # Slice the selected row out of the first input's stacked tensor.
    inputs = self._get_input_tensors(in_layers)[0]
    # NOTE(review): the set_tensors flag is ignored here — self.out_tensor is
    # always assigned, unlike other layers in this file that guard on it.
    self.out_tensor = inputs[self.output_num, :]
    # NOTE(review): this local is dead (never used); the method returns the
    # attribute directly on the next line.
    out_tensor = self.out_tensor
    return self.out_tensor

  def none_tensors(self):
    # Detach the cached tensor (for pickling) and hand it back to the caller.
    out_tensor = self.out_tensor
    self.out_tensor = None
    return out_tensor
    return num_outputs, out_tensor, self.out_tensors, alphas

  def set_tensors(self, tensor):
    self.out_tensor = tensor
    self.num_outputs, self.out_tensor, self.out_tensors, self.alphas = tensor


class SluiceLoss(Layer):
@@ -3395,9 +3359,13 @@ class ANIFeat(Layer):
    return n_feat


class PassThroughLayer(Layer):
class LayerSplitter(Layer):
  """
  Layer which takes a tensor from in_tensor[0].out_tensors at an index
  Only layers which need to output multiple layers set and use the variable
  self.out_tensors.
  This is a utility for those special layers which set self.out_tensors
  to return a layer wrapping a specific tensor in in_layers[0].out_tensors
  """

  def __init__(self, output_num, **kwargs):
@@ -3409,10 +3377,13 @@ class PassThroughLayer(Layer):
    kwargs
    """
    self.output_num = output_num
    super(PassThroughLayer, self).__init__(**kwargs)
    super(LayerSplitter, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    self.out_tensor = self.in_layers[0].out_tensors[self.output_num]
    out_tensor = self.in_layers[0].out_tensors[self.output_num]
    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor


class GraphEmbedPoolLayer(Layer):
@@ -3520,7 +3491,7 @@ class GraphEmbedPoolLayer(Layer):

def GraphCNNPool(num_vertices, **kwargs):
  gcnnpool_layer = GraphEmbedPoolLayer(num_vertices, **kwargs)
  return [PassThroughLayer(x, in_layers=gcnnpool_layer) for x in range(2)]
  return [LayerSplitter(x, in_layers=gcnnpool_layer) for x in range(2)]


class GraphCNN(Layer):
+17 −10
Original line number Diff line number Diff line
import collections

import numpy as np
import six
import tensorflow as tf

from deepchem.data import NumpyDataset
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.mol_graphs import ConvMol
from deepchem.metrics import to_one_hot
from deepchem.models.tensorgraph.graph_layers import WeaveLayer, WeaveGather, \
  Combine_AP, Separate_AP, DTNNEmbedding, DTNNStep, DTNNGather, DAGLayer, \
from deepchem.models.tensorgraph.graph_layers import WeaveGather, \
  DTNNEmbedding, DTNNStep, DTNNGather, DAGLayer, \
  DAGGather, DTNNExtract, MessagePassing, SetGather
from deepchem.models.tensorgraph.graph_layers import WeaveLayerFactory
from deepchem.models.tensorgraph.layers import Dense, Concat, SoftMax, \
  SoftMaxCrossEntropy, GraphConv, BatchNorm, \
  GraphPool, GraphGather, WeightedError, Dropout, BatchNormalization, Stack, Layer, Flatten, GraphCNN, GraphCNNPool
  GraphPool, GraphGather, WeightedError, Dropout, BatchNormalization, Stack, Flatten, GraphCNN, GraphCNNPool
from deepchem.models.tensorgraph.layers import L2Loss, Label, Weights, Feature
from deepchem.models.tensorgraph.tensor_graph import TensorGraph
from deepchem.trans import undo_transforms
@@ -57,28 +61,31 @@ class WeaveTensorGraph(TensorGraph):
        """
    self.atom_features = Feature(shape=(None, self.n_atom_feat))
    self.pair_features = Feature(shape=(None, self.n_pair_feat))
    combined = Combine_AP(in_layers=[self.atom_features, self.pair_features])
    self.pair_split = Feature(shape=(None,), dtype=tf.int32)
    self.atom_split = Feature(shape=(None,), dtype=tf.int32)
    self.atom_to_pair = Feature(shape=(None, 2), dtype=tf.int32)
    weave_layer1 = WeaveLayer(
    weave_layer1A, weave_layer1P = WeaveLayerFactory(
        n_atom_input_feat=self.n_atom_feat,
        n_pair_input_feat=self.n_pair_feat,
        n_atom_output_feat=self.n_hidden,
        n_pair_output_feat=self.n_hidden,
        in_layers=[combined, self.pair_split, self.atom_to_pair])
    weave_layer2 = WeaveLayer(
        in_layers=[
            self.atom_features, self.pair_features, self.pair_split,
            self.atom_to_pair
        ])
    weave_layer2A, weave_layer2P = WeaveLayerFactory(
        n_atom_input_feat=self.n_hidden,
        n_pair_input_feat=self.n_hidden,
        n_atom_output_feat=self.n_hidden,
        n_pair_output_feat=self.n_hidden,
        update_pair=False,
        in_layers=[weave_layer1, self.pair_split, self.atom_to_pair])
    separated = Separate_AP(in_layers=[weave_layer2])
        in_layers=[
            weave_layer1A, weave_layer1P, self.pair_split, self.atom_to_pair
        ])
    dense1 = Dense(
        out_channels=self.n_graph_feat,
        activation_fn=tf.nn.tanh,
        in_layers=[separated])
        in_layers=weave_layer2A)
    batch_norm1 = BatchNormalization(epsilon=1e-5, mode=1, in_layers=[dense1])
    weave_gather = WeaveGather(
        self.batch_size,
+0 −16
Original line number Diff line number Diff line
@@ -28,7 +28,6 @@ from deepchem.models.tensorgraph.layers import InteratomicL2Distances
from deepchem.models.tensorgraph.layers import IterRefLSTMEmbedding
from deepchem.models.tensorgraph.layers import L2Loss
from deepchem.models.tensorgraph.layers import LSTMStep
from deepchem.models.tensorgraph.layers import LayerSplitter
from deepchem.models.tensorgraph.layers import Log
from deepchem.models.tensorgraph.layers import Multiply
from deepchem.models.tensorgraph.layers import ReduceMean
@@ -709,21 +708,6 @@ class TestLayers(test_util.TensorFlowTestCase):
      assert test_1.shape == out_tensor.shape
      assert test_2.shape == out_tensor.shape

  def test_layer_splitter(self):
    """LayerSplitter(i) should recover the i-th tensor of a stacked input."""
    first = np.arange(10).reshape(2, 5)
    second = np.arange(10, 20).reshape(2, 5)

    with self.test_session() as sess:
      first = tf.convert_to_tensor(first, dtype=tf.float32)
      second = tf.convert_to_tensor(second, dtype=tf.float32)
      stacked = tf.stack([first, second])
      split0 = LayerSplitter(0)(stacked)
      split1 = LayerSplitter(1)(stacked)
      sess.run(tf.global_variables_initializer())
      # Each split output must equal the corresponding original tensor.
      sess.run(tf.assert_equal(first, split0.eval()))
      sess.run(tf.assert_equal(second, split1.eval()))

  def test_sluice_loss(self):
    """Test the sluice loss function"""
    input1 = np.ones((3, 4))
+3 −16
Original line number Diff line number Diff line
@@ -2,8 +2,7 @@ import numpy as np
import tensorflow as tf

from deepchem.models import TensorGraph
from deepchem.models.tensorgraph.graph_layers import Combine_AP, Separate_AP, \
  WeaveLayer, WeaveGather, DTNNEmbedding, DTNNGather, DTNNStep, \
from deepchem.models.tensorgraph.graph_layers import WeaveLayer, WeaveGather, DTNNEmbedding, DTNNGather, DTNNStep, \
  DTNNExtract, DAGLayer, DAGGather, MessagePassing, SetGather
from deepchem.models.tensorgraph.layers import Feature, Conv1D, Dense, Flatten, Reshape, Squeeze, Transpose, \
  CombineMeanStd, Repeat, Gather, GRU, L2Loss, Concat, SoftMax, \
@@ -383,26 +382,14 @@ def test_WeightedError_pickle():
  tg.save()


def test_Combine_Separate_AP_pickle():
  """A Combine_AP -> Separate_AP round-trip should build and save cleanly."""
  graph = TensorGraph()
  atoms = Feature(shape=(None, 10))
  pairs = Feature(shape=(None, 5))
  combined = Combine_AP(in_layers=[atoms, pairs])
  separated = Separate_AP(in_layers=[combined])
  graph.add_output(separated)
  graph.set_loss(separated)
  graph.build()
  graph.save()


def test_Weave_pickle():
  tg = TensorGraph()
  atom_feature = Feature(shape=(None, 75))
  pair_feature = Feature(shape=(None, 14))
  pair_split = Feature(shape=(None,), dtype=tf.int32)
  atom_to_pair = Feature(shape=(None, 2), dtype=tf.int32)
  C_AP = Combine_AP(in_layers=[atom_feature, pair_feature])
  weave = WeaveLayer(in_layers=[C_AP, pair_split, atom_to_pair])
  weave = WeaveLayer(
      in_layers=[atom_feature, pair_feature, pair_split, atom_to_pair])
  tg.add_output(weave)
  tg.set_loss(weave)
  tg.build()