Commit 006929eb authored by Boris Dayma's avatar Boris Dayma
Browse files

feat(wandb): add logging through Weights & Biases

parent c9eaf1b6
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -81,6 +81,8 @@ class ValidationCallback(object):
    if model.tensorboard:
      for key in scores:
        model._log_value_to_tensorboard(tag=key, simple_value=scores[key])
    if model.wandb:
      wandb.log(scores)
    if self.save_dir is not None:
      score = scores[self.metrics[self.save_metric].name]
      if not self.save_on_minimum:
+24 −0
Original line number Diff line number Diff line
@@ -17,6 +17,16 @@ from deepchem.models.optimizers import Adam
from deepchem.trans import undo_transforms
from deepchem.utils.evaluate import GeneratorEvaluator

try:
  import wandb
  wandb.ensure_configured()
  if wandb.api.api_key is None:
    _has_wandb = False
    wandb.termwarn("W&B installed but not logged in.  Run `wandb login` or set the WANDB_API_KEY env variable.")
  else:
    _has_wandb = False if os.getenv("WANDB_DISABLED") else True
except (ImportError, AttributeError):
  _has_wandb = False

class KerasModel(Model):
  """This is a DeepChem model implemented by a Keras model.
@@ -104,6 +114,7 @@ class KerasModel(Model):
               learning_rate=0.001,
               optimizer=None,
               tensorboard=False,
               wandb=False,
               log_frequency=100,
               **kwargs):
    """Create a new KerasModel.
@@ -130,6 +141,8 @@ class KerasModel(Model):
      ignored.
    tensorboard: bool
      whether to log progress to TensorBoard during training
    wandb: bool
      whether to log progress to Weights & Biases during training
    log_frequency: int
      The frequency at which to log data. Data is logged using
      `logging` by default. If `tensorboard` is set, data is also
@@ -151,6 +164,15 @@ class KerasModel(Model):
    else:
      self.optimizer = optimizer
    self.tensorboard = tensorboard

    # W&B logging
    if wandb and not _has_wandb:
      logger.warning(
        "You set wandb to True but W&B is not installed. To use wandb logging, "
        "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
      )
    self.wandb = wandb and _has_wandb
    
    # Backwards compatibility
    if "tensorboard_log_frequency" in kwargs:
      logger.warning(
@@ -375,6 +397,8 @@ class KerasModel(Model):
      if self.tensorboard and should_log:
        with self._summary_writer.as_default():
          tf.summary.scalar('loss', batch_loss, current_step)
      if self.wandb and should_log:
        wandb.log({'loss': batch_loss}, step=current_step)

    # Report final results.
    if averaged_batches > 0: