Commit 55d1ff20 authored by Peter Eastman's avatar Peter Eastman
Browse files

Updated GAN example notebook

parent 9c51193d
Loading
Loading
Loading
Loading
+20 −10
Original line number Diff line number Diff line
@@ -95,9 +95,10 @@ class GAN(KerasModel):
    for i in range(n_generators):
      generator = self.create_generator()
      self.generators.append(generator)
      self.gen_variables += generator.trainable_variables
      generator_outputs.append(
          generator([self.noise_input] + self.conditional_inputs))
          generator(
              _list_or_tensor([self.noise_input] + self.conditional_inputs)))
      self.gen_variables += generator.trainable_variables

    # Create the discriminators.

@@ -108,14 +109,16 @@ class GAN(KerasModel):
    for i in range(n_discriminators):
      discriminator = self.create_discriminator()
      self.discriminators.append(discriminator)
      self.discrim_variables += discriminator.trainable_variables
      discrim_train_outputs.append(
          discriminator(self.data_inputs + self.conditional_inputs))
          discriminator(
              _list_or_tensor(self.data_inputs + self.conditional_inputs)))
      for gen_output in generator_outputs:
        if isinstance(gen_output, tf.Tensor):
          gen_output = [gen_output]
        discrim_gen_outputs.append(
            discriminator(gen_output + self.conditional_inputs))
            discriminator(
                _list_or_tensor(gen_output + self.conditional_inputs)))
      self.discrim_variables += discriminator.trainable_variables

    # Compute the loss functions.

@@ -316,8 +319,8 @@ class GAN(KerasModel):
    gen_average_steps = 0
    time1 = time.time()
    if checkpoint_interval > 0:
      manager = tf.train.CheckpointManager(
          self._get_tf('Checkpoint'), self.model_dir, max_checkpoints_to_keep)
      manager = tf.train.CheckpointManager(self._checkpoint, self.model_dir,
                                           max_checkpoints_to_keep)
    for feed_dict in batches:
      # Every call to fit_generator() will increment global_step, but we only
      # want it to get incremented once for the entire batch, so record the
@@ -367,7 +370,7 @@ class GAN(KerasModel):
        gen_loss = gen_error / max(1, gen_average_steps)
        print(
            'Ending global_step %d: generator average loss %g, discriminator average loss %g'
            % (self.global_step, gen_loss, discrim_loss))
            % (global_step, gen_loss, discrim_loss))
        discrim_error = 0.0
        gen_error = 0.0
        discrim_average_steps = 0
@@ -381,7 +384,7 @@ class GAN(KerasModel):
        gen_loss = gen_error / gen_average_steps
        print(
            'Ending global_step %d: generator average loss %g, discriminator average loss %g'
            % (self.global_step, gen_loss, discrim_loss))
            % (global_step, gen_loss, discrim_loss))
      self._exec_with_session(lambda: manager.save())
      time2 = time.time()
      print("TIMING: model fitting took %0.3f s" % (time2 - time1))
@@ -424,7 +427,8 @@ class GAN(KerasModel):
    inputs = [noise_input]
    inputs += conditional_inputs
    inputs = [i.astype(np.float32) for i in inputs]
    pred = self.generators[generator_index](inputs, training=False)
    pred = self.generators[generator_index](
        _list_or_tensor(inputs), training=False)
    if tf.executing_eagerly():
      pred = pred.numpy()
    else:
@@ -432,6 +436,12 @@ class GAN(KerasModel):
    return pred


def _list_or_tensor(inputs):
  if len(inputs) == 1:
    return inputs[0]
  return inputs


class WGAN(GAN):
  """Implements Wasserstein Generative Adversarial Networks.

+49 −35

File changed.

Preview size limit exceeded, changes collapsed.