Commit c200dc48 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

changes

parent 53366e7d
Loading
Loading
Loading
Loading
+222 −0
Original line number Diff line number Diff line
@@ -4,6 +4,228 @@ import numpy as np
from deepchem.utils.geometry_utils import compute_pairwise_distances


def get_partial_charge(atom):
  """Get the partial charge of a given atom.

  For an RDKit atom the charge is read from its `_GasteigerCharge`
  property. Any other object (e.g. an `AtomShim`) is asked directly
  through its `GetPartialCharge()` method.

  Parameters
  ----------
  atom: rdkit atom or `AtomShim` object
    Either an rdkit atom or `AtomShim`

  Returns
  -------
  The partial charge of the atom. For RDKit atoms, `0` is returned when
  the `_GasteigerCharge` property is missing or holds a non-finite value
  (nan or +/-inf).

  Note
  ----
  This function requires RDKit to be installed.
  """
  import math
  from rdkit import Chem
  if not isinstance(atom, Chem.Atom):
    return atom.GetPartialCharge()
  try:
    value = atom.GetProp(str("_GasteigerCharge"))
  except KeyError:
    # The property was never set on this atom; treat it as uncharged.
    return 0
  charge = float(value)
  # The stored value may be a non-finite number. Its textual spelling is
  # not guaranteed to be exactly '-nan' (it could be 'nan', 'inf', ...),
  # so test the parsed value rather than comparing strings.
  if not math.isfinite(charge):
    return 0
  return charge


class MolecularFragment(object):
  """Container holding a subset of a molecule's atoms and coordinates.

  Working with a piece of a molecule is frequently useful. For
  instance, when two molecules form a complex, one may want a fragment
  per molecule containing only the atoms that lie near the other
  molecule (the contact region).

  Ideally this would be done directly in RDKit, but manipulating
  molecular fragments does not seem to be supported there, so this
  class stores the needed per-atom information itself.
  """

  def __init__(self, atoms, coords):
    """Build a fragment from atoms and their coordinates.

    Every input atom is converted to an `AtomShim` that records its
    atomic number, partial charge and coordinate row.

    Parameters
    ----------
    atoms: list
      Each entry in this list should be an RdkitAtom
    coords: np.ndarray
      Array of locations for atoms of shape `(N, 3)` where `N ==
      len(atoms)`.

    Raises
    ------
    ValueError
      If `coords` is not a numpy array, or if its shape is not
      `(len(atoms), 3)`.
    """
    if not isinstance(coords, np.ndarray):
      raise ValueError("Coords must be a numpy array of shape (N, 3)")
    expected_shape = (len(atoms), 3)
    if coords.shape != expected_shape:
      raise ValueError(
          "Coords must be a numpy array of shape `(N, 3)` where `N == len(atoms)`."
      )
    shims = []
    # The shape check above guarantees one coordinate row per atom.
    for atom, position in zip(atoms, coords):
      shims.append(
          AtomShim(atom.GetAtomicNum(), get_partial_charge(atom), position))
    self.atoms = shims
    self.coords = coords

  def GetAtoms(self):
    """Return the atoms of this fragment.

    Returns
    -------
    List of `AtomShim` objects, one per atom in this fragment.
    """
    return self.atoms

  def GetCoords(self):
    """Return the 3D coordinates of this fragment.

    Returns
    -------
    Numpy array of shape `(N, 3)` where `N == len(self.GetAtoms())`.
    This is the same array object that was passed to the constructor.
    """
    return self.coords


class AtomShim(object):
  """Minimal stand-in for an atom.

  Raw RDKit atoms are not stored directly because handling large
  numbers of them seems to lead to segfaults. Keeping only the basic
  per-atom data in an `AtomShim` seems to avoid those problems.
  """

  def __init__(self, atomic_num, partial_charge, atom_coords):
    """Store the basic properties of one atom.

    Parameters
    ----------
    atomic_num: int
      Atomic number for this atom.
    partial_charge: float
      The partial Gasteiger charge for this atom
    atom_coords: np.ndarray
      Of shape (3,) with the coordinates of this atom
    """
    self.coords = atom_coords
    self.partial_charge = partial_charge
    self.atomic_num = atomic_num

  def GetAtomicNum(self):
    """Return the atomic number of this atom.

    Returns
    -------
    Atomic number for this atom, as given at construction.
    """
    return self.atomic_num

  def GetPartialCharge(self):
    """Return the partial charge of this atom.

    Returns
    -------
    Partial Gasteiger charge for this atom, as given at construction.
    """
    return self.partial_charge

  def GetCoords(self):
    """Return the 3D coordinates of this atom.

    Returns
    -------
    The coordinates object given at construction; expected to be a
    numpy array of shape `(3,)`.
    """
    return self.coords


