Commit 836b1b28 authored by peastman

Implemented WGAN

parent bb7166de
+1 −1
@@ -32,4 +32,4 @@ from deepchem.models.tensorgraph.models.graph_models import WeaveTensorGraph, DT
 from deepchem.models.tensorgraph.models.symmetry_function_regression import BPSymmetryFunctionRegression, ANIRegression

 from deepchem.models.tensorgraph.models.seqtoseq import SeqToSeq
-from deepchem.models.tensorgraph.models.gan import GAN
+from deepchem.models.tensorgraph.models.gan import GAN, WGAN
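This import change re-exports the new class next to GAN, which is how the tests below reach it (dc.models.WGAN). A minimal check that the export behaves as the tests assume; this is a sketch, not part of the commit:

import deepchem as dc

# WGAN subclasses GAN, so it plugs into the same fit_gan /
# predict_gan_generator workflow exercised by the tests below.
assert issubclass(dc.models.WGAN, dc.models.GAN)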
+104 −21
@@ -92,17 +92,14 @@ class GAN(TensorGraph):

     # Create the discriminator.

-    discrim_data = []
+    self._discrim_data = []
     for g, d in zip(self.generator, self.data_inputs):
-      discrim_data.append(layers.Concat([g, d], axis=0))
-    discrim_conditional = []
+      self._discrim_data.append(layers.Concat([g, d], axis=0))
+    self._discrim_conditional = []
     for n, c in zip(self.noise_conditional_inputs, self.conditional_inputs):
-      discrim_conditional.append(layers.Concat([n, c], axis=0))
-    self.discriminator = self.create_discriminator(discrim_data,
-                                                   discrim_conditional)
+      self._discrim_conditional.append(layers.Concat([n, c], axis=0))
+    self.discriminator = self.create_discriminator(self._discrim_data,
+                                                   self._discrim_conditional)

-    #if self.discriminator.shape != (None,):
-    #  raise ValueError('Incorrect shape for discriminator output')
-
     # Make a list of all layers in the generator and discriminator.

@@ -349,6 +346,7 @@ class GAN(TensorGraph):

         # Train the generator.

+        if generator_steps > 0.0:
           gen_train_fraction += generator_steps
           while gen_train_fraction >= 1.0:
             feed_dict[self.noise_input] = self.get_noise_batch(self.batch_size)
@@ -365,8 +363,8 @@ class GAN(TensorGraph):

         if discrim_average_steps == checkpoint_interval:
           saver.save(self.session, self.save_file, global_step=self.global_step)
-          discrim_loss = discrim_error / discrim_average_steps
-          gen_loss = gen_error / gen_average_steps
+          discrim_loss = discrim_error / max(1, discrim_average_steps)
+          gen_loss = gen_error / max(1, gen_average_steps)
           print(
               'Ending global_step %d: generator average loss %g, discriminator average loss %g'
               % (self.global_step, gen_loss, discrim_loss))
@@ -431,3 +429,88 @@ class GAN(TensorGraph):
       shape = list(layer.shape)
       shape[0] = 0
       feed_dict[layer] = np.zeros(shape)
+
+
+class WGAN(GAN):
+  """Implements Wasserstein Generative Adversarial Networks.
+
+  This class implements Wasserstein Generative Adversarial Networks (WGANs) as
+  described in Arjovsky et al., "Wasserstein GAN" (https://arxiv.org/abs/1701.07875).
+  A WGAN is conceptually rather different from a conventional GAN, but in
+  practical terms very similar.  It reinterprets the discriminator (often called
+  the "critic" in this context) as learning an approximation to the Earth Mover
+  distance between the training and generated distributions.  The generator is
+  then trained to minimize that distance.  In practice, this just means using
+  slightly different loss functions for training the generator and discriminator.
+
+  WGANs have theoretical advantages over conventional GANs, and they often work
+  better in practice.  In addition, the discriminator's loss function can be
+  directly interpreted as a measure of the quality of the model.  That is an
+  advantage over conventional GANs, where the loss does not directly convey
+  information about the quality of the model.
+
+  The theory on which WGANs are based requires the discriminator's gradient to
+  be bounded.  The original paper achieved this by clipping its weights.  This
+  class instead does it by adding a penalty term to the discriminator's loss,
+  as described in Gulrajani et al. (https://arxiv.org/abs/1704.00028).  This is
+  sometimes found to produce better results.
+
+  There are a few other practical differences between GANs and WGANs.  In a
+  conventional GAN, the discriminator's output must be between 0 and 1 so it can
+  be interpreted as a probability.  In a WGAN, it should produce an unbounded
+  output that can be interpreted as a distance.
+
+  When training a WGAN, you should also usually use a smaller value for
+  generator_steps.  Conventional GANs rely on keeping the generator and
+  discriminator "in balance" with each other.  If the discriminator ever gets
+  too good, it becomes impossible for the generator to fool it and training
+  stalls.  WGANs do not have this problem, and in fact the better the
+  discriminator is, the easier it is for the generator to improve.  It therefore
+  usually works best to perform several training steps on the discriminator for
+  each training step on the generator.
+  """
+
+  def __init__(self, gradient_penalty=10.0, **kwargs):
+    """Construct a WGAN.
+
+    In addition to the following, this class accepts all the keyword arguments
+    from TensorGraph.
+
+    Parameters
+    ----------
+    gradient_penalty: float
+      the magnitude of the gradient penalty loss
+    """
+    super(WGAN, self).__init__(**kwargs)
+    self.gradient_penalty = gradient_penalty
+
+  def create_generator_loss(self, discrim_output, is_training):
+    return layers.ReduceMean(discrim_output * (1 - is_training))
+
+  def create_discriminator_loss(self, discrim_output, is_training):
+    training_data_loss = discrim_output * is_training
+    gen_data_loss = -discrim_output * (1 - is_training)
+    gradient_penalty = GradientPenaltyLayer(discrim_output, self)
+    return gradient_penalty + layers.ReduceMean(training_data_loss +
+                                                gen_data_loss)
+
+
+class GradientPenaltyLayer(layers.Layer):
+  """Implements the gradient penalty loss term for WGANs."""
+
+  def __init__(self, discrim_output, gan):
+    super(GradientPenaltyLayer, self).__init__(discrim_output)
+    self.gan = gan
+
+  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
+    gradients = tf.gradients(self.in_layers[0], self.gan._discrim_data)
+    norm2 = 0.0
+    for g in gradients:
+      g2 = tf.square(g)
+      dims = len(g.shape)
+      if dims > 1:
+        g2 = tf.reduce_sum(g2, axis=list(range(1, dims)))
+      norm2 += g2
+    penalty = tf.square(tf.sqrt(norm2) - 1.0)
+    self.out_tensor = self.gan.gradient_penalty * tf.reduce_mean(penalty)
+    return self.out_tensor
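For readers who want to see the losses above outside the TensorGraph layer machinery, here is a minimal raw-TensorFlow sketch of the same WGAN objective with gradient penalty. It is an illustration, not DeepChem API: the names critic, real_data, and fake_data are assumptions, real_data is assumed to be a rank-2 tensor of shape (batch, features), and, like GradientPenaltyLayer above, the penalty is taken at the data points rather than at the random interpolates used in Gulrajani et al.

import tensorflow as tf

def wgan_gp_losses(critic, real_data, fake_data, gradient_penalty=10.0):
  # critic maps a batch of samples to one unbounded score per sample.
  real_score = critic(real_data)
  fake_score = critic(fake_data)
  # The generator minimizes the critic's score on generated samples,
  # matching create_generator_loss above (the paper flips the signs).
  gen_loss = tf.reduce_mean(fake_score)
  # The critic minimizes score(real) - score(fake), mirroring
  # create_discriminator_loss above.
  critic_loss = tf.reduce_mean(real_score) - tf.reduce_mean(fake_score)
  # Gradient penalty: push the norm of the critic's gradient toward 1.
  grads = tf.gradients(real_score, real_data)[0]
  norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1))
  penalty = gradient_penalty * tf.reduce_mean(tf.square(norm - 1.0))
  return gen_loss, critic_loss + penalty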
+57 −15
@@ -5,12 +5,26 @@ import unittest
 from deepchem.models.tensorgraph import layers


+def generate_batch(batch_size):
+  """Draw training data from a Gaussian distribution, where the mean is a conditional input."""
+  means = 10 * np.random.random([batch_size, 1])
+  values = np.random.normal(means, scale=2.0)
+  return means, values
+
+
+def generate_data(gan, batches, batch_size):
+  for i in range(batches):
+    means, values = generate_batch(batch_size)
+    batch = {gan.data_inputs[0]: values, gan.conditional_inputs[0]: means}
+    yield batch
+
+
 class TestGAN(unittest.TestCase):

   def test_cgan(self):
     """Test fitting a conditional GAN."""

-    class CGAN(dc.models.GAN):
+    class ExampleGAN(dc.models.GAN):

       def get_noise_input_shape(self):
         return (None, 2)
@@ -30,29 +44,57 @@ class TestGAN(unittest.TestCase):
         dense = layers.Dense(10, in_layers=discrim_in, activation_fn=tf.nn.relu)
         return layers.Dense(1, in_layers=dense, activation_fn=tf.sigmoid)

-    gan = CGAN(learning_rate=0.003)
-
-    # The training data is drawn from a Gaussian distribution, where the mean
-    # is a conditional input.
-
-    def generate_batch(batch_size):
-      means = 10 * np.random.random([batch_size, 1])
-      values = np.random.normal(means, scale=2.0)
-      return means, values
-
-    def generate_data(batches, batch_size):
-      for i in range(batches):
-        means, values = generate_batch(batch_size)
-        batch = {gan.data_inputs[0]: values, gan.conditional_inputs[0]: means}
-        yield batch
-
+    gan = ExampleGAN(learning_rate=0.003)
     gan.fit_gan(
-        generate_data(5000, 100), generator_steps=0.5, checkpoint_interval=0)
+        generate_data(gan, 5000, 100),
+        generator_steps=0.5,
+        checkpoint_interval=0)

     # See if it has done a plausible job of learning the distribution.

     means = 10 * np.random.random([1000, 1])
     values = gan.predict_gan_generator(conditional_inputs=[means])
     deltas = values - means
-    assert np.mean(deltas) < 1.0
+    assert abs(np.mean(deltas)) < 1.0
     assert np.std(deltas) > 1.0
+
+  def test_wgan(self):
+    """Test fitting a conditional WGAN."""
+
+    class ExampleWGAN(dc.models.WGAN):
+
+      def get_noise_input_shape(self):
+        return (None, 2)
+
+      def get_data_input_shapes(self):
+        return [(None, 1)]
+
+      def get_conditional_input_shapes(self):
+        return [(None, 1)]
+
+      def create_generator(self, noise_input, conditional_inputs):
+        gen_in = layers.Concat([noise_input] + conditional_inputs)
+        return [layers.Dense(1, in_layers=gen_in)]
+
+      def create_discriminator(self, data_inputs, conditional_inputs):
+        discrim_in = layers.Concat(data_inputs + conditional_inputs)
+        dense = layers.Dense(10, in_layers=discrim_in, activation_fn=tf.nn.relu)
+        return layers.Dense(1, in_layers=dense)
+
+    # We have to set the gradient penalty very small because the generator's
+    # output is only a single number, so the default penalty would constrain
+    # it far too much.
+
+    gan = ExampleWGAN(learning_rate=0.003, gradient_penalty=0.1)
+    gan.fit_gan(
+        generate_data(gan, 10000, 100),
+        generator_steps=0.1,
+        checkpoint_interval=0)
+
+    # See if it has done a plausible job of learning the distribution.
+
+    means = 10 * np.random.random([1000, 1])
+    values = gan.predict_gan_generator(conditional_inputs=[means])
+    deltas = values - means
+    assert abs(np.mean(deltas)) < 1.0
+    assert np.std(deltas) > 1.0
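A closing usage sketch, not part of the commit: the assertions above are deliberately loose, and since generate_batch draws values from N(mean, 2), a fitted model from test_wgan (the gan variable there) can be checked a little more informatively:

import numpy as np

# deltas = generated values minus conditional means; for a well-fit model
# they should look roughly like draws from N(0, 2).
means = 10 * np.random.random([1000, 1])
values = gan.predict_gan_generator(conditional_inputs=[means])
deltas = values - means
print('mean(deltas) = %g (expect ~0)' % np.mean(deltas))
print('std(deltas) = %g (expect ~2)' % np.std(deltas))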