Commit 8d0c6bc9 authored by Kevin Shen's avatar Kevin Shen
Browse files

minor updates to code: comments, guardrails, warnings

minor fixes

wandb guardrails test keras model
parent 89d567ad
Loading
Loading
Loading
Loading
+1 −4
Original line number Diff line number Diff line
@@ -80,9 +80,6 @@ class ValidationCallback(object):
      for key in scores:
        model._log_scalar_to_tensorboard(key, scores[key],
                                         model.get_global_step())
    if model.wandb:
      import wandb
      wandb.log(scores, step=step)
    if self.save_dir is not None:
      score = scores[self.metrics[self.save_metric].name]
      if not self.save_on_minimum:
@@ -90,7 +87,7 @@ class ValidationCallback(object):
      if self._best_score is None or score < self._best_score:
        model.save_checkpoint(model_dir=self.save_dir)
        self._best_score = score
    if model.wandb_logger is not None:
    if model.wandb or (model.wandb_logger is not None):
      # Log data to Wandb
      data = {'eval/' + k: v for k, v in scores.items()}
      model.wandb_logger.log_data(data, step)
+8 −0
Original line number Diff line number Diff line
@@ -3,6 +3,13 @@ import math
import deepchem as dc
import numpy as np
import tensorflow as tf
import unittest

# Optional dependency: Weights & Biases. Probe for it once at import time so
# tests can be skipped cleanly when it is absent.
try:
  import wandb
  has_wandb = True
except ImportError:
  # Catch only ImportError — a bare `except:` would also swallow
  # KeyboardInterrupt/SystemExit and hide genuine errors raised while
  # wandb itself is being imported.
  has_wandb = False


def test_overfit_graph_model():
@@ -297,6 +304,7 @@ def test_tensorboard():
  assert file_size > 0


@unittest.skipIf(not has_wandb, 'Wandb is not installed')
def test_wandblogger():
  """Test logging to Weights & Biases."""
  # Load dataset and Models
+9 −2
Original line number Diff line number Diff line
@@ -162,6 +162,8 @@ class TorchModel(Model):
    regularization_loss: Callable, optional
      a function that takes no arguments, and returns an extra contribution to add
      to the loss function
    wandb_logger: WandbLogger
      the Weights & Biases logger object used to log data and metrics
    """
    super(TorchModel, self).__init__(model=model, model_dir=model_dir, **kwargs)
    if isinstance(loss, Loss):
@@ -187,6 +189,10 @@ class TorchModel(Model):
    self.model = model.to(device)

    # W&B logging
    if wandb:
      logger.warning(
          "`wandb` argument is deprecated. Please use `wandb_logger` instead. "
          "This argument will be removed in a future release of DeepChem.")
    if wandb and not _has_wandb:
      logger.warning(
          "You set wandb to True but W&B is not installed. To use wandb logging, "
@@ -212,7 +218,8 @@ class TorchModel(Model):
        learning_rate=learning_rate,
        optimizer=optimizer,
        tensorboard=tensorboard,
        log_frequency=log_frequency)
        log_frequency=log_frequency,
        regularization_loss=regularization_loss)
    wandb_logger_config.update(**kwargs)

    if self.wandb_logger is not None: