Unverified Commit 874c02d0 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #1643 from VIGS25/transfer-learning-api

First version of pretrained loading
parents a8cadbb5 30669439
Loading
Loading
Loading
Loading
+135 −2
Original line number Diff line number Diff line
@@ -976,7 +976,7 @@ class KerasModel(Model):
      model_dir = self.model_dir
    return tf.train.get_checkpoint_state(model_dir).all_model_checkpoint_paths

  def restore(self, checkpoint=None, model_dir=None):
  def restore(self, checkpoint=None, model_dir=None, session=None):
    """Reload the values of all variables from a checkpoint file.

    Parameters
@@ -987,6 +987,8 @@ class KerasModel(Model):
      list of all available checkpoints.
    model_dir: str, default None
      Directory to restore checkpoint from. If None, use self.model_dir.
    session: tf.Session(), default None
      Session to run restore ops under. If None, self.session is used.
    """
    self._ensure_built()
    if model_dir is None:
@@ -998,7 +1000,9 @@ class KerasModel(Model):
    if tf.executing_eagerly():
      self._checkpoint.restore(checkpoint)
    else:
      self._checkpoint.restore(checkpoint).run_restore_ops(self.session)
      if session is None:
        session = self.session
      self._checkpoint.restore(checkpoint).run_restore_ops(session)

  def get_global_step(self):
    """Get the number of steps of fitting that have been performed."""
