Commit 9c51193d authored by peastman's avatar peastman
Browse files

Converted GAN to KerasModel

parent 4e382ee8
Loading
Loading
Loading
Loading
+22 −17
Original line number Diff line number Diff line
@@ -240,7 +240,7 @@ class KerasModel(Model):
    except ValueError:
      # The loss doesn't depend on any variables.
      self._train_op = 0
    self._train_op_for_vars = {}
    self._custom_train_op = {}
    if self.tensorboard:
      self._summary_ops = tf.summary.scalar('loss', self._loss_tensor)
      self._summary_writer = tf.summary.FileWriter(self.model_dir)
@@ -379,22 +379,27 @@ class KerasModel(Model):
        # In graph mode we execute the training op.

        if train_op is None:
          if loss is not None:
          if loss is None and variables is None:
            train_op = self._train_op
          else:
            if variables is None:
              op_key = (None, loss)
            else:
              op_key = tuple(variables) + (loss,)
            if op_key not in self._custom_train_op:
              if loss is None:
                loss_tensor = self._loss_tensor
              else:
                loss_tensor = loss(
                    [self._output_tensors[i] for i in self._loss_outputs],
                    self._label_placeholders, self._weights_placeholders)
            train_op = self._tf_optimizer.minimize(
                loss_tensor, global_step=self._global_step, var_list=variables)
          elif variables is None:
            train_op = self._train_op
              if variables is None:
                vars = self.model.trainable_variables
              else:
            var_key = tuple(variables)
            if var_key not in self._train_op_for_vars:
              self._train_op_for_vars[var_key] = self._tf_optimizer.minimize(
                  self._loss_tensor,
                  global_step=self._global_step,
                  var_list=variables)
            train_op = self._train_op_for_vars[var_key]
                vars = variables
              self._custom_train_op[op_key] = self._tf_optimizer.minimize(
                  loss_tensor, global_step=self._global_step, var_list=vars)
            train_op = self._custom_train_op[op_key]
        fetches = [train_op, self._loss_tensor, self._global_step]
        if should_log:
          fetches.append(self._summary_ops)
+182 −238

File changed.

Preview size limit exceeded, changes collapsed.

+49 −31

File changed.

Preview size limit exceeded, changes collapsed.