Commit de7f5863 authored by peastman's avatar peastman
Browse files

Updated MAML example for new API

parent 1330ea31
Loading
Loading
Loading
Loading
+20 −2
Original line number Diff line number Diff line
@@ -140,8 +140,8 @@ class MAML(object):
        for s, t in zip(self._input_shapes, self._input_dtypes)
    ]
    variables = learner.variables
    self._loss, _ = learner.compute_model(self._input_placeholders, variables,
                                          False)
    self._loss, self._outputs = learner.compute_model(self._input_placeholders,
                                                      variables, False)
    loss, _ = learner.compute_model(self._input_placeholders, variables, True)

    # Build the meta-learning model.
@@ -272,3 +272,21 @@ class MAML(object):
      feed_dict[p] = v
    for i in range(optimization_steps):
      self._session.run(self._task_train_op, feed_dict=feed_dict)

  def predict_on_batch(self, inputs):
    """Compute the model's outputs for a batch of inputs.

    Parameters
    ----------
    inputs: list of arrays
      the inputs to the model

    Returns
    -------
    (loss, outputs) where loss is the value of the model's loss function, and
    outputs is a list of the model's outputs
    """
    # Pair each placeholder with its corresponding input array.
    feed_dict = dict(zip(self._input_placeholders, inputs))
    # Evaluate loss and outputs in a single session call.
    return self._session.run([self._loss, self._outputs], feed_dict=feed_dict)
+4 −14
Original line number Diff line number Diff line
@@ -93,22 +93,12 @@ class TestMAML(unittest.TestCase):

    new_maml = dc.metalearning.MAML(learner, model_dir=maml.model_dir)
    new_maml.restore()
    feed_dict = {}
    for j in range(len(batch)):
      feed_dict[new_maml._input_placeholders[j]] = batch[j]
      feed_dict[new_maml._meta_placeholders[j]] = batch[j]
    new_loss = np.average(
        np.sqrt(new_maml._session.run(new_maml._loss, feed_dict=feed_dict)))
    assert new_loss == loss1[-1]
    loss, outputs = new_maml.predict_on_batch(batch)
    assert np.sqrt(loss) == loss1[-1]

    # Do the same thing, only using the "restore" argument to fit().

    new_maml = dc.metalearning.MAML(learner, model_dir=maml.model_dir)
    new_maml.fit(0, restore=True)
    feed_dict = {}
    for j in range(len(batch)):
      feed_dict[new_maml._input_placeholders[j]] = batch[j]
      feed_dict[new_maml._meta_placeholders[j]] = batch[j]
    new_loss = np.average(
        np.sqrt(new_maml._session.run(new_maml._loss, feed_dict=feed_dict)))
    assert new_loss == loss1[-1]
    loss, outputs = new_maml.predict_on_batch(batch)
    assert np.sqrt(loss) == loss1[-1]
+32 −29
Original line number Diff line number Diff line
@@ -19,7 +19,7 @@ n_tasks = y.shape[1]
# Toxcast has data on 6874 molecules and 617 tasks.  However, the data is very
# sparse: most tasks do not include data for most molecules.  It also is very
# unbalanced: there are many more negatives than positives.  For each task,
# create a list of alternating postives and negatives so each batch will have
# create a list of alternating positives and negatives so each batch will have
# equal numbers of both.

task_molecules = []
@@ -28,16 +28,9 @@ for i in range(n_tasks):
  negatives = [j for j in range(n_molecules) if w[j, i] > 0 and y[j, i] == 0]
  np.random.shuffle(positives)
  np.random.shuffle(negatives)
  mols = sum((list(x) for x in zip(positives, negatives)), [])
  mols = sum((list(m) for m in zip(positives, negatives)), [])
  task_molecules.append(mols)

# Create the model to train.  We use a simple fully connected network with
# one hidden layer.

model = dc.models.MultitaskClassifier(
    1, n_features, layer_sizes=[1000], dropouts=[0.0])
model.build()

# Define a MetaLearner describing the learning problem.


@@ -48,10 +41,26 @@ class ToxcastLearner(dc.metalearning.MetaLearner):
    self.batch_size = 10
    self.batch_start = [0] * n_tasks
    self.set_task_index(0)
    self.w1 = tf.Variable(
        np.random.normal(size=[n_features, 1000], scale=0.02), dtype=tf.float32)
    self.w2 = tf.Variable(
        np.random.normal(size=[1000, 1], scale=0.02), dtype=tf.float32)
    self.b1 = tf.Variable(np.ones(1000), dtype=tf.float32)
    self.b2 = tf.Variable(np.zeros(1), dtype=tf.float32)

  def compute_model(self, inputs, variables, training):
    """Build the network for one task and return its loss and outputs.

    A two-layer fully connected net: ReLU hidden layer, sigmoid output.
    Returns (loss, [output]) where loss is the mean sigmoid cross entropy.
    """
    features, labels = (tf.cast(t, tf.float32) for t in inputs)
    w1, w2, b1, b2 = variables
    hidden = tf.nn.relu(tf.matmul(features, w1) + b1)
    logits = tf.matmul(hidden, w2) + b2
    probabilities = tf.sigmoid(logits)
    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=logits, labels=labels)
    loss = tf.reduce_mean(cross_entropy)
    return loss, [probabilities]

  @property
  def loss(self):
    return model.loss
  def variables(self):
    return [self.w1, self.w2, self.b1, self.b2]

  def set_task_index(self, index):
    self.task = index
@@ -63,18 +72,13 @@ class ToxcastLearner(dc.metalearning.MetaLearner):
    task = self.task
    start = self.batch_start[task]
    mols = task_molecules[task][start:start + self.batch_size]
    labels = np.zeros((self.batch_size, 1, 2))
    labels[np.arange(self.batch_size), 0, y[mols, task].astype(np.int64)] = 1
    weights = np.ones((self.batch_size, 1))
    feed_dict = {}
    feed_dict[model.features[0].out_tensor] = x[mols, :]
    feed_dict[model.labels[0].out_tensor] = labels
    feed_dict[model.task_weights[0].out_tensor] = weights
    labels = np.zeros((self.batch_size, 1))
    labels[np.arange(self.batch_size), 0] = y[mols, task]
    if start + 2 * self.batch_size > len(task_molecules[task]):
      self.batch_start[task] = 0
    else:
      self.batch_start[task] += self.batch_size
    return feed_dict
    return [x[mols, :], labels]


# Run meta-learning on 80% of the tasks.
@@ -93,16 +97,15 @@ def compute_scores(optimize):
  y_true = []
  y_pred = []
  losses = []
  with model._get_tf("Graph").as_default():
    prediction = tf.nn.softmax(model.outputs[0].out_tensor)
  for task in range(learner.n_training_tasks, n_tasks):
    learner.set_task_index(task)
    if optimize:
        maml.train_on_current_task()
      feed_dict = learner.get_batch()
      y_true.append(feed_dict[model.labels[0].out_tensor][:, 0, 0])
      y_pred.append(maml._session.run(prediction, feed_dict=feed_dict)[:, 0, 0])
      losses.append(maml._session.run(model.loss, feed_dict=feed_dict))
      maml.train_on_current_task(restore=True)
    inputs = learner.get_batch()
    loss, prediction = maml.predict_on_batch(inputs)
    y_true.append(inputs[1])
    y_pred.append(prediction[0][:, 0])
    losses.append(loss)
  y_true = np.concatenate(y_true)
  y_pred = np.concatenate(y_pred)
  print()