Commit dd2a379c authored by atreyamaj's avatar atreyamaj
Browse files

Update tests

parent 36036884
Loading
Loading
Loading
Loading
+8 −10
Original line number Diff line number Diff line
@@ -23,11 +23,10 @@ class TestMATFeaturizer(unittest.TestCase):
    featurizer = MATFeaturizer()
    out = featurizer.featurize(self.mol)
    assert (type(out) == np.ndarray)
    assert (out.shape == (1,))
    assert (out[0].node_features.shape == (3, 36))
    assert (out[0].adjacency_matrix.shape == (3, 3))
    assert (out[0].distance_matrix.shape == (3, 3))
    expected_node_features = np.array([[[
    expected_node_features = np.array([[
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
    ], [
@@ -36,12 +35,11 @@ class TestMATFeaturizer(unittest.TestCase):
    ], [
        0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.
    ]]])
    expected_adjacency_matrix = np.array([[[0., 0., 0.], [0., 0., 1.],
                                           [0., 1., 0.]]])
    expected_distance_matrix = np.array([[[1.e+06, 1.e+06,
    ]])
    expected_adj_matrix = np.array([[0., 0., 0.], [0., 0., 1.], [0., 1., 0.]])
    expected_dist_matrix = np.array([[1.e+06, 1.e+06,
                                      1.e+06], [1.e+06, 0.e+00, 1.e+00],
                                          [1.e+06, 1.e+00, 0.e+00]]])
                                     [1.e+06, 1.e+00, 0.e+00]])
    assert (np.array_equal(out[0].node_features, expected_node_features))
    assert (np.array_equal(out[0].adjacency_matrix, expected_adjacency_matrix))
    assert (np.array_equal(out[0].distance_matrix, expected_distance_matrix))
    assert (np.array_equal(out[0].adjacency_matrix, expected_adj_matrix))
    assert (np.array_equal(out[0].distance_matrix, expected_dist_matrix))