Commit 0efd1f69 authored by Kevin Shen's avatar Kevin Shen
Browse files

wandb initialization

parent bac28d46
Loading
Loading
Loading
Loading
+1 −34
Original line number Diff line number Diff line
@@ -90,36 +90,3 @@ class ValidationCallback(object):
      if self._best_score is None or score < self._best_score:
        model.save_checkpoint(model_dir=self.save_dir)
        self._best_score = score
 No newline at end of file

class WandbCallback(object):
  """
  Weights & Biases Logger
  """

  def __init__(self,
              **kwargs):
    try:
      import wandb
    except ImportError:
      raise ImportError(
        'You want to use `wandb` logger which is not installed yet,'
        ' install it with `pip install wandb`.'
      )




  def __call__(self, model, step):
    """This is invoked by the KerasModel after every step of fitting.

    Parameters
    ----------
    model: KerasModel
      the model that is being trained
    step: int
      the index of the training step that has just completed
    """



+34 −6
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ import tensorflow as tf
import time
import logging
import os

try:
  from collections.abc import Sequence as SequenceCollection
except:
@@ -13,11 +14,13 @@ from deepchem.metrics import Metric
from deepchem.models.losses import Loss
from deepchem.models.models import Model
from deepchem.models.optimizers import Adam, Optimizer, LearningRateSchedule
from deepchem.models.callbacks import ValidationCallback
from deepchem.trans import Transformer, undo_transforms
from deepchem.utils.evaluate import GeneratorEvaluator

from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from deepchem.utils.typing import ArrayLike, LossFn, OneOrMany
from deepchem.models.wandblogger import WandbLogger

try:
  import wandb
@@ -131,7 +134,7 @@ class KerasModel(Model):
               learning_rate: Union[float, LearningRateSchedule] = 0.001,
               optimizer: Optional[Optimizer] = None,
               tensorboard: bool = False,
               wandb: bool = False,
               wandb: Optional[WandbLogger] = None,
               log_frequency: int = 100,
               **kwargs) -> None:
    """Create a new KerasModel.
@@ -158,7 +161,7 @@ class KerasModel(Model):
      ignored.
    tensorboard: bool
      whether to log progress to TensorBoard during training
    wandb: bool
    wandb: bool (MODIFY)
      whether to log progress to Weights & Biases during training
    log_frequency: int
      The frequency at which to log data. Data is logged using
@@ -182,12 +185,28 @@ class KerasModel(Model):
    self.tensorboard = tensorboard

    # W&B logging
    if wandb and not _has_wandb:
    if (wandb is not None) and not _has_wandb:
      logger.warning(
          "You set wandb to True but W&B is not installed. To use wandb logging, "
          "You are using a wandb logger but W&B is not installed. To use wandb logging, "
          "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
      )
    self.wandb = wandb and _has_wandb
    self.wandb_logger = wandb
    #Need to save model??
    #wandb.save_model(...) #save the tf.keras.Model

    #update config with KerasModel Params
    run_params = dict(
      model=model,
      loss=loss,
      output_types=output_types,
      batch_size=batch_size,
      model_dir=model_dir,
      learning_rate=learning_rate,
      optimizer=optimizer,
      tensorboard=tensorboard,
      log_frequency=log_frequency
    )
    self.wandb_logger.wandb.config.update(run_params)

    # Backwards compatibility
    if "tensorboard_log_frequency" in kwargs:
@@ -394,6 +413,15 @@ class KerasModel(Model):

    # Main training loop.

    # Warn if both ValidationCallback and WandbLogger present
    if wandb is not None:
      for c in callbacks:
        if isinstance(c, ValidationCallback):
          logger.warning(
            "You are using both WandbLogger and ValidationCallback. WandbLogger is able to log validation metrics" 
            "so there is no need to have a ValidationCallback. Logging validation metrics twice may take longer."
          )

    for batch in generator:
      self._create_training_ops(batch)
      if restore:
@@ -433,7 +461,7 @@ class KerasModel(Model):
      if self.tensorboard and should_log:
        self._log_scalar_to_tensorboard('loss', batch_loss, current_step)
      if self.wandb and should_log:
        wandb.log({'loss': batch_loss}, step=current_step)
        #wandb.log({'loss': batch_loss}, step=current_step)

    # Report final results.
    if averaged_batches > 0:
+84 −0
Original line number Diff line number Diff line
import torch
import tensorflow as tf
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from deepchem.data import Dataset, NumpyDataset
from deepchem.metrics import Metric

class WandbLogger(object):
    """
  Weights & Biases Logger
  """

    def __init__(self,
                 datasets: List[Dataset],
                 metrics: List[Metric],
                 log_loss: bool = True,
                 name: Optional[str] = None,
                 save_dir: Optional[str] = None,
                 offline: Optional[bool] = False,
                 id: Optional[str] = None,
                 anonymous: Optional[bool] = None,
                 version: Optional[str] = None,
                 project: Optional[str] = None,
                 log_model: Optional[bool] = False,
                 experiment=None,
                 prefix: Optional[str] = '',
                 **kwargs):
        try:
            import wandb
        except ImportError:
            raise ImportError(
                'You want to use `wandb` logger which is not installed yet,'
                ' install it with `pip install wandb`.'
            )

        if offline and log_model:
            # TODO: Different exception type?
            raise Exception(
                f'Providing log_model={log_model} and offline={offline} is an invalid configuration'
                ' since model checkpoints cannot be uploaded in offline mode.\n'
                'Hint: Set `offline=False` to log your model.'
            )
        self.base_model = None # will be set in KerasModel init
        self.datasets = datasets
        self.metrics = metrics
        self.log_loss = log_loss

        self.offline = offline
        self.log_model = log_model
        self.prefix = prefix
        self.experiment = experiment
        # set wandb init arguments
        anonymous_lut = {True: 'allow', False: None}
        self.wandb_init = dict(
            name=name,
            project=project,
            id=version or id,
            dir=save_dir,
            resume='allow',
            anonymous=anonymous_lut.get(anonymous, anonymous)
        )
        self.wandb_init.update(**kwargs)
        # extract parameters
        self.save_dir = self.wandb_init.get('dir')
        self.name = self.wandb_init.get('name')
        self.id = self.wandb_init.get('id')

        #Log the parameters of KerasModel

        self.wandb = wandb.init(**self.wandb_init) if wandb.run is None else wandb.run

    def save_model(self, model):
        #model is a tf.keras.Model
        return None

    def log_data(self, model, step):
        #model is a Deepchem KerasModel
        for dataset in self.datasets:
            scores = model.evaluate(dataset, self.metrics)
        self.wandb.log()


    def update_config(self):
        return None