Unverified Commit ffad1a75 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1997 from deepchem/fragment_util

Add Molecular Fragment Utilities
parents eb3ab25e 7ca8f488
Loading
Loading
Loading
Loading
+253 −3
Original line number Diff line number Diff line
"""A collection of utilities for dealing with Molecular Fragments"""
import itertools
import numpy as np
from typing import List, Optional, Any
from deepchem.utils.geometry_utils import compute_pairwise_distances
from deepchem.utils.rdkit_util import compute_charges


def get_contact_atom_indices(fragments, cutoff=4.5):
def get_partial_charge(atom):
  """Get partial charge of a given atom (rdkit Atom object)

  For an rdkit atom the charge is read from its ``_GasteigerCharge``
  property. Any other object (for example an `AtomShim`) is asked for
  its charge through its `GetPartialCharge()` method.

  Parameters
  ----------
  atom: rdkit atom or `AtomShim` object
    Either an rdkit atom or `AtomShim`

  Returns
  -------
  The partial charge of the atom. For an rdkit atom, 0 is returned when
  the ``_GasteigerCharge`` property is missing or holds a NaN value
  (for example the strings ``'nan'`` or ``'-nan'``).

  Note
  ----
  RDKit is only needed when `atom` is an rdkit atom. Objects exposing
  `GetPartialCharge()` are handled even if RDKit cannot be imported.

  Examples
  --------
  >>> from rdkit import Chem
  >>> mol = Chem.MolFromSmiles("CC")
  >>> atom = mol.GetAtoms()[0]
  >>> get_partial_charge(atom)
  0
  """
  import math
  try:
    from rdkit import Chem
  except ImportError:
    # Without RDKit the argument cannot be an rdkit atom, so only the
    # shim path below is reachable.
    Chem = None
  if Chem is None or not isinstance(atom, Chem.Atom):
    return atom.GetPartialCharge()
  try:
    value = atom.GetProp(str("_GasteigerCharge"))
  except KeyError:
    # No Gasteiger charge has been stored on this atom.
    return 0
  charge = float(value)
  # Any spelling of NaN ('nan', '-nan', 'NaN', ...) is treated as "no charge".
  if math.isnan(charge):
    return 0
  return charge


class MolecularFragment(object):
  """Lightweight container for a subset of a molecule's atoms.

  A fragment stores one `AtomShim` per input atom together with the
  array of 3D coordinates it was built from. A typical use is keeping
  only the atoms of each molecule in a complex that lie in the contact
  region with the other molecule.

  Examples
  --------
  >>> import numpy as np
  >>> from rdkit import Chem
  >>> mol = Chem.MolFromSmiles("C")
  >>> coords = np.array([[0.0, 0.0, 0.0]])
  >>> atom = mol.GetAtoms()[0]
  >>> fragment = MolecularFragment([atom], coords)
  """

  def __init__(self, atoms, coords):
    """Build the fragment from atoms and their coordinates.

    Parameters
    ----------
    atoms: list
      Atoms to wrap. Each entry must provide `GetAtomicNum()` and be
      accepted by `get_partial_charge` (an rdkit atom or `AtomShim`).
    coords: np.ndarray
      Atom positions of shape `(N, 3)` where `N == len(atoms)`.

    Raises
    ------
    ValueError
      If `coords` is not a numpy array or does not have shape
      `(len(atoms), 3)`.
    """
    if not isinstance(coords, np.ndarray):
      raise ValueError("Coords must be a numpy array of shape (N, 3)")
    expected_shape = (len(atoms), 3)
    if coords.shape != expected_shape:
      raise ValueError(
          "Coords must be a numpy array of shape `(N, 3)` where `N == len(atoms)`."
      )
    # Each atom is copied into a shim holding its atomic number, partial
    # charge and the matching row of `coords`.
    shims = []
    for position, atom in enumerate(atoms):
      shims.append(
          AtomShim(atom.GetAtomicNum(), get_partial_charge(atom),
                   coords[position]))
    self.atoms = shims
    self.coords = coords

  def GetAtoms(self):
    """Return the atoms of this fragment.

    Returns
    -------
    The list of `AtomShim` objects created at construction time.
    """
    return self.atoms

  def GetCoords(self):
    """Return the coordinates of this fragment.

    Returns
    -------
    The numpy array of shape `(N, 3)` passed to the constructor, where
    `N == len(self.GetAtoms())`.
    """
    return self.coords


