Commit 3859fcf9 authored by Vignesh's avatar Vignesh
Browse files

Extend load_from_pretrained to build inputs

parent 8edbeaa0
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -1012,6 +1012,7 @@ class KerasModel(Model):
                           checkpoint=None,
                           model_dir=None,
                           include_top=True,
                           inputs=None,
                           **kwargs):
    """Copies variable values from a pretrained model. `source_model` can either
    be a pretrained model or a model with the same architecture. `value_map`
@@ -1047,7 +1048,15 @@ class KerasModel(Model):
    include_top: bool, default True
        if True, copies the weights and bias associated with the final dense
        layer. Used only when assignment map is None
    inputs: List, input tensors for model
        if not None, the weights of both the source model and this model
        are built by calling them on these input tensors. This option is
        needed only for models built by subclassing tf.keras.Model, rather
        than with the tf.keras functional API
    """
    if inputs is not None:
      # Ensure weights for both models are built.
      source_model.model(inputs)
      self.model(inputs)

    self._ensure_built()
    if value_map is None:
+58 −0
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from deepchem.models.losses import L2Loss
from deepchem.feat.mol_graphs import ConvMol


class MLP(dc.models.KerasModel):
@@ -77,6 +78,63 @@ class TestPretrained(unittest.TestCase):
      dest_val = dest_var.numpy()
      np.testing.assert_array_almost_equal(source_val, dest_val)

  def test_load_pretrained_subclassed_model(self):
    """Transfers weights from a trained GraphConvModel into an identically
    configured fresh model via `load_from_pretrained(inputs=...)`, then
    verifies every non-top trainable variable was copied exactly.

    GraphConvModel wraps a subclassed (non-functional) tf.keras.Model, so a
    concrete batch of inputs must be supplied to build the weights before
    they can be copied.
    """
    from rdkit import Chem
    tasks = ['a', 'b']
    labels = np.ones((3, 2))
    smiles_list = ['C', 'CC', 'CCC']
    molecules = [Chem.MolFromSmiles(s) for s in smiles_list]
    features = dc.feat.ConvMolFeaturizer().featurize(molecules)
    dataset = dc.data.NumpyDataset(features, labels, ids=smiles_list)

    # Source model: trained so its variables hold non-trivial values.
    source_model = dc.models.GraphConvModel(
        n_tasks=len(tasks),
        graph_conv_layers=[128, 128],
        dense_layer_size=512,
        dropout=0,
        mode='regression',
        learning_rate=0.001,
        batch_size=8,
        model_dir="model")
    source_model.fit(dataset)

    # Destination model: same architecture, left untrained.
    dest_model = dc.models.GraphConvModel(
        n_tasks=len(tasks),
        graph_conv_layers=[128, 128],
        dense_layer_size=512,
        dropout=0,
        mode='regression',
        learning_rate=0.001,
        batch_size=8)

    # Assemble one padded batch in the layout GraphConvModel expects:
    # atom features, deg_slice, membership, sample count, then the
    # per-degree adjacency lists starting from degree 1.
    X_b, y_b, w_b, ids_b = next(
        dataset.iterbatches(batch_size=8, deterministic=True, pad_batches=True))
    merged = ConvMol.agglomerate_mols(X_b)
    inputs = [
        merged.get_atom_features(), merged.deg_slice,
        np.array(merged.membership),
        np.array(X_b.shape[0])
    ]
    inputs.extend(merged.get_deg_adjacency_lists()[1:])

    dest_model.load_from_pretrained(
        source_model=source_model,
        assignment_map=None,
        value_map=None,
        include_top=False,
        inputs=inputs)

    # include_top=False skips the final dense layer's kernel and bias,
    # so drop the last two variables before comparing.
    source_vars = source_model.model.trainable_variables[:-2]
    dest_vars = dest_model.model.trainable_variables[:-2]
    assert len(source_vars) == len(dest_vars)

    for src_var, dst_var in zip(source_vars, dest_vars):
      np.testing.assert_array_almost_equal(src_var.numpy(), dst_var.numpy())

  def test_restore_equivalency(self):
    """Test for restore based pretrained model loading."""
    source_model = MLP(