Commit 016bdfeb authored by cfperez's avatar cfperez Committed by GitHub
Browse files

Fix asserts statements

Cannot use `()` in asserts. Yields a SyntaxWarning because it will always return True. The arguments must be on the same line or end with a `\` before the newline.
parent 336d8167
Loading
Loading
Loading
Loading
+7 −7
Original line number Diff line number Diff line
@@ -133,7 +133,7 @@ def graph_gather(atoms, membership_placeholder, batch_size):
  """

  # WARNING: Does not work for Batch Size 1! If batch_size = 1, then use reduce_sum!
  assert (batch_size > 1, "graph_gather requires batches larger than 1")
  assert batch_size > 1, "graph_gather requires batches larger than 1"

  # Obtain the partitions for each of the molecules
  activated_par = tf.dynamic_partition(atoms, membership_placeholder,
@@ -277,8 +277,8 @@ class GraphConv(Layer):
  def get_output_shape_for(self, input_shape):
    """Output tensor shape produced by this layer."""
    atom_features_shape = input_shape[0]
    assert (len(atom_features_shape) == 2,
            "MolConv only takes 2 dimensional tensors for x")
    assert len(atom_features_shape) == 2, \
            "MolConv only takes 2 dimensional tensors for x"
    n_atoms = atom_features_shape[0]
    return (n_atoms, self.nb_filter)

@@ -365,8 +365,8 @@ class GraphGather(Layer):
    atom_features_shape = input_shape[0]
    membership_shape = input_shape[2]

    assert (len(atom_features_shape) == 2,
            "GraphGather only takes 2 dimensional tensors")
    assert len(atom_features_shape) == 2, \
            "GraphGather only takes 2 dimensional tensors"
    n_feat = atom_features_shape[1]

    return (self.batch_size, n_feat)
@@ -436,8 +436,8 @@ class GraphPool(Layer):
    # Extract nodes
    atom_features_shape = input_shape[0]

    assert (len(atom_features_shape) == 2,
            "GraphPool only takes 2 dimensional tensors")
    assert len(atom_features_shape) == 2, \
            "GraphPool only takes 2 dimensional tensors"
    return atom_features_shape

  def call(self, x, mask=None):