Commit 47e7c512 authored by Peter Eastman's avatar Peter Eastman
Browse files

Implemented dropout and parameter initialization

parent 385b857e
Loading
Loading
Loading
Loading
+22 −8
Original line number Diff line number Diff line
@@ -20,13 +20,16 @@ from deepchem.metrics import to_one_hot


from deepchem.models.tensorgraph.tensor_graph import TensorGraph
from deepchem.models.tensorgraph.layers import Feature, Label, Weights, WeightedError, Dense, Reshape, SoftMaxCrossEntropy, L2LossLayer
from deepchem.models.tensorgraph.layers import Feature, Label, Weights, WeightedError, Dense, Dropout, Reshape, SoftMaxCrossEntropy, L2Loss, Initializer

class TensorflowMultiTaskClassifier2(TensorGraph):
class TensorGraphMultiTaskClassifier(TensorGraph):
  def __init__(self,
               n_tasks,
               n_features,
               layer_sizes=[1000],
               weight_init_stddevs=[0.02],
               bias_init_consts=[1.0],
               dropouts=[0.5],
               n_classes=2,
               **kwargs):
    super().__init__(mode='classification', **kwargs)
@@ -41,8 +44,12 @@ class TensorflowMultiTaskClassifier2(TensorGraph):

    # Add the dense layers

    for size in layer_sizes:
      layer = Dense(in_layers=[prev_layer], out_channels=size, activation_fn=tf.nn.relu)
    for size, weight_stddev, bias_const, dropout in zip(layer_sizes, weight_init_stddevs, bias_init_consts, dropouts):
      layer = Dense(in_layers=[prev_layer], out_channels=size, activation_fn=tf.nn.relu,
                    weights_initializer=Initializer(tf.truncated_normal_initializer, stddev=weight_stddev),
                    biases_initializer=Initializer(tf.constant_initializer, value=bias_const))
      if dropout > 0.0:
        layer = Dropout(dropout, in_layers=[layer])
      prev_layer = layer

    # Compute the loss function for each label.
@@ -78,11 +85,14 @@ class TensorflowMultiTaskClassifier2(TensorGraph):



class TensorflowMultiTaskRegressor2(TensorGraph):
class TensorGraphMultiTaskRegressor(TensorGraph):
  def __init__(self,
               n_tasks,
               n_features,
               layer_sizes=[1000],
               weight_init_stddevs=[0.02],
               bias_init_consts=[1.0],
               dropouts=[0.5],
               **kwargs):
    super().__init__(mode='regression', **kwargs)
    self.n_tasks = n_tasks
@@ -95,8 +105,12 @@ class TensorflowMultiTaskRegressor2(TensorGraph):

    # Add the dense layers

    for size in layer_sizes:
      layer = Dense(in_layers=[prev_layer], out_channels=size, activation_fn=tf.nn.relu)
    for size, weight_stddev, bias_const, dropout in zip(layer_sizes, weight_init_stddevs, bias_init_consts, dropouts):
      layer = Dense(in_layers=[prev_layer], out_channels=size, activation_fn=tf.nn.relu,
                    weights_initializer=Initializer(tf.truncated_normal_initializer, stddev=weight_stddev),
                    biases_initializer=Initializer(tf.constant_initializer, value=bias_const))
      if dropout > 0.0:
        layer = Dropout(dropout, in_layers=[layer])
      prev_layer = layer

    # Compute the loss function for each label.
@@ -105,7 +119,7 @@ class TensorflowMultiTaskRegressor2(TensorGraph):
    self.add_output(output)
    labels = Label(shape=(None, n_tasks, 1))
    weights = Weights(shape=(None, n_tasks))
    loss = Reshape(shape=(-1, n_tasks), in_layers=[L2LossLayer(in_layers=[labels, output])])
    loss = Reshape(shape=(-1, n_tasks), in_layers=[L2Loss(in_layers=[labels, output])])
    weighted_loss = WeightedError(in_layers=[loss, weights])
    self.set_loss(weighted_loss)

+68 −33
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ class Layer(object):
  def set_tensors(self, tensor):
    self.out_tensor = tensor

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    raise NotImplementedError("Subclasses must implement for themselves")

  def __key(self):
