Commit 071f29ab authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Pocket detector final draft

parent 9c4582ae
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -12,4 +12,4 @@ from deepchem.dock.pose_scoring import GridPoseScorer
from deepchem.dock.docking import Docker
from deepchem.dock.docking import VinaGridRFDocker
from deepchem.dock.docking import VinaGridDNNDocker
from deepchem.dock.binding_pocket import ConvexHullRFPocketFinder
from deepchem.dock.binding_pocket import ConvexHullPocketFinder
+155 −18
Original line number Diff line number Diff line
@@ -19,38 +19,175 @@ from deepchem.feat.atomic_coordinates import AtomicCoordinates
from deepchem.feat.grid_featurizer import load_molecule
from subprocess import call

class BindingPocketFinder(object):
  """Abstract superclass for binding pocket detection"""
def extract_active_site(protein_file, ligand_file, cutoff=4):
  """Extracts a box for the active site.

  The active site is the set of protein atoms lying strictly closer than
  `cutoff` to at least one ligand atom.

  Parameters
  ----------
  protein_file: str
    Protein structure file, read with load_molecule.
  ligand_file: str
    Ligand structure file, read with load_molecule.
  cutoff: float, optional
    Distance threshold, in the same units as the loaded coordinates
    (presumably angstroms -- confirm against load_molecule).

  Returns
  -------
  tuple
    `(box, pocket_atoms, pocket_coords)` where `box` is
    `((x_min, x_max), (y_min, y_max), (z_min, z_max))` with integer bounds
    (floor of the minimum, ceil of the maximum), `pocket_atoms` is the
    sorted list of protein atom indices in the active site, and
    `pocket_coords` is the matching `(n_pocket_atoms, 3)` coordinate array.

  Raises
  ------
  ValueError
    If no protein atom lies within `cutoff` of the ligand.
  """
  # Element 0 of load_molecule's result is used as the coordinate array.
  protein_coords = np.asarray(
      load_molecule(protein_file)[0], dtype=float).reshape(-1, 3)
  ligand_coords = np.asarray(load_molecule(ligand_file)[0], dtype=float)

  # One vectorized distance pass over the protein per ligand atom.
  in_pocket = np.zeros(len(protein_coords), dtype=bool)
  for lig_atom in ligand_coords:
    in_pocket |= np.linalg.norm(protein_coords - lig_atom, axis=1) < cutoff

  pocket_atoms = np.flatnonzero(in_pocket).tolist()
  if not pocket_atoms:
    raise ValueError(
        "No protein atoms found within cutoff %s of the ligand." % str(cutoff))
  # Array of size (n_pocket_atoms, 3), row-aligned with pocket_atoms.
  pocket_coords = protein_coords[pocket_atoms]

  # Snap the bounding box outwards onto integer coordinates.
  lower = np.floor(np.amin(pocket_coords, axis=0))
  upper = np.ceil(np.amax(pocket_coords, axis=0))
  box = tuple((int(lo), int(hi)) for lo, hi in zip(lower, upper))
  return (box, pocket_atoms, pocket_coords)

def compute_overlap(mapping, box1, box2):
  """Computes overlap between the two boxes.

  Based on https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4112621/pdf/1472-6807-14-18.pdf
  Overlap is defined as the fraction (between 0 and 1) of the atoms of
  box1 that are also in box2. Note that overlap is not a symmetric
  measurement.

  Parameters
  ----------
  mapping: dict
    Maps each box to the collection of atom indices it contains.
  box1: hashable
    Key of the box whose atoms are being tested for containment.
  box2: hashable
    Key of the box they are tested against.

  Returns
  -------
  float
    `len(atoms1 & atoms2) / len(atoms1)`, or 0.0 when box1 holds no atoms.

  Raises
  ------
  KeyError
    If either box is missing from `mapping`.
  """
  atom1 = set(mapping[box1])
  atom2 = set(mapping[box2])
  if not atom1:
    # An empty box overlaps nothing; avoids a ZeroDivisionError below.
    return 0.0
  return len(atom1.intersection(atom2)) / float(len(atom1))
def get_all_boxes(coords, pad=5):
  """Get all pocket boxes for protein coords.

  One axis-aligned box is produced per triangular face of the convex hull
  of `coords`. Each box is the bounding box of the face's three vertices,
  snapped outwards to integer coordinates and then padded by `pad` on every
  side (the prescribed number of angstroms).

  TODO(rbharath): It looks like this may perhaps be non-deterministic?

  Parameters
  ----------
  coords: array-like of shape (N, 3)
    Atomic coordinates.
  pad: int, optional
    Padding added to both ends of every axis.

  Returns
  -------
  list
    Boxes of the form `((x_min, x_max), (y_min, y_max), (z_min, z_max))`
    with integer bounds, one per hull simplex (duplicates are not removed).
  """
  coords = np.asarray(coords)
  hull = ConvexHull(coords)
  boxes = []
  for triangle in hull.simplices:
    # Rows are the vertices of this face, columns are x, y, z.
    points = coords[triangle][:, :3]
    # We voxelize so all grids have integral coordinates (convenience)
    lower = np.floor(np.amin(points, axis=0))
    upper = np.ceil(np.amax(points, axis=0))
    boxes.append(tuple(
        (int(lo) - pad, int(hi) + pad) for lo, hi in zip(lower, upper)))
  return boxes

