Commit ffd1189c authored by VIGNESHinZONE's avatar VIGNESHinZONE
Browse files

Additional Utils and bug fixes

parent 4ee456d3
Loading
Loading
Loading
Loading
+14 −9
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ import numpy as np
from deepchem.utils.typing import RDKitAtom, RDKitBond, RDKitMol
from deepchem.feat.graph_data import GraphData
from deepchem.feat.base_classes import MolecularFeaturizer
from deepchem.utils.molecule_feature_utils import one_hot_encode
from deepchem.utils.molecule_feature_utils import get_atom_type_one_hot
from deepchem.utils.molecule_feature_utils import construct_hydrogen_bonding_info
from deepchem.utils.molecule_feature_utils import get_atom_hydrogen_bonding_one_hot
@@ -18,6 +19,7 @@ from deepchem.utils.molecule_feature_utils import get_bond_type_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_is_in_same_ring_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_is_conjugated_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_stereo_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_formal_charge_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_implicit_valence_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_explicit_valence_one_hot
from deepchem.utils.rdkit_utils import compute_all_pairs_shortest_path
@@ -238,12 +240,13 @@ class PAGTNFeaturizer(MolecularFeaturizer):

    self.RING_TYPES = [(5, False), (5, True), (6, False), (6, True)]
    self.ordered_pair = lambda a, b: (a, b) if a < b else (b, a)
    self.max_length = 5
    self.max_length = max_length

  def PAGTNAtomFeaturizer(self, atom):
    atom_type = get_atom_type_one_hot(atom, self.SYMBOLS, False)
    formal_charge = get_atom_formal_charge(atom)
    degree = get_atom_total_degree_one_hot(atom, 10, False)
    formal_charge = get_atom_formal_charge_one_hot(
        atom, include_unknown_set=False)
    degree = get_atom_total_degree_one_hot(atom, list(range(11)), False)
    exp_valence = get_atom_explicit_valence_one_hot(atom, list(range(7)), False)
    imp_valence = get_atom_implicit_valence_one_hot(atom, list(range(6)), False)
    armoticity = get_atom_is_in_aromatic_one_hot(atom)
@@ -281,7 +284,8 @@ class PAGTNFeaturizer(MolecularFeaturizer):
        ring_attach = get_bond_is_in_same_ring_one_hot(path_bonds[path_idx])
        features.append(np.concatenate([bond_type, conjugacy, ring_attach]))
      else:
        features.append([0, 0, 0, 0, 0, 0])
        features.append(np.zeros(6))
        #features.append([0, 0, 0, 0, 0, 0])

    if path_length + 1 > self.max_length:
      path_length = self.max_length + 1
@@ -290,15 +294,14 @@ class PAGTNFeaturizer(MolecularFeaturizer):
    features.append(position_feature)
    if ring_info:
      rfeat = [
          get_atom_type_one_hot(r, allowable_set=self.RING_TYPES)
          for r in ring_info
          one_hot_encode(r, allowable_set=self.RING_TYPES) for r in ring_info
      ]
      rfeat = [True] + np.any(rfeat, axis=0).tolist()
      features.append(rfeat)
    else:
      # This will return a boolean vector with all entries False
      features.append([False] + get_atom_type_one_hot(
          ring_info, allowable_set=self.RING_TYPES))
      features.append([False] +
                      one_hot_encode(ring_info, allowable_set=self.RING_TYPES))
    return np.concatenate(features, axis=0)

  def PAGTNBondFeaturizer(self, mol):
@@ -337,7 +340,9 @@ class PAGTNFeaturizer(MolecularFeaturizer):
    return np.array([src, dest], dtype=np.int), np.array(feats, dtype=np.float)

  def _featurize(self, mol):
    node_features = self.PAGTNAtomFeaturizer(mol)
    node_features = np.asarray(
        [self.PAGTNAtomFeaturizer(atom) for atom in mol.GetAtoms()],
        dtype=np.float)
    edge_index, edge_features = self.PAGTNBondFeaturizer(mol)
    graph = GraphData(node_features, edge_index, edge_features)
    return graph
+26 −0
Original line number Diff line number Diff line
@@ -30,6 +30,7 @@ DEFAULT_ATOM_TYPE_SET = [
]
DEFAULT_HYBRIDIZATION_SET = ["SP", "SP2", "SP3"]
DEFAULT_TOTAL_NUM_Hs_SET = [0, 1, 2, 3, 4]
DEFAULT_FORMAL_CHARGE_SET = [-2, -1, 0, 1, 2]
DEFAULT_TOTAL_DEGREE_SET = [0, 1, 2, 3, 4, 5]
DEFAULT_RING_SIZE_SET = [3, 4, 5, 6, 7, 8]
DEFAULT_BOND_TYPE_SET = ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"]
@@ -308,6 +309,31 @@ def get_atom_formal_charge(atom: RDKitAtom) -> List[float]:
  return [float(atom.GetFormalCharge())]


def get_atom_formal_charge_one_hot(
    atom: RDKitAtom,
    allowable_set: List[int] = DEFAULT_FORMAL_CHARGE_SET,
    include_unknown_set: bool = True) -> List[float]:
  """Get one hot encoding of formal charge of an atom.

  Parameters
  ---------
  atom: rdkit.Chem.rdchem.Atom
    RDKit atom object
  allowable_set: List[int]
    The degree to consider. The default set is `[-2, -1, ..., 2]`
  include_unknown_set: bool, default True
    If true, the index of all types not in `allowable_set` is `len(allowable_set)`.


  Returns
  -------
  List[float]
    A vector of the formal charge.
  """
  return one_hot_encode(atom.GetFormalCharge(), allowable_set,
                        include_unknown_set)


def get_atom_partial_charge(atom: RDKitAtom) -> List[float]:
  """Get a partial charge of an atom.

+8 −0
Original line number Diff line number Diff line
@@ -9,6 +9,8 @@ from deepchem.utils.molecule_feature_utils import get_atom_total_num_Hs_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_is_in_aromatic_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_chirality_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_formal_charge
from deepchem.utils.molecule_feature_utils import get_atom_formal_charge_one_hot

from deepchem.utils.molecule_feature_utils import get_atom_partial_charge
from deepchem.utils.molecule_feature_utils import get_atom_total_degree_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_implicit_valence_one_hot
@@ -119,6 +121,12 @@ class TestGraphConvUtils(unittest.TestCase):
    formal_charge = get_atom_formal_charge(atoms[0])
    assert formal_charge == [0.0]

  def test_get_atom_formal_charge_one_hot(self):
    atoms = self.mol.GetAtoms()
    assert atoms[0].GetSymbol() == "C"
    formal_charge = get_atom_formal_charge_one_hot(atoms[0])
    assert formal_charge == [0.0, 0.0, 1.0, 0.0, 0.0, 0.0]

  def test_get_atom_partial_charge(self):
    from rdkit.Chem import AllChem
    atoms = self.mol.GetAtoms()