@@ -1006,6 +1010,135 @@ class KerasModel(Model):
      return int(self._global_step)
    return self._global_step.eval(session=self.session)

  def _create_assignment_map(self, source_model, include_top=True, **kwargs):
    """Build a default mapping from `source_model` variables to this model's.

    Pairs the trainable variables of the two models positionally, assuming
    both models share the same layer layout: a stack of layers followed by a
    final dense layer that maps to the output tasks. When `include_top` is
    False, the last two trainable variables (presumably the final dense
    layer's kernel and bias) are left out of the mapping, which is useful
    when the two models differ in task type (classification vs regression)
    and/or number of tasks.

    Parameters
    ----------
    source_model: dc.models.KerasModel
        Source model to copy variable values from.
    include_top: bool, default True
        if true, copies the last dense layer

    Returns
    -------
    dict mapping each source variable to its corresponding destination
    variable.
    """
    src_vars = source_model.model.trainable_variables
    dst_vars = self.model.trainable_variables

    # Drop the final dense layer's kernel and bias when the top is excluded.
    if not include_top:
      src_vars, dst_vars = src_vars[:-2], dst_vars[:-2]

    return dict(zip(src_vars, dst_vars))

  def _create_value_map(self, source_model, **kwargs):
    """Snapshot the current values of `source_model`'s trainable variables.

    Used only when the caller did not supply a custom value map. Assumes the
    source model's variables already hold the desired values — in graph mode
    this means restore was previously run under self.session.

    Parameters
    ----------
    source_model: dc.models.KerasModel
        Source model to create value map from

    Returns
    -------
    dict mapping each trainable variable of the source model to a numpy
    array holding its current value.
    """
    value_map = {}
    eager = tf.executing_eagerly()
    for src_var in source_model.model.trainable_variables:
      if eager:
        value_map[src_var] = src_var.numpy()
      else:
        # self.session is used because restore was called in the same session
        value_map[src_var] = src_var.eval(session=self.session)
    return value_map

  def load_from_pretrained(self,
                           source_model,
                           assignment_map=None,
                           value_map=None,
                           checkpoint=None,
                           model_dir=None,
                           include_top=True,
                           **kwargs):
    """Copies variable values from a pretrained model. `source_model` can either
    be a pretrained model or a model with the same architecture. `value_map`
    is a variable-value dictionary. If no `value_map` is provided, the variable
    values are restored to the `source_model` from a checkpoint and a default
    `value_map` is created. `assignment_map` is a dictionary mapping variables
    from the `source_model` to the current model. If no `assignment_map` is
    provided, one is made from scratch and assumes the model is composed of
    several different layers, with the final one being a dense layer. include_top
    is used to control whether or not the final dense layer is used. The default
    assignment map is useful in cases where the type of task is different
    (classification vs regression) and/or number of tasks in the setting.

    Parameters
    ----------
    source_model: dc.KerasModel, required
      source_model can either be the pretrained model or a dc.KerasModel with
      the same architecture as the pretrained model. It is used to restore from
      a checkpoint, if value_map is None and to create a default assignment map
      if assignment_map is None
    assignment_map: Dict, default None
      Dictionary mapping the source_model variables and current model variables
    value_map: Dict, default None
      Dictionary containing source_model trainable variables mapped to numpy
      arrays. If value_map is None, the values are restored and a default
      variable map is created using the restored values
    checkpoint: str, default None
      the path to the checkpoint file to load.  If this is None, the most recent
      checkpoint will be chosen automatically.  Call get_checkpoints() to get a
      list of all available checkpoints
    model_dir: str, default None
      Restore model from custom model directory if needed
    include_top: bool, default True
        if True, copies the weights and bias associated with the final dense
        layer. Used only when assignment map is None

    Raises
    ------
    ValueError
      if a source variable and its mapped destination variable have
      different shapes.
    """

    self._ensure_built()
    if value_map is None:
      logger.info(
          "No value map provided. Creating default value map from restored model."
      )
      if tf.executing_eagerly():
        source_model.restore(model_dir=model_dir, checkpoint=checkpoint)
      else:
        # Restore under our own session so the restored values are visible
        # when _create_value_map evaluates the source variables below.
        source_model.restore(
            model_dir=model_dir, checkpoint=checkpoint, session=self.session)
      value_map = self._create_value_map(source_model=source_model)

    if assignment_map is None:
      logger.info("No assignment map provided. Creating custom assignment map.")
      assignment_map = self._create_assignment_map(
          source_model=source_model, include_top=include_top)

    # Validate all shapes up front so a mismatch fails before any variable is
    # modified. A bare `assert` (as previously used) is stripped under
    # `python -O`, so raise an explicit error instead.
    for source_var, dest_var in assignment_map.items():
      if source_var.shape != dest_var.shape:
        raise ValueError(
            'Cannot assign variable with shape %s to variable with shape %s' %
            (source_var.shape, dest_var.shape))

    if tf.executing_eagerly():
      for source_var, dest_var in assignment_map.items():
        dest_var.assign(value_map[source_var])

    else:
      with self.session.as_default():
        for source_var, dest_var in assignment_map.items():
          assign_op = dest_var.assign(value_map[source_var])
          self.session.run(assign_op)

    # Record the copied variables as initialized, presumably so later calls
    # do not re-initialize (and clobber) them — NOTE(review): semantics of
    # _initialized_vars are defined elsewhere in this class.
    dest_vars = list(assignment_map.values())
    self._initialized_vars.update(set(dest_vars))


class _StandardLoss(object):
  """The implements the loss function for models that use a dc.models.losses.Loss."""
+116 −0
Original line number Diff line number Diff line
import os
import unittest
import deepchem as dc
import numpy as np
import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.keras.layers import Input, Dense
from deepchem.models.losses import L2Loss


class MLP(dc.models.KerasModel):
  """Small multilayer perceptron used as a fixture for pretrained-loading
  tests: one ReLU hidden layer followed by a sigmoid output layer, trained
  with binary cross-entropy."""

  def __init__(self, n_tasks=1, feature_dim=100, hidden_layer_size=64,
               **kwargs):
    # Store the architecture hyperparameters before building the graph,
    # since _build_graph reads them from self.
    self.feature_dim = feature_dim
    self.hidden_layer_size = hidden_layer_size
    self.n_tasks = n_tasks

    model, loss, output_types = self._build_graph()
    super(MLP, self).__init__(
        model=model, loss=loss, output_types=output_types, **kwargs)

  def _build_graph(self):
    """Construct the Keras model and return (model, loss, output_types)."""
    features = Input(dtype=tf.float32, shape=(self.feature_dim,), name="Input")
    hidden = Dense(units=self.hidden_layer_size, activation='relu')(features)
    probs = Dense(units=self.n_tasks, activation='sigmoid')(hidden)

    keras_model = tf.keras.Model(inputs=[features], outputs=[probs])
    return keras_model, dc.models.losses.BinaryCrossEntropy(), ['prediction']


