Commit c5c9728e authored by peastman's avatar peastman
Browse files

GAN duplicates discriminator instead of using is_training flags

parent ab7bc3a7
Loading
Loading
Loading
Loading
+40 −61
Original line number Diff line number Diff line
@@ -66,16 +66,13 @@ class GAN(TensorGraph):
    for shape in self.get_data_input_shapes():
      self.data_inputs.append(layers.Feature(shape=shape))
    self.conditional_inputs = []
    self.noise_conditional_inputs = []
    for shape in self.get_conditional_input_shapes():
      self.conditional_inputs.append(layers.Feature(shape=shape))
      self.noise_conditional_inputs.append(layers.Feature(shape=shape))
    self.is_training = layers.Weights(shape=(None, 1))

    # Create the generator.

    self.generator = self.create_generator(self.noise_input,
                                           self.noise_conditional_inputs)
                                           self.conditional_inputs)
    if not isinstance(self.generator, Sequence):
      raise ValueError('create_generator() must return a list of Layers')
    if len(self.generator) != len(self.data_inputs):
@@ -92,14 +89,18 @@ class GAN(TensorGraph):

    # Create the discriminator.

    self._discrim_data = []
    self.discrim_train = self.create_discriminator(self.data_inputs,
                                                   self.conditional_inputs)

    # Make a copy of the discriminator that takes the generator's output as
    # its input.

    replacements = {}
    for g, d in zip(self.generator, self.data_inputs):
      self._discrim_data.append(layers.Concat([g, d], axis=0))
    self._discrim_conditional = []
    for n, c in zip(self.noise_conditional_inputs, self.conditional_inputs):
      self._discrim_conditional.append(layers.Concat([n, c], axis=0))
    self.discriminator = self.create_discriminator(self._discrim_data,
                                                   self._discrim_conditional)
      replacements[d] = g
    for c in self.conditional_inputs:
      replacements[c] = c
    self.discrim_gen = self.discrim_train.copy(replacements, shared=True)

    # Make a list of all layers in the generator and discriminator.

