Unverified Commit f8ded19b authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2484 from VIGNESHinZONE/util

Adding essential Molecular Utils
parents 4041e83b 676d3b4c
Loading
Loading
Loading
Loading
+56 −0
Original line number Diff line number Diff line
@@ -35,6 +35,8 @@ DEFAULT_RING_SIZE_SET = [3, 4, 5, 6, 7, 8]
# Category lists for bond type, bond stereo and graph distance.
# NOTE(review): presumably the default `allowable_set` of the matching one-hot
# helpers defined elsewhere in this module -- confirm against their signatures.
DEFAULT_BOND_TYPE_SET = ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"]
DEFAULT_BOND_STEREO_SET = ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"]
DEFAULT_GRAPH_DISTANCE_SET = [1, 2, 3, 4, 5, 6, 7]
# Default `allowable_set` of `get_atom_implicit_valence_one_hot`.
DEFAULT_ATOM_IMPLICIT_VALENCE_SET = [0, 1, 2, 3, 4, 5, 6]
# Default `allowable_set` of `get_atom_explicit_valence_one_hot`. Unlike the
# implicit set it starts at 1, so an explicit valence of 0 is not a listed
# category.
DEFAULT_ATOM_EXPLICIT_VALENCE_SET = [1, 2, 3, 4, 5, 6]


class _ChemicalFeaturesFactory:
@@ -356,6 +358,60 @@ def get_atom_total_degree_one_hot(
                        include_unknown_set)


def get_atom_implicit_valence_one_hot(
    atom: RDKitAtom,
    allowable_set: List[int] = DEFAULT_ATOM_IMPLICIT_VALENCE_SET,
    include_unknown_set: bool = True) -> List[float]:
  """One-hot encode the implicit valence of an atom.

  Parameters
  ----------
  atom: rdkit.Chem.rdchem.Atom
    RDKit atom object
  allowable_set: List[int]
    Implicit valence values that get their own slot in the encoding.
    The default set is `[0, 1, ..., 6]`.
  include_unknown_set: bool, default True
    If true, every value that is not in `allowable_set` is assigned the
    extra index `len(allowable_set)`.

  Returns
  -------
  List[float]
    A one-hot vector encoding the implicit valence of the atom.
    Its length is `len(allowable_set)` when `include_unknown_set` is False
    and `len(allowable_set) + 1` when `include_unknown_set` is True.

  """
  implicit_valence = atom.GetImplicitValence()
  return one_hot_encode(implicit_valence, allowable_set, include_unknown_set)


def get_atom_explicit_valence_one_hot(
    atom: RDKitAtom,
    allowable_set: List[int] = DEFAULT_ATOM_EXPLICIT_VALENCE_SET,
    include_unknown_set: bool = True) -> List[float]:
  """One-hot encode the explicit valence of an atom.

  Parameters
  ----------
  atom: rdkit.Chem.rdchem.Atom
    RDKit atom object
  allowable_set: List[int]
    Explicit valence values that get their own slot in the encoding.
    The default set is `[1, ..., 6]`.
  include_unknown_set: bool, default True
    If true, every value that is not in `allowable_set` is assigned the
    extra index `len(allowable_set)`.

  Returns
  -------
  List[float]
    A one-hot vector encoding the explicit valence of the atom.
    Its length is `len(allowable_set)` when `include_unknown_set` is False
    and `len(allowable_set) + 1` when `include_unknown_set` is True.

  """
  explicit_valence = atom.GetExplicitValence()
  return one_hot_encode(explicit_valence, allowable_set, include_unknown_set)


#################################################################
# bond (edge) featurization
#################################################################
+68 −0
Original line number Diff line number Diff line
@@ -626,3 +626,71 @@ def compute_ring_normal(mol, ring_indices):
  v2 = points[2] - points[0]
  normal = np.cross(v1, v2)
  return normal


def compute_all_pairs_shortest_path(
    mol) -> Dict[Tuple[int, int], Tuple[int, ...]]:
  """Computes the shortest path between every pair of atoms in a molecule.

  Paths are expressed in terms of RDKit atom indexes.

  Parameters
  ----------
  mol: rdkit.Chem.rdchem.Mol
    RDKit molecule whose atom pairs are enumerated.

  Returns
  -------
  paths_dict: Dict[Tuple[int, int], Tuple[int, ...]]
    Maps every atom index pair `(i, j)` with `i < j` to the value returned
    by `Chem.rdmolops.GetShortestPath(mol, i, j)`. For the molecule
    `CN=C=O`, for example, the pair `(0, 3)` maps to `(0, 1, 2, 3)`.

  Raises
  ------
  ImportError
    If RDKit cannot be imported.
  """
  try:
    from rdkit import Chem
  except ImportError:
    # Only a failed import is translated; other exceptions propagate as-is.
    raise ImportError("This class requires RDkit installed")
  n_atoms = mol.GetNumAtoms()
  # Start j at i + 1 so each unordered pair is visited exactly once.
  return {(i, j): Chem.rdmolops.GetShortestPath(mol, i, j)
          for i in range(n_atoms)
          for j in range(i + 1, n_atoms)}


def compute_pairwise_ring_info(mol):
  """Collects ring size and aromaticity for atom pairs sharing a ring.

  Rings are taken from `Chem.GetSymmSSSR(mol)`. A ring is flagged as aromatic
  only when every atom in it reports `GetIsAromatic()`.

  Parameters
  ----------
  mol: rdkit.Chem.rdchem.Mol
    RDKit molecule to inspect.

  Returns
  -------
  rings_dict: dict
    Maps an atom index pair `(a, b)` with `a <= b` to a list of distinct
    `(ring_size, is_aromatic)` tuples, one for each different ring
    description the two atoms share. Every ring atom is also paired with
    itself, so keys of the form `(a, a)` are present. The dict is empty when
    no rings are found.

  Raises
  ------
  ImportError
    If RDKit cannot be imported.
  """
  try:
    from rdkit import Chem
  except ImportError:
    # Only a failed import is translated; other exceptions propagate as-is.
    raise ImportError("This class requires RDkit installed")
  rings_dict = {}

  rings = [list(ring) for ring in Chem.GetSymmSSSR(mol)]
  for ring in rings:
    is_aromatic = all(
        mol.GetAtomWithIdx(atom_idx).GetIsAromatic() for atom_idx in ring)
    ring_info = (len(ring), is_aromatic)
    for position, atom_idx in enumerate(ring):
      # The slice starts at `position`, so the atom is paired with itself too.
      for other_idx in ring[position:]:
        atom_pair = (min(atom_idx, other_idx), max(atom_idx, other_idx))
        entries = rings_dict.setdefault(atom_pair, [])
        # Two rings with the same size and aromaticity are recorded once.
        if ring_info not in entries:
          entries.append(ring_info)

  return rings_dict
+22 −0
Original line number Diff line number Diff line
@@ -11,6 +11,8 @@ from deepchem.utils.molecule_feature_utils import get_atom_chirality_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_formal_charge
from deepchem.utils.molecule_feature_utils import get_atom_partial_charge
from deepchem.utils.molecule_feature_utils import get_atom_total_degree_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_implicit_valence_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_explicit_valence_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_type_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_is_in_same_ring_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_is_conjugated_one_hot
@@ -140,6 +142,26 @@ class TestGraphConvUtils(unittest.TestCase):
    one_hot = get_atom_total_degree_one_hot(atoms[3])
    assert one_hot == [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]

  def test_get_atom_implicit_valence_one_hot(self):
    """Check the default implicit-valence encoding of atoms 0 and 3."""
    mol_atoms = self.mol.GetAtoms()
    # (atom index, expected element symbol, expected one-hot vector)
    cases = [
        (0, "C", [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]),
        (3, "O", [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
    ]
    for atom_idx, symbol, expected in cases:
      atom = mol_atoms[atom_idx]
      assert atom.GetSymbol() == symbol
      assert get_atom_implicit_valence_one_hot(atom) == expected

  def test_get_atom_explicit_valence_one_hot(self):
    """Check the default explicit-valence encoding of atoms 0 and 3."""
    mol_atoms = self.mol.GetAtoms()
    # (atom index, expected element symbol, expected one-hot vector)
    cases = [
        (0, "C", [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
        (3, "O", [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
    ]
    for atom_idx, symbol, expected in cases:
      atom = mol_atoms[atom_idx]
      assert atom.GetSymbol() == symbol
      assert get_atom_explicit_valence_one_hot(atom) == expected

  def test_get_bond_type_one_hot(self):
    bonds = self.mol.GetBonds()
    one_hot = get_bond_type_one_hot(bonds[0])
+25 −0
Original line number Diff line number Diff line
@@ -162,3 +162,28 @@ class TestRdkitUtil(unittest.TestCase):

  def test_strip_hydrogens(self):
    # Placeholder: no assertions are implemented for this test yet.
    pass

  def test_all_shortest_pairs(self):
    """Check every pairwise shortest path of the linear molecule CN=C=O."""
    from rdkit import Chem
    # In a four-atom chain, the path from i to j is the run of indexes between.
    expected_paths = {
        (0, 1): (0, 1),
        (0, 2): (0, 1, 2),
        (0, 3): (0, 1, 2, 3),
        (1, 2): (1, 2),
        (1, 3): (1, 2, 3),
        (2, 3): (2, 3)
    }
    mol = Chem.MolFromSmiles("CN=C=O")
    computed_paths = rdkit_utils.compute_all_pairs_shortest_path(mol)
    assert computed_paths == expected_paths

  def test_pairwise_ring_info(self):
    """Check ring info for two aromatic SMILES and one ring-free SMILES."""
    from rdkit import Chem
    # Every recorded atom pair must carry exactly one (size 6, aromatic) entry.
    for smiles in ("c1ccccc1", "c1c2ccccc2ccc1"):
      mol = Chem.MolFromSmiles(smiles)
      ring_info = rdkit_utils.compute_pairwise_ring_info(mol)
      for entries in ring_info.values():
        assert entries == [(6, True)]
    # A molecule without rings yields an empty mapping.
    acyclic_mol = Chem.MolFromSmiles("CN=C=O")
    ring_info = rdkit_utils.compute_pairwise_ring_info(acyclic_mol)
    assert not ring_info