class TestPretrained(unittest.TestCase):
  """Tests for copying pretrained weights between KerasModels, in both
  graph and eager execution modes."""

  def setUp(self):
    # Tiny linearly-separable dataset: label is 1 iff the first feature
    # exceeds the second.
    self.feature_dim = 2
    self.hidden_layer_size = 10
    n_samples = 10

    features = np.random.randn(n_samples, self.feature_dim)
    labels = (features[:, 0] > features[:, 1]).astype(np.float32)

    self.dataset = dc.data.NumpyDataset(features, labels)

  def test_load_from_pretrained_graph_mode(self):
    """Tests loading pretrained model in graph mode."""
    source_model = MLP(
        hidden_layer_size=self.hidden_layer_size,
        feature_dim=self.feature_dim,
        batch_size=10)

    source_model.fit(self.dataset, nb_epoch=1000, checkpoint_interval=0)

    # Destination model has a different number of tasks, so only the
    # non-top layers can be transferred.
    dest_model = MLP(
        feature_dim=self.feature_dim,
        hidden_layer_size=self.hidden_layer_size,
        n_tasks=10)

    assignment_map = dict()
    value_map = dict()
    eager = tf.executing_eagerly()
    source_vars = source_model.model.trainable_variables

    # Map everything except the final dense layer (last kernel + bias).
    for idx, dest_var in enumerate(dest_model.model.trainable_variables[:-2]):
      source_var = source_vars[idx]
      assignment_map[source_var] = dest_var
      if eager:
        value_map[source_var] = source_var.numpy()
      else:
        value_map[source_var] = source_var.eval(session=source_model.session)

    dest_model.load_from_pretrained(
        source_model=source_model,
        assignment_map=assignment_map,
        value_map=value_map)

    # Every mapped destination variable should now hold the source's value.
    for source_var, dest_var in assignment_map.items():
      if eager:
        source_val = source_var.numpy()
        dest_val = dest_var.numpy()
      else:
        source_val = source_var.eval(session=source_model.session)
        dest_val = dest_var.eval(session=dest_model.session)
      np.testing.assert_array_almost_equal(source_val, dest_val)

  def test_load_from_pretrained_eager_mode(self):
    """Tests loading pretrained model in eager execution mode."""
    with context.eager_mode():
      self.test_load_from_pretrained_graph_mode()

  def test_restore_equivalency_graph_mode(self):
    """Test for restore based pretrained model loading in graph mode."""
    source_model = MLP(
        feature_dim=self.feature_dim, hidden_layer_size=self.hidden_layer_size)

    source_model.fit(self.dataset, nb_epoch=1000)

    # Same architecture; with value_map/assignment_map left as None the full
    # model (include_top=True) is restored from the source's checkpoint.
    dest_model = MLP(
        feature_dim=self.feature_dim, hidden_layer_size=self.hidden_layer_size)

    dest_model.load_from_pretrained(
        source_model=source_model,
        assignment_map=None,
        value_map=None,
        model_dir=None,
        include_top=True)

    predictions = np.squeeze(dest_model.predict_on_batch(self.dataset.X))
    np.testing.assert_array_almost_equal(self.dataset.y, np.round(predictions))

  def test_restore_equivalency_eager_mode(self):
    """Test for restore based pretrained model loading in eager mode."""
    with context.eager_mode():
      self.test_restore_equivalency_graph_mode()