Unverified Commit 42fbae1d authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1625 from peastman/gan

Converted GAN to KerasModel
parents 7ca3a111 c545d3ab
Loading
Loading
Loading
Loading
+22 −17
Original line number Diff line number Diff line
@@ -240,7 +240,7 @@ class KerasModel(Model):
    except ValueError:
      # The loss doesn't depend on any variables.
      self._train_op = 0
    self._train_op_for_vars = {}
    self._custom_train_op = {}
    if self.tensorboard:
      self._summary_ops = tf.summary.scalar('loss', self._loss_tensor)
      self._summary_writer = tf.summary.FileWriter(self.model_dir)
@@ -379,22 +379,27 @@ class KerasModel(Model):
        # In graph mode we execute the training op.

        if train_op is None:
          if loss is not None:
          if loss is None and variables is None:
            train_op = self._train_op
          else:
            if variables is None:
              op_key = (None, loss)
            else:
              op_key = tuple(variables) + (loss,)
            if op_key not in self._custom_train_op:
              if loss is None:
                loss_tensor = self._loss_tensor
              else:
                loss_tensor = loss(
                    [self._output_tensors[i] for i in self._loss_outputs],
                    self._label_placeholders, self._weights_placeholders)
            train_op = self._tf_optimizer.minimize(
                loss_tensor, global_step=self._global_step, var_list=variables)
          elif variables is None:
            train_op = self._train_op
              if variables is None:
                vars = self.model.trainable_variables
              else:
            var_key = tuple(variables)
            if var_key not in self._train_op_for_vars:
              self._train_op_for_vars[var_key] = self._tf_optimizer.minimize(
                  self._loss_tensor,
                  global_step=self._global_step,
                  var_list=variables)
            train_op = self._train_op_for_vars[var_key]
                vars = variables
              self._custom_train_op[op_key] = self._tf_optimizer.minimize(
                  loss_tensor, global_step=self._global_step, var_list=vars)
            train_op = self._custom_train_op[op_key]
        fetches = [train_op, self._loss_tensor, self._global_step]
        if should_log:
          fetches.append(self._summary_ops)
+191 −237

File changed.

Preview size limit exceeded, changes collapsed.

+1 −1
Original line number Diff line number Diff line
@@ -247,7 +247,7 @@ class ProgressiveMultitaskRegressor(KerasModel):
        dataset, epochs=nb_epoch, deterministic=deterministic)
    variables = []
    for layer in self._task_layers[task]:
      variables.append(layer.trainable_variables)
      variables += layer.trainable_variables
    loss = TaskLoss(self.model, self.create_loss(), task)
    self.fit_generator(
        generator,
+49 −31

File changed.

Preview size limit exceeded, changes collapsed.

+49 −35

File changed.

Preview size limit exceeded, changes collapsed.