Unverified Commit a1db2ca4 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1024 from lilleswing/move-checkpoints

Move around saved models
parents 2e3a120a 523a810e
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -187,7 +187,8 @@ class TensorGraph(Model):
        if submodel.loss is not None:
          loss = submodel.loss
      if checkpoint_interval > 0:
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
        saver = tf.train.Saver(
            max_to_keep=max_checkpoints_to_keep, save_relative_paths=True)
      if restore:
        self.restore()
      avg_loss, n_averaged_batches = 0.0, 0.0
@@ -788,7 +789,8 @@ class TensorGraph(Model):
      var_names = set([x for x in reader.get_variable_to_shape_map()])
      var_map = {
          x.op.name: x
          for x in tf.global_variables() if x.op.name in var_names
          for x in tf.global_variables()
          if x.op.name in var_names
      }
      saver = tf.train.Saver(var_list=var_map)
      saver.restore(self.session, checkpoint)
+10 −4
Original line number Diff line number Diff line
import os
import tempfile
import unittest

import numpy as np
import os
from nose.tools import assert_true, nottest
from flaky import flaky
import tensorflow as tf
from flaky import flaky
from nose.tools import assert_true
import shutil

import deepchem as dc
from deepchem.data import NumpyDataset
@@ -238,7 +240,11 @@ class TestTensorGraph(unittest.TestCase):
    prediction = np.squeeze(tg.predict_on_batch(X))
    tg.save()

    tg1 = TensorGraph.load_from_dir(tg.model_dir)
    dirpath = tempfile.mkdtemp()
    shutil.rmtree(dirpath)
    shutil.move(tg.model_dir, dirpath)

    tg1 = TensorGraph.load_from_dir(dirpath)
    prediction2 = np.squeeze(tg1.predict_on_batch(X))
    assert_true(np.all(np.isclose(prediction, prediction2, atol=0.01)))