Unverified Commit 7c514756 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2111 from peastman/gan

Fixes to GAN class
parents 6bdba2e0 7d925e7b
Loading
Loading
Loading
Loading
+6 −8
Original line number Diff line number Diff line
@@ -83,13 +83,11 @@ class GAN(KerasModel):
    self.data_input_layers = []
    for shape in self.get_data_input_shapes():
      self.data_input_layers.append(Input(shape=shape))
    self.data_inputs = [i.experimental_ref() for i in self.data_input_layers]
    self.data_inputs = [i.ref() for i in self.data_input_layers]
    self.conditional_input_layers = []
    for shape in self.get_conditional_input_shapes():
      self.conditional_input_layers.append(Input(shape=shape))
    self.conditional_inputs = [
        i.experimental_ref() for i in self.conditional_input_layers
    ]
    self.conditional_inputs = [i.ref() for i in self.conditional_input_layers]

    # Create the generators.

@@ -344,9 +342,9 @@ class GAN(KerasModel):

      inputs = [self.get_noise_batch(self.batch_size)]
      for input in self.data_input_layers:
        inputs.append(feed_dict[input.experimental_ref()])
        inputs.append(feed_dict[input.ref()])
      for input in self.conditional_input_layers:
        inputs.append(feed_dict[input.experimental_ref()])
        inputs.append(feed_dict[input.ref()])
      discrim_error += self.fit_generator(
          [(inputs, [], [])],
          variables=self.discrim_variables,
@@ -373,7 +371,7 @@ class GAN(KerasModel):
      # Write checkpoints and report progress.

      if discrim_average_steps == checkpoint_interval:
        self._exec_with_session(lambda: manager.save())
        manager.save()
        discrim_loss = discrim_error / max(1, discrim_average_steps)
        gen_loss = gen_error / max(1, gen_average_steps)
        print(
@@ -393,7 +391,7 @@ class GAN(KerasModel):
        print(
            'Ending global_step %d: generator average loss %g, discriminator average loss %g'
            % (global_step, gen_loss, discrim_loss))
      self._exec_with_session(lambda: manager.save())
      manager.save()
      time2 = time.time()
      print("TIMING: model fitting took %0.3f s" % (time2 - time1))

+1 −4
Original line number Diff line number Diff line
@@ -128,10 +128,7 @@ class TestGAN(unittest.TestCase):
    # it far too much.

    gan = ExampleWGAN(learning_rate=0.01, gradient_penalty=0.1)
    gan.fit_gan(
        generate_data(gan, 1000, 100),
        generator_steps=0.1,
        checkpoint_interval=0)
    gan.fit_gan(generate_data(gan, 1000, 100), generator_steps=0.1)

    # See if it has done a plausible job of learning the distribution.