Commit b6d4b18e authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

changes

parent aa8095e1
Loading
Loading
Loading
Loading
+167 −1
Original line number Diff line number Diff line
@@ -3,6 +3,163 @@ import itertools
import numpy as np
from deepchem.utils.geometry_utils import compute_pairwise_distances

def get_partial_charge(atom):
  """Get partial charge of a given atom (rdkit Atom object)
  
  Parameters
  ----------
  atom: rdkit atom or `AtomShim` object
    Either an rdkit atom or `AtomShim`
  """
  from rdkit import Chem
  if isinstance(atom, Chem.Atom):
    try:
      value = atom.GetProp(str("_GasteigerCharge"))
      if value == '-nan':
        return 0
      return float(value)
    except KeyError:
      return 0
  else:
    return atom.GetPartialCharge()


class MolecularFragment(object):
  """A class that represents a fragment of a molecule.

  It's often convenient to represent a fragment of a molecule. For
  example, if two molecules form a molecular complex, it may be useful
  to create two fragments which represent the subsets of each molecule
  that's close to the other molecule (in the contact region).

  Ideally, we'd be able to do this in RDKit direct, but manipulating
  molecular fragments doesn't seem to be supported functionality. 
  """

  def __init__(self, atoms):
    """Initialize this object.

    Parameters
    ----------
    atoms: list
      Each entry in this list should be an RdkitAtom
    """
    self.atoms = [AtomShim(x.GetAtomicNum(), get_partial_charge(x)) for x in atoms]
    #self.atoms = atoms 

  def GetAtoms(self):
    """Returns the list of atoms

    Returns
    -------
    list of atoms in this fragment.
    """
    return self.atoms

class AtomShim(object):
  """This is a shim object wrapping an atom.

  We use this class instead of raw RDKit atoms since manipulating a
  large number of rdkit Atoms seems to result in segfaults. Wrapping
  the basic information in an AtomShim seems to avoid issues.
  """

  def __init__(self, atomic_num, partial_charge):
    """Initialize this object

    Parameters
    ----------
    atomic_num: int
      Atomic number for this atom.
    partial_charge: float
      The partial Gasteiger charge for this atom
    """
    self.atomic_num = atomic_num
    self.partial_charge = partial_charge

  def GetAtomicNum(self):
    return self.atomic_num

  def GetPartialCharge(self):
    return self.partial_charge

def merge_molecular_fragments(molecules):
  """Helper method to merge two molecular fragments.

  Parameters
  ----------
  molecules: list
    List of `MolecularFragment` objects. 

  Returns
  -------
  merged: `MolecularFragment`
  """
  if len(molecules) == 0:
    return None
  if len(molecules) == 1:
    return molecules[0]
  else:
    all_atoms = []
    for mol_frag in molecules:
      all_atoms += mol_frag.GetAtoms()
    return MolecularFragment(all_atoms)

def get_mol_subset(coords, mol, atom_indices_to_keep):
  """Strip a subset of the atoms in this molecule

  Parameters
  ----------
  coords: Numpy ndarray
    Must be of shape (N, 3) and correspond to coordinates of mol.
  mol: Rdkit mol or `MolecularFragment`
    The molecule to strip
  atom_indices_to_keep: list
    List of the indices of the atoms to keep. Each index is a unique
    number between `[0, N)`.

  Returns
  -------
  A tuple of (coords, mol_frag) where coords is a Numpy array of
  coordinates with hydrogen coordinates. mol_frag is a
  `MolecularFragment`. 
  """
  from rdkit import Chem
  indexes_to_keep = []
  atoms_to_keep = []
  #####################################################
  # Compute partial charges on molecule if rdkit
  if isinstance(mol, Chem.Mol):
    compute_charges(mol)
  #####################################################
  atoms = list(mol.GetAtoms())
  for index in atom_indices_to_keep:
    indexes_to_keep.append(index)
    atoms_to_keep.append(atoms[index])
  mol_frag = MolecularFragment(atoms_to_keep)
  coords = coords[indexes_to_keep]
  return coords, mol_frag

def strip_hydrogens(coords, mol):
  """Strip the hydrogens from input molecule

  Parameters
  ----------
  coords: Numpy ndarray
    Must be of shape (N, 3) and correspond to coordinates of mol.
  mol: Rdkit mol or `MolecularFragment`
    The molecule to strip

  Returns
  -------
  A tuple of (coords, mol_frag) where coords is a Numpy array of
  coordinates with hydrogen coordinates. mol_frag is a
  `MolecularFragment`. 
  """
  mol_atoms = mol.GetAtoms()
  atomic_numbers = [atom.GetAtomicNum() for atom in mol_atoms]
  atom_indices_to_keep = [ind for (ind, atomic_number) in enumerate(atomic_numbers) if (atomic_number != 1)]
  return get_mol_subset(coords, mol, atom_indices_to_keep)

def get_contact_atom_indices(fragments, cutoff=4.5):
  """Compute that atoms close to contact region.
@@ -47,6 +204,15 @@ def get_contact_atom_indices(fragments, cutoff=4.5):
  keep_inds = [sorted(list(keep)) for keep in keep_inds]
  return keep_inds

  # Now extract atoms
  #atoms_to_keep = []
  #for i, frag_keep_inds in enumerate(keep_inds):
  #  frag = fragments[i]
  #  mol = frag[1]
  #  atoms = mol.GetAtoms()
  #  frag_keep = [atoms[keep_ind] for keep_ind in frag_keep_inds]
  #  atoms_to_keep.append(frag_keep)
  #return atoms_to_keep
  
def reduce_molecular_complex_to_contacts(fragments, cutoff=4.5):
  """Reduce a molecular complex to only those atoms near a contact.
+0 −3
Original line number Diff line number Diff line
@@ -112,7 +112,6 @@ def is_angle_within_cutoff(vector_i, vector_j, angle_cutoff):
  angle = angle_between(vector_i, vector_j) * 180. / np.pi
  return (angle > (180 - angle_cutoff) and angle < (180. + angle_cutoff))


def compute_centroid(coordinates):
  """Compute the (x,y,z) centroid of provided coordinates

@@ -124,7 +123,6 @@ def compute_centroid(coordinates):
  centroid = np.mean(coordinates, axis=0)
  return (centroid)


def subtract_centroid(xyz, centroid):
  """Subtracts centroid from each coordinate.

@@ -143,7 +141,6 @@ def subtract_centroid(xyz, centroid):
  xyz -= np.transpose(centroid)
  return (xyz)


def compute_pairwise_distances(first_xyz, second_xyz):
  """Computes pairwise distances between two molecules.

+445 −0
Original line number Diff line number Diff line
"""The functions in these utilities check that noncovalent interactions happen"""

def is_salt_bridge(atom_i, atom_j):
  """Check if two atoms have correct charges to form a salt bridge"""
  if np.abs(2.0 - np.abs(
      get_partial_charge(atom_i) - get_partial_charge(atom_j))) < 0.01:
    return True
  return False

def compute_salt_bridges(first,
                         second,
                         pairwise_distances,
                         cutoff=5.0):
  """Find salt bridge contacts between two molecules. 

  Parameters:
  -----------
  first: rdkit.rdchem.Mol
    Interacting molecules
  second: rdkit.rdchem.Mol
    Interacting molecules
  pairwise_distances: np.ndarray
    Array of pairwise interatomic distances between molecule atoms (Angstroms)
  cutoff: float
    Cutoff distance for contact consideration

  Returns:
  --------
  salt_bridge_contacts: list of tuples
    List of contacts. Tuple (i, j) indicates that atom i from
    first molecule interacts with atom j from second.
  """

  salt_bridge_contacts = []
  contacts = np.nonzero(pairwise_distances < cutoff)
  contacts = zip(contacts[0], contacts[1])
  for contact in contacts:
    first_atom = first.GetAtoms()[int(contact[0])]
    second_atom = second.GetAtoms()[int(contact[1])]
    if is_salt_bridge(first_atom, second_atom):
      salt_bridge_contacts.append(contact)
  return salt_bridge_contacts


def is_hydrogen_bond(frag1,
                     frag2,
                     contact,
                     hbond_distance_cutoff=4.0,
                     hbond_angle_cutoff=40.0):
  """
  Determine if a pair of atoms (contact = frag1_atom_index,
  frag2_atom_index) between two molecules represents a hydrogen
  bond. Returns a boolean result.

  Parameters
  ----------
  frag1: tuple
    Tuple of (coords, rdkit mol / MolecularFragment
  frag2: tuple
    Tuple of (coords, rdkit mol / MolecularFragment
  contact: Tuple
    Tuple of indices for (atom_i, atom_j) contact. 
  hbond_distance_cutoff: float, optional
    Distance cutoff for hbond. 
  hbond_angle_cutoff: float, optional
    Angle deviance cutoff for hbond 
  """
  frag1_xyz, frag2_xyz = frag1[0], frag2[0]
  frag1_mol, frag2_mol = frag1[1], frag2[1]
  frag1_atom_xyz = frag1_xyz[int(contact[0])]
  frag2_atom_xyz = frag2_xyz[int(contact[1])]
  frag1_atom = frag1_mol.GetAtoms()[int(contact[0])]
  frag2_atom = frag2_mol.GetAtoms()[int(contact[1])]

  # Nitrogen has atomic number 7, and oxygen 8.
  if ((frag2_atom.GetAtomicNum() == 7 or frag2_atom.GetAtomicNum() == 8) and (frag1_atom.GetAtomicNum() == 7 or frag1_atom.GetAtomicNum() == 8)):
    hydrogens = []

    for i, atom in enumerate(frag2_mol.GetAtoms()):
      # If atom is a hydrogen
      if atom.GetAtomicNum() == 1:
        atom_xyz = frag2_xyz[i]
        dist = np.linalg.norm(atom_xyz-frag2_atom_xyz)
        # O-H distance is 0.96 A, N-H is 1.01 A. See http://www.science.uwaterloo.ca/~cchieh/cact/c120/bondel.html
        if dist < 1.3:
          hydrogens.append(atom_xyz)
    
    for j, atom in enumerate(frag1_mol.GetAtoms()):
      # If atom is a hydrogen
      if atom.GetAtomicNum() == 1:
        atom_xyz = frag1_xyz[i]
        dist = np.linalg.norm(atom_xyz-frag1_atom_xyz)
        # O-H distance is 0.96 A, N-H is 1.01 A. See http://www.science.uwaterloo.ca/~cchieh/cact/c120/bondel.html
        if dist < 1.3:
          hydrogens.append(atom_xyz)

    for hydrogen_xyz in hydrogens:
      hydrogen_to_frag2 = frag2_atom_xyz - hydrogen_xyz
      hydrogen_to_frag1 = frag1_atom_xyz - hydrogen_xyz
      return is_angle_within_cutoff(hydrogen_to_frag2, hydrogen_to_frag1, hbond_angle_cutoff)
  return False

def compute_hbonds_in_range(frag1, frag2,
                            pairwise_distances, hbond_dist_bin,
                            hbond_angle_cutoff):
  """
  Find all pairs of (frag1_index_i, frag2_index_j) that hydrogen bond
  given a distance bin and an angle cutoff.

  Parameters
  ----------
  frag1: tuple
    Tuple of (coords, rdkit mol / MolecularFragment
  frag2: tuple
    Tuple of (coords, rdkit mol / MolecularFragment
  pairwise_distances:
    Matrix of shape `(N, M)` with pairwise distances between frag1/frag2.
  hbond_dist_bin: tuple
    Tuple of floats `(min_dist, max_dist)` in angstroms. 
  hbond_angle_cutoffs: list[float]
    List of angles of deviances allowed for hbonds
  """

  contacts = np.nonzero((pairwise_distances > hbond_dist_bin[0]) &
                        (pairwise_distances < hbond_dist_bin[1]))
  contacts = zip(contacts[0], contacts[1])
  hydrogen_bond_contacts = []
  for contact in contacts:
    if is_hydrogen_bond(frag1, frag2, contact,
                        hbond_angle_cutoff):
      hydrogen_bond_contacts.append(contact)
  return hydrogen_bond_contacts

def compute_hydrogen_bonds(frag1, frag2,
                           pairwise_distances, hbond_dist_bins,
                           hbond_angle_cutoffs):
  """Computes hydrogen bonds between proteins and ligands.

  Returns a list of sublists. Each sublist is a series of tuples
  of (protein_index_i, ligand_index_j) that represent a hydrogen
  bond. Each sublist represents a different type of hydrogen
  bond.

  Parameters
  ----------
  frag1: tuple
    Tuple of (coords, rdkit mol / MolecularFragment
  frag2: tuple
    Tuple of (coords, rdkit mol / MolecularFragment
  pairwise_distances:
    Matrix of shape `(N, M)` with pairwise distances between frag1/frag2.
  hbond_dist_bins: list[tuple]
    List of tuples of hbond distance ranges.
  hbond_angle_cutoffs: list[float]
    List of angles of deviances allowed for hbonds
  """

  hbond_contacts = []
  for i, hbond_dist_bin in enumerate(hbond_dist_bins):
    hbond_angle_cutoff = hbond_angle_cutoffs[i]
    hbond_contacts.append(
        compute_hbonds_in_range(frag1, frag2,
                                pairwise_distances, hbond_dist_bin,
                                hbond_angle_cutoff))
  return (hbond_contacts)

def compute_cation_pi(mol1, mol2, charge_tolerance=0.01, **kwargs):
  """Finds aromatic rings in mo1 and cations in mol2 that interact with each other.

  Parameters:
  -----------
  mol1: rdkit.rdchem.Mol
    Molecule to look for interacting rings
  mol2: rdkit.rdchem.Mol
    Molecule to look for interacting cations
  charge_tolerance: float
    Atom is considered a cation if its formal charge is greater
    than 1 - charge_tolerance
  **kwargs:
    Arguments that are passed to is_cation_pi function

  Returns:
  --------
  mol1_pi: dict
    Dictionary that maps atom indices (from mol1) to the number of cations
    (in mol2) they interact with
  mol2_cation: dict
    Dictionary that maps atom indices (from mol2) to the number of aromatic
    atoms (in mol1) they interact with
  """
  mol1_pi = Counter()
  mol2_cation = Counter()
  conformer = mol2.GetConformer()

  aromatic_atoms = set(atom.GetIdx() for atom in mol1.GetAromaticAtoms())
  from rdkit import Chem
  rings = [list(r) for r in Chem.GetSymmSSSR(mol1)]

  for ring in rings:
    # if ring from mol1 is aromatic
    if set(ring).issubset(aromatic_atoms):
      ring_center = compute_ring_center(mol1, ring)
      ring_normal = compute_ring_normal(mol1, ring)

      for atom in mol2.GetAtoms():
        # ...and atom from mol2 is a cation
        if atom.GetFormalCharge() > 1.0 - charge_tolerance:
          cation_position = np.array(conformer.GetAtomPosition(atom.GetIdx()))
          # if angle and distance are correct
          if is_cation_pi(cation_position, ring_center, ring_normal, **kwargs):
            # count atoms forming a contact
            mol1_pi.update(ring)
            mol2_cation.update([atom.GetIndex()])
  return mol1_pi, mol2_cation

def is_cation_pi(cation_position,
                 ring_center,
                 ring_normal,
                 dist_cutoff=6.5,
                 angle_cutoff=30.0):
  """Check if a cation and an aromatic ring form contact.

  Parameters:
  -----------
  ring_center: np.ndarray
    Positions of ring center. Can be computed with the compute_ring_center
    function.
  ring_normal: np.ndarray
    Normal of ring. Can be computed with the compute_ring_normal function.
  dist_cutoff: float
    Distance cutoff. Max allowed distance between ring center
    and cation (in Angstroms).
  angle_cutoff: float
    Angle cutoff. Max allowed deviation from the ideal (0deg)
    angle between ring normal and vector pointing from ring
    center to cation (in degrees).
  """
  cation_to_ring_vec = cation_position - ring_center
  dist = np.linalg.norm(cation_to_ring_vec)
  angle = angle_between(cation_to_ring_vec, ring_normal) * 180. / np.pi
  if ((angle < angle_cutoff or angle > 180.0 - angle_cutoff) and
      (dist < dist_cutoff)):
    return True
  return False

def compute_pi_stack(mol1,
                     mol2,
                     pairwise_distances=None,
                     dist_cutoff=4.4,
                     angle_cutoff=30.):
  """Find aromatic rings in both molecules that form pi-pi contacts.
  For each atom in the contact, count number of atoms in the other molecule
  that form this contact.

  Pseudocode:

  for each aromatic ring in mol1:
    for each aromatic ring in mol2:
      compute distance between centers
      compute angle between normals
      if it counts as parallel pi-pi:
        count interacting atoms
      if it counts as pi-T:
        count interacting atoms

  Parameters:
  -----------
    mol1: rdkit.rdchem.Mol
      First molecule.
    mol2: rdkit.rdchem.Mol
      First molecule.
    pairwise_distances: np.ndarray (optional)
      Array of pairwise interatomic distances (Angstroms)
    dist_cutoff: float
      Distance cutoff. Max allowed distance between the ring center (Angstroms).
    angle_cutoff: float
      Angle cutoff. Max allowed deviation from the ideal angle between rings.

  Returns:
  --------
    mol1_pi_t, mol1_pi_parallel, mol2_pi_t, mol2_pi_parallel: dict
      Dictionaries mapping atom indices to number of atoms they interact with.
      Separate dictionary is created for each type of pi stacking (parallel and
      T-shaped) and each molecule (mol1 and mol2).
  """

  mol1_pi_parallel = Counter()
  mol1_pi_t = Counter()
  mol2_pi_parallel = Counter()
  mol2_pi_t = Counter()

  mol1_aromatic_rings = []
  mol2_aromatic_rings = []
  from rdkit import Chem
  for mol, ring_list in ((mol1, mol1_aromatic_rings),
                         (mol2, mol2_aromatic_rings)):
    aromatic_atoms = {atom.GetIdx() for atom in mol.GetAromaticAtoms()}
    for ring in Chem.GetSymmSSSR(mol):
      # if ring is aromatic
      if set(ring).issubset(aromatic_atoms):
        # save its indices, center, and normal
        ring_center = compute_ring_center(mol, ring)
        ring_normal = compute_ring_normal(mol, ring)
        ring_list.append((ring, ring_center, ring_normal))

  # remember mol1-mol2 pairs we already counted
  counted_pairs_parallel = set()
  counted_pairs_t = set()
  for prot_ring, prot_ring_center, prot_ring_normal in mol1_aromatic_rings:
    for lig_ring, lig_ring_center, lig_ring_normal in mol2_aromatic_rings:
      if is_pi_parallel(
          prot_ring_center,
          prot_ring_normal,
          lig_ring_center,
          lig_ring_normal,
          angle_cutoff=angle_cutoff,
          dist_cutoff=dist_cutoff):
        prot_to_update = set()
        lig_to_update = set()
        for prot_atom_idx in prot_ring:
          for lig_atom_idx in lig_ring:
            if (prot_atom_idx, lig_atom_idx) not in counted_pairs_parallel:
              # if this pair is new, count atoms forming a contact
              prot_to_update.add(prot_atom_idx)
              lig_to_update.add(lig_atom_idx)
              counted_pairs_parallel.add((prot_atom_idx, lig_atom_idx))

        mol1_pi_parallel.update(prot_to_update)
        mol2_pi_parallel.update(lig_to_update)

      if is_pi_t(
          prot_ring_center,
          prot_ring_normal,
          lig_ring_center,
          lig_ring_normal,
          angle_cutoff=angle_cutoff,
          dist_cutoff=dist_cutoff):
        prot_to_update = set()
        lig_to_update = set()
        for prot_atom_idx in prot_ring:
          for lig_atom_idx in lig_ring:
            if (prot_atom_idx, lig_atom_idx) not in counted_pairs_t:
              # if this pair is new, count atoms forming a contact
              prot_to_update.add(prot_atom_idx)
              lig_to_update.add(lig_atom_idx)
              counted_pairs_t.add((prot_atom_idx, lig_atom_idx))

        mol1_pi_t.update(prot_to_update)
        mol2_pi_t.update(lig_to_update)

  return (mol1_pi_t, mol1_pi_parallel, mol2_pi_t, mol2_pi_parallel)

def is_pi_t(ring1_center,
            ring1_normal,
            ring2_center,
            ring2_normal,
            dist_cutoff=5.5,
            angle_cutoff=30.0):
  """Check if two aromatic rings form a T-shaped pi-pi contact.

  Parameters:
  -----------
  ring1_center, ring2_center: np.ndarray
    Positions of centers of the two rings. Can be computed with the
    compute_ring_center function.
  ring1_normal, ring2_normal: np.ndarray
    Normals of the two rings. Can be computed with the compute_ring_normal
    function.
  dist_cutoff: float
    Distance cutoff. Max allowed distance between the ring center (Angstroms).
  angle_cutoff: float
    Angle cutoff. Max allowed deviation from the ideal (90deg) angle between
    the rings (in degrees).
  """
  dist = np.linalg.norm(ring1_center - ring2_center)
  angle = angle_between(ring1_normal, ring2_normal) * 180 / np.pi
  if ((90.0 - angle_cutoff < angle < 90.0 + angle_cutoff) and
      dist < dist_cutoff):
    return True
  return False

def is_pi_parallel(ring1_center,
                   ring1_normal,
                   ring2_center,
                   ring2_normal,
                   dist_cutoff=8.0,
                   angle_cutoff=30.0):
  """Check if two aromatic rings form a parallel pi-pi contact.

  Parameters:
  -----------
    ring1_center, ring2_center: np.ndarray
      Positions of centers of the two rings. Can be computed with the
      compute_ring_center function.
    ring1_normal, ring2_normal: np.ndarray
      Normals of the two rings. Can be computed with the compute_ring_normal
      function.
    dist_cutoff: float
      Distance cutoff. Max allowed distance between the ring center (Angstroms).
    angle_cutoff: float
      Angle cutoff. Max allowed deviation from the ideal (0deg) angle between
      the rings (in degrees).
  """

  dist = np.linalg.norm(ring1_center - ring2_center)
  angle = angle_between(ring1_normal, ring2_normal) * 180 / np.pi
  if ((angle < angle_cutoff or angle > 180.0 - angle_cutoff) and
      dist < dist_cutoff):
    return True
  return False

def compute_binding_pocket_cation_pi(mol1, mol2, **kwargs):
  """Finds cation-pi interactions between mol1 and mol2.

  Parameters:
  -----------
  mol1: rdkit.rdchem.Mol
    Interacting molecules
  mol2: rdkit.rdchem.Mol
    Interacting molecules
  **kwargs:
    Arguments that are passed to compute_cation_pi function

  Returns:
  --------
  mol1_cation_pi, mol2_cation_pi: dict
    Dictionaries that maps atom indices to the number of cations/aromatic
    atoms they interact with
  """
  # find interacting rings from mol1 and cations from mol2 
  mol1_pi, mol2_cation = compute_cation_pi(mol1, mol2, **kwargs)
  # find interacting cations from mol1 and rings from mol2 
  mol2_pi, mol1_cation = compute_cation_pi(mol2, mol1, **kwargs)

  # merge counters
  mol1_cation_pi = Counter()
  mol1_cation_pi.update(mol1_pi)
  mol1_cation_pi.update(mol1_cation)

  mol2_cation_pi = Counter()
  mol2_cation_pi.update(mol2_pi)
  mol2_cation_pi.update(mol2_cation)

  return mol1_cation_pi, mol2_cation_pi
+244 −0

File added.

Preview size limit exceeded, changes collapsed.

+1 −520

File changed.

Preview size limit exceeded, changes collapsed.

Loading