Commit 458164ea authored by Atreya Majumdar
Browse files

Fixed MATFeaturizer, added tests

parent b88ef1e2
Loading
Loading
Loading
Loading
+15 −35
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ class MATFeaturizer(MolecularFeaturizer):
  """
  This class is a featurizer for the Molecule Attention Transformer [1]_.
  The featurizer accepts an RDKit Molecule, and 2 booleans (add_dummy_node and one_hot_formal_charge) as arguments.
  The returned value is a tuple which consists of molecular graph descriptions:
  The returned value is a numpy array which consists of molecular graph descriptions:
    - Node Features
    - Adjacency Matrix
    - Distance Matrix
@@ -60,7 +60,7 @@ class MATFeaturizer(MolecularFeaturizer):
    
    """
    attrib = []
    attrib += one_hot_encode(atom.GetAtomicNumber(),
    attrib += one_hot_encode(atom.GetAtomicNum(),
                             [5, 6, 7, 8, 9, 15, 16, 17, 35, 53, 999])
    attrib += one_hot_encode(len(atom.GetNeighbors()), [0, 1, 2, 3, 4, 5])
    attrib += one_hot_encode(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])
@@ -86,38 +86,18 @@ class MATFeaturizer(MolecularFeaturizer):
    
    Returns
    -------
    tuple: (node_features, adjacency_matrix, distance_matrix)
    numpy.ndarray: (node_features, adjacency_matrix, distance_matrix)
    """

    node_features = np.array([
        self.atom_features(atom, self.one_hot_formal_charge)
        for atom in mol.getAtoms()
    ])

    adjacency_matrix = Chem.rdmolops.getAdjacencyMatrix(mol)

    conformer = mol.GetConformer()
    positional_matrix = np.array([[
        conformer.GetAtomPosition(k).x,
        conformer.GetAtomPosition(k).y,
        conformer.GetAtomPosition(k).z
    ] for k in range(mol.GetNumAtoms())])
    distance_matrix = pairwise_distances(positional_matrix)

    if self.add_dummy_node:
      m = np.zeros((node_features.shape[0] + 1, node_features.shape[1] + 1))
      m[1:, 1:] = node_features
      m[0, 0] = 1.0
      node_features = m

      m = np.zeros((adjacency_matrix.shape[0] + 1,
                    adjacency_matrix.shape[1] + 1))
      m[1:, 1:] = adjacency_matrix
      adjacency_matrix = m

      m = np.full((distance_matrix.shape[0] + 1, distance_matrix.shape[1] + 1),
                  1e6)
      m[1:, 1:] = distance_matrix
      distance_matrix = m
    node_features = np.array(
        [self.atom_features(atom) for atom in mol.GetAtoms()])

    adjacency_matrix = Chem.rdmolops.GetAdjacencyMatrix(mol)

    distance_matrix = Chem.rdmolops.GetDistanceMatrix(mol)

    adjacency_matrix.resize(node_features.shape)
    distance_matrix.resize(node_features.shape)

    return node_features, adjacency_matrix, distance_matrix
    
 No newline at end of file
+22 −3
Original line number Diff line number Diff line
import unittest
from deepchem.feat import material_featurizers
from deepchem.feat import MATFeaturizer
import numpy as np


class TestMATFeaturizer(unittest.TestCase):
  """Tests for MATFeaturizer."""

  def setUp(self):
    """Build the RDKit molecule that the tests featurize."""
    # Kept as a function-local import, exactly as in the original test.
    from rdkit import Chem
    smiles = 'CC(C)CC(=O)'
    self.mol = Chem.MolFromSmiles(smiles)

  def test_mat_featurizer(self):
    """Featurize the molecule and check the type and shape of the output."""
    featurizer = MATFeaturizer()
    out = featurizer.featurize(self.mol)
    # unittest assertion methods are used instead of bare `assert` so the
    # checks still run under `python -O` and report the offending values
    # on failure; `assertIsInstance` replaces the exact `type(...) ==` test.
    self.assertIsInstance(out, np.ndarray)
    self.assertEqual(out.shape, (1, 3, 6, 27))