Commit bac28d46 authored by Kevin Shen's avatar Kevin Shen
Browse files

created wandb branch

parent bb4364ab
Loading
Loading
Loading
Loading
+33 −0
Original line number Diff line number Diff line
@@ -90,3 +90,36 @@ class ValidationCallback(object):
      if self._best_score is None or score < self._best_score:
        model.save_checkpoint(model_dir=self.save_dir)
        self._best_score = score

class WandbCallback(object):
  """
  Weights & Biases Logger
  """

  def __init__(self,
              **kwargs):
    try:
      import wandb
    except ImportError:
      raise ImportError(
        'You want to use `wandb` logger which is not installed yet,'
        ' install it with `pip install wandb`.'
      )




  def __call__(self, model, step):
    """This is invoked by the KerasModel after every step of fitting.

    Parameters
    ----------
    model: KerasModel
      the model that is being trained
    step: int
      the index of the training step that has just completed
    """