Commit be741b65 authored by Kevin Shen's avatar Kevin Shen
Browse files

Removed deprecated wandb code; fixed styles and tests

parent 87aad9d2
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -6,10 +6,12 @@ import sys

class ValidationCallback(object):
  """Performs validation while training a KerasModel.

  This is a callback that can be passed to fit().  It periodically computes a
  set of metrics over a validation set and writes them to a file.  In addition,
  it can save the best model parameters found so far to a directory on disk,
  updating them every time it finds a new best validation score.

  If Tensorboard logging is enabled on the KerasModel, the metrics are also
  logged to Tensorboard.  This only happens when validation coincides with a
  step on which the model writes to the log.  You should therefore make sure
@@ -26,6 +28,7 @@ class ValidationCallback(object):
               save_metric=0,
               save_on_minimum=True):
    """Create a ValidationCallback.

    Parameters
    ----------
    dataset: dc.data.Dataset
@@ -58,6 +61,7 @@ class ValidationCallback(object):

  def __call__(self, model, step):
    """This is invoked by the KerasModel after every step of fitting.

    Parameters
    ----------
    model: KerasModel
+3 −6
Original line number Diff line number Diff line
@@ -198,6 +198,9 @@ class KerasModel(Model):
    self.wandb = wandb and _has_wandb

    self.wandb_logger = wandb_logger
    # If `wandb=True` and no logger is provided, initialize default logger
    if self.wandb and (self.wandb_logger is None):
      self.wandb_logger = WandbLogger()

    # Setup and initialize W&B logging
    if (self.wandb_logger is not None) and (not self.wandb_logger.initialized):
@@ -461,16 +464,10 @@ class KerasModel(Model):
        c(self, current_step)
      if self.tensorboard and should_log:
        self._log_scalar_to_tensorboard('loss', batch_loss, current_step)
      # Wandb flag support (DEPRECATED)
      if self.wandb and should_log:
        wandb.log({'loss': batch_loss}, step=current_step)
      if (self.wandb_logger is not None) and should_log:
        all_data = dict({'train/loss': batch_loss})
        self.wandb_logger.log_data(all_data, step=current_step)

    if self.wandb:
      wandb.finish()

    # Close WandbLogger
    if self.wandb_logger is not None:
      self.wandb_logger.finish()
+8 −1
Original line number Diff line number Diff line
@@ -11,6 +11,12 @@ try:
except:
  has_pytorch = False

try:
  import wandb
  has_wandb = True
except:
  has_wandb = False


@unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
def test_overfit_subclass_model():
@@ -335,7 +341,8 @@ def test_tensorboard():
  assert file_size > 0


@unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
@unittest.skipIf((not has_pytorch) or (not has_wandb),
                 'PyTorch and/or Wandb is not installed')
def test_wandblogger():
  """Test logging to Weights & Biases."""
  # Load dataset and Models
+3 −5
Original line number Diff line number Diff line
@@ -190,6 +190,9 @@ class TorchModel(Model):
    self.wandb = wandb and _has_wandb

    self.wandb_logger = wandb_logger
    # If `wandb=True` and no logger is provided, initialize default logger
    if self.wandb and (self.wandb_logger is None):
      self.wandb_logger = WandbLogger()

    # Setup and initialize W&B logging
    if (self.wandb_logger is not None) and (not self.wandb_logger.initialized):
@@ -427,15 +430,10 @@ class TorchModel(Model):
        c(self, current_step)
      if self.tensorboard and should_log:
        self._log_scalar_to_tensorboard('loss', batch_loss, current_step)
      if self.wandb and should_log:
        wandb.log({'loss': batch_loss}, step=current_step)
      if (self.wandb_logger is not None) and should_log:
        all_data = dict({'train/loss': batch_loss})
        self.wandb_logger.log_data(all_data, step=current_step)

    if self.wandb:
      wandb.finish()

    # Close WandbLogger
    if self.wandb_logger is not None:
      self.wandb_logger.finish()
+3 −1
Original line number Diff line number Diff line
@@ -33,7 +33,9 @@ class WandbLogger(object):
               anonymous: Optional[str] = "never",
               save_run_history: Optional[bool] = False,
               **kwargs):
    """Parameters
    """Creates a WandbLogger.

    Parameters
    ----------
    name: str
      a display name for the run in the W&B dashboard
Loading