Commit a34adf28 authored by Vignesh

Edits to logging and saving for KerasModel

parent 9df90305
+46 −21
import numpy as np
import tensorflow as tf
import time
import logging
import os

logger = logging.getLogger(__name__)

from deepchem.data import NumpyDataset
from deepchem.models.losses import Loss
@@ -348,7 +352,6 @@ class KerasModel(Model):
      inputs, labels, weights = self._prepare_batch(batch)
      self._tensorboard_step += 1
      should_log = (
          self.tensorboard and
          self._tensorboard_step % self.tensorboard_log_frequency == 0)
      if tf.executing_eagerly():

@@ -401,7 +404,7 @@ class KerasModel(Model):
                  loss_tensor, global_step=self._global_step, var_list=vars)
            train_op = self._custom_train_op[op_key]
        fetches = [train_op, self._loss_tensor, self._global_step]
        if should_log:
        if self.tensorboard and should_log:
          fetches.append(self._summary_ops)
        feed_dict = dict(zip(self._input_placeholders, inputs))
        feed_dict.update(dict(zip(self._label_placeholders, labels)))
@@ -409,33 +412,36 @@ class KerasModel(Model):
        fetched_values = self.session.run(fetches, feed_dict=feed_dict)
        avg_loss += fetched_values[1]
        current_step = fetched_values[2]
        if should_log:

        if self.tensorboard and should_log:
          self._summary_writer.reopen()
          self._summary_writer.add_summary(
              fetched_values[3], global_step=current_step)
          self._summary_writer.close()

      # Report progress and write checkpoints.

      averaged_batches += 1
      if checkpoint_interval > 0 and current_step % checkpoint_interval == checkpoint_interval - 1:
        self._exec_with_session(lambda: manager.save())
      if should_log:
        avg_loss = float(avg_loss) / averaged_batches
        print(
        logger.info(
            'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
        avg_loss = 0.0
        averaged_batches = 0

    # Report final results.
      if checkpoint_interval > 0 and current_step % checkpoint_interval == checkpoint_interval - 1:
        self._exec_with_session(lambda: manager.save())

    if checkpoint_interval > 0:
    # Report final results.
    if averaged_batches > 0:
      avg_loss = float(avg_loss) / averaged_batches
        print(
      logger.info(
          'Ending global_step %d: Average loss %g' % (current_step, avg_loss))

    if checkpoint_interval > 0:
      self._exec_with_session(lambda: manager.save())

    time2 = time.time()
      print("TIMING: model fitting took %0.3f s" % (time2 - time1))
    logger.info("TIMING: model fitting took %0.3f s" % (time2 - time1))
    return avg_loss

  def fit_on_batch(self, X, y, w, variables=None, loss=None):
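Side note on the logging change above: progress messages that previously went through print() are now emitted via the module-level logger, so callers only see the per-interval average loss and the fit timing if logging is configured. A minimal sketch of enabling them; the logger name is an assumption based on the file path (deepchem/models/keras_model.py) and may differ:

import logging

# Show INFO-level messages (per-interval average loss, fit timing) on the console.
logging.basicConfig(level=logging.INFO)

# Alternatively, target only this module's logger; the name below is an
# assumption, not confirmed by this commit.
logging.getLogger('deepchem.models.keras_model').setLevel(logging.INFO)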
@@ -898,7 +904,7 @@ class KerasModel(Model):
          pad_batches=pad_batches):
        yield ([X_b], [y_b], [w_b])

  def save_checkpoint(self, max_checkpoints_to_keep=5):
  def save_checkpoint(self, max_checkpoints_to_keep=5, model_dir=None):
    """Save a checkpoint to disk.

    Usually you do not need to call this method, since fit() saves checkpoints
@@ -909,9 +915,15 @@ class KerasModel(Model):
    ----------
    max_checkpoints_to_keep: int
      the maximum number of checkpoints to keep.  Older checkpoints are discarded.
    model_dir: str, default None
      Model directory to save the checkpoint to. If None, defaults to self.model_dir.
    """
    self._ensure_built()
    manager = tf.train.CheckpointManager(self._checkpoint, self.model_dir,
    if model_dir is None:
      model_dir = self.model_dir
    if not os.path.exists(model_dir):
      os.makedirs(model_dir)
    manager = tf.train.CheckpointManager(self._checkpoint, model_dir,
                                         max_checkpoints_to_keep)
    self._exec_with_session(lambda: manager.save())
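With the new model_dir argument, a checkpoint can be written to (and created at) a directory other than the one the model was constructed with. A hedged usage sketch; the toy model, loss, and paths are illustrative and not part of this commit:

import numpy as np
import tensorflow as tf
import deepchem as dc

# Illustrative toy model; any KerasModel works the same way.
keras_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model = dc.models.KerasModel(
    keras_model, dc.models.losses.L2Loss(), model_dir='/tmp/keras_model_default')
model.fit(dc.data.NumpyDataset(np.random.rand(10, 4), np.random.rand(10, 1)), nb_epoch=1)

# Default behaviour: the checkpoint goes to model.model_dir.
model.save_checkpoint()

# New behaviour: write it somewhere else; the directory is created if missing.
model.save_checkpoint(max_checkpoints_to_keep=3, model_dir='/tmp/keras_model_backup')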

@@ -922,12 +934,20 @@ class KerasModel(Model):
      with self.session.as_default():
        f()

  def get_checkpoints(self):
    """Get a list of all available checkpoint files."""
    return tf.train.get_checkpoint_state(
        self.model_dir).all_model_checkpoint_paths
  def get_checkpoints(self, model_dir=None):
    """Get a list of all available checkpoint files.

  def restore(self, checkpoint=None):
    Parameters
    ----------
    model_dir: str, default None
      Directory to get list of checkpoints from. Reverts to self.model_dir if None

    """
    if model_dir is None:
      model_dir = self.model_dir
    return tf.train.get_checkpoint_state(model_dir).all_model_checkpoint_paths

  def restore(self, checkpoint=None, restore_from=None):
    """Reload the values of all variables from a checkpoint file.

    Parameters
@@ -936,10 +956,14 @@ class KerasModel(Model):
      the path to the checkpoint file to load.  If this is None, the most recent
      checkpoint will be chosen automatically.  Call get_checkpoints() to get a
      list of all available checkpoints.
    restore_from: str, default None
      Directory to restore checkpoint from. If None, use self.model_dir.
    """
    self._ensure_built()
    if restore_from is None:
      restore_from = self.model_dir
    if checkpoint is None:
      checkpoint = tf.train.latest_checkpoint(self.model_dir)
      checkpoint = tf.train.latest_checkpoint(restore_from)
    if checkpoint is None:
      raise ValueError('No checkpoint found')
    if tf.executing_eagerly():
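The read side mirrors this: get_checkpoints() and restore() can now be pointed at a directory other than self.model_dir. Continuing the illustrative example above (same hypothetical model and backup directory):

# List the checkpoints stored in the alternate directory.
ckpts = model.get_checkpoints(model_dir='/tmp/keras_model_backup')

# Restore the most recent checkpoint found there...
model.restore(restore_from='/tmp/keras_model_backup')

# ...or a specific one from the returned list.
model.restore(checkpoint=ckpts[-1])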
@@ -975,5 +999,6 @@ class _StandardLoss(object):
        shape = w.shape
      shape = tuple(-1 if x is None else x for x in shape)
      w = tf.reshape(w, shape + (1,) * (len(losses.shape) - len(w.shape)))

    loss = losses * w
    return tf.reduce_mean(loss) + sum(self.model.losses)
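For reference, the weight handling at the end of _StandardLoss appends singleton axes so per-sample weights broadcast against per-element losses before averaging. A standalone sketch with made-up shapes, assuming eager execution:

import tensorflow as tf

losses = tf.constant([[0.5, 1.0], [2.0, 4.0]])  # per-element losses, shape (2, 2)
w = tf.constant([1.0, 0.5])                     # per-sample weights, shape (2,)

# Mirror the reshape above: pad the weight shape with trailing 1s so it
# broadcasts across the extra loss dimensions.
shape = tuple(-1 if x is None else x for x in w.shape)
w = tf.reshape(w, shape + (1,) * (len(losses.shape) - len(w.shape)))

print(tf.reduce_mean(losses * w).numpy())  # (0.5 + 1.0 + 1.0 + 2.0) / 4 = 1.125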