def boxes_to_atoms(atom_coords, boxes):
  """Builds a dict from each box to the indices of the atoms inside it.

  An atom belongs to a box when each of its first three coordinates lies
  within the box's closed interval on the corresponding axis. Progress is
  printed once per box.

  TODO(rbharath): This does a num_atoms x num_boxes computations. Is
  there a reasonable heuristic we can use to speed this up?
  """
  box_to_atoms = {}
  for box_ind, box in enumerate(boxes):
    x_range, y_range, z_range = box
    x_lo, x_hi = x_range
    y_lo, y_hi = y_range
    z_lo, z_hi = z_range
    print("Handing box %d/%d" % (box_ind, len(boxes)))
    members = []
    for atom_ind, atom in enumerate(atom_coords):
      x, y, z = atom[0], atom[1], atom[2]
      in_x = x_lo <= x <= x_hi
      in_y = y_lo <= y <= y_hi
      in_z = z_lo <= z <= z_hi
      if in_x and in_y and in_z:
        members.append(atom_ind)
    box_to_atoms[box] = members
  return box_to_atoms

def merge_boxes(box1, box2):
  """Returns the smallest axis-aligned box containing both input boxes.

  Each box is `((x_min, x_max), (y_min, y_max), (z_min, z_max))`; the
  result takes the smaller lower bound and the larger upper bound per axis.
  """
  (x_a, y_a, z_a), (x_b, y_b, z_b) = box1, box2
  merged = []
  for (lo_a, hi_a), (lo_b, hi_b) in ((x_a, x_b), (y_a, y_b), (z_a, z_b)):
    merged.append((min(lo_a, lo_b), max(hi_a, hi_b)))
  return tuple(merged)

def merge_overlapping_boxes(mapping, boxes, threshold=.8):
  """Merge boxes which have an overlap greater than threshold.

  Boxes are consumed front to back. The box at the head of the work list is

    * dropped, if all of its atoms already lie in an emitted output box;
    * merged into every later box it overlaps by at least `threshold`
      (the later box is replaced by the merged box, whose atoms are the
      union of both), in which case the head itself is not emitted;
    * emitted as an output box otherwise.

  Overlap is the asymmetric measure given by compute_overlap.

  TODO(rbharath): This merge code is terribly inelegant. It's also quadratic
  in number of boxes. It feels like there ought to be an elegant divide and
  conquer approach here. Figure out later...

  Parameters
  ----------
  mapping: dict
    Maps every box in `boxes` to the list of atom indices it contains.
    The caller's dict is not modified.
  boxes: list
    Boxes of the form `((x_min, x_max), (y_min, y_max), (z_min, z_max))`.
  threshold: float, optional
    Minimum overlap at which the head box is merged into a later box.

  Returns
  -------
  tuple
    `(outputs, mapping)`: the list of surviving boxes and a dict mapping
    boxes to their atom indices.
  """
  outputs = []
  remaining = list(boxes)
  while remaining:
    box, candidates = remaining[0], remaining[1:]

    # Nothing new in this box: every atom is already in an output box.
    if any(compute_overlap(mapping, box, kept) == 1 for kept in outputs):
      # Pop the box before moving on. Leaving it at the head of the list
      # would re-examine it forever and starve every box behind it.
      remaining = candidates
      mapping = dict((b, mapping[b]) for b in outputs + candidates)
      continue

    # We know that box has at least one atom not in outputs
    next_remaining = []
    # Carry forward mappings of the boxes already emitted.
    next_mapping = dict((kept, mapping[kept]) for kept in outputs)
    unique_box = True
    for merge_box in candidates:
      if compute_overlap(mapping, box, merge_box) < threshold:
        next_remaining.append(merge_box)
        next_mapping[merge_box] = mapping[merge_box]
      else:
        # Current box has been merged into box further down list.
        # No need to output current box
        unique_box = False
        merged = merge_boxes(box, merge_box)
        next_remaining.append(merged)
        next_mapping[merged] = list(
            set(mapping[box]).union(mapping[merge_box]))
    if unique_box:
      outputs.append(box)
      next_mapping[box] = mapping[box]
    remaining, mapping = next_remaining, next_mapping
  return outputs, mapping


