Commit 2bf077b7 authored by Atreya Majumdar's avatar Atreya Majumdar Committed by atreyamaj
Browse files

Fixed shapes, layers work in sequence

parent f292889d
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -243,4 +243,8 @@ class MATFeaturizer(MolecularFeaturizer):
    adjacency_matrix = self.pad_sequence(adjacency_matrix)
    distance_matrix = self.pad_sequence(distance_matrix)

    node_features = np.expand_dims(node_features, 0)
    adjacency_matrix = np.expand_dims(adjacency_matrix, 0)
    distance_matrix = np.expand_dims(distance_matrix, 0)

    return MATEncoding(node_features, adjacency_matrix, distance_matrix)
+10 −10
Original line number Diff line number Diff line
@@ -24,10 +24,10 @@ class TestMATFeaturizer(unittest.TestCase):
    out = featurizer.featurize(self.mol)
    assert (type(out) == np.ndarray)
    assert (out.shape == (1,))
    assert (out[0].node_features.shape == (3, 36))
    assert (out[0].adjacency_matrix.shape == (3, 3))
    assert (out[0].distance_matrix.shape == (3, 3))
    expected_node_features = np.array([[
    assert (out[0].node_features.shape == (1, 3, 36))
    assert (out[0].adjacency_matrix.shape == (1, 3, 3))
    assert (out[0].distance_matrix.shape == (1, 3, 3))
    expected_node_features = np.array([[[
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
    ], [
@@ -36,12 +36,12 @@ class TestMATFeaturizer(unittest.TestCase):
    ], [
        0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.
    ]])
    expected_adjacency_matrix = np.array([[0., 0., 0.], [0., 0., 1.],
                                          [0., 1., 0.]])
    expected_distance_matrix = np.array([[1.e+06, 1.e+06,
    ]]])
    expected_adjacency_matrix = np.array([[[0., 0., 0.], [0., 0., 1.],
                                           [0., 1., 0.]]])
    expected_distance_matrix = np.array([[[1.e+06, 1.e+06,
                                           1.e+06], [1.e+06, 0.e+00, 1.e+00],
                                         [1.e+06, 1.e+00, 0.e+00]])
                                          [1.e+06, 1.e+00, 0.e+00]]])
    assert (np.array_equal(out[0].node_features, expected_node_features))
    assert (np.array_equal(out[0].adjacency_matrix, expected_adjacency_matrix))
    assert (np.array_equal(out[0].distance_matrix, expected_distance_matrix))