class AtomShim(object):
  """Plain-Python stand-in for a single atom.

  The shim records only an atomic number, a partial charge and a
  coordinate vector, and exposes them through rdkit-style getter
  methods. It is used in place of raw RDKit atoms, which reportedly
  caused segfaults when handled in large numbers.
  """

  def __init__(self, atomic_num: int, partial_charge: float,
               atom_coords: np.ndarray):
    """Store the atom's properties.

    Parameters
    ----------
    atomic_num: int
      Atomic number of the atom.
    partial_charge: float
      Partial (Gasteiger) charge of the atom.
    atom_coords: np.ndarray
      Array of shape (3,) with the position of the atom.
    """
    self.coords = atom_coords
    self.partial_charge = partial_charge
    self.atomic_num = atomic_num

  def GetAtomicNum(self) -> int:
    """Return the atomic number given at construction.

    Returns
    -------
    Atomic number for this atom.
    """
    return self.atomic_num

  def GetPartialCharge(self) -> float:
    """Return the partial charge given at construction.

    Returns
    -------
    Partial Gasteiger charge for this atom.
    """
    return self.partial_charge

  def GetCoords(self) -> np.ndarray:
    """Return the coordinates given at construction.

    Returns
    -------
    Numpy array of shape `(3,)` with coordinates for this atom.
    """
    return self.coords


def merge_molecular_fragments(
    molecules: List[MolecularFragment]) -> Optional[MolecularFragment]:
  """Combine a list of molecular fragments into a single fragment.

  Parameters
  ----------
  molecules: list
    List of `MolecularFragment` objects.

  Returns
  -------
  `None` if `molecules` is empty. If it holds exactly one fragment,
  that same object is returned. Otherwise a new `MolecularFragment` is
  built from the atoms and coordinates of all fragments, concatenated
  in list order.
  """
  n_fragments = len(molecules)
  if n_fragments == 0:
    return None
  if n_fragments == 1:
    return molecules[0]
  merged_atoms = [
      atom for fragment in molecules for atom in fragment.GetAtoms()
  ]
  merged_coords = np.concatenate(
      [fragment.GetCoords() for fragment in molecules])
  return MolecularFragment(merged_atoms, merged_coords)


def get_mol_subset(coords: np.ndarray, mol,
                   atom_indices_to_keep: List[int]) -> MolecularFragment:
  """Build a fragment containing only the selected atoms of a molecule.

  Parameters
  ----------
  coords: Numpy ndarray
    Must be of shape (N, 3) and correspond to coordinates of mol.
  mol: Rdkit mol or `MolecularFragment`
    The molecule to take atoms from.
  atom_indices_to_keep: list
    Indices of the atoms to keep. Each index is a unique number in
    `[0, N)`.

  Returns
  -------
  A `MolecularFragment` holding the selected atoms and the matching
  rows of `coords`, in the order given by `atom_indices_to_keep`.

  Raises
  ------
  ValueError
    If RDKit cannot be imported.

  Note
  ----
  This function requires RDKit to be installed.
  """
  try:
    from rdkit import Chem
  except ModuleNotFoundError:
    raise ValueError("This function requires RDKit to be installed.")
  # Partial charges are computed up front for rdkit molecules only.
  if isinstance(mol, Chem.Mol):
    compute_charges(mol)
  all_atoms = list(mol.GetAtoms())
  kept_indices = list(atom_indices_to_keep)
  kept_atoms = [all_atoms[position] for position in kept_indices]
  return MolecularFragment(kept_atoms, coords[kept_indices])


def strip_hydrogens(coords: np.ndarray, mol) -> MolecularFragment:
  """Return a fragment of the input molecule without its hydrogens.

  Parameters
  ----------
  coords: Numpy ndarray
    Must be of shape (N, 3) and correspond to coordinates of mol.
  mol: Rdkit mol or `MolecularFragment`
    The molecule to strip

  Returns
  -------
  A `MolecularFragment` built by `get_mol_subset` from every atom whose
  atomic number is not 1, together with the matching rows of `coords`.

  Note
  ----
  This function requires RDKit to be installed.
  """
  heavy_atom_indices = [
      position for position, atom in enumerate(mol.GetAtoms())
      if atom.GetAtomicNum() != 1
  ]
  return get_mol_subset(coords, mol, heavy_atom_indices)


def get_contact_atom_indices(fragments: List[Any],
                             cutoff: float = 4.5) -> List[Any]:
  """Compute that atoms close to contact region.

  Molecular complexes can get very large. This can make it unwieldy to
@@ -31,7 +280,7 @@ def get_contact_atom_indices(fragments, cutoff=4.5):
  sorted order.
  """
  # indices to atoms to keep
  keep_inds = [set([]) for _ in fragments]
  keep_inds: List[Any] = [set([]) for _ in fragments]
  for (ind1, ind2) in itertools.combinations(range(len(fragments)), 2):
    frag1, frag2 = fragments[ind1], fragments[ind2]
    pairwise_distances = compute_pairwise_distances(frag1[0], frag2[0])
@@ -48,7 +297,8 @@ def get_contact_atom_indices(fragments, cutoff=4.5):
  return keep_inds


