Commit 109d8ed2 authored by Milosz Grabski's avatar Milosz Grabski
Browse files

Update molgan.py

parent 2f4cd252
Loading
Loading
Loading
Loading
+8 −12
Original line number Diff line number Diff line
@@ -108,9 +108,10 @@ class BasicMolGANModel(WGAN):
    List
        List of shapes used as an input for distriminator.
    """
    return [(self.vertices, self.vertices, self.edges),
            (self.vertices, self.nodes), (self.vertices,
                                          self.vertices), (self.vertices)]
    return [
        (self.vertices, self.vertices, self.edges),
        (self.vertices, self.nodes),
    ]

  def create_generator(self) -> keras.Model:
    """
@@ -190,13 +191,6 @@ class BasicMolGANModel(WGAN):
        shape=(self.vertices, self.vertices, self.edges))
    node_tensor = layers.Input(shape=(self.vertices, self.nodes))

    # this is actually not used, added due to generator
    # output/discriminator input mismatch
    # will be removed in the future release

    e_argmax = layers.Input(shape=(self.vertices, self.vertices))
    n_argmax = layers.Input(shape=(self.vertices))

    graph = MolGANEncoderLayer(
        units=[(128, 64), 128],
        dropout_rate=self.dropout_rate,
@@ -208,8 +202,10 @@ class BasicMolGANModel(WGAN):
    output = layers.Dense(units=1)(dense)

    return keras.Model(
        inputs=[adjacency_tensor, node_tensor, e_argmax, n_argmax],
        outputs=[output])
        inputs=[
            adjacency_tensor,
            node_tensor,
        ], outputs=[output])

  def predict_gan_generator(self,
                            batch_size: int = 1,