Unverified Commit 89ab80eb authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1733 from peastman/tf2

[WIP] Convert to TensorFlow 2
parents e4325ff5 b43af8f1
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
@@ -128,8 +128,7 @@ def get_user_specified_features(df, featurizer, verbose=True):
  time1 = time.time()
  df[featurizer.feature_fields] = df[featurizer.feature_fields].apply(
      pd.to_numeric)
  df_fields = df[featurizer.feature_fields]
  X_shard = df_fields.values
  X_shard = df[featurizer.feature_fields].to_numpy()
  time2 = time.time()
  log("TIMING: user specified processing took %0.3f s" % (time2 - time1),
      verbose)
+10 −11
Original line number Diff line number Diff line
@@ -265,15 +265,15 @@ class Dataset(object):
    else:
      return None

  def make_iterator(self,
  def make_tf_dataset(self,
                      batch_size=100,
                      epochs=1,
                      deterministic=False,
                      pad_batches=False):
    """Create a tf.data.Iterator that iterates over the data in this Dataset.
    """Create a tf.data.Dataset that iterates over the data in this Dataset.

    The iterator's get_next() method returns a tuple of three tensors (X, y, w)
    which can be used to retrieve the features, labels, and weights respectively.
    Each value returned by the Dataset's iterator is a tuple of (X, y, w) for
    one batch.

    Parameters
    ----------
@@ -297,7 +297,7 @@ class Dataset(object):
              tf.TensorShape([None] + list(y.shape)),
              tf.TensorShape([None] + list(w.shape)))

    # Create a Tensorflow Dataset and have it create an Iterator.
    # Create a Tensorflow Dataset.

    def gen_data():
      for epoch in range(epochs):
@@ -305,8 +305,7 @@ class Dataset(object):
                                             pad_batches):
          yield (X, y, w)

    dataset = tf.data.Dataset.from_generator(gen_data, dtypes, shapes)
    return dataset.make_one_shot_iterator()
    return tf.data.Dataset.from_generator(gen_data, dtypes, shapes)


class NumpyDataset(Dataset):
+8 −16
Original line number Diff line number Diff line
@@ -685,27 +685,19 @@ class TestDatasets(test_util.TensorFlowTestCase):
    assert new_data.y.shape == (num_datapoints * num_datasets, num_tasks)
    assert len(new_data.tasks) == len(datasets[0].tasks)

  def test_make_iterator(self):
  def test_make_tf_dataset(self):
    """Test creating a Tensorflow Iterator from a Dataset."""
    X = np.random.random((100, 5))
    y = np.random.random((100, 1))
    dataset = dc.data.NumpyDataset(X, y)
    iterator = dataset.make_iterator(
    iterator = dataset.make_tf_dataset(
        batch_size=10, epochs=2, deterministic=True)
    next_element = iterator.get_next()
    with self.session() as sess:
      for i in range(20):
        batch_X, batch_y, batch_w = sess.run(next_element)
    for i, (batch_X, batch_y, batch_w) in enumerate(iterator):
      offset = (i % 10) * 10
      np.testing.assert_array_equal(X[offset:offset + 10, :], batch_X)
      np.testing.assert_array_equal(y[offset:offset + 10, :], batch_y)
      np.testing.assert_array_equal(np.ones((10, 1)), batch_w)
      finished = False
      try:
        sess.run(next_element)
      except tf.errors.OutOfRangeError:
        finished = True
    assert finished
    assert i == 19


if __name__ == "__main__":
+48 −83
Original line number Diff line number Diff line
@@ -111,7 +111,8 @@ class MAML(object):
    # Record inputs.

    self.learner = learner
    self._learning_rate = learning_rate
    self.learning_rate = learning_rate
    self.optimization_steps = optimization_steps
    self.meta_batch_size = meta_batch_size
    self.optimizer = optimizer

@@ -127,68 +128,13 @@ class MAML(object):
    self.model_dir = model_dir
    self.save_file = "%s/%s" % (self.model_dir, "model")

    learner.select_task()
    example_inputs = learner.get_batch()
    self._input_shapes = [(None,) + i.shape[1:] for i in example_inputs]
    self._input_dtypes = [x.dtype for x in example_inputs]
    self._input_placeholders = [
        tf.placeholder(dtype=tf.as_dtype(t), shape=s)
        for s, t in zip(self._input_shapes, self._input_dtypes)
    ]
    self._meta_placeholders = [
        tf.placeholder(dtype=tf.as_dtype(t), shape=s)
        for s, t in zip(self._input_shapes, self._input_dtypes)
    ]
    variables = learner.variables
    self._loss, self._outputs = learner.compute_model(self._input_placeholders,
                                                      variables, False)
    loss, _ = learner.compute_model(self._input_placeholders, variables, True)

    # Build the meta-learning model.

    updated_variables = variables
    for i in range(optimization_steps):
      gradients = tf.gradients(loss, updated_variables)
      updated_variables = [
          v if g is None else v - self._learning_rate * g
          for v, g in zip(updated_variables, gradients)
      ]
      if i == optimization_steps - 1:
        # In the final loss, use different placeholders for all inputs so the loss will be
        # computed from a different batch.

        inputs = self._meta_placeholders
      else:
        inputs = self._input_placeholders
      loss, outputs = learner.compute_model(inputs, updated_variables, True)
    self._meta_loss = loss

    # Create variables for accumulating the gradients.

    variables = list(learner.variables)
    gradients = tf.gradients(self._meta_loss, variables)
    for i in reversed(range(len(variables))):
      if gradients[i] is None:
        del variables[i]
        del gradients[i]
    zero_gradients = [tf.zeros(g.shape, g.dtype) for g in gradients]
    summed_gradients = [tf.Variable(z, trainable=False) for z in zero_gradients]
    self._clear_gradients = tf.group(
        *[s.assign(z) for s, z in zip(summed_gradients, zero_gradients)])
    self._add_gradients = tf.group(
        *[s.assign_add(g) for s, g in zip(summed_gradients, gradients)])

    # Create the optimizers for meta-optimization and task optimization.

    self._global_step = tf.placeholder(tf.int32, [])
    grads_and_vars = list(zip(summed_gradients, variables))
    self._meta_train_op = optimizer._create_optimizer(
        self._global_step).apply_gradients(grads_and_vars)
    task_optimizer = GradientDescent(learning_rate=self._learning_rate)
    self._task_train_op = task_optimizer._create_optimizer(
        self._global_step).minimize(self._loss)
    self._session = tf.Session()
    self._session.run(tf.global_variables_initializer())
    self._global_step = tf.Variable(0, trainable=False)
    self._tf_optimizer = optimizer._create_optimizer(self._global_step)
    task_optimizer = GradientDescent(learning_rate=self.learning_rate)
    self._tf_task_optimizer = task_optimizer._create_optimizer(
        self._global_step)

    # Create a Checkpoint for saving.

@@ -227,32 +173,53 @@ class MAML(object):

    # Main optimization loop.

    learner = self.learner
    variables = learner.variables
    for i in range(steps):
      self._session.run(self._clear_gradients)
      for j in range(self.meta_batch_size):
        self.learner.select_task()
        inputs = self.learner.get_batch()
        feed_dict = {}
        feed_dict[self._global_step] = i
        for k in range(len(inputs)):
          feed_dict[self._input_placeholders[k]] = inputs[k]
          feed_dict[self._meta_placeholders[k]] = inputs[k]
        self._session.run(self._add_gradients, feed_dict=feed_dict)
      self._session.run(self._meta_train_op)
        learner.select_task()
        meta_loss, meta_gradients = self._compute_meta_loss(
            learner.get_batch(), learner.get_batch(), variables)
        if j == 0:
          summed_gradients = meta_gradients
        else:
          summed_gradients = [
              s + g for s, g in zip(summed_gradients, meta_gradients)
          ]
      self._tf_optimizer.apply_gradients(zip(summed_gradients, variables))

      # Do checkpointing.

      if i == steps - 1 or time.time() >= checkpoint_time + checkpoint_interval:
        with self._session.as_default():
        manager.save()
        checkpoint_time = time.time()

  @tf.function
  def _compute_meta_loss(self, inputs, inputs2, variables):
    """This is called during fitting to compute the meta-loss (the loss after a
    few steps of optimization), and its gradient.

    Runs `self.optimization_steps` steps of plain gradient descent on `inputs`
    (the inner loop), then evaluates the loss on `inputs2` with the updated
    variables.  The whole inner loop is recorded on an outer tape so that the
    returned gradient differentiates through the optimization steps themselves
    (second-order MAML).

    Parameters
    ----------
    inputs: a batch used for the inner-loop optimization steps
    inputs2: a second batch on which the final (meta) loss is evaluated
      # presumably a different batch from the same task so the meta-loss
      # measures generalization — confirm against the caller in fit()
    variables: list of model variables to differentiate the meta-loss against

    Returns
    -------
    (meta_loss, meta_gradients): the loss after the inner-loop updates, and its
    gradient with respect to the ORIGINAL `variables` (not the updated copies)
    """
    updated_variables = variables
    # Outer tape records everything below, including the inner update steps,
    # so meta_gradients can flow back through the inner optimization.
    with tf.GradientTape() as meta_tape:
      for k in range(self.optimization_steps):
        # Inner tape: one step of gradient descent on the task batch.
        with tf.GradientTape() as tape:
          loss, _ = self.learner.compute_model(inputs, updated_variables, True)
        gradients = tape.gradient(loss, updated_variables)
        # Variables with no gradient (g is None) are left unchanged; the rest
        # take a manual SGD step so the update stays differentiable.
        updated_variables = [
            v if g is None else v - self.learning_rate * g
            for v, g in zip(updated_variables, gradients)
        ]
      # Meta-loss: evaluate the adapted model on the second batch.
      meta_loss, _ = self.learner.compute_model(inputs2, updated_variables,
                                                True)
    meta_gradients = meta_tape.gradient(meta_loss, variables)
    return meta_loss, meta_gradients

  def restore(self):
    """Reload the model parameters from the most recent checkpoint file."""
    last_checkpoint = tf.train.latest_checkpoint(self.model_dir)
    if last_checkpoint is None:
      raise ValueError('No checkpoint found')
    self._checkpoint.restore(last_checkpoint).run_restore_ops(self._session)
    self._checkpoint.restore(last_checkpoint)

  def train_on_current_task(self, optimization_steps=1, restore=True):
    """Perform a few steps of gradient descent to fine tune the model on the current task.
@@ -266,12 +233,13 @@ class MAML(object):
    """
    if restore:
      self.restore()
    inputs = self.learner.get_batch()
    feed_dict = {}
    for p, v in zip(self._input_placeholders, inputs):
      feed_dict[p] = v
    variables = self.learner.variables
    for i in range(optimization_steps):
      self._session.run(self._task_train_op, feed_dict=feed_dict)
      inputs = self.learner.get_batch()
      with tf.GradientTape() as tape:
        loss, _ = self.learner.compute_model(inputs, variables, True)
      gradients = tape.gradient(loss, variables)
      self._tf_task_optimizer.apply_gradients(zip(gradients, variables))

  def predict_on_batch(self, inputs):
    """Compute the model's outputs for a batch of inputs.
@@ -286,7 +254,4 @@ class MAML(object):
    (loss, outputs) where loss is the value of the model's loss function, and
    outputs is a list of the model's outputs
    """
    feed_dict = {}
    for p, v in zip(self._input_placeholders, inputs):
      feed_dict[p] = v
    return self._session.run([self._loss, self._outputs], feed_dict=feed_dict)
    return self.learner.compute_model(inputs, self.learner.variables, False)
+8 −18
Original line number Diff line number Diff line
@@ -62,17 +62,13 @@ class TestMAML(unittest.TestCase):
    loss2 = []
    for i in range(50):
      learner.select_task()
      maml.restore()
      batch = learner.get_batch()
      feed_dict = {}
      for j in range(len(batch)):
        feed_dict[maml._input_placeholders[j]] = batch[j]
        feed_dict[maml._meta_placeholders[j]] = batch[j]
      loss1.append(
          np.average(
              np.sqrt(maml._session.run(maml._loss, feed_dict=feed_dict))))
      loss2.append(
          np.average(
              np.sqrt(maml._session.run(maml._meta_loss, feed_dict=feed_dict))))
      loss, outputs = maml.predict_on_batch(batch)
      loss1.append(np.sqrt(loss))
      maml.train_on_current_task()
      loss, outputs = maml.predict_on_batch(batch)
      loss2.append(np.sqrt(loss))

    # Initially the model should do a bad job of fitting the sine function.

@@ -82,23 +78,17 @@ class TestMAML(unittest.TestCase):

    assert np.average(loss2) < 1.0

    # If we train on the current task, the loss should go down.

    maml.train_on_current_task()
    assert np.average(
        np.sqrt(maml._session.run(maml._loss, feed_dict=feed_dict))) < loss1[-1]

    # Verify that we can create a new MAML object, reload the parameters from the first one, and
    # get the same result.

    new_maml = dc.metalearning.MAML(learner, model_dir=maml.model_dir)
    new_maml = dc.metalearning.MAML(SineLearner(), model_dir=maml.model_dir)
    new_maml.restore()
    loss, outputs = new_maml.predict_on_batch(batch)
    assert np.sqrt(loss) == loss1[-1]

    # Do the same thing, only using the "restore" argument to fit().

    new_maml = dc.metalearning.MAML(learner, model_dir=maml.model_dir)
    new_maml = dc.metalearning.MAML(SineLearner(), model_dir=maml.model_dir)
    new_maml.fit(0, restore=True)
    loss, outputs = new_maml.predict_on_batch(batch)
    assert np.sqrt(loss) == loss1[-1]
Loading