Unverified Commit b26a6014 authored by Kevin Shen's avatar Kevin Shen Committed by GitHub
Browse files

WandbLogger fixes: removed finish() from Model fit(), ValidationCallbacks fixes (#2586)

* removed finish from fit(), separate logging keys for different eval datasets

* yapf format fix

* fixed type annotations, wandb_logger existence checking in callback
parent 1a2d2e9f
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -93,7 +93,7 @@ class ValidationCallback(object):
      if self._best_score is None or score < self._best_score:
        model.save_checkpoint(model_dir=self.save_dir)
        self._best_score = score
    if model.wandb or (model.wandb_logger is not None):
    if model.wandb_logger is not None:
      # Log data to Wandb
      data = {'eval/' + k: v for k, v in scores.items()}
      model.wandb_logger.log_data(data, step)
      model.wandb_logger.log_data(data, step, dataset_id=id(self.dataset))
+0 −4
Original line number Diff line number Diff line
@@ -468,10 +468,6 @@ class KerasModel(Model):
        all_data = dict({'train/loss': batch_loss})
        self.wandb_logger.log_data(all_data, step=current_step)

    # Close WandbLogger
    if self.wandb_logger is not None:
      self.wandb_logger.finish()

    # Report final results.
    if averaged_batches > 0:
      avg_loss = float(avg_loss) / averaged_batches
+7 −3
Original line number Diff line number Diff line
@@ -335,15 +335,19 @@ def test_wandblogger():
       tf.keras.layers.Dense(1)])
  model = dc.models.KerasModel(
      keras_model, dc.models.losses.L2Loss(), wandb_logger=wandblogger)
  vc = dc.models.ValidationCallback(valid_dataset, 1, [metric])
  model.fit(train_dataset, nb_epoch=10, callbacks=[vc])
  vc_train = dc.models.ValidationCallback(train_dataset, 1, [metric])
  vc_valid = dc.models.ValidationCallback(valid_dataset, 1, [metric])
  model.fit(train_dataset, nb_epoch=10, callbacks=[vc_train, vc_valid])
  # call model.fit again to test multiple fit() calls
  model.fit(train_dataset, nb_epoch=10, callbacks=[vc_train, vc_valid])
  wandblogger.finish()

  run_data = wandblogger.run_history
  valid_score = model.evaluate(valid_dataset, [metric], transformers)

  assert math.isclose(
      valid_score["pearson_r2_score"],
      run_data['eval/pearson_r2_score'],
      run_data['eval/pearson_r2_score_(1)'],
      abs_tol=0.0005)


+7 −3
Original line number Diff line number Diff line
@@ -360,15 +360,19 @@ def test_wandblogger():
      torch.nn.Linear(1000, 1))
  model = dc.models.TorchModel(
      pytorch_model, dc.models.losses.L2Loss(), wandb_logger=wandblogger)
  vc = dc.models.ValidationCallback(valid_dataset, 1, [metric])
  model.fit(train_dataset, nb_epoch=10, callbacks=[vc])
  vc_train = dc.models.ValidationCallback(train_dataset, 1, [metric])
  vc_valid = dc.models.ValidationCallback(valid_dataset, 1, [metric])
  model.fit(train_dataset, nb_epoch=10, callbacks=[vc_train, vc_valid])
  # call model.fit again to test multiple fit() calls
  model.fit(train_dataset, nb_epoch=10, callbacks=[vc_train, vc_valid])
  wandblogger.finish()

  run_data = wandblogger.run_history
  valid_score = model.evaluate(valid_dataset, [metric], transformers)

  assert math.isclose(
      valid_score["pearson_r2_score"],
      run_data['eval/pearson_r2_score'],
      run_data['eval/pearson_r2_score_(1)'],
      abs_tol=0.0005)


+0 −4
Original line number Diff line number Diff line
@@ -452,10 +452,6 @@ class TorchModel(Model):
        all_data = dict({'train/loss': batch_loss})
        self.wandb_logger.log_data(all_data, step=current_step)

    # Close WandbLogger
    if self.wandb_logger is not None:
      self.wandb_logger.finish()

    # Report final results.
    if averaged_batches > 0:
      avg_loss = float(avg_loss) / averaged_batches
Loading