Commit 888f747b authored by Kevin Shen

removed epoch logging + train metrics, added TorchModel integration

parent c3f0e078
+2 −24 (ValidationCallback)
@@ -2,7 +2,6 @@
Callback functions that can be invoked while fitting a KerasModel.
"""
import sys
import math


class ValidationCallback(object):
@@ -25,8 +24,7 @@ class ValidationCallback(object):
               output_file=sys.stdout,
               save_dir=None,
               save_metric=0,
               save_on_minimum=True,
               logging_strategy="step"):
               save_on_minimum=True):
    """Create a ValidationCallback.
    Parameters
    ----------
@@ -48,10 +46,6 @@ class ValidationCallback(object):
      if True, the best model is considered to be the one that minimizes the
      validation metric.  If False, the best model is considered to be the one
      that maximizes it.
    logging_strategy: str
      the logging strategy used for logging (step or epoch). If "step",
      logging interval will be the value provided for `interval`. If "epoch",
      then logging will happen at the end of every training epoch.
    """
    self.dataset = dataset
    self.interval = interval
@@ -61,12 +55,6 @@ class ValidationCallback(object):
    self.save_metric = save_metric
    self.save_on_minimum = save_on_minimum
    self._best_score = None
    if logging_strategy != "step" and logging_strategy != "epoch":
      print(
          "ValidationCallback: `logging_strategy` needs to be either 'step' or 'epoch'. Defaulting to 'step'."
      )
      logging_strategy = "step"
    self.logging_strategy = logging_strategy

  def __call__(self, model, step):
    """This is invoked by the KerasModel after every step of fitting.
@@ -77,15 +65,7 @@ class ValidationCallback(object):
    step: int
      the index of the training step that has just completed
    """

    # Check if we should log to Wandb on this iteration
    steps_per_epoch = math.ceil(len(model.dataset) / model.batch_size)
    should_log = False
    if (self.logging_strategy == "step" and step % self.interval == 0) or \
            (self.logging_strategy == "epoch" and step % steps_per_epoch == 0):
      should_log = True

    if should_log is False:
    if step % self.interval != 0:
      return
    scores = model.evaluate(self.dataset, self.metrics)
    message = 'Step %d validation:' % step
@@ -99,7 +79,6 @@ class ValidationCallback(object):
    if model.wandb:
      import wandb
      wandb.log(scores, step=step)

    if self.save_dir is not None:
      score = scores[self.metrics[self.save_metric].name]
      if not self.save_on_minimum:
@@ -107,7 +86,6 @@ class ValidationCallback(object):
      if self._best_score is None or score < self._best_score:
        model.save_checkpoint(model_dir=self.save_dir)
        self._best_score = score

    if model.wandb_logger is not None:
      # Log data to Wandb
      data = {'eval/' + k: v for k, v in scores.items()}
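Note: with `logging_strategy` removed, validation is driven purely by `interval`. A minimal usage sketch of the simplified callback (synthetic data, layer sizes, and the choice of metric are illustrative, not from this commit):

import numpy as np
import tensorflow as tf
import deepchem as dc

# Toy regression data, split into train/validation (illustrative only).
X = np.random.rand(100, 10)
y = np.random.rand(100, 1)
train = dc.data.NumpyDataset(X[:80], y[:80])
valid = dc.data.NumpyDataset(X[80:], y[80:])

keras_model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(1)
])
model = dc.models.KerasModel(keras_model,
                             dc.models.losses.L2Loss(),
                             batch_size=10)

# After this commit the callback evaluates strictly every `interval`
# training steps; the "epoch" strategy no longer exists.
metric = dc.metrics.Metric(dc.metrics.pearson_r2_score)
vc = dc.models.ValidationCallback(valid, interval=10, metrics=[metric])
model.fit(train, nb_epoch=5, callbacks=[vc])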
+7 −75 (KerasModel)
@@ -3,7 +3,6 @@ import tensorflow as tf
import time
import logging
import os
import math

try:
  from collections.abc import Sequence as SequenceCollection
@@ -136,7 +135,6 @@ class KerasModel(Model):
               tensorboard: bool = False,
               wandb: bool = False,
               log_frequency: int = 100,
               logging_strategy: Optional[str] = "step",
               wandb_logger: Optional[WandbLogger] = None,
               **kwargs) -> None:
    """Create a new KerasModel.
@@ -174,7 +172,7 @@ class KerasModel(Model):
      like a printout every 10 batch steps, you'd set
      `log_frequency=10` for example.
    wandb_logger: WandbLogger
      the Weights & Biases logger to log data and metrics
      the Weights & Biases logger object used to log data and metrics
    """
    super(KerasModel, self).__init__(model=model, model_dir=model_dir, **kwargs)
    if isinstance(loss, Loss):
@@ -191,7 +189,7 @@ class KerasModel(Model):
    # W&B flag support (DEPRECATED)
    if wandb:
      logger.warning(
          "'wandb' argument is deprecated. Please use wandb_logger instead. "
          "`wandb` argument is deprecated. Please use `wandb_logger` instead. "
          "This argument will be removed in a future release of DeepChem.")
    if wandb and not _has_wandb:
      logger.warning(
@@ -219,14 +217,6 @@ class KerasModel(Model):
    if self.wandb_logger is not None:
      self.wandb_logger.update_config(wandb_logger_config)

    # Check for valid logging strategy
    if logging_strategy != "step" and logging_strategy != "epoch":
      logger.warning(
          "Warning: `logging_strategy` needs to be either 'step' or 'epoch'. Defaulting to 'step'."
      )
      logging_strategy = "step"
    self.logging_strategy = logging_strategy

    # Backwards compatibility
    if "tensorboard_log_frequency" in kwargs:
      logger.warning(
@@ -316,8 +306,7 @@ class KerasModel(Model):
          variables: Optional[List[tf.Variable]] = None,
          loss: Optional[LossFn] = None,
          callbacks: Union[Callable, List[Callable]] = [],
          all_losses: Optional[List[float]] = None,
          metrics: Optional[List[Metric]] = None) -> float:
          all_losses: Optional[List[float]] = None) -> float:
    """Train this model on a dataset.

    Parameters
@@ -351,16 +340,11 @@ class KerasModel(Model):
      If specified, all logged losses are appended into this list. Note that
      you can call `fit()` repeatedly with the same list and losses will
      continue to be appended.
    metrics: Optional[List[Metric]], optional (default None)
      metrics to compute on the dataset used during training. If None,
      no metrics and scores will be computed and only training loss will be logged.

    Returns
    -------
    The average loss over the most recent checkpoint interval
   """
    self.dataset = dataset
    self.metrics = metrics
    return self.fit_generator(
        self.default_generator(dataset,
                               epochs=nb_epoch,
@@ -439,9 +423,6 @@ class KerasModel(Model):

    # Main training loop.

    # Calculate the number of steps in a training epoch
    steps_per_epoch = math.ceil(len(self.dataset) / self.batch_size)

    for batch in generator:
      self._create_training_ops(batch)
      if restore:
@@ -474,17 +455,6 @@ class KerasModel(Model):
        avg_loss = 0.0
        averaged_batches = 0

      # Calculate epoch number, sample count, and metrics
      epoch_num = self._get_epoch_num(current_step)
      sample_count = self._get_sample_count(current_step)

      # Decide whether to calculate metrics at this current step
      scores = None
      if self.metrics is not None and self.metrics:
        if (self.logging_strategy == "step" and current_step % self.log_frequency == 0) or \
           (self.logging_strategy == "epoch" and current_step % steps_per_epoch == 0):
          scores = self.evaluate(self.dataset, self.metrics)

      if checkpoint_interval > 0 and current_step % checkpoint_interval == checkpoint_interval - 1:
        manager.save()
      for c in callbacks:
@@ -494,25 +464,19 @@ class KerasModel(Model):
      # Wandb flag support (DEPRECATED)
      if self.wandb and should_log:
        wandb.log({'loss': batch_loss}, step=current_step)

      if self.wandb_logger is not None:
      if (self.wandb_logger is not None) and should_log:
        all_data = dict({
            'train/epoch': epoch_num,
            'train/sample_count': sample_count,
            'train/loss': batch_loss
        })
        if scores is not None:
          scores = {'train/' + k: v for k, v in scores.items()}
          all_data.update(scores)
        self.wandb_logger.log_data(all_data, step=current_step)

    if self.wandb:
      wandb.finish()

    # Close WandbLogger
    if self.wandb_logger is not None:
      self.wandb_logger.finish()

    if self.wandb:
      wandb.finish()

    # Report final results.
    if averaged_batches > 0:
      avg_loss = float(avg_loss) / averaged_batches
@@ -1159,38 +1123,6 @@ class KerasModel(Model):
    """Get the number of steps of fitting that have been performed."""
    return int(self._global_step)

  def _get_epoch_num(self, step):
    """Get the epoch number corresponding to current step.

    Parameters
    ----------
    step: int
    the current step during training

    Returns
    -------
    the current step's epoch number (does not have to be an int)
    """
    dataset_size = len(self.dataset)
    steps_per_epoch = math.ceil(dataset_size / self.batch_size)
    epoch_num = step / steps_per_epoch
    return epoch_num

  def _get_sample_count(self, step):
    """Get the number of samples seen during training at step.

    Parameters
    ----------
    step: int
    the current step during training

    Returns
    -------
    the number of samples seen by the model by the current step
    """
    sample_count = step * self.batch_size
    return sample_count

  def _log_scalar_to_tensorboard(self, name: str, value: Any, step: int):
    """Log a scalar value to Tensorboard."""
    with self._summary_writer.as_default():
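Note: `_get_epoch_num` and `_get_sample_count` are gone, and `fit()` no longer logs `train/epoch` or `train/sample_count`. Both quantities remain recoverable outside the model from the step counter; a sketch with a hypothetical helper that mirrors the removed code:

import math

def epoch_and_samples(step, dataset_size, batch_size):
  # steps_per_epoch rounds up, exactly as the removed helpers did.
  steps_per_epoch = math.ceil(dataset_size / batch_size)
  return step / steps_per_epoch, step * batch_size

# e.g. step 250 with 1000 samples and batch size 32:
# ceil(1000 / 32) = 32 steps per epoch, so epoch ~7.81, 8000 samples seen.
epoch_num, sample_count = epoch_and_samples(250, 1000, 32)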
+5 −7 (KerasModel tests)
@@ -4,9 +4,6 @@ import deepchem as dc
import numpy as np
import tensorflow as tf

from deepchem.models import ValidationCallback


def test_overfit_graph_model():
  """Test fitting a KerasModel defined as a graph."""
  n_data_points = 10
@@ -305,16 +302,17 @@ def test_wandblogger():
  tasks, datasets, transformers = dc.molnet.load_delaney(featurizer='ECFP',
                                                         splitter='random')
  train_dataset, valid_dataset, test_dataset = datasets
  metric = dc.metrics.Metric(dc.metrics.pearson_r2_score)
  wandblogger = dc.models.WandbLogger(anonymous="allow", save_run_history=True)

  keras_model = tf.keras.Sequential(
      [tf.keras.layers.Dense(10, activation='relu'),
       tf.keras.layers.Dense(1)])
  metric = dc.metrics.Metric(dc.metrics.pearson_r2_score)
  wandblogger = dc.models.WandbLogger(anonymous="allow", save_run_history=True)
  model = dc.models.KerasModel(keras_model,
                               dc.models.losses.L2Loss(),
                               wandb_logger=wandblogger)
  vc = ValidationCallback(valid_dataset, 1, [metric])
  model.fit(train_dataset, nb_epoch=10, metrics=[metric], callbacks=[vc])
  vc = dc.models.ValidationCallback(valid_dataset, 1, [metric])
  model.fit(train_dataset, nb_epoch=10, callbacks=[vc])

  run_data = wandblogger.run_history._data
  valid_score = model.evaluate(valid_dataset, [metric], transformers)
+27 −0 (TorchModel tests)
@@ -2,6 +2,7 @@ import os
import unittest
import deepchem as dc
import numpy as np
import math

try:
  import torch
@@ -333,6 +334,32 @@ def test_tensorboard():
  file_size = os.stat(event_file).st_size
  assert file_size > 0

@unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
def test_wandblogger():
  """Test logging to Weights & Biases."""
  # Load dataset and Models
  tasks, datasets, transformers = dc.molnet.load_delaney(featurizer='ECFP',
                                                         splitter='random')
  train_dataset, valid_dataset, test_dataset = datasets
  metric = dc.metrics.Metric(dc.metrics.pearson_r2_score)
  wandblogger = dc.models.WandbLogger(anonymous="allow", save_run_history=True)

  pytorch_model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1000),
        torch.nn.Dropout(p=0.5),
        torch.nn.Linear(1000, 1))
  model = dc.models.TorchModel(pytorch_model,
                               dc.models.losses.L2Loss(),
                               wandb_logger=wandblogger)
  vc = dc.models.ValidationCallback(valid_dataset, 1, [metric])
  model.fit(train_dataset, nb_epoch=10, callbacks=[vc])

  run_data = wandblogger.run_history._data
  valid_score = model.evaluate(valid_dataset, [metric], transformers)

  assert math.isclose(valid_score["pearson_r2_score"],
                      run_data['eval/pearson_r2_score'],
                      abs_tol=0.0005)

@unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
def test_fit_variables():
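Note: both tests index `run_data['eval/pearson_r2_score']` because ValidationCallback namespaces its scores before handing them to the logger (see the callbacks diff above). An illustrative value, not a real run result:

scores = {'pearson_r2_score': 0.61}  # illustrative only
data = {'eval/' + k: v for k, v in scores.items()}
# data == {'eval/pearson_r2_score': 0.61}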
+34 −0 (TorchModel)
@@ -19,6 +19,7 @@ from deepchem.utils.evaluate import GeneratorEvaluator

from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from deepchem.utils.typing import ArrayLike, LossFn, OneOrMany
from deepchem.models.wandblogger import WandbLogger

try:
  import wandb
@@ -118,6 +119,7 @@ class TorchModel(Model):
               wandb: bool = False,
               log_frequency: int = 100,
               device: Optional[torch.device] = None,
               wandb_logger: Optional[WandbLogger] = None,
               **kwargs) -> None:
    """Create a new TorchModel.

@@ -187,6 +189,26 @@ class TorchModel(Model):
      )
    self.wandb = wandb and _has_wandb

    self.wandb_logger = wandb_logger

    # Setup and initialize W&B logging
    if (self.wandb_logger is not None) and (not self.wandb_logger.initialized):
      self.wandb_logger.setup()

    # Update config with TorchModel params
    wandb_logger_config = dict(loss=loss,
                               output_types=output_types,
                               batch_size=batch_size,
                               model_dir=model_dir,
                               learning_rate=learning_rate,
                               optimizer=optimizer,
                               tensorboard=tensorboard,
                               log_frequency=log_frequency)
    wandb_logger_config.update(**kwargs)

    if self.wandb_logger is not None:
      self.wandb_logger.update_config(wandb_logger_config)

    self.log_frequency = log_frequency
    if self.tensorboard:
      self._summary_writer = torch.utils.tensorboard.SummaryWriter(
@@ -406,6 +428,18 @@ class TorchModel(Model):
        self._log_scalar_to_tensorboard('loss', batch_loss, current_step)
      if self.wandb and should_log:
        wandb.log({'loss': batch_loss}, step=current_step)
      if (self.wandb_logger is not None) and should_log:
        all_data = dict({
          'train/loss': batch_loss
        })
        self.wandb_logger.log_data(all_data, step=current_step)

    if self.wandb:
      wandb.finish()

    # Close WandbLogger
    if self.wandb_logger is not None:
      self.wandb_logger.finish()

    # Report final results.
    if averaged_batches > 0:
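Note: the TorchModel wiring follows the same WandbLogger lifecycle as KerasModel. A sketch assembled from the method names in this diff (`anonymous="allow"` is taken from the tests; the config and loss values are illustrative):

import deepchem as dc

logger = dc.models.WandbLogger(anonymous="allow")
if not logger.initialized:
  logger.setup()                                # start the underlying wandb run
logger.update_config({'batch_size': 100})       # merged into the run config
logger.log_data({'train/loss': 0.5}, step=1)    # what fit_generator calls per logged step
logger.finish()                                 # close the run at the end of fit()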