Commit 36036884 authored by atreyamaj's avatar atreyamaj
Browse files

Update

parent 3736c5fe
Loading
Loading
Loading
Loading
+8 −9
Original line number Diff line number Diff line
@@ -30,7 +30,6 @@ class MATEncoding:
class MATFeaturizer(MolecularFeaturizer):
  """
  This class is a featurizer for the Molecule Attention Transformer [1]_.
  The featurizer accepts an RDKit Molecule, and a boolean (one_hot_formal_charge) as arguments.
  The returned value is a numpy array which consists of molecular graph descriptions:
    - Node Features
    - Adjacency Matrix
@@ -128,7 +127,7 @@ class MATFeaturizer(MolecularFeaturizer):
    """
    return np.array([self.atom_features(atom) for atom in mol.GetAtoms()])

  def add_dummy_node(
  def _add_dummy_node(
      self, node_features: np.ndarray, adj_matrix: np.ndarray,
      dist_matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
@@ -167,7 +166,7 @@ class MATFeaturizer(MolecularFeaturizer):

    return node_features, adj_matrix, dist_matrix

  def pad_array(self, array: np.ndarray, shape: Any) -> np.ndarray:
  def _pad_array(self, array: np.ndarray, shape: Any) -> np.ndarray:
    """
    Pads an array to the desired shape.

@@ -188,7 +187,7 @@ class MATFeaturizer(MolecularFeaturizer):
    result[slices] = array
    return result

  def pad_sequence(self, sequence: np.ndarray) -> np.ndarray:
  def _pad_sequence(self, sequence: np.ndarray) -> np.ndarray:
    """
    Pads a given sequence using the pad_array function.

@@ -204,7 +203,7 @@ class MATFeaturizer(MolecularFeaturizer):
    """
    shapes = np.stack([np.array(t.shape) for t in sequence])
    max_shape = tuple(np.max(shapes, axis=0))
    return np.stack([self.pad_array(t, shape=max_shape) for t in sequence])
    return np.stack([self._pad_array(t, shape=max_shape) for t in sequence])

  def _featurize(self, datapoint: RDKitMol, **kwargs) -> np.ndarray:
    """
@@ -236,11 +235,11 @@ class MATFeaturizer(MolecularFeaturizer):
    adjacency_matrix = Chem.GetAdjacencyMatrix(datapoint)
    distance_matrix = Chem.GetDistanceMatrix(datapoint)

    node_features, adjacency_matrix, distance_matrix = self.add_dummy_node(
    node_features, adjacency_matrix, distance_matrix = self._add_dummy_node(
        node_features, adjacency_matrix, distance_matrix)

    node_features = self.pad_sequence(node_features)
    adjacency_matrix = self.pad_sequence(adjacency_matrix)
    distance_matrix = self.pad_sequence(distance_matrix)
    node_features = self._pad_sequence(node_features)
    adjacency_matrix = self._pad_sequence(adjacency_matrix)
    distance_matrix = self._pad_sequence(distance_matrix)

    return MATEncoding(node_features, adjacency_matrix, distance_matrix)
+3 −3
Original line number Diff line number Diff line
@@ -24,9 +24,9 @@ class TestMATFeaturizer(unittest.TestCase):
    out = featurizer.featurize(self.mol)
    assert (type(out) == np.ndarray)
    assert (out.shape == (1,))
    assert (out[0].node_features.shape == (1, 3, 36))
    assert (out[0].adjacency_matrix.shape == (1, 3, 3))
    assert (out[0].distance_matrix.shape == (1, 3, 3))
    assert (out[0].node_features.shape == (3, 36))
    assert (out[0].adjacency_matrix.shape == (3, 3))
    assert (out[0].distance_matrix.shape == (3, 3))
    expected_node_features = np.array([[[
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.