Commit 36036884 authored by atreyamaj
Browse files

Update

parent 3736c5fe
...@@ -30,7 +30,6 @@ class MATEncoding: ...@@ -30,7 +30,6 @@ class MATEncoding:
class MATFeaturizer(MolecularFeaturizer): class MATFeaturizer(MolecularFeaturizer):
""" """
This class is a featurizer for the Molecule Attention Transformer [1]_. This class is a featurizer for the Molecule Attention Transformer [1]_.
The featurizer accepts an RDKit Molecule, and a boolean (one_hot_formal_charge) as arguments.
The returned value is a numpy array which consists of molecular graph descriptions: The returned value is a numpy array which consists of molecular graph descriptions:
- Node Features - Node Features
- Adjacency Matrix - Adjacency Matrix
...@@ -128,7 +127,7 @@ class MATFeaturizer(MolecularFeaturizer): ...@@ -128,7 +127,7 @@ class MATFeaturizer(MolecularFeaturizer):
""" """
return np.array([self.atom_features(atom) for atom in mol.GetAtoms()]) return np.array([self.atom_features(atom) for atom in mol.GetAtoms()])
def add_dummy_node( def _add_dummy_node(
self, node_features: np.ndarray, adj_matrix: np.ndarray, self, node_features: np.ndarray, adj_matrix: np.ndarray,
dist_matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: dist_matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
""" """
...@@ -167,7 +166,7 @@ class MATFeaturizer(MolecularFeaturizer): ...@@ -167,7 +166,7 @@ class MATFeaturizer(MolecularFeaturizer):
return node_features, adj_matrix, dist_matrix return node_features, adj_matrix, dist_matrix
def pad_array(self, array: np.ndarray, shape: Any) -> np.ndarray: def _pad_array(self, array: np.ndarray, shape: Any) -> np.ndarray:
""" """
Pads an array to the desired shape. Pads an array to the desired shape.
...@@ -188,7 +187,7 @@ class MATFeaturizer(MolecularFeaturizer): ...@@ -188,7 +187,7 @@ class MATFeaturizer(MolecularFeaturizer):
result[slices] = array result[slices] = array
return result return result
def pad_sequence(self, sequence: np.ndarray) -> np.ndarray: def _pad_sequence(self, sequence: np.ndarray) -> np.ndarray:
""" """
Pads a given sequence using the pad_array function. Pads a given sequence using the pad_array function.
...@@ -204,7 +203,7 @@ class MATFeaturizer(MolecularFeaturizer): ...@@ -204,7 +203,7 @@ class MATFeaturizer(MolecularFeaturizer):
""" """
shapes = np.stack([np.array(t.shape) for t in sequence]) shapes = np.stack([np.array(t.shape) for t in sequence])
max_shape = tuple(np.max(shapes, axis=0)) max_shape = tuple(np.max(shapes, axis=0))
return np.stack([self.pad_array(t, shape=max_shape) for t in sequence]) return np.stack([self._pad_array(t, shape=max_shape) for t in sequence])
def _featurize(self, datapoint: RDKitMol, **kwargs) -> np.ndarray: def _featurize(self, datapoint: RDKitMol, **kwargs) -> np.ndarray:
""" """
...@@ -236,11 +235,11 @@ class MATFeaturizer(MolecularFeaturizer): ...@@ -236,11 +235,11 @@ class MATFeaturizer(MolecularFeaturizer):
adjacency_matrix = Chem.GetAdjacencyMatrix(datapoint) adjacency_matrix = Chem.GetAdjacencyMatrix(datapoint)
distance_matrix = Chem.GetDistanceMatrix(datapoint) distance_matrix = Chem.GetDistanceMatrix(datapoint)
node_features, adjacency_matrix, distance_matrix = self.add_dummy_node( node_features, adjacency_matrix, distance_matrix = self._add_dummy_node(
node_features, adjacency_matrix, distance_matrix) node_features, adjacency_matrix, distance_matrix)
node_features = self.pad_sequence(node_features) node_features = self._pad_sequence(node_features)
adjacency_matrix = self.pad_sequence(adjacency_matrix) adjacency_matrix = self._pad_sequence(adjacency_matrix)
distance_matrix = self.pad_sequence(distance_matrix) distance_matrix = self._pad_sequence(distance_matrix)
return MATEncoding(node_features, adjacency_matrix, distance_matrix) return MATEncoding(node_features, adjacency_matrix, distance_matrix)
...@@ -24,9 +24,9 @@ class TestMATFeaturizer(unittest.TestCase): ...@@ -24,9 +24,9 @@ class TestMATFeaturizer(unittest.TestCase):
out = featurizer.featurize(self.mol) out = featurizer.featurize(self.mol)
assert (type(out) == np.ndarray) assert (type(out) == np.ndarray)
assert (out.shape == (1,)) assert (out.shape == (1,))
assert (out[0].node_features.shape == (1, 3, 36)) assert (out[0].node_features.shape == (3, 36))
assert (out[0].adjacency_matrix.shape == (1, 3, 3)) assert (out[0].adjacency_matrix.shape == (3, 3))
assert (out[0].distance_matrix.shape == (1, 3, 3)) assert (out[0].distance_matrix.shape == (3, 3))
expected_node_features = np.array([[[ expected_node_features = np.array([[[
1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0. 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment