Commit ed78f8e7 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Adding test for frozen weights

parent 40275168
Loading
Loading
Loading
Loading
+19 −12
Original line number Diff line number Diff line
@@ -270,7 +270,8 @@ class ProgressiveMultitaskRegressor(TensorflowMultiTaskRegressor):

    return task_costs

  def fit(self, dataset, max_checkpoints_to_keep=5, **kwargs):
  def fit(self, dataset, tasks=None, close_session=True,
          max_checkpoints_to_keep=5, **kwargs):
    """Fit the model.

    Progressive networks are fit by training one task at a time. Iteratively
@@ -286,22 +287,28 @@ class ProgressiveMultitaskRegressor(TensorflowMultiTaskRegressor):
    AssertionError
      If model is not in training mode.
    """
    if tasks is None:
      tasks = range(self.n_tasks)
    with self.train_graph.graph.as_default():
      task_train_ops = {}
      for task in range(self.n_tasks):
        task_train_ops[task] = self.get_training_op(
            self.train_graph.graph, self.train_graph.loss, task)
      with self._get_shared_session(train=True) as sess:

      sess = self._get_shared_session(train=True)
      #with self._get_shared_session(train=True) as sess:
      sess.run(tf.initialize_all_variables())
      # Save an initial checkpoint.
      saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
      saver.save(sess, self._save_path, global_step=0)
        for task in range(self.n_tasks):
      for task in tasks:
        print("Fitting on task %d" % task)
        self.fit_task(sess, dataset, task, task_train_ops[task], **kwargs)
        saver.save(sess, self._save_path, global_step=task)
      # Always save a final checkpoint when complete.
      saver.save(sess, self._save_path, global_step=self.n_tasks)
      if close_session:
        sess.close()

  def get_training_op(self, graph, losses, task):
    """Get training op for applying gradients to variables.
+36 −0
Original line number Diff line number Diff line
@@ -105,3 +105,39 @@ class TestProgressive(test_util.TensorFlowTestCase):
        batch_size=2, verbosity="high")

    prog_model.fit(dataset)

  def test_frozen_weights(self):
    """Test that fitting one task doesn't change predictions of another.
    
    Tests that weights are frozen when training different tasks.
    """
    n_tasks = 2
    n_samples = 10
    n_features = 100
    np.random.seed(123)
    ids = np.arange(n_samples)
    X = np.random.rand(n_samples, n_features)
    y = np.zeros((n_samples, n_tasks))
    w = np.ones((n_samples, n_tasks))
    dataset = dc.data.NumpyDataset(X, y, w, ids)

    n_layers = 3
    prog_model = dc.models.ProgressiveMultitaskRegressor(
        n_tasks=n_tasks, n_features=n_features,
        alpha_init_stddevs=[.08]*n_layers, layer_sizes=[100]*n_layers,
        weight_init_stddevs=[.02]*n_layers, bias_init_consts=[1.]*n_layers,
        dropouts=[0.]*n_layers, learning_rate=0.003,
        batch_size=2, verbosity="high")

    # Fit just on task zero 
    # Notice that we keep the session open
    prog_model.fit(dataset, tasks=[0], close_session=False)
    y_pred_task_zero = prog_model.predict(dataset)[:, 0]

    # Fit on task one
    prog_model.fit(dataset, tasks=[1])
    y_pred_task_zero_after = prog_model.predict(dataset)[:, 0]

    # The predictions for task zero should not change after training
    # on task one. 
    np.testing.assert_allclose(y_pred_task_zero, y_pred_task_zero_after)