def reduce_molecular_complex_to_contacts(fragments, cutoff=4.5):
def reduce_molecular_complex_to_contacts(fragments: List,
                                         cutoff: float = 4.5) -> List:
  """Reduce a molecular complex to only those atoms near a contact.

  Molecular complexes can get very large. This can make it unwieldy to
+42 −0
Original line number Diff line number Diff line
import os
import unittest
import numpy as np
from deepchem.utils import rdkit_util
from deepchem.utils.fragment_util import get_contact_atom_indices
from deepchem.utils.fragment_util import merge_molecular_fragments
from deepchem.utils.fragment_util import get_partial_charge
from deepchem.utils.fragment_util import strip_hydrogens
from deepchem.utils.fragment_util import MolecularFragment
from deepchem.utils.fragment_util import AtomShim


class TestFragmentUtil(unittest.TestCase):
@@ -18,3 +24,39 @@ class TestFragmentUtil(unittest.TestCase):
    complexes = rdkit_util.load_complex([self.protein_file, self.ligand_file])
    contact_indices = get_contact_atom_indices(complexes)
    assert len(contact_indices) == 2

  def test_create_molecular_fragment(self):
    """A fragment built from a loaded ligand keeps all atoms and coords."""
    coords, rdkit_mol = rdkit_util.load_molecule(self.ligand_file)
    built = MolecularFragment(rdkit_mol.GetAtoms(), coords)
    n_expected = len(rdkit_mol.GetAtoms())
    n_actual = len(built.GetAtoms())
    assert n_expected == n_actual
    same_coords = built.GetCoords() == coords
    assert same_coords.all()

  def test_strip_hydrogens(self):
    """Stripping hydrogens keeps exactly the non-hydrogen atoms.

    The check is run on both accepted inputs: the loaded RDKit molecule
    and a `MolecularFragment` built from it.
    """
    mol_xyz, mol_rdk = rdkit_util.load_molecule(self.ligand_file)
    fragment = MolecularFragment(mol_rdk.GetAtoms(), mol_xyz)
    n_heavy = sum(1 for atom in mol_rdk.GetAtoms() if atom.GetAtomicNum() != 1)

    # Test on RDKit
    frag = strip_hydrogens(mol_xyz, mol_rdk)
    assert len(frag.GetAtoms()) == n_heavy
    assert all(atom.GetAtomicNum() != 1 for atom in frag.GetAtoms())
    assert frag.GetCoords().shape == (n_heavy, 3)

    # Test on MolecularFragment
    frag = strip_hydrogens(mol_xyz, fragment)
    assert len(frag.GetAtoms()) == n_heavy
    assert all(atom.GetAtomicNum() != 1 for atom in frag.GetAtoms())
    assert frag.GetCoords().shape == (n_heavy, 3)

  def test_merge_molecular_fragments(self):
    """Merging two copies of a fragment doubles the atom count."""
    coords, rdkit_mol = rdkit_util.load_molecule(self.ligand_file)
    first = MolecularFragment(rdkit_mol.GetAtoms(), coords)
    second = MolecularFragment(rdkit_mol.GetAtoms(), coords)
    merged = merge_molecular_fragments([first, second])
    n_single = len(rdkit_mol.GetAtoms())
    assert n_single * 2 == len(merged.GetAtoms())

  def test_get_partial_charge(self):
    """An atom of a freshly parsed SMILES molecule reports charge 0."""
    from rdkit import Chem
    ethane = Chem.MolFromSmiles("CC")
    first_atom = ethane.GetAtoms()[0]
    assert get_partial_charge(first_atom) == 0

  def test_atom_shim(self):
    """An `AtomShim` hands back the values it was constructed with."""
    expected_num = 5
    expected_charge = 1
    expected_coords = np.array([0., 1., 2.])
    shim = AtomShim(expected_num, expected_charge, expected_coords)
    assert shim.GetAtomicNum() == expected_num
    assert shim.GetPartialCharge() == expected_charge
    coords_match = shim.GetCoords() == expected_coords
    assert coords_match.all()
+19 −0
Original line number Diff line number Diff line
@@ -88,6 +88,25 @@ Molecular Utilities

.. autofunction:: deepchem.utils.rdkit_util.write_molecule

Molecular Fragment Utilities
----------------------------

It's often convenient to manipulate subsets of a molecule. The :code:`MolecularFragment` class aids in such manipulations.

.. autoclass:: deepchem.utils.fragment_util.MolecularFragment
  :members:

.. autoclass:: deepchem.utils.fragment_util.AtomShim
  :members:

.. autofunction:: deepchem.utils.fragment_util.strip_hydrogens

.. autofunction:: deepchem.utils.fragment_util.merge_molecular_fragments

.. autofunction:: deepchem.utils.fragment_util.get_contact_atom_indices

.. autofunction:: deepchem.utils.fragment_util.reduce_molecular_complex_to_contacts

Coordinate Box Utilities
------------------------