Unverified Commit 494c8dee authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #981 from peastman/mixgan

Implemented MIX+GAN
parents ae754761 cee61b48
Loading
Loading
Loading
Loading
+120 −45
Original line number Diff line number Diff line
@@ -50,14 +50,33 @@ class GAN(TensorGraph):
  create_generator_loss()
  create_discriminator_loss()
  get_noise_batch()

  This class allows a GAN to have multiple generators and discriminators, a model
  known as MIX+GAN.  It is described in Arora et al., "Generalization and
  Equilibrium in Generative Adversarial Nets (GANs)" (https://arxiv.org/abs/1703.00573).
  This can lead to better models, and is especially useful for reducing mode
  collapse, since different generators can learn different parts of the
  distribution.  To use this technique, simply specify the number of generators
  and discriminators when calling the constructor.  You can then tell
  predict_gan_generator() which generator to use for predicting samples.
  """

  def __init__(self, **kwargs):
  def __init__(self, n_generators=1, n_discriminators=1, **kwargs):
    """Construct a GAN.

    This class accepts all the keyword arguments from TensorGraph.
    In addition to the parameters listed below, this class accepts all the
    keyword arguments from TensorGraph.

    Parameters
    ----------
    n_generators: int
      the number of generators to include
    n_discriminators: int
      the number of discriminators to include
    """
    super(GAN, self).__init__(use_queue=False, **kwargs)
    self.n_generators = n_generators
    self.n_discriminators = n_discriminators

    # Create the inputs.

@@ -69,40 +88,49 @@ class GAN(TensorGraph):
    for shape in self.get_conditional_input_shapes():
      self.conditional_inputs.append(layers.Feature(shape=shape))

    # Create the generator.
    # Create the generators.

    self.generator = self.create_generator(self.noise_input,
    self.generators = []
    for i in range(n_generators):
      generator = self.create_generator(self.noise_input,
                                        self.conditional_inputs)
    if not isinstance(self.generator, Sequence):
      if not isinstance(generator, Sequence):
        raise ValueError('create_generator() must return a list of Layers')
    if len(self.generator) != len(self.data_inputs):
      if len(generator) != len(self.data_inputs):
        raise ValueError(
            'The number of generator outputs must match the number of data inputs'
        )
    for g, d in zip(self.generator, self.data_inputs):
      for g, d in zip(generator, self.data_inputs):
        if g.shape != d.shape:
          raise ValueError(
              'The shapes of the generator outputs must match the shapes of the data inputs'
          )
    for g in self.generator:
      for g in generator:
        self.add_output(g)
      self.generators.append(generator)

    # Create the discriminator.
    # Create the discriminators.

    self.discrim_train = self.create_discriminator(self.data_inputs,
    self.discrim_train = []
    self.discrim_gen = []
    for i in range(n_discriminators):
      discrim_train = self.create_discriminator(self.data_inputs,
                                                self.conditional_inputs)
      self.discrim_train.append(discrim_train)

    # Make a copy of the discriminator that takes the generator's output as
      # Make a copy of the discriminator that takes each generator's output as
      # its input.

      for generator in self.generators:
        replacements = {}
    for g, d in zip(self.generator, self.data_inputs):
        for g, d in zip(generator, self.data_inputs):
          replacements[d] = g
        for c in self.conditional_inputs:
          replacements[c] = c
    self.discrim_gen = self.discrim_train.copy(replacements, shared=True)
        discrim_gen = discrim_train.copy(replacements, shared=True)
        self.discrim_gen.append(discrim_gen)

    # Make a list of all layers in the generator and discriminator.
    # Make a list of all layers in the generators and discriminators.

    def add_layers_to_set(layer, layers):
      if layer not in layers:
@@ -111,21 +139,63 @@ class GAN(TensorGraph):
          add_layers_to_set(i, layers)

    gen_layers = set()
    for layer in self.generator:
    for generator in self.generators:
      for layer in generator:
        add_layers_to_set(layer, gen_layers)
    discrim_layers = set()
    add_layers_to_set(self.discrim_train, discrim_layers)
    for discriminator in self.discrim_train:
      add_layers_to_set(discriminator, discrim_layers)
    discrim_layers -= gen_layers

    # Create submodels for training the generator and discriminator.
    # Compute the loss functions.

    gen_losses = [self.create_generator_loss(d) for d in self.discrim_gen]
    discrim_losses = []
    for i in range(n_discriminators):
      for j in range(n_generators):
        discrim_losses.append(
            self.create_discriminator_loss(
                self.discrim_train[i], self.discrim_gen[i * n_generators + j]))
    if n_generators == 1 and n_discriminators == 1:
      total_gen_loss = gen_losses[0]
      total_discrim_loss = discrim_losses[0]
    else:
      # Create learnable weights for the generators and discriminators.

      gen_alpha = layers.Variable(np.ones((1, n_generators)))
      gen_weights = layers.SoftMax(gen_alpha)
      discrim_alpha = layers.Variable(np.ones((1, n_discriminators)))
      discrim_weights = layers.SoftMax(discrim_alpha)

      # Compute the weighted errors

      weight_products = layers.Reshape(
          (n_generators * n_discriminators,),
          in_layers=layers.Reshape(
              (n_discriminators,
               1), in_layers=discrim_weights) * layers.Reshape(
                   (1, n_generators), in_layers=gen_weights))
      total_gen_loss = layers.WeightedError((layers.Stack(gen_losses, axis=0),
                                             weight_products))
      total_discrim_loss = layers.WeightedError((layers.Stack(
          discrim_losses, axis=0), weight_products))
      gen_layers.add(gen_alpha)
      discrim_layers.add(gen_alpha)
      discrim_layers.add(discrim_alpha)

      # Add an entropy term to the loss.

      entropy = -(
          layers.ReduceSum(layers.Log(gen_weights)) / n_generators +
          layers.ReduceSum(layers.Log(discrim_weights)) / n_discriminators)
      total_discrim_loss += entropy

    # Create submodels for training the generators and discriminators.

    gen_loss = self.create_generator_loss(self.discrim_gen)
    discrim_loss = self.create_discriminator_loss(self.discrim_train,
                                                  self.discrim_gen)
    self.generator_submodel = self.create_submodel(
        layers=gen_layers, loss=gen_loss)
        layers=gen_layers, loss=total_gen_loss)
    self.discriminator_submodel = self.create_submodel(
        layers=discrim_layers, loss=discrim_loss)
        layers=discrim_layers, loss=total_discrim_loss)

  def get_noise_input_shape(self):
    """Get the shape of the generator's noise input layer.
@@ -370,7 +440,8 @@ class GAN(TensorGraph):
  def predict_gan_generator(self,
                            batch_size=1,
                            noise_input=None,
                            conditional_inputs=[]):
                            conditional_inputs=[],
                            generator_index=0):
    """Use the GAN to generate a batch of samples.

    Parameters
@@ -386,6 +457,9 @@ class GAN(TensorGraph):
    conditional_inputs: list of arrays
      the values to use for all conditional inputs.  This must be specified if
      the GAN has any conditional inputs.
    generator_index: int
      the index of the generator (between 0 and n_generators-1) to use for
      generating the samples.

    Returns
    -------
@@ -402,7 +476,8 @@ class GAN(TensorGraph):
    batch[self.noise_input] = noise_input
    for layer, value in zip(self.conditional_inputs, conditional_inputs):
      batch[layer] = value
    return self.predict_on_generator([batch])
    return self.predict_on_generator(
        [batch], outputs=self.generators[generator_index])

  def _set_empty_inputs(self, feed_dict, layers):
    """Set entries in a feed dict corresponding to a batch size of 0."""
+37 −17
Original line number Diff line number Diff line
@@ -20,11 +20,6 @@ def generate_data(gan, batches, batch_size):
    yield batch


class TestGAN(unittest.TestCase):

  def test_cgan(self):
    """Test fitting a conditional GAN."""

class ExampleGAN(dc.models.GAN):

  def get_noise_input_shape(self):
@@ -45,6 +40,12 @@ class TestGAN(unittest.TestCase):
    dense = layers.Dense(10, in_layers=discrim_in, activation_fn=tf.nn.relu)
    return layers.Dense(1, in_layers=dense, activation_fn=tf.sigmoid)


class TestGAN(unittest.TestCase):

  def test_cgan(self):
    """Test fitting a conditional GAN."""

    gan = ExampleGAN(learning_rate=0.003)
    gan.fit_gan(
        generate_data(gan, 5000, 100),
@@ -59,6 +60,25 @@ class TestGAN(unittest.TestCase):
    assert abs(np.mean(deltas)) < 1.0
    assert np.std(deltas) > 1.0

  def test_mix_gan(self):
    """Test a GAN with multiple generators and discriminators."""

    gan = ExampleGAN(n_generators=2, n_discriminators=2, learning_rate=0.003)
    gan.fit_gan(
        generate_data(gan, 5000, 100),
        generator_steps=0.5,
        checkpoint_interval=0)

    # See if it has done a plausible job of learning the distribution.

    means = 10 * np.random.random([1000, 1])
    for i in range(2):
      values = gan.predict_gan_generator(
          conditional_inputs=[means], generator_index=i)
      deltas = values - means
      assert abs(np.mean(deltas)) < 1.0
      assert np.std(deltas) > 1.0

  @flaky
  def test_wgan(self):
    """Test fitting a conditional WGAN."""