Commit 16ecf372 authored by miaecle's avatar miaecle
Browse files

little fix

parent b8fe666c
Loading
Loading
Loading
Loading
+9 −3
Original line number Diff line number Diff line
@@ -49,7 +49,12 @@ class TestMolGraphs(unittest.TestCase):
        # 0 atoms of degree 4
        # 0 atoms of degree 5
        # 0 atoms of degree 6
        np.array([[0, 0], [0, 0], [0, 4], [0, 0], [0, 0], [0, 0], [0,0]]))
        # 0 atoms of degree 7
        # 0 atoms of degree 8
        # 0 atoms of degree 9
        # 0 atoms of degree 10
        np.array([[0, 0], [0, 0], [0, 4], [0, 0], [0, 0], [0, 0], [0, 0], 
                  [0, 0], [0, 0], [0, 0], [0, 0]]))

  def test_get_atom_features(self):
    """Test that the atom features are computed properly."""
@@ -168,10 +173,11 @@ class TestMolGraphs(unittest.TestCase):

    # Check that atoms are only connected to themselves.
    assert np.array_equal(
        deg_adj_lists[6], [[6, 6, 6, 6, 6, 6]])
        deg_adj_lists[10], [[10, 10, 10, 10, 10, 10, 10, 10, 10, 10]])
    assert np.array_equal(
        deg_adj_lists[1], [[1]])
    # Check that there's one atom of each degree.
    assert np.array_equal(
        null_mol.get_deg_slice(),
        [[0, 1], [1, 1], [2, 1], [3, 1], [4, 1], [5, 1], [6, 1]])
        [[0, 1], [1, 1], [2, 1], [3, 1], [4, 1], [5, 1], [6, 1],
         [7, 1], [8, 1], [9, 1], [10, 1]])
+1 −1
Original line number Diff line number Diff line
@@ -23,7 +23,7 @@ class TestGraphTopology(unittest.TestCase):
    n_atoms = 5
    n_feat = 10
    batch_size = 3
    max_deg = 6
    max_deg = 10
    min_deg = 0
    topology = GraphTopology(n_feat)