def merge_molecular_fragments(molecules):
  """Combine a list of molecular fragments into a single fragment.

  Parameters
  ----------
  molecules: list
    List of `MolecularFragment` objects.

  Returns
  -------
  `None` if `molecules` is empty. If it holds exactly one fragment, that
  same object is returned unchanged. Otherwise a new `MolecularFragment`
  containing the atoms and coordinates of all inputs, in input order.
  """
  num_fragments = len(molecules)
  if num_fragments == 0:
    return None
  if num_fragments == 1:
    return molecules[0]
  merged_atoms = []
  coord_blocks = []
  for fragment in molecules:
    merged_atoms.extend(fragment.GetAtoms())
    coord_blocks.append(fragment.GetCoords())
  # Stack the per-fragment coordinate arrays so rows line up with atoms.
  return MolecularFragment(merged_atoms, np.concatenate(coord_blocks))


def get_mol_subset(coords, mol, atom_indices_to_keep):
  """Extract the selected atoms of a molecule as a fragment.

  Parameters
  ----------
  coords: Numpy ndarray
    Must be of shape (N, 3) and correspond to coordinates of mol.
  mol: Rdkit mol or `MolecularFragment`
    The molecule to take atoms from
  atom_indices_to_keep: list
    List of the indices of the atoms to keep. Each index is a unique
    number between `[0, N)`.

  Returns
  -------
  Returns a `MolecularFragment` holding the kept atoms and the matching
  rows of `coords`, in the order given by `atom_indices_to_keep`.

  Raises
  ------
  ValueError
    If RDKit cannot be imported.

  Note
  ----
  This function requires RDKit to be installed.
  """
  try:
    from rdkit import Chem
  except ModuleNotFoundError:
    raise ValueError("This function requires RDKit to be installed.")
  # Only RDKit molecules get charges computed here; presumably this makes
  # partial charges available on the atoms before they are wrapped.
  if isinstance(mol, Chem.Mol):
    compute_charges(mol)
  all_atoms = list(mol.GetAtoms())
  kept_indices = list(atom_indices_to_keep)
  kept_atoms = [all_atoms[idx] for idx in kept_indices]
  return MolecularFragment(kept_atoms, coords[kept_indices])


def strip_hydrogens(coords, mol):
  """Remove the hydrogen atoms from the input molecule.

  Atoms whose atomic number is 1 are dropped; all other atoms are
  handed to `get_mol_subset`.

  Parameters
  ----------
  coords: Numpy ndarray
    Must be of shape (N, 3) and correspond to coordinates of mol.
  mol: Rdkit mol or `MolecularFragment`
    The molecule to strip

  Returns
  -------
  The result of `get_mol_subset` for the non-hydrogen atoms, i.e. a
  `MolecularFragment` whose atoms and coordinates exclude hydrogens.

  Note
  ----
  This function requires RDKit to be installed.
  """
  heavy_atom_indices = [
      idx for idx, atom in enumerate(mol.GetAtoms())
      if atom.GetAtomicNum() != 1
  ]
  return get_mol_subset(coords, mol, heavy_atom_indices)


def get_contact_atom_indices(fragments, cutoff=4.5):
  """Compute that atoms close to contact region.

+15 −0
Original line number Diff line number Diff line
@@ -2,6 +2,8 @@ import os
import unittest
from deepchem.utils import rdkit_util
from deepchem.utils.fragment_util import get_contact_atom_indices
from deepchem.utils.fragment_util import merge_molecular_fragments
from deepchem.utils.fragment_util import MolecularFragment


class TestFragmentUtil(unittest.TestCase):
@@ -18,3 +20,16 @@ class TestFragmentUtil(unittest.TestCase):
    complexes = rdkit_util.load_complex([self.protein_file, self.ligand_file])
    contact_indices = get_contact_atom_indices(complexes)
    assert len(contact_indices) == 2

  def test_create_molecular_fragment(self):
    """A fragment built from the ligand keeps every atom and coordinate."""
    ligand_coords, ligand_mol = rdkit_util.load_molecule(self.ligand_file)
    ligand_fragment = MolecularFragment(ligand_mol.GetAtoms(), ligand_coords)
    num_mol_atoms = len(ligand_mol.GetAtoms())
    num_fragment_atoms = len(ligand_fragment.GetAtoms())
    assert num_mol_atoms == num_fragment_atoms
    coords_match = ligand_fragment.GetCoords() == ligand_coords
    assert coords_match.all()

  def test_merge_molecular_fragments(self):
    """Merging two copies of the ligand fragment doubles the atom count."""
    ligand_coords, ligand_mol = rdkit_util.load_molecule(self.ligand_file)
    first = MolecularFragment(ligand_mol.GetAtoms(), ligand_coords)
    second = MolecularFragment(ligand_mol.GetAtoms(), ligand_coords)
    merged = merge_molecular_fragments([first, second])
    expected_num_atoms = len(ligand_mol.GetAtoms()) * 2
    assert expected_num_atoms == len(merged.GetAtoms())
+8 −0
Original line number Diff line number Diff line
@@ -88,6 +88,14 @@ Molecular Utilities

.. autofunction:: deepchem.utils.rdkit_util.write_molecule

Molecular Fragment Utilities
----------------------------

It's often convenient to manipulate subsets of a molecule. The :code:`MolecularFragment` class aids in such manipulations.

.. autoclass:: deepchem.utils.fragment_util.MolecularFragment
  :members:

Coordinate Box Utilities
------------------------