Commit d3dc29fa authored by leswing's avatar leswing
Browse files

Documentation on Layer Splitter

parent bb1a590f
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -16,7 +16,7 @@ from deepchem.nn import activations
from deepchem.nn import initializations
from deepchem.nn import model_ops

from deepchem.models.tensorgraph.layers import Layer, PassThroughLayer
from deepchem.models.tensorgraph.layers import Layer, LayerSplitter
from deepchem.models.tensorgraph.layers import convert_to_layers


@@ -197,7 +197,7 @@ class WeaveLayer(Layer):

def WeaveLayerFactory(**kwargs):
  weaveLayer = WeaveLayer(**kwargs)
  return [PassThroughLayer(i, in_layers=weaveLayer) for i in range(2)]
  return [LayerSplitter(i, in_layers=weaveLayer) for i in range(2)]


class WeaveGather(Layer):
+7 −6
Original line number Diff line number Diff line
@@ -3031,9 +3031,7 @@ def AlphaShare(in_layers=None, **kwargs):
  output_layers = []
  alpha_share = AlphaShareLayer(in_layers=in_layers, **kwargs)
  num_outputs = len(in_layers)
  return [
      PassThroughLayer(x, in_layers=alpha_share) for x in range(num_outputs)
  ]
  return [LayerSplitter(x, in_layers=alpha_share) for x in range(num_outputs)]


class AlphaShareLayer(Layer):
@@ -3182,9 +3180,12 @@ class BetaShare(Layer):
    self.out_tensor, self.betas = tensor


class PassThroughLayer(Layer):
class LayerSplitter(Layer):
  """
  Layer which takes a tensor from in_tensor[0].out_tensors at an index
  This is a special layer which only makes sense in the context of TensorGraph.
  It takes a single input layer which sets the class variable self.out_tensors
  to a list of tensors.
  """

  def __init__(self, output_num, **kwargs):
@@ -3196,7 +3197,7 @@ class PassThroughLayer(Layer):
    kwargs
    """
    self.output_num = output_num
    super(PassThroughLayer, self).__init__(**kwargs)
    super(LayerSplitter, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    out_tensor = self.in_layers[0].out_tensors[self.output_num]
@@ -3310,7 +3311,7 @@ class GraphEmbedPoolLayer(Layer):

def GraphCNNPool(num_vertices, **kwargs):
  gcnnpool_layer = GraphEmbedPoolLayer(num_vertices, **kwargs)
  return [PassThroughLayer(x, in_layers=gcnnpool_layer) for x in range(2)]
  return [LayerSplitter(x, in_layers=gcnnpool_layer) for x in range(2)]


class GraphCNN(Layer):