Commit f27a3fdf authored by Bharath Ramsundar, committed by GitHub

Merge pull request #855 from peastman/submodel

Changes to support GANs
parents 2745d99d babd9621
+35 −0
@@ -270,6 +270,16 @@ class Layer(object):
  def __neg__(self):
    return Multiply([self, Constant(-1.0)])

  def __div__(self, other):
    if not isinstance(other, Layer):
      other = Constant(other)
    return Divide([self, other])

  def __truediv__(self, other):
    if not isinstance(other, Layer):
      other = Constant(other)
    return Divide([self, other])


def _convert_layer_to_tensor(value, dtype=None, name=None, as_ref=False):
  return tf.convert_to_tensor(value.out_tensor, dtype=dtype, name=name)
@@ -1102,6 +1112,31 @@ class Multiply(Layer):
    return out_tensor


class Divide(Layer):
  """Compute the ratio of the input layers."""

  def __init__(self, in_layers=None, **kwargs):
    super(Divide, self).__init__(in_layers, **kwargs)
    try:
      shape1 = list(self.in_layers[0].shape)
      shape2 = list(self.in_layers[1].shape)
      if len(shape1) < len(shape2):
        shape2, shape1 = shape1, shape2
      offset = len(shape1) - len(shape2)
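      # Broadcast like numpy: align trailing dimensions and keep the larger
      # of each aligned pair (via _max_dimension) for the output shape.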
      for i in range(len(shape2)):
        shape1[i + offset] = _max_dimension(shape1[i + offset], shape2[i])
      self._shape = tuple(shape1)
    except:
      # Input shapes are not known; leave self._shape unset.
      pass

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    inputs = self._get_input_tensors(in_layers)
    out_tensor = inputs[0] / inputs[1]
    if set_tensors:
      self.out_tensor = out_tensor
    return out_tensor


class Log(Layer):
  """Compute the natural log of the input."""

+0 −1
@@ -149,7 +149,6 @@ class TestANIRegression(unittest.TestCase):
    assert self.model.layer_structures == restored_model.layer_structures
    assert self.model.atom_number_cases == restored_model.atom_number_cases
    assert self.model.batch_size == restored_model.batch_size
    assert self.model.learning_rate == restored_model.learning_rate
    assert self.model.use_queue == restored_model.use_queue

    assert expected == predicted
+162 −41
@@ -12,7 +12,7 @@ from tensorflow.python.framework.errors_impl import OutOfRangeError
from deepchem.data import NumpyDataset
from deepchem.metrics import to_one_hot, from_one_hot
from deepchem.models.models import Model
from deepchem.models.tensorgraph.layers import InputFifoQueue, Label, Feature, Weights
from deepchem.models.tensorgraph.layers import InputFifoQueue, Label, Feature, Weights, Constant
from deepchem.models.tensorgraph.optimizers import Adam
from deepchem.trans import undo_transforms
from deepchem.utils.evaluate import GeneratorEvaluator
@@ -60,11 +60,12 @@ class TensorGraph(Model):
    self.labels = list()
    self.outputs = list()
    self.task_weights = list()
    self.loss = None
    self.submodels = list()
    self.loss = Constant(0)
    self.built = False
    self.queue_installed = False
    self.optimizer = None
    self.learning_rate = learning_rate
    self.optimizer = Adam(
        learning_rate=learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-7)

    # Singular place to hold Tensor objects which don't serialize
    # These have to be reconstructed on restoring from pickle
@@ -116,7 +117,8 @@ class TensorGraph(Model):
          max_checkpoints_to_keep=5,
          checkpoint_interval=1000,
          deterministic=False,
          restore=False):
          restore=False,
          submodel=None):
    """Train this model on a dataset.

    Parameters
@@ -129,23 +131,28 @@ class TensorGraph(Model):
      the maximum number of checkpoints to keep.  Older checkpoints are discarded.
    checkpoint_interval: int
      the frequency at which to write checkpoints, measured in training steps.
      Set this to 0 to disable automatic checkpointing.
    deterministic: bool
      if True, the samples are processed in order.  If False, a different random
      order is used for each epoch.
    restore: bool
      if True, restore the model from the most recent checkpoint and continue training
      from there.  If False, retrain the model from scratch.
    submodel: Submodel
      an alternate training objective to use.  This should have been created by
      calling create_submodel().
    """
    return self.fit_generator(
        self.default_generator(
            dataset, epochs=nb_epoch, deterministic=deterministic),
        max_checkpoints_to_keep, checkpoint_interval, restore)
        max_checkpoints_to_keep, checkpoint_interval, restore, submodel)

  def fit_generator(self,
                    feed_dict_generator,
                    max_checkpoints_to_keep=5,
                    checkpoint_interval=1000,
                    restore=False):
                    restore=False,
                    submodel=None):
    """Train this model on data from a generator.

    Parameters
@@ -157,9 +164,17 @@ class TensorGraph(Model):
      the maximum number of checkpoints to keep.  Older checkpoints are discarded.
    checkpoint_interval: int
      the frequency at which to write checkpoints, measured in training steps.
      Set this to 0 to disable automatic checkpointing.
    restore: bool
      if True, restore the model from the most recent checkpoint and continue training
      from there.  If False, retrain the model from scratch.
    submodel: Submodel
      an alternate training objective to use.  This should have been created by
      calling create_submodel().

    Returns
    -------
    the average loss over the most recent checkpoint interval
    """

    def create_feed_dict():
@@ -175,20 +190,18 @@ class TensorGraph(Model):
      self.build()
    with self._get_tf("Graph").as_default():
      time1 = time.time()
      loss = self.loss
      if submodel is None:
        train_op = self._get_tf('train_op')
      else:
        train_op = submodel.get_train_op()
        if submodel.loss is not None:
          loss = submodel.loss
      if checkpoint_interval > 0:
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
      self.session.run(tf.global_variables_initializer())
      if restore:
        self.restore()
      else:
        # Initialize variables that have pre-trained values.
        for layer in self.layers.values():
          if layer.variable_values is not None:
            variables = self.get_layer_variables(layer)
            for var, val in zip(variables, layer.variable_values):
              self.session.run(var.assign(val))
      avg_loss, n_averaged_batches = 0.0, 0.0
      coord = tf.train.Coordinator()
      n_samples = 0
      n_enqueued = [0]
      final_sample = [None]
@@ -212,17 +225,16 @@ class TensorGraph(Model):
        n_samples += 1
        should_log = (self.tensorboard and
                      n_samples % self.tensorboard_log_frequency == 0)
        fetches = [train_op, self.loss.out_tensor]
        fetches = [train_op, loss.out_tensor]
        if should_log:
          fetches.append(self._get_tf("summary_op"))
        fetched_values = self.session.run(fetches, feed_dict=feed_dict)
        if should_log:
          self._log_tensorboard(fetches[2])
        loss = fetched_values[1]
        avg_loss += loss
        avg_loss += fetched_values[1]
        n_averaged_batches += 1
        self.global_step += 1
        if self.global_step % checkpoint_interval == checkpoint_interval - 1:
        if checkpoint_interval > 0 and self.global_step % checkpoint_interval == checkpoint_interval - 1:
          saver.save(self.session, self.save_file, global_step=self.global_step)
          avg_loss = float(avg_loss) / n_averaged_batches
          print('Ending global_step %d: Average loss %g' % (self.global_step,
@@ -230,11 +242,14 @@ class TensorGraph(Model):
          avg_loss, n_averaged_batches = 0.0, 0.0
      if n_averaged_batches > 0:
        avg_loss = float(avg_loss) / n_averaged_batches
      if checkpoint_interval > 0:
        if n_averaged_batches > 0:
          print('Ending global_step %d: Average loss %g' % (self.global_step,
                                                            avg_loss))
        saver.save(self.session, self.save_file, global_step=self.global_step)
        time2 = time.time()
        print("TIMING: model fitting took %0.3f s" % (time2 - time1))
    return avg_loss

  def _log_tensorboard(self, summary):
    """
@@ -250,11 +265,11 @@ class TensorGraph(Model):
    writer.add_summary(summary, global_step=global_step)
    writer.close()

  def fit_on_batch(self, X, y, w):
  def fit_on_batch(self, X, y, w, submodel=None):
    if not self.built:
      self.build()
    dataset = NumpyDataset(X, y)
    return self.fit(dataset, nb_epoch=1)
    return self.fit(dataset, nb_epoch=1, submodel=submodel)

  def default_generator(self,
                        dataset,
@@ -435,6 +450,9 @@ class TensorGraph(Model):
    for l in self.features + self.labels + self.task_weights + self.outputs:
      add_layers_to_list(l, sorted_layers)
    add_layers_to_list(self.loss, sorted_layers)
    for submodel in self.submodels:
      if submodel.loss is not None:
        add_layers_to_list(submodel.loss, sorted_layers)
    return sorted_layers

  def build(self):
@@ -453,9 +471,23 @@ class TensorGraph(Model):
          self.rnn_zero_states += layer.rnn_zero_states
          layer.add_summary_to_tg()
      self.session = tf.Session()

      self.built = True

      # Ensure all training operators have been created.

      self._get_tf('train_op')
      for submodel in self.submodels:
        train_op = submodel.get_train_op()

      # Initialize variables.

      self.session.run(tf.global_variables_initializer())
      for layer in self.layers.values():
        if layer.variable_values is not None:
          variables = self.get_layer_variables(layer)
          for var, val in zip(variables, layer.variable_values):
            self.session.run(var.assign(val))

    for layer in self.layers.values():
      if layer.tensorboard:
        self.tensorboard = True
@@ -515,6 +547,46 @@ class TensorGraph(Model):
    """Set the optimizer to use for fitting."""
    self.optimizer = optimizer

  def create_submodel(self, layers=None, loss=None, optimizer=None):
    """Create an alternate objective for training one piece of a TensorGraph.

    A TensorGraph consists of a set of layers, and specifies a loss function and
    optimizer to use for training those layers.  Usually this is sufficient, but
    there are cases where you want to train different parts of a model separately.
    For example, a GAN consists of a generator and a discriminator.  They are
    trained separately, and they use different loss functions.

    A submodel defines an alternate objective to use in cases like this.  It may
    optionally specify any of the following: a subset of layers in the model to
    train; a different loss function; and a different optimizer to use.  This
    method creates a submodel, which you can then pass to fit() to use for
    training.

    Parameters
    ----------
    layers: list
      the list of layers to train.  If None, all layers in the model will be
      trained.
    loss: Layer
      the loss function to optimize.  If None, the model's main loss function
      will be used.
    optimizer: Optimizer
      the optimizer to use for training.  If None, the model's main optimizer
      will be used.

    Returns
    -------
    the newly created submodel, which can be passed to any of the fitting
    methods.
    """
    if self.built:
      raise ValueError('Submodels must be created before build() is called.')
    submodel = Submodel(self, layers, loss, optimizer)
    self.submodels.append(submodel)
    if loss is not None:
      self._add_layer(loss)
    return submodel
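
For a concrete picture of the call pattern, here is a condensed sketch along the lines of the new test_submodels test below (the values are illustrative):

import numpy as np
import deepchem as dc
from deepchem.models.tensorgraph.layers import Feature, Variable

tg = dc.models.TensorGraph(learning_rate=0.1, batch_size=1)
features = Feature(shape=(None, 1))
var1 = Variable([2.0])
var2 = Variable([2.0])
tg.add_output(var1)
tg.add_output(var2)
tg.set_loss((var1 - 1) * (var1 - 1) + (var2 - 1) * (var2 - 1) + features)

# An alternate objective that is only allowed to adjust var2.
submodel = tg.create_submodel(
    layers=[var2], loss=var1 * var1 + var2 * var2 + features)
generator = [{features: np.zeros((1, 1))}] * 500
tg.fit_generator(generator, submodel=submodel)  # drives var2 toward 0, leaves var1 alone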

  def get_pickling_errors(self, obj, seen=None):
    if seen is None:
      seen = []
@@ -557,8 +629,6 @@ class TensorGraph(Model):
      must_restore = True
      for layer in self.topsort():
        out_tensors.append(layer.none_tensors())
      optimizer = self.optimizer
      self.optimizer = None
      training_placeholder = self._training_placeholder
      self._training_placeholder = None
      self.built = False
@@ -578,7 +648,6 @@ class TensorGraph(Model):
      for index, layer in enumerate(self.topsort()):
        layer.set_tensors(out_tensors[index])
      self._training_placeholder = training_placeholder
      self.optimizer = optimizer
      self.built = True
    self.tensor_objects = tensor_objects
    self.rnn_initial_states = rnn_initial_states
@@ -652,17 +721,17 @@ class TensorGraph(Model):
    elif obj == "FileWriter":
      self.tensor_objects['FileWriter'] = tf.summary.FileWriter(self.model_dir)
    elif obj == 'Optimizer':
      if self.optimizer is None:
        self.optimizer = Adam(
            learning_rate=self.learning_rate,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-7)
      self.tensor_objects['Optimizer'] = self.optimizer._create_optimizer(
          self._get_tf('GlobalStep'))
    elif obj == 'train_op':
      self.tensor_objects['train_op'] = self._get_tf('Optimizer').minimize(
          self.loss.out_tensor, global_step=self._get_tf('GlobalStep'))
      opt = self._get_tf('Optimizer')
      global_step = self._get_tf('GlobalStep')
      try:
        self.tensor_objects['train_op'] = opt.minimize(
            self.loss.out_tensor, global_step=global_step)
      except ValueError:
        # The loss doesn't depend on any variables.
        self.tensor_objects['train_op'] = 0
    elif obj == 'summary_op':
      self.tensor_objects['summary_op'] = tf.summary.merge_all(
          key=tf.GraphKeys.SUMMARIES)
@@ -671,6 +740,21 @@ class TensorGraph(Model):
        self.tensor_objects['GlobalStep'] = tf.Variable(0, trainable=False)
    return self._get_tf(obj)

  def save_checkpoint(self, max_checkpoints_to_keep=5):
    """Save a checkpoint to disk.

    Usually you do not need to call this method, since fit() saves checkpoints
    automatically.  If you have disabled automatic checkpointing during fitting,
    this can be called to manually write checkpoints.

    Parameters
    ----------
    max_checkpoints_to_keep: int
      the maximum number of checkpoints to keep.  Older checkpoints are discarded.
    """
    saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
    saver.save(self.session, self.save_file, global_step=self.global_step)
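
Combined with the new option of passing checkpoint_interval=0 to fit(), this supports fully manual checkpointing; a small sketch, assuming an already-constructed model tg and a dataset train_dataset (both placeholders here):

tg.fit(train_dataset, nb_epoch=10, checkpoint_interval=0)  # no automatic saves
tg.save_checkpoint()  # write one checkpoint explicitly
tg.restore()          # later, reload the most recent checkpoint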

  def restore(self):
    """Reload the values of all variables from the most recent checkpoint file."""
    if not self.built:
@@ -756,3 +840,40 @@ class TFWrapper(object):

  def __call__(self):
    return self.tf_class(**self.kwargs)


class Submodel(object):
  """An alternate objective for training one piece of a TensorGraph."""

  def __init__(self, graph, layers, loss, optimizer):
    """Create a submodel.

    In normal use, you should call create_submodel() on the TensorGraph instead
    of using this constructor directly."""
    self.graph = graph
    self.layers = layers
    self.loss = loss
    self.optimizer = optimizer
    self._train_op = None

  def get_train_op(self):
    """Get the Tensorflow operator to use for training."""
    if self._train_op is None:
      if self.layers is None:
        variables = None
      else:
        variables = []
        for layer in self.layers:
          variables += self.graph.get_layer_variables(layer)
      if self.loss is None:
        loss = self.graph.loss
      else:
        loss = self.loss
      if self.optimizer is None:
        optimizer = self.graph.optimizer
      else:
        optimizer = self.optimizer
      global_step = self.graph._get_tf('GlobalStep')
      tf_opt = optimizer._create_optimizer(global_step)
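      # minimize()'s third positional argument is var_list; passing the
      # collected `variables` restricts the update to the chosen layers,
      # while None falls back to all trainable variables.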
      self._train_op = tf_opt.minimize(loss.out_tensor, global_step, variables)
    return self._train_op
+53 −2
@@ -311,21 +311,31 @@ class TestTensorGraph(unittest.TestCase):
    expected.append(2 * v2)
    tg.add_output(-c1)
    expected.append(-v1)
    tg.add_output(c1 / c2)
    expected.append(v1 / v2)
    tg.add_output(c1 / 2)
    expected.append(v1 / 2)
    for o, e in zip(tg.outputs, expected):
      value = tg.predict_on_batch(np.array([0]), outputs=o)
      assert np.array_equal(e, value)

  def test_initialize_variable(self):
    """Test methods for initializing a variable."""
    # Set by variable constructor.

    tg = dc.models.TensorGraph(use_queue=False)
    features = Feature(shape=(None, 1))
    tg.set_loss(Dense(1, in_layers=features))
    var = Variable([10.0])
    tg.add_output(var)
    tg.fit_generator([])
    assert tg.predict_on_batch(np.zeros((1, 1))) == [10.0]

    # Set by set_variable_initial_values().

    tg = dc.models.TensorGraph(use_queue=False)
    tg.set_loss(Dense(1, in_layers=features))
    var.set_variable_initial_values([[15.0]])
    tg.fit_generator([])
    tg.add_output(var)
    assert tg.predict_on_batch(np.zeros((1, 1))) == [15.0]

  def test_copy_layers(self):
@@ -350,3 +360,44 @@ class TestTensorGraph(unittest.TestCase):
      values = tg.session.run(variables)
    for v1, v2 in zip(values, copy.in_layers[0].variable_values):
      assert np.array_equal(v1, v2)

  def test_submodels(self):
    """Test optimizing submodels."""
    tg = dc.models.TensorGraph(learning_rate=0.1, batch_size=1)
    features = Feature(shape=(None, 1))
    var1 = Variable([2.0])
    var2 = Variable([2.0])
    tg.add_output(var1)
    tg.add_output(var2)
    loss = (var1 - 1) * (var1 - 1) + (var2 - 1) * (var2 - 1) + features
    tg.set_loss(loss)
    subloss1 = var1 * var1 + features
    subloss2 = var1 * var1 + var2 * var2 + features
    submodel1 = tg.create_submodel(loss=subloss1)
    submodel2 = tg.create_submodel(layers=[var2], loss=subloss2)
    data = np.zeros((1, 1))
    generator = [{features: data}] * 500

    # Optimize submodel 1.  This should send var1 to 0 while leaving var2 unchanged.

    tg.fit_generator(generator, submodel=submodel1)
    self.assertAlmostEqual(
        0.0, tg.predict_on_batch(data, outputs=var1)[0], places=4)
    self.assertAlmostEqual(
        2.0, tg.predict_on_batch(data, outputs=var2)[0], places=4)

    # Optimize the main loss.  This should send both variables toward 1.

    tg.fit_generator(generator)
    self.assertAlmostEqual(
        1.0, tg.predict_on_batch(data, outputs=var1)[0], places=4)
    self.assertAlmostEqual(
        1.0, tg.predict_on_batch(data, outputs=var2)[0], places=4)

    # Optimize submodel 2.  This should send var2 to 0 while leaving var1 unchanged.

    tg.fit_generator(generator, submodel=submodel2)
    self.assertAlmostEqual(
        1.0, tg.predict_on_batch(data, outputs=var1)[0], places=4)
    self.assertAlmostEqual(
        0.0, tg.predict_on_batch(data, outputs=var2)[0], places=4)
+308 −0

File added (contents collapsed: preview size limit exceeded).