@@ -80,7 +80,7 @@ class TensorWrapper(Layer):
    self.out_tensor = out_tensor
    super(TensorWrapper, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    """Take no actions."""
    pass

@@ -98,6 +98,26 @@ def convert_to_layers(in_layers):
  return layers


class Initializer(object):
  """Picklable stand-in for a Tensorflow initializer.

  Tensorflow initializer objects cannot be pickled, so layers record one of
  these instead and invoke it to build the real initializer on demand.
  """

  def __init__(self, initializer_class, **kwargs):
    """Record how to construct a Tensorflow initializer.

    Parameters
    ----------
    initializer_class: class
      the type of initializer to create
    kwargs:
      any other arguments will be passed on to the Tensorflow initializer's
      constructor
    """
    self.initializer_class, self.kwargs = initializer_class, kwargs

  def __call__(self):
    """Instantiate and return the wrapped Tensorflow initializer."""
    return self.initializer_class(**self.kwargs)

class Conv1D(Layer):

  def __init__(self, width, out_channels, **kwargs):
@@ -106,7 +126,7 @@ class Conv1D(Layer):
    self.out_tensor = None
    super(Conv1D, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -150,7 +170,7 @@ class Dense(Layer):
      scope_name = self.name
    self.scope_name = scope_name

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -197,7 +217,7 @@ class Flatten(Layer):
  def __init__(self, **kwargs):
    super(Flatten, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -219,7 +239,7 @@ class Reshape(Layer):
    self.shape = shape
    super(Reshape, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -234,7 +254,7 @@ class Transpose(Layer):
    super(Transpose, self).__init__(**kwargs)
    self.perm = perm

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -249,7 +269,7 @@ class CombineMeanStd(Layer):
  def __init__(self, **kwargs):
    super(CombineMeanStd, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -269,7 +289,7 @@ class Repeat(Layer):
    self.n_times = n_times
    super(Repeat, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -290,7 +310,7 @@ class GRU(Layer):
    self.batch_size = batch_size
    super(GRU, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -315,7 +335,7 @@ class TimeSeriesDense(Layer):
    self.out_channels = out_channels
    super(TimeSeriesDense, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -337,7 +357,7 @@ class Input(Layer):
    super(Input, self).__init__(**kwargs)
    self.op_type = "cpu"

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -380,7 +400,7 @@ class L2Loss(Layer):
  def __init__(self, **kwargs):
    super(L2Loss, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -395,7 +415,7 @@ class SoftMax(Layer):
  def __init__(self, **kwargs):
    super(SoftMax, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -412,7 +432,7 @@ class Concat(Layer):
    self.axis = axis
    super(Concat, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -434,7 +454,7 @@ class InteratomicL2Distances(Layer):
    self.ndim = ndim
    super(InteratomicL2Distances, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -458,7 +478,7 @@ class SoftMaxCrossEntropy(Layer):
  def __init__(self, **kwargs):
    super(SoftMaxCrossEntropy, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -477,7 +497,7 @@ class ReduceMean(Layer):
    self.axis = axis
    super(ReduceMean, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -493,7 +513,7 @@ class ReduceMean(Layer):

class ToFloat(Layer):

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -509,7 +529,7 @@ class ReduceSum(Layer):
    self.axis = axis
    super(ReduceSum, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -529,7 +549,7 @@ class ReduceSquareDifference(Layer):
    self.axis = axis
    super(ReduceSquareDifference, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -547,7 +567,7 @@ class Conv2D(Layer):
    self.kernel_size = kernel_size
    super(Conv2D, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -575,7 +595,7 @@ class MaxPool(Layer):
    self.padding = padding
    super(MaxPool, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -599,7 +619,7 @@ class InputFifoQueue(Layer):
    super(InputFifoQueue, self).__init__(**kwargs)
    self.op_type = "cpu"

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -639,7 +659,7 @@ class GraphConv(Layer):
    self.activation_fn = activation_fn
    super(GraphConv, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -745,7 +765,7 @@ class GraphPool(Layer):
    self.max_degree = max_degree
    super(GraphPool, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -795,7 +815,7 @@ class GraphGather(Layer):
    self.activation_fn = activation_fn
    super(GraphGather, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -837,7 +857,7 @@ class GraphGather(Layer):

class BatchNorm(Layer):

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -849,7 +869,7 @@ class BatchNorm(Layer):

class WeightedError(Layer):

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -927,7 +947,7 @@ class VinaFreeEnergy(Layer):
    out_tensor = tf.exp(-((d - 3) / 2)**2)
    return out_tensor

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    """
    Parameters
    ----------
@@ -983,7 +1003,7 @@ class WeightedLinearCombo(Layer):
    self.std = std
    super(WeightedLinearCombo, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)
@@ -1035,7 +1055,7 @@ class NeighborList(Layer):
    self.stop = stop
    super(NeighborList, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    """Creates tensors associated with neighbor-listing."""
    if in_layers is None:
      in_layers = self.in_layers
@@ -1287,6 +1307,21 @@ class NeighborList(Layer):
            tf.transpose(tf.stack(tf.meshgrid(*mesh_args))), (self.n_cells,
                                                              self.ndim)))

class Dropout(Layer):
  """Randomly zeroes a fraction of its input while training.

  Dropout is active only during training: the graph feeds a float
  'training' tensor that is 1.0 while fitting and 0.0 at prediction
  time, which scales the drop probability to zero when predicting.
  """

  def __init__(self, dropout_prob, **kwargs):
    # dropout_prob: fraction of input values to drop during training.
    self.dropout_prob = dropout_prob
    super(Dropout, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None, **kwargs):
    """Apply dropout to the first input layer's tensor."""
    inputs = self.in_layers if in_layers is None else in_layers
    inputs = convert_to_layers(inputs)
    parent_tensor = inputs[0].out_tensor
    # keep_prob is (1 - p) while training and exactly 1.0 at prediction,
    # since kwargs['training'] is a 1.0/0.0 placeholder fed by the graph.
    keep_prob = 1.0 - self.dropout_prob * kwargs['training']
    self.out_tensor = tf.nn.dropout(parent_tensor, keep_prob)
    return self.out_tensor


class AtomicConvolution(Layer):

@@ -1316,7 +1351,7 @@ class AtomicConvolution(Layer):
    self.atom_types = atom_types
    super(AtomicConvolution, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None):
  def create_tensor(self, in_layers=None, **kwargs):
    """
    Parameters
    ----------
+12 −3
Original line number Diff line number Diff line
@@ -120,9 +120,11 @@ class TensorGraph(Model):
    def create_feed_dict():
      if self.use_queue:
        while True:
          yield {}
          yield {self._training_placeholder: 1.0}
      for d in feed_dict_generator:
        yield {k.out_tensor: v for k, v in six.iteritems(d)}
        feed_dict = {k.out_tensor: v for k, v in six.iteritems(d)}
        feed_dict[self._training_placeholder] = 1.0
        yield feed_dict

    if not self.built:
      self.build()
@@ -252,6 +254,7 @@ class TensorGraph(Model):
              self.layers[k.name].out_tensor: v
              for k, v in six.iteritems(feed_dict)
          }
          feed_dict[self._training_placeholder] = 0.0
          result = np.array(sess.run(out_tensors, feed_dict=feed_dict))
          if len(result.shape) == 3:
            result = np.transpose(result, axes=[1, 0, 2])
@@ -308,6 +311,7 @@ class TensorGraph(Model):
    if self.built:
      return
    with self._get_tf("Graph").as_default():
      self._training_placeholder = tf.placeholder(dtype=tf.float32, shape=())
      if self.random_seed is not None:
        tf.set_random_seed(self.random_seed)
      self._install_queue()
@@ -316,7 +320,7 @@ class TensorGraph(Model):
      for node in order:
        with tf.name_scope(node):
          node_layer = self.layers[node]
          node_layer.create_tensor()
          node_layer.create_tensor(training=self._training_placeholder)
      self.built = True

    for layer in self.layers.values():
@@ -342,6 +346,7 @@ class TensorGraph(Model):
    shapes = []
    pre_q_inputs = []
    q = InputFifoQueue(shapes, names, in_layers=pre_q_inputs)

    for layer in self.features + self.labels + self.task_weights:
      pre_q_input = layer.create_pre_q(self.batch_size)
      shapes.append(pre_q_input.shape)
@@ -374,6 +379,8 @@ class TensorGraph(Model):
      for node in self.topsort():
        node_layer = self.layers[node]
        out_tensors.append(node_layer.none_tensors())
      training_placeholder = self._training_placeholder
      self._training_placeholder = None
      self.built = False

    # Pickle itself
@@ -386,6 +393,7 @@ class TensorGraph(Model):
      for index, node in enumerate(self.topsort()):
        node_layer = self.layers[node]
        node_layer.set_tensors(out_tensors[index])
      self._training_placeholder = training_placeholder
      self.built = True
    self.tensor_objects = tensor_objects

@@ -498,6 +506,7 @@ def _enqueue_batch(tg, generator, graph, sess, coord):
    num_samples = 0
    for feed_dict in generator:
      enq = {}
      enq[tg._training_placeholder] = 1.0
      for layer in tg.features + tg.labels + tg.task_weights:
        enq[tg.get_pre_q_input(layer).out_tensor] = feed_dict[layer]
      sess.run(tg.input_queue.out_tensor, feed_dict=enq)