Commit cee61b48 authored by peastman's avatar peastman
Browse files

yapf

parent 34a1b60d
Loading
Loading
Loading
Loading
+20 −7
Original line number Diff line number Diff line
@@ -75,6 +75,8 @@ class GAN(TensorGraph):
      the number of discriminators to include
    """
    super(GAN, self).__init__(use_queue=False, **kwargs)
    self.n_generators = n_generators
    self.n_discriminators = n_discriminators

    # Create the inputs.

@@ -151,8 +153,9 @@ class GAN(TensorGraph):
    discrim_losses = []
    for i in range(n_discriminators):
      for j in range(n_generators):
        discrim_losses.append(self.create_discriminator_loss(self.discrim_train[i],
                                                      self.discrim_gen[i*n_generators+j]))
        discrim_losses.append(
            self.create_discriminator_loss(
                self.discrim_train[i], self.discrim_gen[i * n_generators + j]))
    if n_generators == 1 and n_discriminators == 1:
      total_gen_loss = gen_losses[0]
      total_discrim_loss = discrim_losses[0]
@@ -166,16 +169,25 @@ class GAN(TensorGraph):

      # Compute the weighted errors

      weight_products = layers.Reshape((n_generators*n_discriminators,), in_layers=layers.Reshape((n_discriminators, 1), in_layers=discrim_weights) * layers.Reshape((1, n_generators), in_layers=gen_weights))
      total_gen_loss = layers.WeightedError((layers.Stack(gen_losses, axis=0), weight_products))
      total_discrim_loss = layers.WeightedError((layers.Stack(discrim_losses, axis=0), weight_products))
      weight_products = layers.Reshape(
          (n_generators * n_discriminators,),
          in_layers=layers.Reshape(
              (n_discriminators,
               1), in_layers=discrim_weights) * layers.Reshape(
                   (1, n_generators), in_layers=gen_weights))
      total_gen_loss = layers.WeightedError((layers.Stack(gen_losses, axis=0),
                                             weight_products))
      total_discrim_loss = layers.WeightedError((layers.Stack(
          discrim_losses, axis=0), weight_products))
      gen_layers.add(gen_alpha)
      discrim_layers.add(gen_alpha)
      discrim_layers.add(discrim_alpha)

      # Add an entropy term to the loss.

      entropy = -(layers.ReduceSum(layers.Log(gen_weights))/n_generators + layers.ReduceSum(layers.Log(discrim_weights))/n_discriminators)
      entropy = -(
          layers.ReduceSum(layers.Log(gen_weights)) / n_generators +
          layers.ReduceSum(layers.Log(discrim_weights)) / n_discriminators)
      total_discrim_loss += entropy

    # Create submodels for training the generators and discriminators.
@@ -464,7 +476,8 @@ class GAN(TensorGraph):
    batch[self.noise_input] = noise_input
    for layer, value in zip(self.conditional_inputs, conditional_inputs):
      batch[layer] = value
    return self.predict_on_generator([batch], outputs=self.generators[generator_index])
    return self.predict_on_generator(
        [batch], outputs=self.generators[generator_index])

  def _set_empty_inputs(self, feed_dict, layers):
    """Set entries in a feed dict corresponding to a batch size of 0."""
+2 −1
Original line number Diff line number Diff line
@@ -73,7 +73,8 @@ class TestGAN(unittest.TestCase):

    means = 10 * np.random.random([1000, 1])
    for i in range(2):
      values = gan.predict_gan_generator(conditional_inputs=[means], generator_index=i)
      values = gan.predict_gan_generator(
          conditional_inputs=[means], generator_index=i)
      deltas = values - means
      assert abs(np.mean(deltas)) < 1.0
      assert np.std(deltas) > 1.0