Commit 16de6587 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Adding RF powered pocket selector

parent afe44833
Loading
Loading
Loading
Loading
+9 −4
Original line number Diff line number Diff line
@@ -225,13 +225,18 @@ class Dataset(object):
class NumpyDataset(Dataset):
  """A Dataset defined by in-memory numpy arrays."""

  def __init__(self, X, y, w=None, ids=None):
  def __init__(self, X, y=None, w=None, ids=None):
    n_samples = len(X)
    # The -1 indicates that y will be reshaped to have length -1
    if n_samples > 0:
      if y is not None:
        y = np.reshape(y, (n_samples, -1))
        if w is not None:
          w = np.reshape(w, (n_samples, -1))
      else:
        # Set labels to be zero, with zero weights
        y = np.zeros((n_samples, 1))
        w = np.zeros_like(y)
    n_tasks = y.shape[1]
    if ids is None:
      ids = np.arange(n_samples)
+1 −0
Original line number Diff line number Diff line
@@ -13,3 +13,4 @@ from deepchem.dock.docking import Docker
from deepchem.dock.docking import VinaGridRFDocker
from deepchem.dock.docking import VinaGridDNNDocker
from deepchem.dock.binding_pocket import ConvexHullPocketFinder
from deepchem.dock.binding_pocket import RFConvexHullPocketFinder
+114 −10
Original line number Diff line number Diff line
@@ -9,15 +9,20 @@ __author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2017, Stanford University"
__license__ = "GPL"

import numpy as np
import os
import pybel
import tempfile
import numpy as np
import openbabel as ob
from rdkit import Chem
from subprocess import call
from scipy.spatial import ConvexHull
from deepchem.feat import hydrogenate_and_compute_partial_charges
from deepchem.feat.atomic_coordinates import AtomicCoordinates
from deepchem.feat.grid_featurizer import load_molecule
from subprocess import call
from deepchem.feat.binding_pocket_features import BindingPocketFeaturizer 
from deepchem.feat.fingerprints import CircularFingerprint 
from deepchem.models.sklearn_models import SklearnModel
from deepchem.data.datasets import NumpyDataset

def extract_active_site(protein_file, ligand_file, cutoff=4):
  """Extracts a box for the active site."""
@@ -185,19 +190,118 @@ class ConvexHullPocketFinder(BindingPocketFinder):
  def find_pockets(self, protein_file, ligand_file):
    """Find list of suitable binding pockets on protein."""
    protein_coords = load_molecule(protein_file, add_hydrogens=False)[0]
    ######################################################################### DEBUG
    #print("protein_coords")
    #print(protein_coords)
    ######################################################################### DEBUG
    ligand_coords = load_molecule(ligand_file, add_hydrogens=False)[0]
    boxes = get_all_boxes(protein_coords, self.pad)
    mapping = boxes_to_atoms(protein_coords, boxes)
    pockets, pocket_atoms = merge_overlapping_boxes(mapping, boxes)
    pockets, pocket_atoms_map = merge_overlapping_boxes(mapping, boxes)
    pocket_coords = []
    for pocket in pockets:
      atoms = pocket_atoms[pocket]
      atoms = pocket_atoms_map[pocket]
      coords = np.zeros((len(atoms), 3))
      for ind, atom in enumerate(atoms):
        coords[ind] = protein_coords[atom]
      pocket_coords.append(coords)
    return pockets, pocket_atoms, pocket_coords
    return pockets, pocket_atoms_map, pocket_coords

