Commit eb02699b authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

PAGTN featuriser

parent 9de39778
Loading
Loading
Loading
Loading
+135 −0
Original line number Diff line number Diff line
@@ -217,3 +217,138 @@ class MolGraphConvFeaturizer(MolecularFeaturizer):
        node_features=atom_features,
        edge_index=np.asarray([src, dest], dtype=int),
        edge_features=bond_features)


from deepchem.feat import MolecularFeaturizer
import numpy as np
from deepchem.feat.graph_data import GraphData


class PAGTNFeaturizer(MolecularFeaturizer):
  """Featurize a molecule as a fully connected atom graph.

  Every atom is a node and every ordered pair of atom indices ``(i, j)``
  (including ``i == j``) is an edge. An edge feature encodes the bonds along
  the shortest path between the two atoms, a one-hot slot for the path length
  and flags for the ring types that contain both atoms.

  Parameters
  ----------
  max_length : int, default 5
    Number of bonds along a path that are encoded in each edge feature.

  Notes
  -----
  This class requires RDKit and DGL-LifeSci to be installed.
  """

  def __init__(self, max_length=5):
    from functools import partial

    try:
      # RDKit is imported here only to fail early when it is missing.
      from rdkit import Chem  # noqa: F401
      from dgllife.utils import ConcatFeaturizer
      from dgllife.utils import one_hot_encoding
      from dgllife.utils import atom_formal_charge_one_hot, atom_type_one_hot, atom_degree_one_hot
      from dgllife.utils import atom_explicit_valence_one_hot, atom_implicit_valence_one_hot
      from dgllife.utils import atom_is_aromatic
      from dgllife.utils import bond_type_one_hot, bond_is_conjugated, bond_is_in_ring
    except ImportError as err:
      raise ImportError(
          "This class requires Rdkit & DGLLifeSci to be installed.") from err

    self.SYMBOLS = [
        'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe',
        'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd',
        'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In',
        'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'W', 'Ru', 'Nb', 'Re', 'Te', 'Rh',
        'Tc', 'Ba', 'Bi', 'Hf', 'Mo', 'U', 'Sm', 'Os', 'Ir', 'Ce', 'Gd', 'Ga',
        'Cs', '*', 'UNK'
    ]
    # (ring size, is aromatic) combinations that get their own ring flag.
    self.RING_TYPES = [(5, False), (5, True), (6, False), (6, True)]

    # The DGL-LifeSci helpers are only importable inside this method, so the
    # ones needed later are kept on the instance.
    self._one_hot_encoding = one_hot_encoding
    self.atom_featurizer = ConcatFeaturizer([
        partial(
            atom_type_one_hot, allowable_set=self.SYMBOLS,
            encode_unknown=False),
        atom_formal_charge_one_hot,
        atom_degree_one_hot,
        partial(
            atom_explicit_valence_one_hot,
            allowable_set=list(range(7)),
            encode_unknown=False),
        partial(
            atom_implicit_valence_one_hot,
            allowable_set=list(range(6)),
            encode_unknown=False),
        atom_is_aromatic,
    ])
    self.bond_featurizer = ConcatFeaturizer(
        [bond_type_one_hot, bond_is_conjugated, bond_is_in_ring])
    self.max_length = max_length

  @staticmethod
  def ordered_pair(a, b):
    """Return ``(a, b)`` sorted so that the smaller value comes first."""
    return (a, b) if a < b else (b, a)

  def PAGTNAtomFeaturizer(self, atom):
    """Compute the node feature vector of a single atom.

    The vector concatenates, in order: atom type, formal charge, degree,
    explicit valence, implicit valence and aromaticity.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
      RDKit atom instance.

    Returns
    -------
    np.ndarray
      1-D float array with the concatenated atom features.
    """
    return np.asarray(self.atom_featurizer(atom), dtype=float)

  def bond_features(self, mol, path_atoms, ring_info):
    """Compute the edge features for a given pair of nodes.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
      RDKit molecule instance.
    path_atoms : tuple
      Atom indices of the shortest path between the given pair of nodes.
    ring_info : list
      ``(ring size, is aromatic)`` entries for the rings that contain both
      atoms; empty when the atoms do not share a ring.

    Returns
    -------
    np.ndarray
      1-D array of length ``7 * max_length + 7``.
    """
    import warnings

    n_path_atoms = len(path_atoms)
    path_bonds = []
    for idx in range(n_path_atoms - 1):
      bond = mol.GetBondBetweenAtoms(path_atoms[idx], path_atoms[idx + 1])
      if bond is None:
        warnings.warn('Valid idx of bonds must be passed')
      path_bonds.append(bond)

    features = []
    for idx in range(self.max_length):
      bond = path_bonds[idx] if idx < len(path_bonds) else None
      if bond is None:
        # Padding slot: either the path is shorter than max_length or the
        # bond could not be found (already warned about above).
        features.append([0, 0, 0, 0, 0, 0])
      else:
        features.append(self.bond_featurizer(bond))

    # One-hot slot for the number of atoms on the path. Paths with
    # max_length or more atoms all share the last slot.
    if n_path_atoms < self.max_length:
      position = n_path_atoms
    else:
      position = self.max_length + 1
    position_feature = np.zeros(self.max_length + 2)
    position_feature[position] = 1
    features.append(position_feature)

    if ring_info:
      ring_flags = [
          self._one_hot_encoding(ring, allowable_set=self.RING_TYPES)
          for ring in ring_info
      ]
      # Leading True marks "the two atoms share at least one ring".
      features.append([True] + np.any(ring_flags, axis=0).tolist())
    else:
      # No shared ring: the leading flag and every ring-type flag are False.
      features.append([False] * (1 + len(self.RING_TYPES)))
    return np.concatenate(features, axis=0)

  def PAGTNBondFeaturizer(self, mol):
    """Compute edge indices and edge features for all ordered atom pairs.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
      RDKit molecule instance.

    Returns
    -------
    edge_index : np.ndarray
      Integer array of shape ``(2, N)`` holding source and destination atom
      indices, where ``N`` is the number of ordered atom pairs.
    edge_features : np.ndarray
      Float array with one row of length ``7 * max_length + 7`` per pair.
    """
    n_atoms = mol.GetNumAtoms()
    # NOTE(review): `utils` is not defined in the visible part of this file;
    # it must be in scope at module level for the two calls below - confirm.
    # To get the shortest paths between two nodes.
    paths_dict = utils.compute_all_pairs_shortest_path(mol)
    # To get info if two nodes belong to the same ring.
    rings_dict = utils.compute_pairwise_ring_info(mol)

    # 6 values per bond slot, (max_length + 2) path-length slots and
    # 1 + len(RING_TYPES) ring flags.
    feature_size = 7 * self.max_length + 7

    src = []
    dest = []
    feats = []
    for i in range(n_atoms):
      for j in range(n_atoms):
        src.append(i)
        dest.append(j)

        path = paths_dict.get((i, j))
        if path is None:
          # Pairs without a known path get an all-zero feature vector.
          feats.append(np.zeros(feature_size))
          continue
        ring_info = rings_dict.get(self.ordered_pair(i, j), [])
        feats.append(self.bond_features(mol, path, ring_info))

    # Builtin int/float are used because the np.int / np.float aliases were
    # removed from NumPy.
    return np.array([src, dest], dtype=int), np.array(feats, dtype=float)

  def _featurize(self, mol):
    """Build the graph representation of a molecule.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
      RDKit molecule instance.

    Returns
    -------
    GraphData
      Graph built from per-atom node features, the edge index and the edge
      features of all ordered atom pairs.
    """
    node_features = np.asarray(
        [self.PAGTNAtomFeaturizer(atom) for atom in mol.GetAtoms()],
        dtype=float)
    edge_index, edge_features = self.PAGTNBondFeaturizer(mol)
    return GraphData(node_features, edge_index, edge_features)