Commit 971f981b authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Fix

parent 90962639
Loading
Loading
Loading
Loading
+6 −5
Original line number Diff line number Diff line
@@ -104,7 +104,7 @@ class KerasModel(Model):
               learning_rate=0.001,
               optimizer=None,
               tensorboard=False,
               tensorboard_log_frequency=100,
               log_frequency=100,
               **kwargs):
    """Create a new KerasModel.

@@ -130,8 +130,9 @@ class KerasModel(Model):
      ignored.
    tensorboard: bool
      whether to log progress to TensorBoard during training
    tensorboard_log_frequency: int
      the frequency at which to log data to TensorBoard, measured in batches
    log_frequency: int
      The frequency at which to log data. Data is logged using `logging` by
      default. If `tensorboard` is set, data is also logged to TensorBoard.
    """
    super(KerasModel, self).__init__(
        model_instance=model, model_dir=model_dir, **kwargs)
@@ -146,7 +147,7 @@ class KerasModel(Model):
    else:
      self.optimizer = optimizer
    self.tensorboard = tensorboard
    self.tensorboard_log_frequency = tensorboard_log_frequency
    self.log_frequency = log_frequency
    if self.tensorboard:
      self._summary_writer = tf.summary.create_file_writer(self.model_dir)
    if output_types is None:
@@ -348,7 +349,7 @@ class KerasModel(Model):

      # Report progress and write checkpoints.
      averaged_batches += 1
      should_log = (current_step % self.tensorboard_log_frequency == 0)
      should_log = (current_step % self.log_frequency == 0)
      if should_log:
        avg_loss = float(avg_loss) / averaged_batches
        logger.info(
+1 −1
Original line number Diff line number Diff line
@@ -235,7 +235,7 @@ class TestKerasModel(unittest.TestCase):
        keras_model,
        dc.models.losses.CategoricalCrossEntropy(),
        tensorboard=True,
        tensorboard_log_frequency=1)
        log_frequency=1)
    model.fit(dataset, nb_epoch=10)
    files_in_dir = os.listdir(model.model_dir)
    event_file = list(filter(lambda x: x.startswith("events"), files_in_dir))