class BindingPocketFinder(object):
  """Abstract superclass for binding pocket detectors.

  Subclasses implement `find_pockets`.
  """

  def find_pockets(self, protein_file, ligand_file=None):
    """Finds potential binding pockets in proteins.

    Parameters
    ----------
    protein_file: str
      Protein structure file to search for pockets.
    ligand_file: str, optional
      Ligand structure file. Optional here so that detectors which take a
      ligand and detectors which do not share one signature.

    Raises
    ------
    NotImplementedError
      Always; subclasses must override this method.
    """
    raise NotImplementedError

class ConvexHullPocketFinder(BindingPocketFinder):
  """Implementation that uses convex hull of protein to find pockets.

  Based on https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4112621/pdf/1472-6807-14-18.pdf
  """
  def __init__(self, pad=5):
    """Stores the padding applied to every candidate box.

    Parameters
    ----------
    pad: int, optional
      Padding passed through to get_all_boxes.
    """
    self.pad = pad

  def find_all_pockets(self, protein_file):
    """Find list of binding pockets on protein.

    Returns every box produced by get_all_boxes for the protein, without
    any merging.
    """
    # Element 0 of load_molecule's result is used as the coordinates.
    coords = load_molecule(protein_file)[0]
    return get_all_boxes(coords, self.pad)

  def find_pockets(self, protein_file, ligand_file):
    """Find list of suitable binding pockets on protein.

    Parameters
    ----------
    protein_file: str
      Protein structure file.
    ligand_file: str
      Ligand structure file. Accepted for interface compatibility; the
      ligand is not currently used when selecting pockets.

    Returns
    -------
    tuple
      `(merged_boxes, mapping)` as returned by merge_overlapping_boxes:
      the merged pocket boxes and a dict from box to protein atom indices.
    """
    protein_coords = load_molecule(protein_file)[0]
    boxes = get_all_boxes(protein_coords, self.pad)
    mapping = boxes_to_atoms(protein_coords, boxes)
    merged_boxes, mapping = merge_overlapping_boxes(mapping, boxes)
    return merged_boxes, mapping
