Unverified Commit 085a3e53 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2652 from atreyamaj/featfix

Fixes to MATFeaturizer and Tests
parents 8b73e2e4 23682b9a
Loading
Loading
Loading
Loading
+158 −25
Original line number Diff line number Diff line
@@ -2,12 +2,34 @@ from deepchem.feat.base_classes import MolecularFeaturizer
from deepchem.utils.molecule_feature_utils import one_hot_encode
from deepchem.utils.typing import RDKitMol, RDKitAtom
import numpy as np
from typing import Tuple, Any
from dataclasses import dataclass


@dataclass
class MATEncoding:
  """
  Dataclass specific to the Molecular Attention Transformer [1]_.

  This dataclass class wraps around three different matrices for a given molecule: Node Features, Adjacency Matrix, and the Distance Matrix.

  Parameters
  ----------
  node_features: np.ndarray
    Node Features matrix for the molecule. For MAT, derived from the construct_node_features_matrix function.
  adjacency_matrix: np.ndarray
    Adjacency matrix for the molecule. Derived from rdkit.Chem.rdmolops.GetAdjacencyMatrix
  distance_matrix: np.ndarray
    Distance matrix for the molecule. Derived from rdkit.Chem.rdmolops.GetDistanceMatrix
  """
  node_features: np.ndarray
  adjacency_matrix: np.ndarray
  distance_matrix: np.ndarray


