Commit 33789fc5 authored by Vignesh's avatar Vignesh
Browse files

Make restore a toggle in pretrained model loading

parent 17540813
Loading
Loading
Loading
Loading
+43 −5
Original line number Diff line number Diff line
@@ -1010,9 +1010,34 @@ class KerasModel(Model):

    return assignment_map

  def _create_value_map(self, source_model, **kwargs):
    """Snapshot the current values of the source model's trainable variables.

    Builds a dict mapping each trainable variable of ``source_model`` to its
    value as a numpy array.  Used only as the default when the caller of
    ``load_from_pretrained`` does not supply a custom value map.

    Parameters
    ----------
    source_model: dc.models.KerasModel
        Source model whose trainable variables are read
    **kwargs: dict
        Ignored; accepted for call-site compatibility

    Returns
    -------
    dict
        Maps each trainable variable of the source model to a numpy array
        holding its current value.
    """
    # In eager mode a variable's value is read directly; in graph mode it
    # must be evaluated inside this model's session.
    if tf.executing_eagerly():
      read_value = lambda v: v.numpy()
    else:
      read_value = lambda v: v.eval(session=self.session)

    return {
        var: read_value(var) for var in source_model.model.trainable_variables
    }

  def load_from_pretrained(self,
                           source_model,
                           assignment_map=None,
                           value_map=None,
                           checkpoint=None,
                           model_dir=None,
                           include_top=True,
@@ -1036,6 +1061,9 @@ class KerasModel(Model):
      model.
    assignment_map: Dict, default None
      Dictionary containing layer mapping between source and current model layers
    value_map: Dict, default None
      Dictionary containing source model trainable variables mapped to numpy
      arrays
    checkpoint: str, default None
      the path to the checkpoint file to load.  If this is None, the most recent
      checkpoint will be chosen automatically.  Call get_checkpoints() to get a
@@ -1046,25 +1074,35 @@ class KerasModel(Model):
        if True, copies the weights and bias associated with the final dense layer.
        Used only when assignment map is None.
    """

    self._ensure_built()
    if value_map is None:
      logger.info(
          "No value map provided. Creating default value map from restored model."
      )
      if tf.executing_eagerly():
        source_model.restore(model_dir=model_dir, checkpoint=checkpoint)
      else:
        source_model.restore(
            model_dir=model_dir, checkpoint=checkpoint, session=self.session)
      value_map = self._create_value_map(source_model=source_model)

    if assignment_map is None:
      logger.info("No assignment map provided. Creating custom assignment map.")
      assignment_map = self._create_assignment_map(
          source_model=source_model, include_top=include_top)

    if tf.executing_eagerly():
      source_model.restore(model_dir=model_dir, checkpoint=checkpoint)
      for source_var, dest_var in assignment_map.items():
        assert source_var.shape == dest_var.shape
        dest_var.assign(source_var)
        dest_var.assign(value_map[source_var])

    else:
      source_model.restore(
          model_dir=model_dir, checkpoint=checkpoint, session=self.session)
      self._assign_ops = []
      with self.session.as_default():
        for source_var, dest_var in assignment_map.items():
          assert source_var.shape == dest_var.shape
          assign_op = dest_var.assign(source_var)
          assign_op = dest_var.assign(value_map[source_var])
          self._assign_ops.append(assign_op)
          self.session.run(assign_op)

+49 −11
Original line number Diff line number Diff line
@@ -51,6 +51,7 @@ class TestPretrained(unittest.TestCase):
        feature_dim=self.feature_dim,
        model_dir=model_dir,
        batch_size=10)

    model.fit(self.dataset, nb_epoch=1000)
    predictions = np.squeeze(model.predict_on_batch(self.dataset.X))
    np.testing.assert_array_almost_equal(self.dataset.y, np.round(predictions))
@@ -58,9 +59,12 @@ class TestPretrained(unittest.TestCase):
  def test_load_from_pretrained_graph_mode(self):
    """Tests loading pretrained model in graph mode."""
    source_model = MLP(
        model_dir="./MLP/",
        hidden_layer_size=self.hidden_layer_size,
        feature_dim=self.feature_dim,
        hidden_layer_size=self.hidden_layer_size)
        model_dir=None,
        batch_size=10)

    source_model.fit(self.dataset, nb_epoch=1000, checkpoint_interval=0)

    dest_model = MLP(
        feature_dim=self.feature_dim,
@@ -68,29 +72,34 @@ class TestPretrained(unittest.TestCase):
        n_tasks=10)

    assignment_map = dict()
    value_map = dict()
    dest_vars = dest_model.model.trainable_variables[:-2]

    for idx, dest_var in enumerate(dest_vars):
      source_var = source_model.model.trainable_variables[idx]
      assignment_map[source_var] = dest_var
      value_map[source_var] = source_var.eval(session=source_model.session)

    dest_model.load_from_pretrained(
        source_model=source_model,
        assignment_map=assignment_map,
        include_top=False)
        value_map=value_map)

    for source_var, dest_var in assignment_map.items():
      np.testing.assert_array_almost_equal(
          source_var.eval(session=dest_model.session),
          source_var.eval(session=source_model.session),
          dest_var.eval(session=dest_model.session))

  def test_load_from_pretrained_eager(self):
  def test_load_from_pretrained_eager_mode(self):
    """Tests loading pretrained model in eager execution mode."""
    with context.eager_mode():
      source_model = MLP(
          model_dir="./MLP/",
          hidden_layer_size=self.hidden_layer_size,
          feature_dim=self.feature_dim,
          hidden_layer_size=self.hidden_layer_size)
          model_dir=None,
          batch_size=10)

      source_model.fit(self.dataset, nb_epoch=1000, checkpoint_interval=0)

      dest_model = MLP(
          feature_dim=self.feature_dim,
@@ -98,20 +107,25 @@ class TestPretrained(unittest.TestCase):
          n_tasks=10)

      assignment_map = dict()
      value_map = dict()
      dest_vars = dest_model.model.trainable_variables[:-2]

      for idx, dest_var in enumerate(dest_vars):
        source_var = source_model.model.trainable_variables[idx]
        assignment_map[source_var] = dest_var
        value_map[source_var] = source_var.numpy()

      dest_model.load_from_pretrained(
          source_model=source_model, assignment_map=assignment_map)
          source_model=source_model,
          value_map=value_map,
          assignment_map=assignment_map)

      for source_var, dest_var in assignment_map.items():
        np.testing.assert_array_almost_equal(source_var.numpy(),
                                             dest_var.numpy())

  def test_restore_equivalency(self):
  def test_restore_equivalency_graph_mode(self):
    """Test for restore based pretrained model loading in graph mode."""
    source_model = MLP(
        model_dir="./MLP/",
        feature_dim=self.feature_dim,
@@ -121,8 +135,32 @@ class TestPretrained(unittest.TestCase):
        feature_dim=self.feature_dim, hidden_layer_size=self.hidden_layer_size)

    dest_model.load_from_pretrained(
        source_model=source_model, assignment_map=None, include_top=True)
        source_model=source_model,
        assignment_map=None,
        value_map=None,
        include_top=True)

    predictions = np.squeeze(dest_model.predict_on_batch(self.dataset.X))

    np.testing.assert_array_almost_equal(self.dataset.y, np.round(predictions))

  def test_restore_equivalency_eager_mode(self):
    """Test for restore based pretrained model loading in eager mode."""
    with context.eager_mode():
      # Pretrained model restored from its checkpoint directory.
      pretrained = MLP(
          model_dir="./MLP/",
          feature_dim=self.feature_dim,
          hidden_layer_size=self.hidden_layer_size)

      # Fresh model with an identical architecture to load weights into.
      target = MLP(
          feature_dim=self.feature_dim,
          hidden_layer_size=self.hidden_layer_size)

      # No explicit maps: exercise the default restore-based loading path,
      # copying every layer including the final dense layer.
      target.load_from_pretrained(
          source_model=pretrained,
          assignment_map=None,
          value_map=None,
          include_top=True)

      preds = np.squeeze(target.predict_on_batch(self.dataset.X))
      np.testing.assert_array_almost_equal(self.dataset.y, np.round(preds))