Commit 9a459f34 authored by leswing's avatar leswing
Browse files

GOGO

parent b1700271
Loading
Loading
Loading
Loading
+4 −33
Original line number Diff line number Diff line
@@ -3051,7 +3051,6 @@ class AlphaShareLayer(Layer):
  Returns
  -------
  out_tensor: a tensor with shape [len(in_layers), x, y] where x, y were the original layer dimensions
    out_tensor should be fed into LayerSplitter
  Distance matrix.
  """

@@ -3109,37 +3108,6 @@ class AlphaShareLayer(Layer):
    self.num_outputs, self.out_tensor, self.out_tensors, self.alphas = tensor


class LayerSplitter(Layer):
  """
  Returns the nth output of a layer
  Assumes out_tensor has shape [x, :] where x is the total number of intended output tensors
  """

  def __init__(self, output_num, **kwargs):
    """
    Parameters
    ----------
    output_num: int
        returns the out_tensor[output_num, :] of a layer
    """
    self.output_num = output_num
    super(LayerSplitter, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)[0]
    self.out_tensor = inputs[self.output_num, :]
    out_tensor = self.out_tensor
    return self.out_tensor

  def none_tensors(self):
    out_tensor = self.out_tensor
    self.out_tensor = None
    return out_tensor

  def set_tensors(self, tensor):
    self.out_tensor = tensor


class SluiceLoss(Layer):
  """
  Calculates the loss in a Sluice Network
@@ -3231,7 +3199,10 @@ class PassThroughLayer(Layer):
    super(PassThroughLayer, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    self.out_tensor = self.in_layers[0].out_tensors[self.output_num]
    out_tensor = self.in_layers[0].out_tensors[self.output_num]
    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor


class GraphEmbedPoolLayer(Layer):