class MATFeaturizer(MolecularFeaturizer):
  """
  This class is a featurizer for the Molecule Attention Transformer [1]_.
  The featurizer accepts an RDKit Molecule, and a boolean (one_hot_formal_charge) as arguments.
  The returned value is a numpy array which consists of molecular graph descriptions:
    - Node Features
    - Adjacency Matrix
@@ -28,18 +50,37 @@ class MATFeaturizer(MolecularFeaturizer):
  This class requires RDKit to be installed.
  """

  def __init__(
      self,
      one_hot_formal_charge: bool = True,
  ):
  def __init__(self):
    pass

  def construct_mol(self, mol: RDKitMol) -> RDKitMol:
    """
    Processes an input RDKitMol further to be able to extract id-specific Conformers from it using mol.GetConformer().

    Parameters
    ----------
    one_hot_formal_charge: bool, default True
      If True, formal charges on atoms are one-hot encoded.
    mol: RDKitMol
      RDKit Mol object.

    Returns
    ----------
    mol: RDKitMol
      A processed RDKitMol objeect which is embedded, UFF Optimized and has Hydrogen atoms removed. If the former conditions are not met and there is a value error, then 2D Coordinates are computed instead.
    """
    try:
      from rdkit.Chem import AllChem
      from rdkit import Chem
    except ModuleNotFoundError:
      pass
    try:
      mol = Chem.AddHs(mol)
      AllChem.EmbedMolecule(mol, maxAttempts=5000)
      AllChem.UFFOptimizeMolecule(mol)
      mol = Chem.RemoveHs(mol)
    except ValueError:
      AllChem.Compute2DCoords(mol)

    self.one_hot_formal_charge = one_hot_formal_charge
    return mol

  def atom_features(self, atom: RDKitAtom) -> np.ndarray:
    """
@@ -55,7 +96,6 @@ class MATFeaturizer(MolecularFeaturizer):
    ----------
    Atom_features: ndarray
      Numpy array containing atom features.

    """
    attrib = []
    attrib += one_hot_encode(atom.GetAtomicNum(),
@@ -63,16 +103,108 @@ class MATFeaturizer(MolecularFeaturizer):
    attrib += one_hot_encode(len(atom.GetNeighbors()), [0, 1, 2, 3, 4, 5])
    attrib += one_hot_encode(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])

    if self.one_hot_formal_charge:
      attrib += one_hot_encode(atom.GetFormalCharge(), [-1, 0, 1])
    else:
      attrib.append(atom.GetFormalCharge())
    attrib += one_hot_encode(atom.GetFormalCharge(),
                             [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5])

    attrib.append(atom.IsInRing())
    attrib.append(atom.GetIsAromatic())

    return np.array(attrib, dtype=np.float32)

  def construct_node_features_matrix(self, mol: RDKitMol) -> np.ndarray:
    """
    This function constructs a matrix of atom features for all atoms in a given molecule using the atom_features function.

    Parameters
    ----------
    mol: RDKitMol
      RDKit Mol object.

    Returns
    ----------
    Atom_features: ndarray
      Numpy array containing atom features.
    """
    return np.array([self.atom_features(atom) for atom in mol.GetAtoms()])

  def _add_dummy_node(
      self, node_features: np.ndarray, adj_matrix: np.ndarray,
      dist_matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Adds a single dummy node to the molecule, which is consequently reflected in the Node Features Matrix, Adjacency Matrix and the Distance Matrix.

    Parameters
    ----------
    node_features: np.ndarray
      Node Features matrix for a given molecule.
    adjacency_matrix: np.ndarray
      Adjacency matrix for a given molecule.
    distance_matrix: np.ndarray
      Distance matrix for a given molecule.

    Returns
    ----------
    Atom_features: Tuple[np.ndarray, np.ndarray, np.ndarray]
      A tuple containing three numpy arrays: node_features, adjacency_matrix, distance_matrix.
    """

    if node_features is not None:
      m = np.zeros((node_features.shape[0] + 1, node_features.shape[1] + 1))
      m[1:, 1:] = node_features
      m[0, 0] = 1.0
      node_features = m

    if adj_matrix is not None:
      m = np.zeros((adj_matrix.shape[0] + 1, adj_matrix.shape[1] + 1))
      m[1:, 1:] = adj_matrix
      adj_matrix = m

    if dist_matrix is not None:
      m = np.full((dist_matrix.shape[0] + 1, dist_matrix.shape[1] + 1), 1e6)
      m[1:, 1:] = dist_matrix
      dist_matrix = m

    return node_features, adj_matrix, dist_matrix

  def _pad_array(self, array: np.ndarray, shape: Any) -> np.ndarray:
    """
    Pads an array to the desired shape.

    Parameters
    ----------
    array: np.ndarray
      Array to be padded.
    shape: int or Tuple
      Shape the array is padded to.

    Returns
    ----------
    array: np.ndarray
      Array padded to input shape.
    """
    result = np.zeros(shape=shape)
    slices = tuple(slice(s) for s in array.shape)
    result[slices] = array
    return result

  def _pad_sequence(self, sequence: np.ndarray) -> np.ndarray:
    """
    Pads a given sequence using the pad_array function.

    Parameters
    ----------
    sequence: np.ndarray
      Arrays in this sequence are padded to the largest shape in the sequence.

    Returns
    ----------
    array: np.ndarray
      Sequence with padded arrays.
    """
    shapes = np.stack([np.array(t.shape) for t in sequence])
    max_shape = tuple(np.max(shapes, axis=0))
    return np.stack([self._pad_array(t, shape=max_shape) for t in sequence])

  def _featurize(self, datapoint: RDKitMol, **kwargs) -> np.ndarray:
    """
    Featurize the molecule.
@@ -84,25 +216,26 @@ class MATFeaturizer(MolecularFeaturizer):

    Returns
    -------
    np.ndarray: A concatenated matrix consisting of node_features, adjacency_matrix and distance_matrix.
    MATEncoding: A MATEncoding dataclass instance consisting of processed node_features, adjacency_matrix and distance_matrix.
    """
    if 'mol' in kwargs:
      datapoint = kwargs.get("mol")
      raise DeprecationWarning(
          'Mol is being phased out as a parameter, please pass "datapoint" instead.'
      )

    try:
    from rdkit import Chem
    except:
      raise ImportError("This class requires RDKit to be installed.")

    node_features = np.array(
        [self.atom_features(atom) for atom in datapoint.GetAtoms()])
    adjacency_matrix = Chem.rdmolops.GetAdjacencyMatrix(datapoint)
    distance_matrix = Chem.rdmolops.GetDistanceMatrix(datapoint)
    datapoint = self.construct_mol(datapoint)

    result = np.concatenate(
        [node_features, adjacency_matrix, distance_matrix], axis=1)
    node_features = self.construct_node_features_matrix(datapoint)
    adjacency_matrix = Chem.GetAdjacencyMatrix(datapoint)
    distance_matrix = Chem.GetDistanceMatrix(datapoint)

    return result
    node_features, adjacency_matrix, distance_matrix = self._add_dummy_node(
        node_features, adjacency_matrix, distance_matrix)

    node_features = self._pad_sequence(node_features)
    adjacency_matrix = self._pad_sequence(adjacency_matrix)
    distance_matrix = self._pad_sequence(distance_matrix)

    return MATEncoding(node_features, adjacency_matrix, distance_matrix)
+19 −8
Original line number Diff line number Diff line
@@ -23,12 +23,23 @@ class TestMATFeaturizer(unittest.TestCase):
    featurizer = MATFeaturizer()
    out = featurizer.featurize(self.mol)
    assert (type(out) == np.ndarray)
    assert (out.shape == (1, 2, 31))
    correct_array = np.array([[[
        0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1.
    assert (out[0].node_features.shape == (3, 36))
    assert (out[0].adjacency_matrix.shape == (3, 3))
    assert (out[0].distance_matrix.shape == (3, 3))
    expected_node_features = np.array([[
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
    ], [
        0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0.
    ]]])
    assert (np.array_equal(out, correct_array))
        0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.
    ], [
        0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.
    ]])
    expected_adj_matrix = np.array([[0., 0., 0.], [0., 0., 1.], [0., 1., 0.]])
    expected_dist_matrix = np.array([[1.e+06, 1.e+06,
                                      1.e+06], [1.e+06, 0.e+00, 1.e+00],
                                     [1.e+06, 1.e+00, 0.e+00]])
    assert (np.array_equal(out[0].node_features, expected_node_features))
    assert (np.array_equal(out[0].adjacency_matrix, expected_adj_matrix))
    assert (np.array_equal(out[0].distance_matrix, expected_dist_matrix))