Commit ec53bdd0 authored by peastman's avatar peastman
Browse files

Fixed test cases

parent 111c1e0b
Loading
Loading
Loading
Loading
+38 −4
Original line number Diff line number Diff line
@@ -13,7 +13,41 @@ from deepchem.models.tensorgraph.tensor_graph import TensorGraph, TFWrapper
from deepchem.models.tensorgraph.layers import Layer, Feature, Label, Weights, \
    WeightedError, Dense, Dropout, WeightDecay, Reshape, SparseSoftMaxCrossEntropy, \
    L2Loss, ReduceSum, Concat, Stack, TensorWrapper, ReLU, Squeeze, SoftMax, Cast
from deepchem.models.tensorgraph.IRV import Slice
from deepchem.models.tensorgraph.layers import convert_to_layers


class Slice(Layer):
  """ Choose a slice of input on the last axis given order,
  Suppose input x has two dimensions,
  output f(x) = x[:, slice_num:slice_num+1]
  """

  def __init__(self, slice_num, axis=1, **kwargs):
    """
    Parameters
    ----------
    slice_num: int
      index of slice number
    axis: int
      axis id
    """
    self.slice_num = slice_num
    self.axis = axis
    super(Slice, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)

    slice_num = self.slice_num
    axis = self.axis
    inputs = in_layers[0].out_tensor
    out_tensor = tf.slice(inputs, [0] * axis + [slice_num], [-1] * axis + [1])

    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor


class ProgressiveMultitaskRegressor(TensorGraph):