Commit 50c6a926 authored by Vignesh's avatar Vignesh
Browse files

First version of pretrained loading

parent 94809d39
Loading
Loading
Loading
Loading
+31 −0
Original line number Diff line number Diff line
@@ -953,6 +953,37 @@ class KerasModel(Model):
      return int(self._global_step)
    return self._global_step.eval(session=self.session)

  def load_pretrained(self,
                      assignment_map=None,
                      checkpoint=None,
                      model_dir=None):
    """Load variables from a pretrained model.

    Parameters
    ----------
    assignment_map: Dict, default None
      Dictionary mapping source model variables to the corresponding
      variables of this model.  If None, the full checkpoint is
      restored via `restore()` instead.
    checkpoint: str, default None
      Checkpoint to restore from; passed through to `restore()`.
      Only used when `assignment_map` is None.
    model_dir: str, default None
      Directory to search for checkpoints; passed through to
      `restore()`.  Only used when `assignment_map` is None.
    """
    self._ensure_built()
    if assignment_map is None:
      # No explicit variable mapping: fall back to a full restore.
      self.restore(checkpoint=checkpoint, model_dir=model_dir)
      return

    if tf.executing_eagerly():
      for source_var, dest_var in assignment_map.items():
        dest_var.assign(source_var)
    else:
      # Build all assign ops first, then execute them in a single
      # session.run call instead of one run per variable.
      self._assign_ops = [
          dest_var.assign(source_var)
          for source_var, dest_var in assignment_map.items()
      ]
      self.session.run(self._assign_ops)

      # Record the assigned variables as initialized so later
      # initialization passes do not clobber the loaded values.
      if hasattr(self, '_initialized_vars'):
        self._initialized_vars.update(set(assignment_map.values()))
      else:
        self._initialized_vars = set(assignment_map.values())


class _StandardLoss(object):
  """The implements the loss function for models that use a dc.models.losses.Loss."""
+82 −0
Original line number Diff line number Diff line
import os
import unittest
import deepchem as dc
import numpy as np
import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.keras.layers import Input, Dense
from deepchem.models.losses import L2Loss


class MLP(dc.models.KerasModel):
  """A small multilayer perceptron wrapped as a DeepChem KerasModel.

  One ReLU hidden layer followed by a linear output layer with a unit
  per task, trained with an L2 loss.
  """

  def __init__(self, n_tasks=1, feature_dim=100, hidden_layer_size=64,
               **kwargs):
    self.n_tasks = n_tasks
    self.feature_dim = feature_dim
    self.hidden_layer_size = hidden_layer_size

    keras_model, loss, output_types = self._build_graph()
    super(MLP, self).__init__(
        model=keras_model, loss=loss, output_types=output_types, **kwargs)

  def _build_graph(self):
    """Construct the network and return (model, loss, output_types)."""
    features = Input(dtype=tf.float32, shape=(self.feature_dim,), name="Input")
    hidden = Dense(
        units=self.hidden_layer_size, activation=tf.nn.relu)(features)
    predictions = Dense(units=self.n_tasks)(hidden)

    keras_model = tf.keras.Model(inputs=[features], outputs=[predictions])
    return keras_model, L2Loss(), ['prediction']


class TestPretrained(unittest.TestCase):
  """Tests for KerasModel.load_pretrained."""

  def setUp(self):
    # Train a small source model and checkpoint it to self.model_dir so
    # the test can reload it.  The path is stored on self so the test
    # and tearDown use the same directory instead of re-hardcoding it.
    self.model_dir = "./MLP/"
    self.feature_dim = 2
    self.hidden_layer_size = 2
    data_points = 100

    X = np.random.randn(data_points, self.feature_dim)
    y = np.random.randn(data_points)

    dataset = dc.data.NumpyDataset(X, y)

    model = MLP(
        hidden_layer_size=self.hidden_layer_size,
        feature_dim=self.feature_dim,
        model_dir=self.model_dir,
        batch_size=10)
    model.fit(dataset, nb_epoch=100)

  def tearDown(self):
    # Remove the checkpoint directory created in setUp so repeated runs
    # do not leak state into the working directory.
    import shutil
    shutil.rmtree(self.model_dir, ignore_errors=True)

  def test_load_pretrained(self):
    """Variables mapped via assignment_map are copied into the new model."""
    source_model = MLP(
        model_dir=self.model_dir,
        feature_dim=self.feature_dim,
        hidden_layer_size=self.hidden_layer_size)
    source_model.restore()

    # The destination model has a different number of tasks, so only the
    # layers below the output head can be transferred.
    dest_model = MLP(
        feature_dim=self.feature_dim,
        hidden_layer_size=self.hidden_layer_size,
        n_tasks=10)

    # Exclude the final kernel and bias, whose shapes differ between the
    # 1-task source and the 10-task destination.
    dest_variables = dest_model.model.trainable_variables[:-2]

    assignment_map = dict()
    for idx, variable in enumerate(dest_variables):
      source_variable = source_model.model.trainable_variables[idx]
      assignment_map[source_variable] = variable

    dest_model.load_pretrained(assignment_map=assignment_map)

    for var_old, var_new in assignment_map.items():
      val_old = dest_model.session.run(var_old)
      val_new = dest_model.session.run(var_new)

      np.testing.assert_array_almost_equal(val_old, val_new)