class RFConvexHullPocketFinder(BindingPocketFinder):
  """Uses pre-trained RF model + ConvexHulPocketFinder to select pockets."""

  def __init__(self, pad=5):
    self.pad = pad
    self.convex_finder = ConvexHullPocketFinder(pad)

    # Load binding pocket model
    self.base_dir = tempfile.mkdtemp()
    print("About to download trained model.")
    # TODO(rbharath): Shift refined to full once trained.
    call(("wget -c http://deepchem.io.s3-website-us-west-1.amazonaws.com/trained_models/pocket_random_refined_RF.tar.gz").split())
    call(("tar -zxvf pocket_random_refined_RF.tar.gz").split())
    call(("mv pocket_random_refined_RF %s" % (self.base_dir)).split())
    self.model_dir = os.path.join(self.base_dir, "pocket_random_refined_RF")

    # Fit model on dataset
    self.model = SklearnModel(model_dir=self.model_dir)
    self.model.reload()

    # Create featurizers
    self.pocket_featurizer = BindingPocketFeaturizer()
    self.ligand_featurizer = CircularFingerprint(size=1024)

  def find_pockets(self, protein_file, ligand_file):
    """Compute features for a given complex

    TODO(rbharath): This has a log of code overlap with
    compute_binding_pocket_features in
    examples/binding_pockets/binding_pocket_datasets.py. Find way to refactor
    to avoid code duplication.
    """
    ##################################################### DEBUG
    #print("ENTERING!!!!!!!!!!!!!!!!!!!!")
    ##################################################### DEBUG
    if not ligand_file.endswith(".sdf"):
      raise ValueError("Only .sdf ligand files can be featurized.")
    ligand_basename = os.path.basename(ligand_file).split(".")[0]
    ligand_mol2 = os.path.join(
        self.base_dir, ligand_basename + ".mol2")

    # Write mol2 file for ligand
    obConversion = ob.OBConversion()
    conv_out = obConversion.SetInAndOutFormats(str("sdf"), str("mol2"))
    ob_mol = ob.OBMol()
    obConversion.ReadFile(ob_mol, str(ligand_file))
    obConversion.WriteFile(ob_mol, str(ligand_mol2))
      
    # Featurize ligand
    mol = Chem.MolFromMol2File(str(ligand_mol2), removeHs=False)
    if mol is None:
      return None, None
    # Default for CircularFingerprint
    n_ligand_features = 1024
    ligand_features = self.ligand_featurizer.featurize([mol])

    # Featurize pocket
    pockets, pocket_atoms_map, pocket_coords = self.convex_finder.find_pockets(
        protein_file, ligand_file)
    n_pockets = len(pockets)
    n_pocket_features = BindingPocketFeaturizer.n_features

    features = np.zeros((n_pockets, n_pocket_features+n_ligand_features))
    pocket_features = self.pocket_featurizer.featurize(
        protein_file, pockets, pocket_atoms_map, pocket_coords)
    # Note broadcast operation
    features[:, :n_pocket_features] = pocket_features
    features[:, n_pocket_features:] = ligand_features
    dataset = NumpyDataset(X=features)
    pocket_preds = self.model.predict(dataset)
    ############################################################# DEBUG
    pocket_pred_proba = np.squeeze(self.model.predict_proba(dataset))
    #print("pockets")
    #print(pockets)
    #print("pocket_features")
    #print(pocket_features)
    #print("features")
    #print(features)
    #print("n_pockets")
    #print(n_pockets)
    #print("pocket_pred_proba")
    #print(pocket_pred_proba)
    ############################################################# DEBUG

    # Find pockets which are active
    active_pockets = []
    active_pocket_atoms_map = {}
    active_pocket_coords = []
    for pocket_ind in range(len(pockets)):
      #################################################### DEBUG
      # TODO(rbharath): FIX THIS! For now since models are broken, using a bogus
      # cutoff.
      #if pocket_preds[pocket_ind] == 1:
      #print("pocket_pred_proba.shape")
      #print(pocket_pred_proba.shape)
      if pocket_pred_proba[pocket_ind][1] > .15:
      #################################################### DEBUG
        pocket = pockets[pocket_ind]
        active_pockets.append(pocket)
        active_pocket_atoms_map[pocket] = pocket_atoms_map[pocket]
        active_pocket_coords.append(pocket_coords[pocket_ind])
    return active_pockets, active_pocket_atoms_map, active_pocket_coords
+26 −6
Original line number Diff line number Diff line
@@ -134,12 +134,6 @@ class TestBindingPocket(unittest.TestCase):
    print(protein.xyz.shape)
    print("n_protein_atoms")
    print(n_protein_atoms)
    ############################################################## DEBUG
    #from deepchem.feat.grid_featurizer import load_molecule
    #protein_coords = load_molecule(protein_file, add_hydrogens=False)[0]
    #print("protein_coords.shape")
    #print(protein_coords.shape)
    ############################################################## DEBUG
    for pocket in pockets:
      pocket_atoms = pocket_atoms_map[pocket]
      for atom in pocket_atoms:
@@ -149,6 +143,32 @@ class TestBindingPocket(unittest.TestCase):

    assert len(pockets) < len(all_pockets)

  def test_rf_convex_find_pockets(self):
    """Test that filter with pre-trained RF models works."""
    current_dir = os.path.dirname(os.path.realpath(__file__))
    protein_file = os.path.join(current_dir, "1jld_protein.pdb")
    ligand_file = os.path.join(current_dir, "1jld_ligand.sdf")

    protein = md.load(protein_file)

    finder = dc.dock.RFConvexHullPocketFinder()
    pockets, pocket_atoms_map, pocket_coords = finder.find_pockets(
        protein_file, ligand_file)
    # Test that every atom in pocket maps exists
    n_protein_atoms = protein.xyz.shape[1]
    print("protein.xyz.shape")
    print(protein.xyz.shape)
    print("n_protein_atoms")
    print(n_protein_atoms)
    print("len(pockets)")
    print(len(pockets))
    for pocket in pockets:
      pocket_atoms = pocket_atoms_map[pocket]
      for atom in pocket_atoms:
        # Check that the atoms is actually in protein
        assert atom >= 0
        assert atom < n_protein_atoms

  def test_extract_active_site(self):
    """Test that computed pockets have strong overlap with true binding pocket."""
    current_dir = os.path.dirname(os.path.realpath(__file__))
+0 −10
Original line number Diff line number Diff line
@@ -44,18 +44,8 @@ class BindingPocketFeaturizer(Featurizer):
    all_features = np.zeros((n_pockets, n_residues)) 
    for pocket_num, (pocket, coords) in enumerate(zip(pockets, pocket_coords)):
      pocket_atoms = pocket_atoms_map[pocket]
      ################################################ DEBUG
      #print("len(pocket_atoms)")
      #print(len(pocket_atoms))
      #print("protein.top")
      #print(protein.top)
      ################################################ DEBUG
      for ind, atom in enumerate(pocket_atoms):
        atom_name = str(protein.top.atom(atom))
        ################################################ DEBUG
        #print("ind, atom, atom_name")
        #print(ind, atom, atom_name)
        ################################################ DEBUG
        # atom_name is of format RESX-ATOMTYPE
        # where X is a 1 to 4 digit number
        residue = atom_name[:3]
Loading