Commit 4ee456d3 authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

Integerating few utils

parent c0629454
Loading
Loading
Loading
Loading
+16 −27
Original line number Diff line number Diff line
@@ -18,6 +18,10 @@ from deepchem.utils.molecule_feature_utils import get_bond_type_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_is_in_same_ring_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_is_conjugated_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_stereo_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_implicit_valence_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_explicit_valence_one_hot
from deepchem.utils.rdkit_utils import compute_all_pairs_shortest_path
from deepchem.utils.rdkit_utils import compute_pairwise_ring_info


def _construct_atom_feature(
@@ -219,27 +223,10 @@ class MolGraphConvFeaturizer(MolecularFeaturizer):
        edge_features=bond_features)


from deepchem.feat import MolecularFeaturizer
import numpy as np
from deepchem.feat.graph_data import GraphData


class PAGTNFeaturizer(MolecularFeaturizer):

  def __init__(self, max_length=5):

    try:
      from rdkit import Chem
      from dgllife.utils import ConcatFeaturizer
      from dgllife.utils import atom_formal_charge_one_hot, atom_type_one_hot, atom_degree_one_hot
      from dgllife.utils import atom_explicit_valence_one_hot, atom_implicit_valence_one_hot
      from dgllife.utils import atom_is_aromatic
      from functools import partial
      from dgllife.utils import one_hot_encoding
      from dgllife.utils import bond_type_one_hot, ConcatFeaturizer, bond_is_conjugated, bond_is_in_ring
    except:
      raise ImportError(
          "This class requires Rdkit & DGLLifeSci to be installed.")
    self.SYMBOLS = [
        'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe',
        'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd',
@@ -251,15 +238,13 @@ class PAGTNFeaturizer(MolecularFeaturizer):

    self.RING_TYPES = [(5, False), (5, True), (6, False), (6, True)]
    self.ordered_pair = lambda a, b: (a, b) if a < b else (b, a)
    self.bond_featurizer = ConcatFeaturizer(
        [bond_type_one_hot, bond_is_conjugated, bond_is_in_ring])
    self.max_length = 5

  def PAGTNAtomFeaturizer(atom):
  def PAGTNAtomFeaturizer(self, atom):
    atom_type = get_atom_type_one_hot(atom, self.SYMBOLS, False)
    formal_charge = get_atom_formal_charge(atom)
    degree = get_atom_total_degree_one_hot(atom, 10, False)
    exp_valence = atom_explicit_valence_one_hot(atom, list(range(7)), False)
    exp_valence = get_atom_explicit_valence_one_hot(atom, list(range(7)), False)
    imp_valence = get_atom_implicit_valence_one_hot(atom, list(range(6)), False)
    armoticity = get_atom_is_in_aromatic_one_hot(atom)
    atom_feat = np.concatenate([
@@ -291,7 +276,10 @@ class PAGTNFeaturizer(MolecularFeaturizer):

    for path_idx in range(self.max_length):
      if path_idx < len(path_bonds):
        features.append(self.bond_featurizer(path_bonds[path_idx]))
        bond_type = get_bond_type_one_hot(path_bonds[path_idx])
        conjugacy = get_bond_is_conjugated_one_hot(path_bonds[path_idx])
        ring_attach = get_bond_is_in_same_ring_one_hot(path_bonds[path_idx])
        features.append(np.concatenate([bond_type, conjugacy, ring_attach]))
      else:
        features.append([0, 0, 0, 0, 0, 0])

@@ -302,14 +290,15 @@ class PAGTNFeaturizer(MolecularFeaturizer):
    features.append(position_feature)
    if ring_info:
      rfeat = [
          one_hot_encoding(r, allowable_set=self.RING_TYPES) for r in ring_info
          get_atom_type_one_hot(r, allowable_set=self.RING_TYPES)
          for r in ring_info
      ]
      rfeat = [True] + np.any(rfeat, axis=0).tolist()
      features.append(rfeat)
    else:
      # This will return a boolean vector with all entries False
      features.append(
          [False] + one_hot_encoding(ring_info, allowable_set=self.RING_TYPES))
      features.append([False] + get_atom_type_one_hot(
          ring_info, allowable_set=self.RING_TYPES))
    return np.concatenate(features, axis=0)

  def PAGTNBondFeaturizer(self, mol):
@@ -327,9 +316,9 @@ class PAGTNFeaturizer(MolecularFeaturizer):

    n_atoms = mol.GetNumAtoms()
    # To get the shortest paths between two nodes.
    paths_dict = utils.compute_all_pairs_shortest_path(mol)
    paths_dict = compute_all_pairs_shortest_path(mol)
    # To get info if two nodes belong to the same ring.
    rings_dict = utils.compute_pairwise_ring_info(mol)
    rings_dict = compute_pairwise_ring_info(mol)
    # Featurizer
    feats = []
    src = []