Commit 34a1b60d authored by peastman's avatar peastman
Browse files

Added documentation on MIX+GAN

parent 7a179a64
Loading
Loading
Loading
Loading
+18 −1
Original line number Diff line number Diff line
@@ -50,12 +50,29 @@ class GAN(TensorGraph):
  create_generator_loss()
  create_discriminator_loss()
  get_noise_batch()

  This class allows a GAN to have multiple generators and discriminators, a model
  known as MIX+GAN.  It is described in Arora et al., "Generalization and
  Equilibrium in Generative Adversarial Nets (GANs)" (https://arxiv.org/abs/1703.00573).
  This can lead to better models, and is especially useful for reducing mode
  collapse, since different generators can learn different parts of the
  distribution.  To use this technique, simply specify the number of generators
  and discriminators when calling the constructor.  You can then tell
  predict_gan_generator() which generator to use for predicting samples.
  """

  def __init__(self, n_generators=1, n_discriminators=1, **kwargs):
    """Construct a GAN.

    This class accepts all the keyword arguments from TensorGraph.
    In addition to the parameters listed below, this class accepts all the
    keyword arguments from TensorGraph.

    Parameters
    ----------
    n_generators: int
      the number of generators to include
    n_discriminators: int
      the number of discriminators to include
    """
    super(GAN, self).__init__(use_queue=False, **kwargs)