Commit 9fa2f9ea authored by Vignesh

Moved to layer-based assignment map, removed restore use
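
A rough usage sketch of the new API (MLP and ./MLP/ are taken from the
updated tests below; any dc.models.KerasModel subclass should work the
same way):

    source_model = MLP(model_dir="./MLP/", feature_dim=2,
                       hidden_layer_size=10)
    dest_model = MLP(feature_dim=2, hidden_layer_size=10, n_tasks=10)

    # With assignment_map=None, default_assignment_map() builds a
    # layer-keyed map; include_top=False skips the final Dense layer.
    dest_model.load_from_pretrained(
        source_model=source_model,
        assignment_map=None,
        model_dir="./MLP/",
        include_top=False)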

parent 50c6a926
+71 −25
@@ -927,7 +927,7 @@ class KerasModel(Model):
     return tf.train.get_checkpoint_state(
         self.model_dir).all_model_checkpoint_paths
 
-  def restore(self, checkpoint=None):
+  def restore(self, checkpoint=None, model_dir=None, session=None):
     """Reload the values of all variables from a checkpoint file.
 
     Parameters
@@ -938,14 +938,18 @@ class KerasModel(Model):
       list of all available checkpoints.
+    model_dir: str, default None
+      Directory to load the checkpoint from, if different from the model's
+      own model_dir.
+    session: tf.Session, default None
+      Session to run the restore ops in (graph mode only); defaults to the
+      model's own session.
     """
     self._ensure_built()
+    if model_dir is None:
+      model_dir = self.model_dir
     if checkpoint is None:
-      checkpoint = tf.train.latest_checkpoint(self.model_dir)
+      checkpoint = tf.train.latest_checkpoint(model_dir)
     if checkpoint is None:
       raise ValueError('No checkpoint found')
     if tf.executing_eagerly():
       self._checkpoint.restore(checkpoint)
     else:
-      self._checkpoint.restore(checkpoint).run_restore_ops(self.session)
+      if session is None:
+        session = self.session
+      self._checkpoint.restore(checkpoint).run_restore_ops(session)
 
   def get_global_step(self):
     """Get the number of steps of fitting that have been performed."""
@@ -953,36 +957,78 @@ class KerasModel(Model):
       return int(self._global_step)
     return self._global_step.eval(session=self.session)
 
-  def load_pretrained(self,
+  def default_assignment_map(self, source_model, include_top=True, **kwargs):
+    """Creates a default assignment map between the layers of the source
+    model and the layers of the current model. This is used only when a
+    custom assignment map is missing.
+
+    Parameters
+    ----------
+    source_model: dc.models.KerasModel
+      Source model to copy variable values from.
+    include_top: bool, default True
+      If True, copies the last Dense layer.
+    """
+    assignment_map = {}
+
+    for idx, layer in enumerate(source_model.model.layers[:-1]):
+      assignment_map[layer] = self.model.layers[idx]
+    if include_top:
+      assignment_map[source_model.model.layers[-1]] = self.model.layers[-1]
+
+    return assignment_map
+
+  def load_from_pretrained(self,
                            source_model,
                            assignment_map=None,
                            checkpoint=None,
-                      model_dir=None):
-    """Load from a pretrained model.
+                           model_dir=None,
+                           include_top=True,
+                           **kwargs):
+    """Loads a set of layer weights from a pretrained model. The method takes
+    a source model, an assignment map, and several other optional arguments.
+    The assignment map is a dictionary mapping layers of the source model to
+    layers of the current model whose weights should be copied over. If no
+    assignment map is supplied, a default one is generated; it assumes the
+    architecture has a Dense layer as its last layer, which is copied only
+    when include_top is True. The default assignment map allows applying an
+    existing trained model to a different multitask setting involving the
+    same architecture, or to a different task, such as classification
+    instead of regression or vice versa.
 
     Parameters
     ----------
+    source_model: dc.models.KerasModel
+      Source model to copy layer weights from.
     assignment_map: Dict, default None
-      Dictionary containing variable mapping between source and current model
-      variables
+      Dictionary containing the layer mapping between source and current
+      model layers.
     checkpoint: str, default None
       the path to the checkpoint file to load.  If this is None, the most recent
       checkpoint will be chosen automatically.  Call get_checkpoints() to get a
       list of all available checkpoints.
     model_dir: str, default None
       Restore model from custom model directory if needed
+    include_top: bool, default True
+      If True, copies the last Dense layer. Used only when assignment_map is
+      None.
     """
     self._ensure_built()
     if assignment_map is None:
-      self.restore(checkpoint=checkpoint, model_dir=model_dir)
-    else:
+      assignment_map = self.default_assignment_map(
+          source_model=source_model, include_top=include_top)
 
+    self._assign_ops = []
     if tf.executing_eagerly():
-        for source_var, dest_var in assignment_map.items():
-          dest_var.assign(source_var)
+      source_model.restore(model_dir=model_dir, checkpoint=checkpoint)
+      for source_layer, dest_layer in assignment_map.items():
+        dest_vars = dest_layer.trainable_variables
+        dest_layer.set_weights(source_layer.get_weights())
     else:
-        self._assign_ops = []
-        for source_var, dest_var in assignment_map.items():
-          assign_op = dest_var.assign(source_var)
-          self._assign_ops.append(assign_op)
-          self.session.run(assign_op)
+      source_model.restore(
+          model_dir=model_dir, checkpoint=checkpoint, session=self.session)
+      with self.session.as_default():
+        for source_layer, dest_layer in assignment_map.items():
+          dest_vars = dest_layer.trainable_variables
+          dest_layer.set_weights(source_layer.get_weights())
 
-        if hasattr(self, '_initialized_vars'):
-          self._initialized_vars.update(set(assignment_map.values()))
-        else:
-          self._initialized_vars = set(assignment_map.values())
+    self._initialized_vars.update(set(dest_vars))
 
 
 class _StandardLoss(object):
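
A sketch of a custom, layer-keyed assignment map (mirroring the graph-mode
test below; source_model and dest_model as in that test):

    assignment_map = {}
    # Keys are source layers, values are destination layers. The final
    # Dense layer is left out of the map, so its weights are not copied
    # (include_top is ignored when a custom map is supplied).
    for idx, dest_layer in enumerate(dest_model.model.layers[:-1]):
      assignment_map[source_model.model.layers[idx]] = dest_layer

    dest_model.load_from_pretrained(
        source_model=source_model,
        assignment_map=assignment_map,
        include_top=False)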
+75 −20
@@ -22,12 +22,12 @@ class MLP(dc.models.KerasModel):

   def _build_graph(self):
     inputs = Input(dtype=tf.float32, shape=(self.feature_dim,), name="Input")
-    out1 = Dense(units=self.hidden_layer_size, activation=tf.nn.relu)(inputs)
+    out1 = Dense(units=self.hidden_layer_size, activation='relu')(inputs)
 
-    final = Dense(units=self.n_tasks)(out1)
+    final = Dense(units=self.n_tasks, activation='sigmoid')(out1)
     outputs = [final]
     output_types = ['prediction']
-    loss = L2Loss()
+    loss = dc.models.losses.BinaryCrossEntropy()
 
     model = tf.keras.Model(inputs=[inputs], outputs=outputs)
     return model, loss, output_types
@@ -38,27 +38,29 @@ class TestPretrained(unittest.TestCase):
   def setUp(self):
     model_dir = "./MLP/"
     self.feature_dim = 2
-    self.hidden_layer_size = 2
-    data_points = 100
+    self.hidden_layer_size = 10
+    data_points = 10
 
     X = np.random.randn(data_points, self.feature_dim)
-    y = np.random.randn(data_points)
+    y = (X[:, 0] > X[:, 1]).astype(np.float32)
 
-    dataset = dc.data.NumpyDataset(X, y)
+    self.dataset = dc.data.NumpyDataset(X, y)
 
     model = MLP(
         hidden_layer_size=self.hidden_layer_size,
         feature_dim=self.feature_dim,
         model_dir=model_dir,
         batch_size=10)
-    model.fit(dataset, nb_epoch=100)
+    model.fit(self.dataset, nb_epoch=1000)
+    predictions = np.squeeze(model.predict_on_batch(self.dataset.X))
+    np.testing.assert_array_almost_equal(self.dataset.y, np.round(predictions))
 
-  def test_load_pretrained(self):
+  def test_load_from_pretrained_graph_mode(self):
+    """Tests loading pretrained model in graph mode."""
     source_model = MLP(
         model_dir="./MLP/",
         feature_dim=self.feature_dim,
         hidden_layer_size=self.hidden_layer_size)
-    source_model.restore()
 
     dest_model = MLP(
         feature_dim=self.feature_dim,
@@ -66,17 +68,70 @@ class TestPretrained(unittest.TestCase):
         n_tasks=10)
 
     assignment_map = dict()
-    dest_variables = dest_model.model.trainable_variables[:
-                                                          -2]  #Excluding the last weight and bias
+    dest_layers = dest_model.model.layers[:-1]
 
-    for idx, variable in enumerate(dest_variables):
-      source_variable = source_model.model.trainable_variables[idx]
-      assignment_map[source_variable] = variable
+    for idx, dest_layer in enumerate(dest_layers):
+      source_layer = source_model.model.layers[idx]
+      assignment_map[source_layer] = dest_layer
 
-    dest_model.load_pretrained(assignment_map=assignment_map)
+    dest_model.load_from_pretrained(
+        source_model=source_model,
+        assignment_map=assignment_map,
+        include_top=False)
 
-    for var_old, var_new in assignment_map.items():
-      val_old = dest_model.session.run(var_old)
-      val_new = dest_model.session.run(var_new)
-
-      np.testing.assert_array_almost_equal(val_old, val_new)
+    for source_layer, dest_layer in assignment_map.items():
+      for var_old, var_new in zip(source_layer.trainable_variables,
+                                  dest_layer.trainable_variables):
+        # TODO: fix this by running session ops every time
+        np.testing.assert_array_almost_equal(
+            var_old.eval(session=dest_model.session),
+            var_new.eval(session=dest_model.session))

+  def test_load_from_pretrained_eager(self):
+    """Tests loading pretrained model in eager execution mode."""
+    with context.eager_mode():
+      source_model = MLP(
+          model_dir="./MLP/",
+          feature_dim=self.feature_dim,
+          hidden_layer_size=self.hidden_layer_size)
+
+      dest_model = MLP(
+          feature_dim=self.feature_dim,
+          hidden_layer_size=self.hidden_layer_size,
+          n_tasks=10)
+
+      assignment_map = dict()
+      dest_layers = dest_model.model.layers[:-1]
+
+      for idx, dest_layer in enumerate(dest_layers):
+        source_layer = source_model.model.layers[idx]
+        assignment_map[source_layer] = dest_layer
+
+      dest_model.load_from_pretrained(
+          source_model=source_model, assignment_map=assignment_map)
+
+      for source_layer, dest_layer in assignment_map.items():
+        for var_old, var_new in zip(source_layer.trainable_variables,
+                                    dest_layer.trainable_variables):
+          np.testing.assert_array_almost_equal(var_old.numpy(), var_new.numpy())
+
+  def test_restore_equivalency(self):
+    """Tests that loading all layers from a pretrained model behaves like a
+    plain restore when the architectures match."""
+    source_model = MLP(
+        model_dir="./MLP/",
+        feature_dim=self.feature_dim,
+        hidden_layer_size=self.hidden_layer_size)
+
+    dest_model = MLP(
+        feature_dim=self.feature_dim, hidden_layer_size=self.hidden_layer_size)
+
+    dest_model.load_from_pretrained(
+        source_model=source_model, assignment_map=None, include_top=True)
+
+    dest_model.fit(self.dataset, nb_epoch=1)
+    predictions = np.squeeze(dest_model.predict_on_batch(self.dataset.X))
+
+    np.testing.assert_array_almost_equal(self.dataset.y, np.round(predictions))
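
As the last test checks, loading every layer through the default map is
expected to behave like a plain restore when the two architectures match.
A minimal sketch of that equivalence, using the models from the test:

    dest_model.load_from_pretrained(
        source_model=source_model, assignment_map=None, include_top=True)
    # ...leaves dest_model with the same weights as:
    dest_model.restore(model_dir="./MLP/")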