+119 −11
Original line number Diff line number Diff line
@@ -22,33 +22,141 @@ class TestPoseGeneration(unittest.TestCase):
  """

  def test_convex_rf_init(self):
    """Tests that ConvexHullPocketFinder can be initialized."""
    finder = dc.dock.ConvexHullPocketFinder()
    assert isinstance(finder, dc.dock.ConvexHullPocketFinder)

  def test_get_all_boxes(self):
    """Tests that binding pockets are detected."""
    current_dir = os.path.dirname(os.path.realpath(__file__))
    protein_file = os.path.join(current_dir, "1jld_protein.pdb")
    coords = dc.feat.grid_featurizer.load_molecule(protein_file)[0]

    boxes = dc.dock.binding_pocket.get_all_boxes(coords)
    assert isinstance(boxes, list)
    # Pocket is of form ((x_min, x_max), (y_min, y_max), (z_min, z_max))
    for pocket in boxes:
      assert len(pocket) == 3
      assert len(pocket[0]) == 2
      assert len(pocket[1]) == 2
      assert len(pocket[2]) == 2
      (x_min, x_max), (y_min, y_max), (z_min, z_max) = pocket
      assert x_min < x_max
      assert y_min < y_max
      assert z_min < z_max

  def test_boxes_to_atoms(self):
    """Test that mapping of protein atoms to boxes is meaningful."""
    current_dir = os.path.dirname(os.path.realpath(__file__))
    protein_file = os.path.join(current_dir, "1jld_protein.pdb")
    coords = dc.feat.grid_featurizer.load_molecule(protein_file)[0]
    boxes = dc.dock.binding_pocket.get_all_boxes(coords)

    mapping = dc.dock.binding_pocket.boxes_to_atoms(coords, boxes)
    assert isinstance(mapping, dict)
    # items() works on both Python 2 and 3; iteritems() is Python-2-only.
    for box, box_atoms in mapping.items():
      (x_min, x_max), (y_min, y_max), (z_min, z_max) = box
      # Every atom assigned to a box must lie inside it on all three axes.
      for atom_ind in box_atoms:
        atom = coords[atom_ind]
        assert x_min <= atom[0] <= x_max
        assert y_min <= atom[1] <= y_max
        assert z_min <= atom[2] <= z_max

  def test_compute_overlap(self):
    """Checks compute_overlap in both directions for nested boxes."""
    overlap_fn = dc.dock.binding_pocket.compute_overlap
    # The inner box's atoms are a strict subset of the outer box's atoms.
    inner = ((1, 2), (1, 2), (1, 2))
    outer = ((1, 3), (1, 3), (1, 3))
    atoms_by_box = {inner: [1, 2, 3, 4], outer: [1, 2, 3, 4, 5]}
    # All of inner's atoms are in outer: complete overlap.
    np.testing.assert_almost_equal(overlap_fn(atoms_by_box, inner, outer), 1)
    # Only 4 of outer's 5 atoms are in inner: 80% overlap.
    np.testing.assert_almost_equal(overlap_fn(atoms_by_box, outer, inner), .8)

  def test_merge_overlapping_boxes(self):
    """Tests that overlapping boxes are merged."""
    # box2 contains box1
    box1 = ((1, 2), (1, 2), (1, 2))
    box2 = ((1, 3), (1, 3), (1, 3))
    mapping = {box1: [1, 2, 3, 4], box2: [1, 2, 3, 4, 5]}
    boxes = [box1, box2]
    merged_boxes, _ = dc.dock.binding_pocket.merge_overlapping_boxes(
        mapping, boxes)
    print("merged_boxes")
    print(merged_boxes)
    assert len(merged_boxes) == 1
    assert merged_boxes[0] == ((1, 3), (1, 3), (1, 3))

    # box1 contains box2
    box1 = ((1, 3), (1, 3), (1, 3))
    box2 = ((1, 2), (1, 2), (1, 2))
    mapping = {box1: [1, 2, 3, 4, 5, 6], box2: [1, 2, 3, 4]}
    boxes = [box1, box2]
    merged_boxes, _ = dc.dock.binding_pocket.merge_overlapping_boxes(
        mapping, boxes)
    print("merged_boxes")
    print(merged_boxes)
    assert len(merged_boxes) == 1
    assert merged_boxes[0] == ((1, 3), (1, 3), (1, 3))

    # box1 contains box2, box3
    box1 = ((1, 3), (1, 3), (1, 3))
    box2 = ((1, 2), (1, 2), (1, 2))
    box3 = ((1, 2.5), (1, 2.5), (1, 2.5))
    mapping = {box1: [1, 2, 3, 4, 5, 6], box2: [1, 2, 3, 4],
               box3: [1, 2, 3, 4, 5]}
    # Rebuild the box list so that box3 is actually part of this case;
    # reusing the previous list would silently repeat the two-box case.
    boxes = [box1, box2, box3]
    merged_boxes, _ = dc.dock.binding_pocket.merge_overlapping_boxes(
        mapping, boxes)
    print("merged_boxes")
    print(merged_boxes)
    assert len(merged_boxes) == 1
    assert merged_boxes[0] == ((1, 3), (1, 3), (1, 3))

  def test_convex_rf_find_pockets(self):
    """Test that some pockets are filtered out."""
    current_dir = os.path.dirname(os.path.realpath(__file__))
    protein_file = os.path.join(current_dir, "1jld_protein.pdb")
    ligand_file = os.path.join(current_dir, "1jld_ligand.sdf")

    finder = dc.dock.ConvexHullPocketFinder()

    all_pockets = finder.find_all_pockets(protein_file)
    # find_pockets yields (boxes, box -> atoms mapping). Compare the box
    # list itself: len() of the 2-tuple would make the check vacuous.
    pockets, _ = finder.find_pockets(protein_file, ligand_file)

    assert len(pockets) < len(all_pockets)

  def test_extract_active_site(self):
    """Checks that some detected pocket covers most of the true active site.

    The active site is extracted from the protein/ligand pair, and at least
    one pocket from ConvexHullPocketFinder must contain more than half of
    the active site's atoms.
    """
    data_dir = os.path.dirname(os.path.realpath(__file__))
    protein_file = os.path.join(data_dir, "1jld_protein.pdb")
    ligand_file = os.path.join(data_dir, "1jld_ligand.sdf")

    site = dc.dock.binding_pocket.extract_active_site(
        protein_file, ligand_file)
    active_site_box, active_site_atoms, active_site_coords = site
    print("active_site_box")
    print(active_site_box)
    print("len(active_site_atoms)")
    print(len(active_site_atoms))

    pocket_finder = dc.dock.ConvexHullPocketFinder()
    pockets, pocket_atoms = pocket_finder.find_pockets(
        protein_file, ligand_file)

    # Register the active site in the box -> atoms dict so that
    # compute_overlap can look it up alongside the detected pockets.
    print("active_site_box")
    print(active_site_box)
    pocket_atoms[active_site_box] = active_site_atoms
    found_overlap = False
    for pocket in pockets:
      print("pocket")
      print(pocket)
      overlap = dc.dock.binding_pocket.compute_overlap(
          pocket_atoms, active_site_box, pocket)
      found_overlap = found_overlap or overlap > .5
      print("Overlap for pocket is %f" % overlap)
    assert found_overlap
    
+0 −1
Original line number Diff line number Diff line
@@ -20,7 +20,6 @@ from functools import partial
from deepchem.feat import ComplexFeaturizer
from deepchem.utils.save import log


"""
http://stackoverflow.com/questions/38987/how-can-i-merge-two-python-dictionaries-in-a-single-expression
"""