Commit 32301db3 authored by Milosz Grabski's avatar Milosz Grabski
Browse files

pytest fix, change inputs in critic

parent e29b3ac2
Loading
Loading
Loading
Loading
+10 −1
Original line number Diff line number Diff line
@@ -190,6 +190,13 @@ class BasicMolGANModel(WGAN):
        shape=(self.vertices, self.vertices, self.edges))
    node_tensor = layers.Input(shape=(self.vertices, self.nodes))

    # this is actually not used, added due to generator
    # output/discriminator input mismatch
    # will be removed in the future release

    e_argmax = layers.Input(shape=(self.vertices, self.vertices))
    n_argmax = layers.Input(shape=(self.vertices))

    graph = MolGANEncoderLayer(
        units=[(128, 64), 128],
        dropout_rate=self.dropout_rate,
@@ -200,7 +207,9 @@ class BasicMolGANModel(WGAN):
    dense = layers.Dropout(self.dropout_rate)(dense)
    output = layers.Dense(units=1)(dense)

    return keras.Model(inputs=[adjacency_tensor, node_tensor], outputs=[output])
    return keras.Model(
        inputs=[adjacency_tensor, node_tensor, e_argmax, n_argmax],
        outputs=[output])

  def predict_gan_generator(self,
                            batch_size: int = 1,