Commit 5168ec9c authored by peastman
Browse files

Converted WGAN to work with TF2

parent ff189a79
Loading
Loading
Loading
Loading
+43 −21
Original line number Diff line number Diff line
@@ -115,15 +115,12 @@ class GAN(KerasModel):
      discriminator = self.create_discriminator()
      self.discriminators.append(discriminator)
      discrim_train_outputs.append(
          discriminator(
              _list_or_tensor(self.data_input_layers +
                              self.conditional_input_layers)))
          self._call_discriminator(discriminator, self.data_input_layers, True))
      for gen_output in generator_outputs:
        if isinstance(gen_output, tf.Tensor):
          gen_output = [gen_output]
        discrim_gen_outputs.append(
            discriminator(
                _list_or_tensor(gen_output + self.conditional_input_layers)))
            self._call_discriminator(discriminator, gen_output, False))
      self.discrim_variables += discriminator.trainable_variables

    # Compute the loss functions.
@@ -182,6 +179,15 @@ class GAN(KerasModel):
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    super(GAN, self).__init__(model, self.gen_loss_fn, **kwargs)

  def _call_discriminator(self, discriminator, inputs, train):
    """Run the discriminator on data inputs plus any conditional inputs.

    Kept as its own method so subclasses (e.g. WGAN) can override it, for
    example to also return a gradient penalty term.
    """
    combined_inputs = inputs + self.conditional_input_layers
    return discriminator(_list_or_tensor(combined_inputs))

  def get_noise_input_shape(self):
    """Get the shape of the generator's noise input layer.

@@ -494,24 +500,37 @@ class WGAN(GAN):
    self.gradient_penalty = gradient_penalty
    super(WGAN, self).__init__(**kwargs)

  def _call_discriminator(self, discriminator, inputs, train):
    """Invoke the discriminator, adding a gradient penalty during training.

    For training data the call is routed through GradientPenaltyLayer,
    which returns the discriminator output together with the WGAN-GP
    penalty term; for generated data the discriminator is called directly.
    """
    if not train:
      return discriminator(
          _list_or_tensor(inputs + self.conditional_input_layers))
    penalty_layer = GradientPenaltyLayer(self, discriminator)
    return penalty_layer(inputs, self.conditional_input_layers)

  def create_generator_loss(self, discrim_output):
    """Generator loss for a WGAN: the mean discriminator output."""
    mean_layer = Lambda(lambda t: tf.reduce_mean(t))
    return mean_layer(discrim_output)

  def create_discriminator_loss(self, discrim_output_train, discrim_output_gen):
    """Discriminator (critic) loss for a WGAN with gradient penalty.

    Parameters
    ----------
    discrim_output_train: list
      [output, penalty] pair produced by GradientPenaltyLayer for real
      training data.
    discrim_output_gen: Tensor
      discriminator output for generated samples.

    Returns the Wasserstein loss (mean real output minus mean generated
    output) plus the gradient penalty term.
    """
    # The stale pre-TF2 body that called GradientPenaltyLayer(self) with the
    # old one-argument constructor was removed; the penalty is now computed
    # inside _call_discriminator and arrives as discrim_output_train[1].
    return Lambda(lambda x: tf.reduce_mean(x[0] - x[1]))(
        [discrim_output_train[0], discrim_output_gen]) + discrim_output_train[1]


class GradientPenaltyLayer(Layer):
  """Implements the gradient penalty loss term for WGANs.

  Calls the discriminator inside a tf.GradientTape so the gradient of the
  discriminator output with respect to its inputs can be penalized
  (WGAN-GP style).  The layer returns both the raw discriminator output
  and the weighted penalty term.
  """

  def __init__(self, gan, discriminator, **kwargs):
    """Create the layer.

    Parameters
    ----------
    gan: GAN
      the owning WGAN; supplies the gradient_penalty coefficient.
    discriminator: tf.keras.Model
      the discriminator to invoke and penalize.
    """
    # The stripped diff showed both the old one-argument and new
    # two-argument constructors; only the new TF2 signature is kept.
    super(GradientPenaltyLayer, self).__init__(**kwargs)
    self.gan = gan
    self.discriminator = discriminator

  def call(self, inputs, conditional_inputs):
    """Invoke the discriminator and compute the gradient penalty.

    Returns [output, penalty], where output is the discriminator output on
    inputs + conditional_inputs and penalty is
    gan.gradient_penalty * mean((||grad|| - 1)^2), or 0.0 when no input
    yields a gradient.
    """
    with tf.GradientTape() as tape:
      # Inputs are plain tensors, not tape-tracked variables, so they must
      # be watched explicitly for tape.gradient to work.
      for layer in inputs:
        tape.watch(layer)
      output = self.discriminator(_list_or_tensor(inputs + conditional_inputs))
    gradients = tape.gradient(output, inputs)
    gradients = [g for g in gradients if g is not None]
    if len(gradients) > 0:
      norm2 = 0.0
      for g in gradients:
        g2 = tf.square(g)
        # NOTE(review): the next two lines sit in a hunk gap of the diff and
        # are reconstructed from the use of `dims` below — confirm upstream.
        dims = len(g.shape)
        if dims > 1:
          # Sum squared gradients over every non-batch axis so norm2 holds
          # one squared norm per sample.
          g2 = tf.reduce_sum(g2, axis=list(range(1, dims)))
        norm2 += g2
      penalty = tf.square(tf.sqrt(norm2) - 1.0)
      penalty = self.gan.gradient_penalty * tf.reduce_mean(penalty)
    else:
      penalty = 0.0
    return [output, penalty]