Commit d0b20b17 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by Bharath Ramsundar
Browse files

Making progress on debugging tests

parent 03b787d3
Loading
Loading
Loading
Loading
+41 −30
Original line number Diff line number Diff line
@@ -7,6 +7,8 @@ import itertools
from deepchem.utils.hash_utils import hash_ecfp
from deepchem.feat import ComplexFeaturizer
from deepchem.utils import rdkit_utils
from deepchem.utils.rdkit_utils import load_complex
from deepchem.utils.rdkit_utils import load_molecule
from deepchem.utils.hash_utils import vectorize
from deepchem.utils.voxel_utils import voxelize
from deepchem.utils.voxel_utils import convert_atom_to_voxel
@@ -16,28 +18,35 @@ from deepchem.utils.rdkit_utils import MoleculeLoadException
from deepchem.utils.geometry_utils import compute_pairwise_distances
from deepchem.utils.geometry_utils import subtract_centroid

from typing import Tuple, Dict

logger = logging.getLogger(__name__)


def featurize_contacts_ecfp(frag1,
                            frag2,
                            pairwise_distances=None,
                            cutoff=4.5,
                            ecfp_degree=2):
def featurize_contacts_ecfp(
    frag1: Tuple,
    frag2: Tuple,
    pairwise_distances: np.ndarray = None,
    cutoff: float = 4.5,
    ecfp_degree: int = 2) -> Tuple[Dict[int, str], Dict[int, str]]:
  """Computes ECFP dicts for pairwise interaction between two molecular fragments.

  Parameters
  ----------
  frag1: Tuple
    A tuple of (coords, mol) returned by `rdkit_util.load_molecule`.
    A tuple of (coords, mol) returned by `load_molecule`.
  frag2: Tuple
    A tuple of (coords, mol) returned by `rdkit_util.load_molecule`.
    A tuple of (coords, mol) returned by `load_molecule`.
  pairwise_distances: np.ndarray
    Array of pairwise fragment-fragment distances (Angstroms)
  cutoff: float
    Cutoff distance for contact consideration
  ecfp_degree: int
    ECFP radius

  Returns
  -------
  Tuple of dictionaries of ECFP contact fragments
  """
  if pairwise_distances is None:
    pairwise_distances = compute_pairwise_distances(frag1[0], frag2[0])
@@ -71,7 +80,7 @@ class ContactCircularFingerprint(ComplexFeaturizer):
  `(2*size,)`
  """

  def __init__(self, cutoff=4.5, radius=2, size=8):
  def __init__(self, cutoff: float = 4.5, radius: int = 2, size: int = 8):
    """
    Parameters
    ----------
@@ -86,18 +95,19 @@ class ContactCircularFingerprint(ComplexFeaturizer):
    self.radius = radius
    self.size = size

  def _featurize_complex(self, molecular_complex):
  def _featurize(self, mol_pdb: str, complex_pdb: str):
    """
    Compute featurization for a molecular complex

    Parameters
    ----------
    molecular_complex: Object
      Some representation of a molecular complex.
    mol_pdb: str
      Filename for ligand molecule
    complex_pdb: str
      Filename for protein molecule
    """
    try:
      fragments = rdkit_util.load_complex(
          molecular_complex, add_hydrogens=False)
      fragments = load_complex((mol_pdb, complex_pdb), add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
@@ -141,12 +151,12 @@ class ContactCircularVoxelizer(ComplexFeaturizer):
  """

  def __init__(self,
               cutoff=4.5,
               radius=2,
               size=8,
               box_width=16.0,
               voxel_width=1.0,
               flatten=False):
               cutoff: float = 4.5,
               radius: int = 2,
               size: int = 8,
               box_width: float = 16.0,
               voxel_width: float = 1.0,
               flatten: bool = False):
    """
    Parameters
    ----------
@@ -173,19 +183,20 @@ class ContactCircularVoxelizer(ComplexFeaturizer):
    self.voxels_per_edge = int(self.box_width / self.voxel_width)
    self.flatten = flatten

  def _featurize_complex(self, molecular_complex):
  def _featurize(self, mol_pdb: str, complex_pdb: str):
    """
    Compute featurization for a single mol/protein complex
    Compute featurization for a molecular complex

    Parameters
    ----------
    molecular_complex: Object
      A representation of a molecular complex, produced by
      `rdkit_util.load_complex`.
    mol_pdb: str
      Filename for ligand molecule
    complex_pdb: str
      Filename for protein molecule
    """
    molecular_complex = (mol_pdb, complex_pdb)
    try:
      fragments = rdkit_util.load_complex(
          molecular_complex, add_hydrogens=False)
      fragments = load_complex(molecular_complex, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
@@ -202,10 +213,10 @@ class ContactCircularVoxelizer(ComplexFeaturizer):
          sum([
              voxelize(
                  convert_atom_to_voxel,
                  self.box_width,
                  self.voxel_width,
                  hash_ecfp,
                  xyz,
                  self.box_width,
                  self.voxel_width,
                  feature_dict=ecfp_dict,
                  nb_channel=self.size) for xyz, ecfp_dict in zip(
                      xyzs,
+2 −2
Original line number Diff line number Diff line
@@ -9,9 +9,9 @@ class TestContactFeaturizers(unittest.TestCase):
  def setUp(self):
    # TODO test more formats for ligand
    current_dir = os.path.dirname(os.path.realpath(__file__))
    self.protein_file = os.path.join(current_dir,
    self.protein_file = os.path.join(current_dir, 'data',
                                     '3ws9_protein_fixer_rdkit.pdb')
    self.ligand_file = os.path.join(current_dir, '3ws9_ligand.sdf')
    self.ligand_file = os.path.join(current_dir, 'data', '3ws9_ligand.sdf')
    self.complex_files = [(self.protein_file, self.ligand_file)]

  def test_contact_fingerprint_shape(self):
+30 −1
Original line number Diff line number Diff line
@@ -5,7 +5,8 @@ from typing import List, Optional, Sequence, Set, Tuple, Union

from deepchem.utils.typing import RDKitAtom, RDKitMol
from deepchem.utils.geometry_utils import compute_pairwise_distances
from deepchem.utils.rdkit_utils import compute_charges

#from deepchem.utils.rdkit_utils import compute_charges


class AtomShim(object):
@@ -357,3 +358,31 @@ def reduce_molecular_complex_to_contacts(
    contact_frag = get_mol_subset(frag[0], frag[1], keep)
    reduced_complex.append(contact_frag)
  return reduced_complex


# TODO: This is duplicated! Clean up
def compute_charges(mol):
  """Attempt to compute Gasteiger Charges on Mol

  This also has the side effect of calculating charges on mol.  The
  mol passed into this function has to already have been sanitized

  Parameters
  ----------
  mol: rdkit molecule

  Returns
  -------
  No return since updates in place.

  Note
  ----
  This function requires RDKit to be installed.
  """
  from rdkit.Chem import AllChem
  try:
    # Updates charges in place
    AllChem.ComputeGasteigerCharges(mol)
  except Exception as e:
    logging.exception("Unable to compute charges for mol")
    raise MoleculeLoadException(e)
+189 −8
Original line number Diff line number Diff line
@@ -9,11 +9,16 @@ properties of molecules.

import os
import logging
import itertools
import numpy as np
from io import StringIO
from deepchem.utils.pdbqt_utils import pdbqt_to_pdb
from deepchem.utils.pdbqt_utils import convert_mol_to_pdbqt
from deepchem.utils.pdbqt_utils import convert_protein_to_pdbqt
from deepchem.utils.geometry_utils import compute_pairwise_distances
from deepchem.utils.fragment_utils import MolecularFragment
from typing import Any, List, Tuple, Set, Optional, Dict
from deepchem.utils.typing import OneOrMany, RDKitMol

logger = logging.getLogger(__name__)

@@ -168,10 +173,10 @@ def compute_charges(mol):
    raise MoleculeLoadException(e)


def load_complex(molecular_complex,
                 add_hydrogens=True,
                 calc_charges=True,
                 sanitize=True):
def load_complex(molecular_complex: OneOrMany[str],
                 add_hydrogens: bool = True,
                 calc_charges: bool = True,
                 sanitize: bool = True) -> List[Tuple]:
  """Loads a molecular complex.

  Given some representation of a molecular complex, returns a list of
@@ -372,12 +377,29 @@ def merge_molecules(molecules):
    return combined


def compute_all_ecfp(mol, indices=None, degree=2):
def compute_all_ecfp(mol: RDKitMol,
                     indices: Optional[Set[int]] = None,
                     degree: int = 2) -> Dict[int, str]:
  """Obtain molecular fragment for all atoms emanating outward to given degree.

  For each fragment, compute SMILES string (for now) and hash to
  an int. Return a dictionary mapping atom index to hashed
  SMILES.

  Parameters
  ----------
  mol: rdkit Molecule
    Molecule to compute ecfp fragments on
  indices: Optional[Set[int]]
    List of atom indices for molecule. Default is all indices. If
    specified will only compute fragments for specified atoms.
  degree: int
    Graph degree to use when computing ECFP fingerprints

  Parameters
  ----------
  

  """

  ecfp_dict = {}
@@ -393,13 +415,16 @@ def compute_all_ecfp(mol, indices=None, degree=2):
  return ecfp_dict


def compute_contact_centroid(molecular_complex, cutoff=4.5):
def compute_contact_centroid(molecular_complex: Any,
                             cutoff: float = 4.5) -> np.ndarray:
  """Computes the (x,y,z) centroid of the contact regions of this molecular complex.

  For a molecular complex, it's necessary for various featurizations
  that compute voxel grids to find a reasonable center for the
  voxelization. This function computes the centroid of all the contact
  atoms, defined as an atom that's within `cutoff` Angstroms of an
  atom from a different molecule.

  Parameters
  ----------
  molecular_complex: Object
@@ -415,14 +440,53 @@ def compute_contact_centroid(molecular_complex, cutoff=4.5):
  return (centroid)


def reduce_molecular_complex_to_contacts(fragments: List,
                                         cutoff: float = 4.5) -> List:
  """Reduce a molecular complex to only those atoms near a contact.

  Molecular complexes can get very large. This can make it unwieldy to
  compute functions on them. To improve memory usage, it can be very
  useful to trim out atoms that aren't close to contact regions. This
  function takes in a molecular complex and returns a new molecular
  complex representation that contains only contact atoms. The contact
  atoms are computed by calling `get_contact_atom_indices` under the
  hood.

  Parameters
  ----------
  fragments: List
    As returned by `rdkit_util.load_complex`, a list of tuples of
    `(coords, mol)` where `coords` is a `(N_atoms, 3)` array and `mol`
    is the rdkit molecule object.
  cutoff: float
    The cutoff distance in angstroms.

  Returns
  -------
  A list of length `len(molecular_complex)`. Each entry in this list
  is a tuple of `(coords, MolecularShim)`. The coords is stripped down
  to `(N_contact_atoms, 3)` where `N_contact_atoms` is the number of
  contact atoms for this complex. `MolecularShim` is used since it's
  tricky to make a RDKit sub-molecule. 
  """
  atoms_to_keep = get_contact_atom_indices(fragments, cutoff)
  reduced_complex = []
  for frag, keep in zip(fragments, atoms_to_keep):
    contact_frag = get_mol_subset(frag[0], frag[1], keep)
    reduced_complex.append(contact_frag)
  return reduced_complex


def compute_ring_center(mol, ring_indices):
  """Computes 3D coordinates of a center of a given ring.

  Parameters:
  -----------
  mol: rdkit.rdchem.Mol
    Molecule containing a ring
  ring_indices: array-like
    Indices of atoms forming a ring

  Returns:
  --------
  ring_centroid: np.ndarray
@@ -435,3 +499,120 @@ def compute_ring_center(mol, ring_indices):
    ring_xyz[i] = np.array(atom_position)
  ring_centroid = compute_centroid(ring_xyz)
  return ring_centroid


def get_contact_atom_indices(fragments: List, cutoff: float = 4.5) -> List:
  """Compute that atoms close to contact region.

  Molecular complexes can get very large. This can make it unwieldy to
  compute functions on them. To improve memory usage, it can be very
  useful to trim out atoms that aren't close to contact regions. This
  function computes pairwise distances between all pairs of molecules
  in the molecular complex. If an atom is within cutoff distance of
  any atom on another molecule in the complex, it is regarded as a
  contact atom. Otherwise it is trimmed.

  Parameters
  ----------
  fragments: List
    As returned by `rdkit_util.load_complex`, a list of tuples of
    `(coords, mol)` where `coords` is a `(N_atoms, 3)` array and `mol`
    is the rdkit molecule object.
  cutoff: float
    The cutoff distance in angstroms.

  Returns
  -------
  A list of length `len(molecular_complex)`. Each entry in this list
  is a list of atom indices from that molecule which should be kept, in
  sorted order.
  """
  # indices to atoms to keep
  keep_inds: List[Set] = [set([]) for _ in fragments]
  for (ind1, ind2) in itertools.combinations(range(len(fragments)), 2):
    frag1, frag2 = fragments[ind1], fragments[ind2]
    pairwise_distances = compute_pairwise_distances(frag1[0], frag2[0])
    # contacts is of form (x_coords, y_coords), a tuple of 2 lists
    contacts = np.nonzero((pairwise_distances < cutoff))
    # contacts[0] is the x_coords, that is the frag1 atoms that have
    # nonzero contact.
    frag1_atoms = set([int(c) for c in contacts[0].tolist()])
    # contacts[1] is the y_coords, the frag2 atoms with nonzero contacts
    frag2_atoms = set([int(c) for c in contacts[1].tolist()])
    keep_inds[ind1] = keep_inds[ind1].union(frag1_atoms)
    keep_inds[ind2] = keep_inds[ind2].union(frag2_atoms)
  keep_ind_lists = [sorted(list(keep)) for keep in keep_inds]
  return keep_ind_lists

  # Now extract atoms
  #atoms_to_keep = []
  #for i, frag_keep_inds in enumerate(keep_inds):
  #  frag = fragments[i]
  #  mol = frag[1]
  #  atoms = mol.GetAtoms()
  #  frag_keep = [atoms[keep_ind] for keep_ind in frag_keep_inds]
  #  atoms_to_keep.append(frag_keep)
  #return atoms_to_keep


def get_mol_subset(coords, mol, atom_indices_to_keep):
  """Strip a subset of the atoms in this molecule

  Parameters
  ----------
  coords: Numpy ndarray
    Must be of shape (N, 3) and correspond to coordinates of mol.
  mol: Rdkit mol or `MolecularFragment`
    The molecule to strip
  atom_indices_to_keep: list
    List of the indices of the atoms to keep. Each index is a unique
    number between `[0, N)`.

  Returns
  -------
  A tuple of (coords, mol_frag) where coords is a Numpy array of
  coordinates with hydrogen coordinates. mol_frag is a
  `MolecularFragment`. 
  """
  from rdkit import Chem
  indexes_to_keep = []
  atoms_to_keep = []
  #####################################################
  # Compute partial charges on molecule if rdkit
  if isinstance(mol, Chem.Mol):
    compute_charges(mol)
  #####################################################
  atoms = list(mol.GetAtoms())
  for index in atom_indices_to_keep:
    indexes_to_keep.append(index)
    atoms_to_keep.append(atoms[index])
  coords = coords[indexes_to_keep]
  mol_frag = MolecularFragment(atoms_to_keep, coords)
  return coords, mol_frag


def compute_ring_normal(mol, ring_indices):
  """Computes normal to a plane determined by a given ring.

  Parameters:
  -----------
  mol: rdkit.rdchem.Mol
    Molecule containing a ring
  ring_indices: array-like
    Indices of atoms forming a ring

  Returns:
  --------
  normal: np.ndarray
    Normal vector
  """
  conformer = mol.GetConformer()
  points = np.zeros((3, 3))
  for i, atom_idx in enumerate(ring_indices[:3]):
    atom_position = conformer.GetAtomPosition(atom_idx)
    points[i] = np.array(atom_position)

  v1 = points[1] - points[0]
  v2 = points[2] - points[0]
  normal = np.cross(v1, v2)
  return normal
+42 −36
Original line number Diff line number Diff line
import os
import unittest
from deepchem.utils.rdkit_utils import load_molecule
from deepchem.utils.rdkit_utils import compute_ring_center
from deepchem.utils.rdkit_utils import compute_ring_normal
from deepchem.utils.noncovalent_utils import is_pi_parallel
from deepchem.utils.noncovalent_utils import is_pi_t
from deepchem.utils.noncovalent_utils import compute_pi_stack
from deepchem.utils.noncovalent_utils import is_cation_pi
from deepchem.utils.noncovalent_utils import compute_cation_pi
from deepchem.utils.noncovalent_utils import compute_binding_pocket_cation_pi


class TestPiInteractions(unittest.TestCase):

  def setUp(self):
@@ -11,24 +24,24 @@ class TestPiInteractions(unittest.TestCase):
    Compute2DCoords(self.cycle4)

    # load and sanitize two real molecules
    _, self.prot = rdkit_util.load_molecule(
        os.path.join(current_dir, '../../feat/tests/3ws9_protein_fixer_rdkit.pdb'),
    _, self.prot = load_molecule(
        os.path.join(current_dir,
                     '../../feat/tests/3ws9_protein_fixer_rdkit.pdb'),
        add_hydrogens=False,
        calc_charges=False,
        sanitize=True)

    _, self.lig = rdkit_util.load_molecule(
    _, self.lig = load_molecule(
        os.path.join(current_dir, '../../feat/tests/3ws9_ligand.sdf'),
        add_hydrogens=False,
        calc_charges=False,
        sanitize=True)

  def test_compute_ring_center(self):
    self.assertTrue(
        np.allclose(rdkit_util.compute_ring_center(self.cycle4, range(4)), 0))
    self.assertTrue(np.allclose(compute_ring_center(self.cycle4, range(4)), 0))

  def test_compute_ring_normal(self):
    normal = rdkit_util.compute_ring_normal(self.cycle4, range(4))
    normal = compute_ring_normal(self.cycle4, range(4))
    self.assertTrue(
        np.allclose(np.abs(normal / np.linalg.norm(normal)), [0, 0, 1]))

@@ -42,21 +55,15 @@ class TestPiInteractions(unittest.TestCase):
    for ring2_normal in (np.array([2.0, 0, 0]), np.array([-3.0, 0, 0])):
      # parallel normals
      self.assertTrue(
          rdkit_util.is_pi_parallel(ring1_center,
                                    ring1_normal_true,
                                    ring2_center_true,
          is_pi_parallel(ring1_center, ring1_normal_true, ring2_center_true,
                         ring2_normal))
      # perpendicular normals
      self.assertFalse(
          rdkit_util.is_pi_parallel(ring1_center,
                                    ring1_normal_false,
                                    ring2_center_true,
          is_pi_parallel(ring1_center, ring1_normal_false, ring2_center_true,
                         ring2_normal))
      # too far away
      self.assertFalse(
          rdkit_util.is_pi_parallel(ring1_center,
                                    ring1_normal_true,
                                    ring2_center_false,
          is_pi_parallel(ring1_center, ring1_normal_true, ring2_center_false,
                         ring2_normal))

  def test_is_pi_t(self):
@@ -69,27 +76,27 @@ class TestPiInteractions(unittest.TestCase):
    for ring2_normal in (np.array([2.0, 0, 0]), np.array([-3.0, 0, 0])):
      # perpendicular normals
      self.assertTrue(
          rdkit_util.is_pi_t(ring1_center, ring1_normal_true, ring2_center_true,
          is_pi_t(ring1_center, ring1_normal_true, ring2_center_true,
                  ring2_normal))
      # parallel normals
      self.assertFalse(
          rdkit_util.is_pi_t(ring1_center, ring1_normal_false, ring2_center_true,
          is_pi_t(ring1_center, ring1_normal_false, ring2_center_true,
                  ring2_normal))
      # too far away
      self.assertFalse(
          rdkit_util.is_pi_t(ring1_center, ring1_normal_true, ring2_center_false,
          is_pi_t(ring1_center, ring1_normal_true, ring2_center_false,
                  ring2_normal))

  def test_compute_pi_stack(self):
    # order of the molecules shouldn't matter
    dicts1 = rdkit_util.compute_pi_stack(self.prot, self.lig)
    dicts2 = rdkit_util.compute_pi_stack(self.lig, self.prot)
    dicts1 = compute_pi_stack(self.prot, self.lig)
    dicts2 = compute_pi_stack(self.lig, self.prot)
    for i, j in ((0, 2), (1, 3)):
      self.assertEqual(dicts1[i], dicts2[j])
      self.assertEqual(dicts1[j], dicts2[i])

    # with this criteria we should find both types of stacking
    for d in rdkit_util.compute_pi_stack(
    for d in compute_pi_stack(
        self.lig, self.prot, dist_cutoff=7, angle_cutoff=40.):
      self.assertGreater(len(d), 0)

@@ -102,26 +109,25 @@ class TestPiInteractions(unittest.TestCase):

    # parallel normals
    self.assertTrue(
        rdkit_util.is_cation_pi(cation_position, ring_center_true, ring_normal_true))
        is_cation_pi(cation_position, ring_center_true, ring_normal_true))
    # perpendicular normals
    self.assertFalse(
        rdkit_util.is_cation_pi(cation_position, ring_center_true, ring_normal_false))
        is_cation_pi(cation_position, ring_center_true, ring_normal_false))
    # too far away
    self.assertFalse(
        rdkit_util.is_cation_pi(cation_position, ring_center_false, ring_normal_true))
        is_cation_pi(cation_position, ring_center_false, ring_normal_true))

  def test_compute_cation_pi(self):
    # TODO(rbharath): find better example, currently dicts are empty
    dicts1 = rdkit_util.compute_cation_pi(self.prot, self.lig)
    dicts2 = rdkit_util.compute_cation_pi(self.lig, self.prot)
    dicts1 = compute_cation_pi(self.prot, self.lig)
    dicts2 = compute_cation_pi(self.lig, self.prot)

  def test_compute_binding_pocket_cation_pi(self):
    # TODO find better example, currently dicts are empty
    prot_dict, lig_dict = rdkit_util.compute_binding_pocket_cation_pi(
        self.prot, self.lig)
    prot_dict, lig_dict = compute_binding_pocket_cation_pi(self.prot, self.lig)

    exp_prot_dict, exp_lig_dict = rdkit_util.compute_cation_pi(self.prot, self.lig)
    add_lig, add_prot = rdkit_util.compute_cation_pi(self.lig, self.prot)
    exp_prot_dict, exp_lig_dict = compute_cation_pi(self.prot, self.lig)
    add_lig, add_prot = compute_cation_pi(self.lig, self.prot)
    for exp_dict, to_add in ((exp_prot_dict, add_prot), (exp_lig_dict,
                                                         add_lig)):
      for atom_idx, count in to_add.items():