@@ -113,14 +114,14 @@ class GAN(TensorGraph):
    for layer in self.generator:
      add_layers_to_set(layer, gen_layers)
    discrim_layers = set()
    add_layers_to_set(self.discriminator, discrim_layers)
    add_layers_to_set(self.discrim_train, discrim_layers)
    discrim_layers -= gen_layers

    # Create submodels for training the generator and discriminator.

    gen_loss = self.create_generator_loss(self.discriminator, self.is_training)
    discrim_loss = self.create_discriminator_loss(self.discriminator,
                                                  self.is_training)
    gen_loss = self.create_generator_loss(self.discrim_gen)
    discrim_loss = self.create_discriminator_loss(self.discrim_train,
                                                  self.discrim_gen)
    self.generator_submodel = self.create_submodel(
        layers=gen_layers, loss=gen_loss)
    self.discriminator_submodel = self.create_submodel(
@@ -221,7 +222,7 @@ class GAN(TensorGraph):
    """
    raise NotImplementedError("Subclasses must implement this.")

  def create_generator_loss(self, discrim_output, is_training):
  def create_generator_loss(self, discrim_output):
    """Create the loss function for the generator.

    The default implementation is appropriate for most cases.  Subclasses can
@@ -230,21 +231,17 @@ class GAN(TensorGraph):
    Parameters
    ----------
    discrim_output: Layer
      the discriminator's output layer, which computes the probability that each
      sample is training data.
    is_training: Layer
      outputs a set of flags indicating whether each sample is actually training
      data (1) or generated data (0).
      the output from the discriminator on a batch of generated data.  This is
      its estimate of the probability that each sample is training data.

    Returns
    -------
    A Layer object that outputs the loss function to use for optimizing the
    generator.
    """
    prob = discrim_output + 1e-10
    return -layers.ReduceMean(layers.Log(prob) * (1 - is_training))
    return -layers.ReduceMean(layers.Log(discrim_output + 1e-10))

  def create_discriminator_loss(self, discrim_output, is_training):
  def create_discriminator_loss(self, discrim_output_train, discrim_output_gen):
    """Create the loss function for the discriminator.

    The default implementation is appropriate for most cases.  Subclasses can
@@ -252,20 +249,20 @@ class GAN(TensorGraph):

    Parameters
    ----------
    discrim_output: Layer
      the discriminator's output layer, which computes the probability that each
      sample is training data.
    is_training: Layer
      outputs a set of flags indicating whether each sample is actually training
      data (1) or generated data (0).
    discrim_output_train: Layer
      the output from the discriminator on a batch of training data.  This is
      its estimate of the probability that each sample is training data.
    discrim_output_gen: Layer
      the output from the discriminator on a batch of generated data.  This is
      its estimate of the probability that each sample is training data.

    Returns
    -------
    A Layer object that outputs the loss function to use for optimizing the
    discriminator.
    """
    training_data_loss = layers.Log(discrim_output + 1e-10) * is_training
    gen_data_loss = layers.Log(1 - discrim_output + 1e-10) * (1 - is_training)
    training_data_loss = layers.Log(discrim_output_train + 1e-10)
    gen_data_loss = layers.Log(1 - discrim_output_gen + 1e-10)
    return -layers.ReduceMean(training_data_loss + gen_data_loss)

  def fit_gan(self,
@@ -317,26 +314,10 @@ class GAN(TensorGraph):

        global_step = self.global_step

        # Train the discriminator on training data.
        # Train the discriminator.

        feed_dict = dict(feed_dict)
        feed_dict[self.noise_input] = self.get_noise_batch(0)
        feed_dict[self.is_training] = np.ones((self.batch_size, 1))
        self._set_empty_inputs(feed_dict, self.noise_conditional_inputs)
        discrim_error += self.fit_generator(
            [feed_dict],
            submodel=self.discriminator_submodel,
            checkpoint_interval=0)
        self.global_step = global_step

        # Train the discriminator on generated data.

        feed_dict[self.noise_input] = self.get_noise_batch(self.batch_size)
        feed_dict[self.is_training] = np.zeros((self.batch_size, 1))
        for n, c in zip(self.noise_conditional_inputs, self.conditional_inputs):
          feed_dict[n] = feed_dict[c]
        self._set_empty_inputs(feed_dict, self.data_inputs)
        self._set_empty_inputs(feed_dict, self.conditional_inputs)
        discrim_error += self.fit_generator(
            [feed_dict],
            submodel=self.discriminator_submodel,
@@ -419,7 +400,7 @@ class GAN(TensorGraph):
      noise_input = self.get_noise_batch(batch_size)
    batch = {}
    batch[self.noise_input] = noise_input
    for layer, value in zip(self.noise_conditional_inputs, conditional_inputs):
    for layer, value in zip(self.conditional_inputs, conditional_inputs):
      batch[layer] = value
    return self.predict_on_generator([batch])

@@ -484,26 +465,24 @@ class WGAN(GAN):
    super(WGAN, self).__init__(**kwargs)
    self.gradient_penalty = gradient_penalty

  def create_generator_loss(self, discrim_output, is_training):
    return layers.ReduceMean(discrim_output * (1 - is_training))
  def create_generator_loss(self, discrim_output):
    return layers.ReduceMean(discrim_output)

  def create_discriminator_loss(self, discrim_output, is_training):
    training_data_loss = discrim_output * is_training
    gen_data_loss = -discrim_output * (1 - is_training)
    gradient_penalty = GradientPenaltyLayer(discrim_output, self)
    return gradient_penalty + layers.ReduceMean(training_data_loss +
                                                gen_data_loss)
  def create_discriminator_loss(self, discrim_output_train, discrim_output_gen):
    gradient_penalty = GradientPenaltyLayer(discrim_output_train, self)
    return gradient_penalty + layers.ReduceMean(discrim_output_train -
                                                discrim_output_gen)


class GradientPenaltyLayer(layers.Layer):
  """Implements the gradient penalty loss term for WGANs."""

  def __init__(self, discrim_output, gan):
    super(GradientPenaltyLayer, self).__init__(discrim_output)
  def __init__(self, discrim_output_train, gan):
    super(GradientPenaltyLayer, self).__init__([discrim_output_train])
    self.gan = gan

  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    gradients = tf.gradients(self.in_layers[0], self.gan._discrim_data)
    gradients = tf.gradients(self.in_layers[0], self.gan.data_inputs)
    norm2 = 0.0
    for g in gradients:
      g2 = tf.square(g)