Unverified Commit 7e745b93 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #2246 from peastman/tensorboard

Fixed writing to tensorboard from callback
parents af0c1a69 dd63de23
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -78,7 +78,8 @@ class ValidationCallback(object):
    print(message, file=self.output_file)
    if model.tensorboard:
      for key in scores:
        model._log_value_to_tensorboard(tag=key, simple_value=scores[key])
        model._log_scalar_to_tensorboard(key, scores[key],
                                         model.get_global_step())
    if model.wandb:
      import wandb
      wandb.log(scores, step=step)
+6 −2
Original line number Diff line number Diff line
@@ -431,8 +431,7 @@ class KerasModel(Model):
      for c in callbacks:
        c(self, current_step)
      if self.tensorboard and should_log:
        with self._summary_writer.as_default():
          tf.summary.scalar('loss', batch_loss, current_step)
        self._log_scalar_to_tensorboard('loss', batch_loss, current_step)
      if self.wandb and should_log:
        wandb.log({'loss': batch_loss}, step=current_step)

@@ -1075,6 +1074,11 @@ class KerasModel(Model):
    """Get the number of steps of fitting that have been performed."""
    return int(self._global_step)

  def _log_scalar_to_tensorboard(self, name: str, value: Any, step: int):
    """Log a scalar value to Tensorboard."""
    with self._summary_writer.as_default():
      tf.summary.scalar(name, value, step)

  def _create_assignment_map(self,
                             source_model: "KerasModel",
                             include_top: bool = True,
+5 −1
Original line number Diff line number Diff line
@@ -402,7 +402,7 @@ class TorchModel(Model):
      for c in callbacks:
        c(self, current_step)
      if self.tensorboard and should_log:
        self._summary_writer.add_scalar('loss', batch_loss, current_step)
        self._log_scalar_to_tensorboard('loss', batch_loss, current_step)
      if self.wandb and should_log:
        wandb.log({'loss': batch_loss}, step=current_step)

@@ -983,6 +983,10 @@ class TorchModel(Model):
    """Get the number of steps of fitting that have been performed."""
    return self._global_step

  def _log_scalar_to_tensorboard(self, name: str, value: Any, step: int):
    """Log a scalar value to Tensorboard."""
    self._summary_writer.add_scalar(name, value, step)

  def _create_assignment_map(self,
                             source_model: "TorchModel",
                             include_top: bool = True,