Commit c3688a96 authored by Vignesh

Switched back to variable based restore

parent 9e8bd59d
+43 −26
@@ -957,10 +957,14 @@ class KerasModel(Model):
      return int(self._global_step)
    return self._global_step.eval(session=self.session)

-  def default_assignment_map(self, source_model, include_top=True, **kwargs):
+  def _create_assignment_map(self, source_model, include_top=True, **kwargs):
    """
-    Creates a default assignment map between layers of source and current model.
-    This is used only when a custom assignment map is missing.
+    Creates a default assignment map between variables of source and current model.
+    This is used only when a custom assignment map is missing. This assumes the
+    model is made of different layers followed by a dense layer for mapping to
+    output tasks. include_top is used to control whether or not the final dense
+    layer is used. The default assignment map is useful in cases where the type
+    of task is different (classification vs regression) and/or the number of
+    tasks differs.

    Parameters
    ----------
@@ -970,11 +974,15 @@ class KerasModel(Model):
        if true, copies the last dense layer
    """
    assignment_map = {}
+    source_vars = source_model.model.trainable_variables
+    dest_vars = self.model.trainable_variables

-    for idx, layer in enumerate(source_model.model.layers[:-1]):
-      assignment_map[layer] = self.model.layers[idx]
-    if include_top:
-      assignment_map[source_model.model.layers[-1]] = self.model.layers[-1]
+    if not include_top:
+      source_vars = source_vars[:-2]
+      dest_vars = dest_vars[:-2]

+    for source_var, dest_var in zip(source_vars, dest_vars):
+      assignment_map[source_var] = dest_var

    return assignment_map
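
For context, the new default map simply pairs trainable variables positionally and, when include_top is False, drops the final dense layer's kernel and bias. Below is a minimal standalone sketch of the same pairing logic using plain tf.keras models; make_mlp, the layer sizes, and the feature width are illustrative placeholders, not part of this commit:

import tensorflow as tf

def make_mlp(n_tasks):
  # Toy architecture: a shared feature layer plus a task-specific dense head.
  return tf.keras.Sequential([
      tf.keras.layers.Dense(32, activation='relu', input_shape=(100,)),
      tf.keras.layers.Dense(n_tasks),
  ])

source = make_mlp(n_tasks=1)   # stands in for the pretrained model
dest = make_mlp(n_tasks=10)    # same body, different number of output tasks

include_top = False
source_vars = source.trainable_variables
dest_vars = dest.trainable_variables
if not include_top:
  # Drop the kernel and bias of the final dense layer (the last two variables).
  source_vars, dest_vars = source_vars[:-2], dest_vars[:-2]

# Positional pairing: variable i of the source maps to variable i of the dest.
for source_var, dest_var in zip(source_vars, dest_vars):
  assert source_var.shape == dest_var.shape
  dest_var.assign(source_var)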

@@ -985,19 +993,23 @@ class KerasModel(Model):
                           model_dir=None,
                           include_top=True,
                           **kwargs):
    """Loads a set of layer weights from a pretrained model. The method takes in
    a source model and an assignment map, and several other optional arguments.
    The assignment map is a dictionary between the layers of the source model
    and the layers in the current model, whose weights we want to copy over. If
    an assignment map is not supplied, a default assignment map is generated,
    which assumes the architecture has a Dense layer as its last layer and can be
    included by setting include_top to True. The default assignment map allows
    applying an existing trained model to a different multi-task setting involving
    the same model or to a different task like classification from regression or
    vice-versa.
    """Copies variable values between pretrained model and current model. This
    method assumes that the variable values of the pretrained model are saved to
    disk. source_model is a dc.KerasModel with the same architecture as the
    pretrained model. The variable values are then restored to the source_model.
    assignment_map is a dictionary mapping the variables from the source model
    to those in the current model. If no assignment_map is provided, one is made
    from scratch and assumes the model is composed of several different layers,
    with the final one being a dense layer. include_top is used to control whether
    or not the final dense layer is used. The default assignment map is useful in
    cases where the type of task is different (classification vs regression) and/or
    number of tasks in the setting.

    Parameters
    ----------
+    source_model: dc.KerasModel, required
+      Model which has the same topology and architecture as the pretrained
+      model.
    assignment_map: Dict, default None
      Dictionary containing layer mapping between source and current model layers
    checkpoint: str, default None
@@ -1007,27 +1019,32 @@ class KerasModel(Model):
    model_dir: str, default None
      Restore model from custom model directory if needed
    include_top: bool, default True
-        if true, copies the last dense layer. Used only when assignment map is None
+        if True, copies the weights and bias associated with the final dense layer.
+        Used only when assignment map is None.
    """
    self._ensure_built()
    if assignment_map is None:
-      assignment_map = self.default_assignment_map(
+      assignment_map = self._create_assignment_map(
          source_model=source_model, include_top=include_top)

    self._assign_ops = []
    if tf.executing_eagerly():
      source_model.restore(model_dir=model_dir, checkpoint=checkpoint)
-      for source_layer, dest_layer in assignment_map.items():
-        dest_vars = dest_layer.trainable_variables
-        dest_layer.set_weights(source_layer.get_weights())
+      for source_var, dest_var in assignment_map.items():
+        assert source_var.shape == dest_var.shape
+        dest_var.assign(source_var)

    else:
      source_model.restore(
          model_dir=model_dir, checkpoint=checkpoint, session=self.session)
      self._assign_ops = []
      with self.session.as_default():
-        for source_layer, dest_layer in assignment_map.items():
-          dest_vars = dest_layer.trainable_variables
-          dest_layer.set_weights(source_layer.get_weights())
+        for source_var, dest_var in assignment_map.items():
+          assert source_var.shape == dest_var.shape
+          assign_op = dest_var.assign(source_var)
+          self._assign_ops.append(assign_op)
+          self.session.run(assign_op)

    dest_vars = list(assignment_map.values())
    self._initialized_vars.update(set(dest_vars))
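
A typical call with the new behavior looks roughly like the sketch below. MLPModel, the layer sizes, and the checkpoint directory are illustrative placeholders, and a checkpoint for source_model is assumed to already exist under model_dir; any dc.models.KerasModel subclass with the same body architecture would work the same way:

import deepchem as dc
import tensorflow as tf
from deepchem.models.losses import L2Loss

class MLPModel(dc.models.KerasModel):
  """Toy wrapper used only for illustration; not part of this commit."""

  def __init__(self, n_features, n_tasks, **kwargs):
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(32, activation='relu', input_shape=(n_features,)),
        tf.keras.layers.Dense(n_tasks),
    ])
    super(MLPModel, self).__init__(model, L2Loss(), **kwargs)

source_model = MLPModel(n_features=100, n_tasks=1, model_dir='/tmp/pretrained')
dest_model = MLPModel(n_features=100, n_tasks=10)

# Restore the pretrained variable values into source_model from its latest
# checkpoint, then copy every trainable variable except the final dense
# layer's kernel and bias into dest_model (include_top=False) via the
# default assignment map.
dest_model.load_from_pretrained(
    source_model=source_model,
    assignment_map=None,
    checkpoint=None,
    model_dir='/tmp/pretrained',
    include_top=False)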


+15 −21
@@ -68,25 +68,21 @@ class TestPretrained(unittest.TestCase):
        n_tasks=10)

    assignment_map = dict()
-    dest_layers = dest_model.model.layers[:-1]
+    dest_vars = dest_model.model.trainable_variables[:-2]

-    for idx, dest_layer in enumerate(dest_layers):
-      source_layer = source_model.model.layers[idx]
-      assignment_map[source_layer] = dest_layer
+    for idx, dest_var in enumerate(dest_vars):
+      source_var = source_model.model.trainable_variables[idx]
+      assignment_map[source_var] = dest_var

    dest_model.load_from_pretrained(
        source_model=source_model,
        assignment_map=assignment_map,
        include_top=False)

-    for source_layer, dest_layer in assignment_map.items():
-      for var_old, var_new in zip(source_layer.trainable_variables,
-                                  dest_layer.trainable_variables):
-        # Need to fix this by running session ops everytime
-        np.testing.assert_array_almost_equal(
-            var_old.eval(session=dest_model.session),
-            var_new.eval(session=dest_model.session))
+    for source_var, dest_var in assignment_map.items():
+      np.testing.assert_array_almost_equal(
+          source_var.eval(session=dest_model.session),
+          dest_var.eval(session=dest_model.session))

  def test_load_from_pretrained_eager(self):
    """Tests loading pretrained model in eager execution mode."""
@@ -102,19 +98,18 @@ class TestPretrained(unittest.TestCase):
          n_tasks=10)

      assignment_map = dict()
-      dest_layers = dest_model.model.layers[:-1]
+      dest_vars = dest_model.model.trainable_variables[:-2]

-      for idx, dest_layer in enumerate(dest_layers):
-        source_layer = source_model.model.layers[idx]
-        assignment_map[source_layer] = dest_layer
+      for idx, dest_var in enumerate(dest_vars):
+        source_var = source_model.model.trainable_variables[idx]
+        assignment_map[source_var] = dest_var

      dest_model.load_from_pretrained(
          source_model=source_model, assignment_map=assignment_map)

-      for source_layer, dest_layer in assignment_map.items():
-        for var_old, var_new in zip(source_layer.trainable_variables,
-                                    dest_layer.trainable_variables):
-          np.testing.assert_array_almost_equal(var_old.numpy(), var_new.numpy())
+      for source_var, dest_var in assignment_map.items():
+        np.testing.assert_array_almost_equal(source_var.numpy(),
+                                             dest_var.numpy())

  def test_restore_equivalency(self):
    source_model = MLP(
@@ -128,7 +123,6 @@ class TestPretrained(unittest.TestCase):
    dest_model.load_from_pretrained(
        source_model=source_model, assignment_map=None, include_top=True)

-    dest_model.fit(self.dataset, nb_epoch=1)
    predictions = np.squeeze(dest_model.predict_on_batch(self.dataset.X))

    np.testing.assert_array_almost_equal(self.dataset.y, np.round(predictions))
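
As a related sanity check, not part of this commit: after a full transfer with include_top=True, the two models should compute the same function, assuming the architecture carries no non-trainable state such as batch normalization statistics. Continuing from the hypothetical MLPModel sketch above:

import numpy as np

dest_model.load_from_pretrained(
    source_model=source_model, assignment_map=None, include_top=True)

X = np.random.rand(16, 100).astype(np.float32)  # toy feature batch
np.testing.assert_array_almost_equal(
    np.squeeze(source_model.predict_on_batch(X)),
    np.squeeze(dest_model.predict_on_batch(X)))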