Commit a626d8f5 authored by Milosz Grabski's avatar Milosz Grabski
Browse files

MultiGraphConvolution example

parent 3b466cc2
Loading
Loading
Loading
Loading
+15 −0
Original line number Diff line number Diff line
@@ -607,6 +607,21 @@ class MultiGraphConvolutionLayer(tf.keras.layers.Layer):
  It simplifies the overall framework, but might be moved to
  GraphEncoderLayer in the future in order to reduce number of layers.

  Example
  --------
  vertices = 9
  nodes = 5
  edges = 5
  units = 128

  layer_1 = MultiGraphConvolutionLayer(units=(128,64))
  layer_2 = GraphAggregationLayer(units=128)
  adjacency_tensor= layers.Input(shape=(vertices, vertices, edges))
  node_tensor = layers.Input(shape=(vertices,nodes))
  hidden = layer_1([adjacency_tensor,node_tensor])
  output = layer_2(hidden)
  model = keras.Model(inputs=[adjacency_tensor,node_tensor], outputs=[output])

  References
  ----------
  .. [1] Nicola De Cao et al. "MolGAN: An implicit generative model