Unverified commit ff21bdac authored by Bharath Ramsundar, committed by GitHub

Merge pull request #1620 from peastman/progressive

Converted progressive multitask models to KerasModel
parents 792dbac1 ef3c40a4
+60 −11
@@ -240,6 +240,7 @@ class KerasModel(Model):
    except ValueError:
      # The loss doesn't depend on any variables.
      self._train_op = 0
    self._train_op_for_vars = {}
    if self.tensorboard:
      self._summary_ops = tf.summary.scalar('loss', self._loss_tensor)
      self._summary_writer = tf.summary.FileWriter(self.model_dir)
@@ -259,7 +260,9 @@ class KerasModel(Model):
          max_checkpoints_to_keep=5,
          checkpoint_interval=1000,
          deterministic=False,
          restore=False):
          restore=False,
          variables=None,
          loss=None):
    """Train this model on a dataset.

    Parameters
@@ -279,17 +282,26 @@ class KerasModel(Model):
    restore: bool
      if True, restore the model from the most recent checkpoint and continue training
      from there.  If False, retrain the model from scratch.
    variables: list of tf.Variable
      the variables to train.  If None (the default), all trainable variables in
      the model are used.
    loss: function
      a function of the form f(outputs, labels, weights) that computes the loss
      for each batch.  If None (the default), the model's standard loss function
      is used.
   """
    return self.fit_generator(
        self.default_generator(
            dataset, epochs=nb_epoch, deterministic=deterministic),
        max_checkpoints_to_keep, checkpoint_interval, restore)
        max_checkpoints_to_keep, checkpoint_interval, restore, variables, loss)

  def fit_generator(self,
                    generator,
                    max_checkpoints_to_keep=5,
                    checkpoint_interval=1000,
                    restore=False):
                    restore=False,
                    variables=None,
                    loss=None):
    """Train this model on data from a generator.

    Parameters
@@ -305,6 +317,13 @@ class KerasModel(Model):
    restore: bool
      if True, restore the model from the most recent checkpoint and continue training
      from there.  If False, retrain the model from scratch.
    variables: list of tf.Variable
      the variables to train.  If None (the default), all trainable variables in
      the model are used.
    loss: function
      a function of the form f(outputs, labels, weights) that computes the loss
      for each batch.  If None (the default), the model's standard loss function
      is used.

    Returns
    -------
@@ -316,6 +335,7 @@ class KerasModel(Model):
                                           max_checkpoints_to_keep)
    avg_loss = 0.0
    averaged_batches = 0
    train_op = None
    time1 = time.time()

    # Main training loop.
@@ -334,6 +354,8 @@ class KerasModel(Model):

        # In eager mode we execute the loss function, accumulating the gradients.

        if loss is None:
          loss = self._loss_fn
        with tf.GradientTape() as tape:
          if len(inputs) == 1:
            inputs = inputs[0]
@@ -342,18 +364,38 @@ class KerasModel(Model):
            outputs = [outputs]
          if self._loss_outputs is not None:
            outputs = [outputs[i] for i in self._loss_outputs]
          loss = self._loss_fn(outputs, labels, weights)
        avg_loss += loss
        grads = tape.gradient(loss, self.model.trainable_variables)
        self._tf_optimizer.apply_gradients(
            zip(grads, self.model.trainable_variables))
          batch_loss = loss(outputs, labels, weights)
        avg_loss += batch_loss
        if variables is None:
          vars = self.model.trainable_variables
        else:
          vars = variables
        grads = tape.gradient(batch_loss, vars)
        self._tf_optimizer.apply_gradients(zip(grads, vars))
        tf.assign_add(self._global_step, 1)
        current_step = self._global_step.numpy()
      else:

        # In graph mode we execute the training op.

        fetches = [self._train_op, self._loss_tensor, self._global_step]
        if train_op is None:
          if loss is not None:
            loss_tensor = loss(
                [self._output_tensors[i] for i in self._loss_outputs],
                self._label_placeholders, self._weights_placeholders)
            train_op = self._tf_optimizer.minimize(
                loss_tensor, global_step=self._global_step, var_list=variables)
          elif variables is None:
            train_op = self._train_op
          else:
            var_key = tuple(variables)
            if var_key not in self._train_op_for_vars:
              self._train_op_for_vars[var_key] = self._tf_optimizer.minimize(
                  self._loss_tensor,
                  global_step=self._global_step,
                  var_list=variables)
            train_op = self._train_op_for_vars[var_key]
        fetches = [train_op, self._loss_tensor, self._global_step]
        if should_log:
          fetches.append(self._summary_ops)
        feed_dict = dict(zip(self._input_placeholders, inputs))
@@ -391,7 +433,7 @@ class KerasModel(Model):
    print("TIMING: model fitting took %0.3f s" % (time2 - time1))
    return avg_loss

  def fit_on_batch(self, X, y, w):
  def fit_on_batch(self, X, y, w, variables=None, loss=None):
    """Perform a single step of training.

    Parameters
@@ -402,11 +444,18 @@ class KerasModel(Model):
      the labels for the batch
    w: ndarray
      the weights for the batch
    variables: list of tf.Variable
      the variables to train.  If None (the default), all trainable variables in
      the model are used.
    loss: function
      a function of the form f(outputs, labels, weights) that computes the loss
      for each batch.  If None (the default), the model's standard loss function
      is used.
   """
    if not self.built:
      self.build()
    dataset = NumpyDataset(X, y, w)
    return self.fit(dataset, nb_epoch=1)
    return self.fit(dataset, nb_epoch=1, variables=variables, loss=loss)

  def _predict(self, generator, transformers, outputs, uncertainty):
    """
+22 −0
@@ -595,6 +595,28 @@ class Stack(tf.keras.layers.Layer):
    return tf.stack(inputs, axis=self.axis)


class Variable(tf.keras.layers.Layer):
  """Output a trainable value."""

  def __init__(self, initial_value, **kwargs):
    """Construct a variable layer.

    Parameters
    ----------
    initial_value: array or Tensor
      the initial value the layer should output
    """
    super(Variable, self).__init__(**kwargs)
    self.initial_value = initial_value

  def build(self, input_shape):
    self.var = tf.Variable(self.initial_value, dtype=self.dtype)
    self.built = True

  def call(self, inputs):
    return self.var
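
Note: the new Variable layer above simply wraps a trainable value so it can participate in a Keras model; the progressive model further down uses it for the per-adapter alpha scaling factors. A minimal sketch of the mechanics (calling build() directly here is only for illustration):

import tensorflow as tf
from deepchem.models import layers

# A layer whose output is a trainable value, initialized to [0.5].
alpha = layers.Variable([0.5])
alpha.build(None)  # creates the underlying tf.Variable as alpha.var

# call() ignores its inputs and returns the variable, so the layer's output
# can be multiplied into another tensor (as the adapters below do) while the
# value stays visible to optimizers via trainable_variables.
print(alpha.var)
print(alpha.trainable_variables)
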


class VinaFreeEnergy(tf.keras.layers.Layer):
  """Computes free-energy as defined by Autodock Vina.

+15 −0
@@ -109,6 +109,21 @@ class SoftmaxCrossEntropy(Loss):
        labels, output, reduction=tf.losses.Reduction.NONE)


class SparseSoftmaxCrossEntropy(Loss):
  """The cross entropy between two probability distributions.

  The labels should have shape (batch_size) or (batch_size, tasks), and be
  integer class labels.  The outputs should have shape (batch_size, classes) or
  (batch_size, tasks, classes) and be logits that are converted to probabilities
  using a softmax function.
  """

  def __call__(self, output, labels):
    labels = tf.cast(labels, tf.int32)
    return tf.losses.sparse_softmax_cross_entropy(
        labels, output, reduction=tf.losses.Reduction.NONE)


def _make_shapes_consistent(output, labels):
  """Try to make inputs have the same shape by adding dimensions of size 1."""
  shape1 = output.shape
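
Note: SparseSoftmaxCrossEntropy is added so the converted classifier below can compute its loss directly from integer class labels. A minimal sketch of the shape convention (random logits are illustrative only):

import numpy as np
import tensorflow as tf
from deepchem.models.losses import SparseSoftmaxCrossEntropy

# Logits for a batch of 4 samples over 3 classes, plus integer class labels.
logits = tf.constant(np.random.randn(4, 3), dtype=tf.float32)
labels = tf.constant([0, 2, 1, 2])

# Like the other Loss subclasses, reduction is Reduction.NONE, so this yields
# one loss value per sample (shape (4,)) rather than a scalar.
loss = SparseSoftmaxCrossEntropy()
per_sample_loss = loss(logits, labels)
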
+86 −122
@@ -9,48 +9,13 @@ import collections
from deepchem.utils.save import log
from deepchem.metrics import to_one_hot
from deepchem.metrics import from_one_hot
from deepchem.models.tensorgraph.tensor_graph import TensorGraph, TFWrapper
from deepchem.models.tensorgraph.layers import Layer, Feature, Label, Weights, \
    WeightedError, Dense, Dropout, WeightDecay, Reshape, SparseSoftMaxCrossEntropy, \
    L2Loss, ReduceSum, Concat, Stack, TensorWrapper, ReLU, Squeeze, SoftMax, Cast
from deepchem.models.tensorgraph.layers import convert_to_layers
from deepchem.models import KerasModel, layers
from deepchem.models.losses import L2Loss, SparseSoftmaxCrossEntropy
from deepchem.models.keras_model import _StandardLoss
from tensorflow.keras.layers import Input, Dense, Dropout, ReLU, Concatenate, Add, Multiply, Softmax


class Slice(Layer):
  """ Choose a slice of input on the last axis given order,
  Suppose input x has two dimensions,
  output f(x) = x[:, slice_num:slice_num+1]
  """

  def __init__(self, slice_num, axis=1, **kwargs):
    """
    Parameters
    ----------
    slice_num: int
      index of slice number
    axis: int
      axis id
    """
    self.slice_num = slice_num
    self.axis = axis
    super(Slice, self).__init__(**kwargs)

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)

    slice_num = self.slice_num
    axis = self.axis
    inputs = in_layers[0].out_tensor
    out_tensor = tf.slice(inputs, [0] * axis + [slice_num], [-1] * axis + [1])

    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor


class ProgressiveMultitaskRegressor(TensorGraph):
class ProgressiveMultitaskRegressor(KerasModel):
  """Implements a progressive multitask neural network for regression.

  Progressive Networks: https://arxiv.org/pdf/1606.04671v3.pdf
@@ -109,7 +74,8 @@ class ProgressiveMultitaskRegressor(TensorGraph):
      same value is used for every layer.
    """

    super(ProgressiveMultitaskRegressor, self).__init__(**kwargs)
    if weight_decay_penalty != 0.0:
      raise ValueError('Weight decay is not currently supported')
    self.n_tasks = n_tasks
    self.n_features = n_features
    self.layer_sizes = layer_sizes
@@ -133,81 +99,65 @@ class ProgressiveMultitaskRegressor(TensorGraph):
      self.activation_fns = [activation_fns] * n_layers

    # Add the input features.
    self.mol_features = Feature(shape=(None, n_features))
    self._task_labels = Label(shape=(None, n_tasks))
    self._task_weights = Weights(shape=(None, n_tasks))
    mol_features = Input(shape=(n_features,))

    all_layers = {}
    outputs = []
    self._task_layers = []
    for task in range(self.n_tasks):
      task_layers = []
      for i in range(n_layers):
        if i == 0:
          prev_layer = self.mol_features
          prev_layer = mol_features
        else:
          prev_layer = all_layers[(i - 1, task)]
          if task > 0:
            lateral_contrib, trainables = self.add_adapter(all_layers, task, i)
            task_layers.extend(trainables)

        layer = Dense(
            in_layers=[prev_layer],
            out_channels=layer_sizes[i],
            activation_fn=None,
            weights_initializer=TFWrapper(
                tf.truncated_normal_initializer,
        dense = Dense(
            layer_sizes[i],
            kernel_initializer=tf.truncated_normal_initializer(
                stddev=self.weight_init_stddevs[i]),
            biases_initializer=TFWrapper(
                tf.constant_initializer, value=self.bias_init_consts[i]))
        task_layers.append(layer)
            bias_initializer=tf.constant_initializer(
                value=self.bias_init_consts[i]))
        layer = dense(prev_layer)
        task_layers.append(dense)

        if i > 0 and task > 0:
          layer = layer + lateral_contrib
          layer = Add()([layer, lateral_contrib])
        assert self.activation_fns[i] is tf.nn.relu, "Only ReLU is supported"
        layer = ReLU(in_layers=[layer])
        layer = ReLU()(layer)
        if self.dropouts[i] > 0.0:
          layer = Dropout(self.dropouts[i], in_layers=[layer])
          layer = Dropout(self.dropouts[i])(layer)
        all_layers[(i, task)] = layer

      prev_layer = all_layers[(n_layers - 1, task)]
      layer = Dense(
          in_layers=[prev_layer],
          out_channels=n_outputs,
          weights_initializer=TFWrapper(
              tf.truncated_normal_initializer,
      dense = Dense(
          n_outputs,
          kernel_initializer=tf.truncated_normal_initializer(
              stddev=self.weight_init_stddevs[-1]),
          biases_initializer=TFWrapper(
              tf.constant_initializer, value=self.bias_init_consts[-1]))
      task_layers.append(layer)
          bias_initializer=tf.constant_initializer(
              value=self.bias_init_consts[-1]))
      layer = dense(prev_layer)
      task_layers.append(dense)

      if task > 0:
        lateral_contrib, trainables = self.add_adapter(all_layers, task,
                                                       n_layers)
        task_layers.extend(trainables)
        layer = layer + lateral_contrib
        layer = Add()([layer, lateral_contrib])
      output_layer = self.create_output(layer)
      outputs.append(output_layer)
      self._task_layers.append(task_layers)

      label = Slice(task, axis=1, in_layers=[self._task_labels])
      weight = Slice(task, axis=1, in_layers=[self._task_weights])
      task_loss = self.create_loss(layer, label, weight)
      self.create_submodel(layers=task_layers, loss=task_loss, optimizer=None)
    outputs = layers.Stack(axis=1)(outputs)
    model = tf.keras.Model(inputs=mol_features, outputs=outputs)
    super(ProgressiveMultitaskRegressor,
          self).__init__(model, self.create_loss(), **kwargs)

    outputs = Stack(axis=1, in_layers=outputs)
    self.add_output(outputs)

    # Weight decay not activated
    """
    if weight_decay_penalty != 0.0:
      weighted_loss = WeightDecay(
          weight_decay_penalty,
          weight_decay_penalty_type,
          in_layers=[weighted_loss])
    """

  def create_loss(self, layer, label, weight):
    weighted_loss = ReduceSum(L2Loss(in_layers=[label, layer, weight]))
    return weighted_loss
  def create_loss(self):
    return L2Loss()

  def create_output(self, layer):
    return layer
@@ -235,35 +185,31 @@ class ProgressiveMultitaskRegressor(TensorGraph):
      prev_layers.append(all_layers[(i - 1, prev_task)])
    # prev_layers is a list with elements of size
    # (batch_size, layer_sizes[i-1])
    prev_layer = Concat(axis=1, in_layers=prev_layers)
    with self._get_tf("Graph").as_default():
      alpha = TensorWrapper(
          tf.Variable(
              tf.truncated_normal((1,), stddev=alpha_init_stddev),
              name="alpha_layer_%d_task%d" % (i, task)))
    if len(prev_layers) == 1:
      prev_layer = prev_layers[0]
    else:
      prev_layer = Concatenate(axis=1)(prev_layers)
    alpha = layers.Variable(tf.truncated_normal((1,), stddev=alpha_init_stddev))
    trainable_layers.append(alpha)

    prev_layer = prev_layer * alpha
    prev_layer = Multiply()([prev_layer, alpha([])])
    dense1 = Dense(
        in_layers=[prev_layer],
        out_channels=layer_sizes[i - 1],
        activation_fn=None,
        weights_initializer=TFWrapper(
            tf.truncated_normal_initializer, stddev=weight_init_stddev),
        biases_initializer=TFWrapper(
            tf.constant_initializer, value=bias_init_const))
        layer_sizes[i - 1],
        kernel_initializer=tf.truncated_normal_initializer(
            stddev=weight_init_stddev),
        bias_initializer=tf.constant_initializer(value=bias_init_const))
    prev_layer = dense1(prev_layer)
    trainable_layers.append(dense1)

    dense2 = Dense(
        in_layers=[dense1],
        out_channels=layer_sizes[i],
        activation_fn=None,
        weights_initializer=TFWrapper(
            tf.truncated_normal_initializer, stddev=weight_init_stddev),
        biases_initializer=None)
        layer_sizes[i],
        kernel_initializer=tf.truncated_normal_initializer(
            stddev=weight_init_stddev),
        use_bias=False)
    prev_layer = dense2(prev_layer)
    trainable_layers.append(dense2)

    return dense2, trainable_layers
    return prev_layer, trainable_layers

  def fit(self,
          dataset,
@@ -276,28 +222,40 @@ class ProgressiveMultitaskRegressor(TensorGraph):
    for task in range(self.n_tasks):
      self.fit_task(
          dataset,
          task,
          nb_epoch=nb_epoch,
          max_checkpoints_to_keep=max_checkpoints_to_keep,
          checkpoint_interval=checkpoint_interval,
          deterministic=deterministic,
          restore=restore,
          submodel=task,
          **kwargs)

  def fit_task(self,
               dataset,
               task,
               nb_epoch=10,
               max_checkpoints_to_keep=5,
               checkpoint_interval=1000,
               deterministic=False,
               restore=False,
               submodel=None,
               **kwargs):
    """Fit one task."""
    shape = dataset.get_shape()
    batch = [[np.zeros((self.batch_size,) + s[1:])] for s in shape]
    self._create_training_ops(batch)
    generator = self.default_generator(
        dataset, epochs=nb_epoch, deterministic=deterministic)
    self.fit_generator(generator, max_checkpoints_to_keep, checkpoint_interval,
                       restore, self.submodels[submodel])
    variables = []
    for layer in self._task_layers[task]:
      variables += layer.trainable_variables
    loss = TaskLoss(self.model, self.create_loss(), task)
    self.fit_generator(
        generator,
        max_checkpoints_to_keep,
        checkpoint_interval,
        restore,
        variables=variables,
        loss=loss)


class ProgressiveMultitaskClassifier(ProgressiveMultitaskRegressor):
@@ -338,15 +296,21 @@ class ProgressiveMultitaskClassifier(ProgressiveMultitaskRegressor):
        n_outputs=n_outputs,
        **kwargs)

  def create_loss(self, layer, label, weight):
    task_label = Squeeze(squeeze_dims=1, in_layers=[label])
    task_label = Cast(dtype=tf.int32, in_layers=[task_label])
    task_weight = Squeeze(squeeze_dims=1, in_layers=[weight])

    loss = SparseSoftMaxCrossEntropy(in_layers=[task_label, layer])
    weighted_loss = WeightedError(in_layers=[loss, task_weight])
    return weighted_loss
  def create_loss(self):
    return SparseSoftmaxCrossEntropy()

  def create_output(self, layer):
    output = SoftMax(in_layers=[layer])
    return output
    return Softmax()(layer)


class TaskLoss(_StandardLoss):

  def __init__(self, model, loss, task):
    super(TaskLoss, self).__init__(model, loss)
    self.task = task

  def __call__(self, outputs, labels, weights):
    outputs = [t[:, self.task] for t in outputs]
    labels = [t[:, self.task] for t in labels]
    weights = [t[:, self.task] for t in weights]
    return super(TaskLoss, self).__call__(outputs, labels, weights)
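
Note: after the conversion, fit() trains the tasks sequentially; each fit_task() call collects that task's trainable variables and wraps the standard loss in a TaskLoss, which slices out the task's column of outputs, labels, and weights before handing both to fit_generator(). A minimal usage sketch, assuming the public dc.models API and a small random dataset (values are illustrative only):

import numpy as np
import deepchem as dc

# Toy multitask regression data: 100 samples, 10 features, 3 tasks.
np.random.seed(0)
X = np.random.rand(100, 10)
y = np.random.rand(100, 3)
dataset = dc.data.NumpyDataset(X, y)

model = dc.models.ProgressiveMultitaskRegressor(
    n_tasks=3, n_features=10, layer_sizes=[64], batch_size=50)

# Trains task 0 first, then task 1 (whose adapters read the already-trained
# task-0 activations), then task 2, as in the progressive networks paper.
model.fit(dataset, nb_epoch=5)
predictions = model.predict(dataset)  # per-task outputs stacked along axis 1
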
+78 −0
@@ -278,3 +278,81 @@ class TestKerasModel(unittest.TestCase):
    event_file = os.path.join(model.model_dir, event_file[0])
    file_size = os.stat(event_file).st_size
    assert file_size > 0

  def test_fit_variables(self):
    """Test training a subset of the variables in a model."""

    class VarModel(tf.keras.Model):

      def __init__(self, **kwargs):
        super(VarModel, self).__init__(**kwargs)
        self.var1 = tf.Variable([0.5])
        self.var2 = tf.Variable([0.5])

      def call(self, inputs, training=False):
        return [self.var1, self.var2]

    def loss(outputs, labels, weights):
      return (outputs[0] * outputs[1] - labels[0])**2

    keras_model = VarModel()
    model = dc.models.KerasModel(keras_model, loss, learning_rate=0.01)
    x = np.ones((1, 1))
    vars = model.predict_on_batch(x)
    assert np.allclose(vars[0], 0.5)
    assert np.allclose(vars[1], 0.5)
    model.fit_generator([(x, x, x)] * 300)
    vars = model.predict_on_batch(x)
    assert np.allclose(vars[0], 1.0)
    assert np.allclose(vars[1], 1.0)
    model.fit_generator([(x, 2 * x, x)] * 300, variables=[keras_model.var1])
    vars = model.predict_on_batch(x)
    assert np.allclose(vars[0], 2.0)
    assert np.allclose(vars[1], 1.0)
    model.fit_generator([(x, x, x)] * 300, variables=[keras_model.var2])
    vars = model.predict_on_batch(x)
    assert np.allclose(vars[0], 2.0)
    assert np.allclose(vars[1], 0.5)

  def test_fit_variables_eager(self):
    """Test training a subset of the variables in a model, in eager mode."""
    with context.eager_mode():
      self.test_fit_variables()

  def test_fit_loss(self):
    """Test specifying a different loss function when calling fit()."""

    class VarModel(tf.keras.Model):

      def __init__(self, **kwargs):
        super(VarModel, self).__init__(**kwargs)
        self.var1 = tf.Variable([0.5])
        self.var2 = tf.Variable([0.5])

      def call(self, inputs, training=False):
        return [self.var1, self.var2]

    def loss1(outputs, labels, weights):
      return (outputs[0] * outputs[1] - labels[0])**2

    def loss2(outputs, labels, weights):
      return (outputs[0] + outputs[1] - labels[0])**2

    keras_model = VarModel()
    model = dc.models.KerasModel(keras_model, loss1, learning_rate=0.01)
    x = np.ones((1, 1))
    vars = model.predict_on_batch(x)
    assert np.allclose(vars[0], 0.5)
    assert np.allclose(vars[1], 0.5)
    model.fit_generator([(x, x, x)] * 300)
    vars = model.predict_on_batch(x)
    assert np.allclose(vars[0], 1.0)
    assert np.allclose(vars[1], 1.0)
    model.fit_generator([(x, 3 * x, x)] * 300, loss=loss2)
    vars = model.predict_on_batch(x)
    assert np.allclose(vars[0] + vars[1], 3.0)

  def test_fit_loss_eager(self):
    """Test specifying a different loss function when calling fit(), in eager mode."""
    with context.eager_mode():
      self.test_fit_loss()