Commit 36036884 authored by atreyamaj's avatar atreyamaj
Browse files

Update

parent 3736c5fe
......@@ -30,7 +30,6 @@ class MATEncoding:
class MATFeaturizer(MolecularFeaturizer):
"""
This class is a featurizer for the Molecule Attention Transformer [1]_.
The featurizer accepts an RDKit Molecule, and a boolean (one_hot_formal_charge) as arguments.
The returned value is a numpy array which consists of molecular graph descriptions:
- Node Features
- Adjacency Matrix
......@@ -128,7 +127,7 @@ class MATFeaturizer(MolecularFeaturizer):
"""
return np.array([self.atom_features(atom) for atom in mol.GetAtoms()])
def add_dummy_node(
def _add_dummy_node(
self, node_features: np.ndarray, adj_matrix: np.ndarray,
dist_matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
......@@ -167,7 +166,7 @@ class MATFeaturizer(MolecularFeaturizer):
return node_features, adj_matrix, dist_matrix
def pad_array(self, array: np.ndarray, shape: Any) -> np.ndarray:
def _pad_array(self, array: np.ndarray, shape: Any) -> np.ndarray:
"""
Pads an array to the desired shape.
......@@ -188,7 +187,7 @@ class MATFeaturizer(MolecularFeaturizer):
result[slices] = array
return result
def pad_sequence(self, sequence: np.ndarray) -> np.ndarray:
def _pad_sequence(self, sequence: np.ndarray) -> np.ndarray:
"""
Pads a given sequence using the pad_array function.
......@@ -204,7 +203,7 @@ class MATFeaturizer(MolecularFeaturizer):
"""
shapes = np.stack([np.array(t.shape) for t in sequence])
max_shape = tuple(np.max(shapes, axis=0))
return np.stack([self.pad_array(t, shape=max_shape) for t in sequence])
return np.stack([self._pad_array(t, shape=max_shape) for t in sequence])
def _featurize(self, datapoint: RDKitMol, **kwargs) -> np.ndarray:
"""
......@@ -236,11 +235,11 @@ class MATFeaturizer(MolecularFeaturizer):
adjacency_matrix = Chem.GetAdjacencyMatrix(datapoint)
distance_matrix = Chem.GetDistanceMatrix(datapoint)
node_features, adjacency_matrix, distance_matrix = self.add_dummy_node(
node_features, adjacency_matrix, distance_matrix = self._add_dummy_node(
node_features, adjacency_matrix, distance_matrix)
node_features = self.pad_sequence(node_features)
adjacency_matrix = self.pad_sequence(adjacency_matrix)
distance_matrix = self.pad_sequence(distance_matrix)
node_features = self._pad_sequence(node_features)
adjacency_matrix = self._pad_sequence(adjacency_matrix)
distance_matrix = self._pad_sequence(distance_matrix)
return MATEncoding(node_features, adjacency_matrix, distance_matrix)
......@@ -24,9 +24,9 @@ class TestMATFeaturizer(unittest.TestCase):
out = featurizer.featurize(self.mol)
assert (type(out) == np.ndarray)
assert (out.shape == (1,))
assert (out[0].node_features.shape == (1, 3, 36))
assert (out[0].adjacency_matrix.shape == (1, 3, 3))
assert (out[0].distance_matrix.shape == (1, 3, 3))
assert (out[0].node_features.shape == (3, 36))
assert (out[0].adjacency_matrix.shape == (3, 3))
assert (out[0].distance_matrix.shape == (3, 3))
expected_node_features = np.array([[[
1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment