Commit 29466145 authored by Milosz Grabski's avatar Milosz Grabski
Browse files

updated docstrings

parent 7f73791e
Loading
Loading
Loading
Loading
+23 −1
Original line number Diff line number Diff line
@@ -437,6 +437,12 @@ class GraphConvolutionLayer(tf.keras.layers.Layer):
    training: bool
      Should this layer be run in training mode.
      Typically decided by main model, influences things like dropout.

    Returns
    --------
    tuple(tf.Tensor,tf.Tensor,tf.Tensor)
      First and second are original input tensors
      Third is the result of convolution
    """

    ic = len(inputs)
@@ -531,6 +537,11 @@ class GraphAggregationLayer(tf.keras.layers.Layer):
    training: bool
      Should this layer be run in training mode.
      Typically decided by main model, influences things like dropout.

    Returns
    --------
    aggregation tensor: tf.Tensor
      Result of aggregation function on input convolution tensor.
    """

    i = self.d1(inputs)
@@ -621,6 +632,11 @@ class MultiGraphConvolutionLayer(tf.keras.layers.Layer):
    training: bool
      Should this layer be run in training mode.
      Typically decided by main model, influences things like dropout.

    Returns
    --------
    convolution tensor: tf.Tensor
      Result of input tensors going through convolution a number of times.
    """

    adjacency_tensor = inputs[0]
@@ -714,6 +730,12 @@ class GraphEncoderLayer(tf.keras.layers.Layer):
    training: bool
      Should this layer be run in training mode.
      Typically decided by main model, influences things like dropout.

    Returns
    --------
    encoder tensor: tf.Tensor
      Tensor that been through number of convolutions followed
      by aggregation.
    """

    output = self.multi_graph_convolution_layer(inputs)
+1 −1
Original line number Diff line number Diff line
@@ -67,7 +67,7 @@ class test_molgan_layers(unittest.TestCase):
    assert layer.edges == 5
    assert layer.dropout_rate == 0.0

  def test_graph_encoder_later(self):
  def test_graph_encoder_layer(self):
    vertices = 9
    nodes = 5
    edges = 5