Commit 960f3b91 authored by Boris Dayma's avatar Boris Dayma
Browse files

style(keras_model.py): use yapf

parent 74e65b4e
Loading
Loading
Loading
Loading
+31 −35
Original line number Diff line number Diff line
@@ -158,9 +158,8 @@ class KerasModel(Model):
      like a printout every 10 batch steps, you'd set
      `log_frequency=10` for example.
    """
    super(KerasModel, self).__init__(model_instance=model,
                                     model_dir=model_dir,
                                     **kwargs)
    super(KerasModel, self).__init__(
        model_instance=model, model_dir=model_dir, **kwargs)
    self.model = model
    if isinstance(loss, Loss):
      self._loss_fn = _StandardLoss(model, loss)
@@ -225,8 +224,8 @@ class KerasModel(Model):
    self._built = True
    self._global_step = tf.Variable(0, trainable=False)
    self._tf_optimizer = self.optimizer._create_optimizer(self._global_step)
    self._checkpoint = tf.train.Checkpoint(optimizer=self._tf_optimizer,
                                           model=self.model)
    self._checkpoint = tf.train.Checkpoint(
        optimizer=self._tf_optimizer, model=self.model)

  def _create_inputs(self, example_inputs):
    """The first time this is called, create tensors representing the inputs and outputs."""
@@ -300,11 +299,10 @@ class KerasModel(Model):
      every step.  This can be used to perform validation, logging, etc.
   """
    return self.fit_generator(
        self.default_generator(dataset,
                               epochs=nb_epoch,
                               deterministic=deterministic),
        max_checkpoints_to_keep, checkpoint_interval, restore, variables, loss,
        callbacks)
        self.default_generator(
            dataset, epochs=nb_epoch,
            deterministic=deterministic), max_checkpoints_to_keep,
        checkpoint_interval, restore, variables, loss, callbacks)

  def fit_generator(self,
                    generator,
@@ -394,8 +392,8 @@ class KerasModel(Model):
      should_log = (current_step % self.log_frequency == 0)
      if should_log:
        avg_loss = float(avg_loss) / averaged_batches
        logger.info('Ending global_step %d: Average loss %g' %
                    (current_step, avg_loss))
        logger.info(
            'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
        avg_loss = 0.0
        averaged_batches = 0

@@ -412,8 +410,8 @@ class KerasModel(Model):
    # Report final results.
    if averaged_batches > 0:
      avg_loss = float(avg_loss) / averaged_batches
      logger.info('Ending global_step %d: Average loss %g' %
                  (current_step, avg_loss))
      logger.info(
          'Ending global_step %d: Average loss %g' % (current_step, avg_loss))

    if checkpoint_interval > 0:
      manager.save()
@@ -730,10 +728,10 @@ class KerasModel(Model):
    a NumPy array of the model produces a single output, or a list of arrays
    if it produces multiple outputs
    """
    generator = self.default_generator(dataset,
                                       mode='predict',
                                       pad_batches=False)
    return self.predict_on_generator(generator,
    generator = self.default_generator(
        dataset, mode='predict', pad_batches=False)
    return self.predict_on_generator(
        generator,
        transformers=transformers,
        outputs=outputs,
        output_types=output_types)
@@ -754,9 +752,8 @@ class KerasModel(Model):
    a NumPy array of the embeddings model produces, or a list
    of arrays if it produces multiple embeddings
    """
    generator = self.default_generator(dataset,
                                       mode='predict',
                                       pad_batches=False)
    generator = self.default_generator(
        dataset, mode='predict', pad_batches=False)
    return self._predict(generator, [], None, False, ['embedding'])

  def predict_uncertainty(self, dataset, masks=50):
@@ -787,9 +784,8 @@ class KerasModel(Model):
    sum_sq_pred = []
    sum_var = []
    for i in range(masks):
      generator = self.default_generator(dataset,
                                         mode='uncertainty',
                                         pad_batches=False)
      generator = self.default_generator(
          dataset, mode='uncertainty', pad_batches=False)
      results = self._predict(generator, [], None, True, None)
      if len(sum_pred) == 0:
        for p, v in results:
@@ -870,8 +866,8 @@ class KerasModel(Model):
    # Use a GradientTape to compute gradients.

    X = tf.constant(X[0])
    with tf.GradientTape(persistent=True,
                         watch_accessed_variables=False) as tape:
    with tf.GradientTape(
        persistent=True, watch_accessed_variables=False) as tape:
      tape.watch(X)
      outputs = self._compute_model(X)
      if isinstance(outputs, tf.Tensor):
@@ -948,8 +944,8 @@ class KerasModel(Model):
    ([inputs], [outputs], [weights])
    """
    for epoch in range(epochs):
      for (X_b, y_b, w_b,
           ids_b) in dataset.iterbatches(batch_size=self.batch_size,
      for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
          batch_size=self.batch_size,
          deterministic=deterministic,
          pad_batches=pad_batches):
        yield ([X_b], [y_b], [w_b])
@@ -1128,8 +1124,8 @@ class KerasModel(Model):

    if assignment_map is None:
      logger.info("No assignment map provided. Creating custom assignment map.")
      assignment_map = self._create_assignment_map(source_model=source_model,
                                                   include_top=include_top)
      assignment_map = self._create_assignment_map(
          source_model=source_model, include_top=include_top)

    for source_var, dest_var in assignment_map.items():
      assert source_var.deref().shape == dest_var.shape