Commit a65fad8c authored by Peter Eastman's avatar Peter Eastman
Browse files

Workaround for bug in TF 1.14

parent 2b709f7b
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -137,10 +137,11 @@ class GAN(KerasModel):
      # Create learnable weights for the generators and discriminators.

      gen_alpha = layers.Variable(np.ones((1, n_generators)), dtype=tf.float32)
      gen_weights = Softmax()(gen_alpha([]))
      # We pass an input to the Variable layer to work around a bug in TF 1.14.
      gen_weights = Softmax()(gen_alpha([self.noise_input]))
      discrim_alpha = layers.Variable(
          np.ones((1, n_discriminators)), dtype=tf.float32)
      discrim_weights = Softmax()(discrim_alpha([]))
      discrim_weights = Softmax()(discrim_alpha([self.noise_input]))

      # Compute the weighted errors