Commit 1303c61c authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Removing deepdock files

parent eaee272b
Loading
Loading
Loading
Loading

deepchem/dock/__init__.py

deleted100644 → 0
+0 −11
Original line number Diff line number Diff line
"""
Imports all submodules 
"""
from deepchem.dock.pose_generation import PoseGenerator
from deepchem.dock.pose_generation import VinaPoseGenerator
from deepchem.dock.pose_scoring import PoseScorer
from deepchem.dock.pose_scoring import GridPoseScorer
from deepchem.dock.docking import Docker
from deepchem.dock.docking import VinaGridRFDocker
from deepchem.dock.binding_pocket import ConvexHullPocketFinder
from deepchem.dock.binding_pocket import RFConvexHullPocketFinder

deepchem/dock/binding_pocket.py

deleted100644 → 0
+0 −301
Original line number Diff line number Diff line
"""
Computes putative binding pockets on protein.
"""
__author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2017, Stanford University"
__license__ = "MIT"

import os
import logging
import tempfile
import numpy as np
from subprocess import call
from scipy.spatial import ConvexHull
from deepchem.feat.binding_pocket_features import BindingPocketFeaturizer
from deepchem.feat.fingerprints import CircularFingerprint
from deepchem.models.sklearn_models import SklearnModel
from deepchem.utils import rdkit_util

logger = logging.getLogger(__name__)


def extract_active_site(protein_file, ligand_file, cutoff=4):
  """Extracts a box for the active site."""
  protein_coords = rdkit_util.load_molecule(
      protein_file, add_hydrogens=False)[0]
  ligand_coords = rdkit_util.load_molecule(
      ligand_file, add_hydrogens=True, calc_charges=True)[0]
  num_ligand_atoms = len(ligand_coords)
  num_protein_atoms = len(protein_coords)
  pocket_inds = []
  pocket_atoms = set([])
  for lig_atom_ind in range(num_ligand_atoms):
    lig_atom = ligand_coords[lig_atom_ind]
    for protein_atom_ind in range(num_protein_atoms):
      protein_atom = protein_coords[protein_atom_ind]
      if np.linalg.norm(lig_atom - protein_atom) < cutoff:
        if protein_atom_ind not in pocket_atoms:
          pocket_atoms = pocket_atoms.union(set([protein_atom_ind]))
  # Should be an array of size (n_pocket_atoms, 3)
  pocket_atoms = list(pocket_atoms)
  n_pocket_atoms = len(pocket_atoms)
  pocket_coords = np.zeros((n_pocket_atoms, 3))
  for ind, pocket_ind in enumerate(pocket_atoms):
    pocket_coords[ind] = protein_coords[pocket_ind]

  x_min = int(np.floor(np.amin(pocket_coords[:, 0])))
  x_max = int(np.ceil(np.amax(pocket_coords[:, 0])))
  y_min = int(np.floor(np.amin(pocket_coords[:, 1])))
  y_max = int(np.ceil(np.amax(pocket_coords[:, 1])))
  z_min = int(np.floor(np.amin(pocket_coords[:, 2])))
  z_max = int(np.ceil(np.amax(pocket_coords[:, 2])))
  return (((x_min, x_max), (y_min, y_max), (z_min, z_max)), pocket_atoms,
          pocket_coords)


def compute_overlap(mapping, box1, box2):
  """Computes overlap between the two boxes.

  Overlap is defined as % atoms of box1 in box2. Note that
  overlap is not a symmetric measurement.
  """
  atom1 = set(mapping[box1])
  atom2 = set(mapping[box2])
  return len(atom1.intersection(atom2)) / float(len(atom1))


def get_all_boxes(coords, pad=5):
  """Get all pocket boxes for protein coords.

  We pad all boxes the prescribed number of angstroms.

  TODO(rbharath): It looks like this may perhaps be non-deterministic?
  """
  hull = ConvexHull(coords)
  boxes = []
  for triangle in hull.simplices:
    # coords[triangle, 0] gives the x-dimension of all triangle points
    # Take transpose to make sure rows correspond to atoms.
    points = np.array(
        [coords[triangle, 0], coords[triangle, 1], coords[triangle, 2]]).T
    # We voxelize so all grids have integral coordinates (convenience)
    x_min, x_max = np.amin(points[:, 0]), np.amax(points[:, 0])
    x_min, x_max = int(np.floor(x_min)) - pad, int(np.ceil(x_max)) + pad
    y_min, y_max = np.amin(points[:, 1]), np.amax(points[:, 1])
    y_min, y_max = int(np.floor(y_min)) - pad, int(np.ceil(y_max)) + pad
    z_min, z_max = np.amin(points[:, 2]), np.amax(points[:, 2])
    z_min, z_max = int(np.floor(z_min)) - pad, int(np.ceil(z_max)) + pad
    boxes.append(((x_min, x_max), (y_min, y_max), (z_min, z_max)))
  return boxes


def boxes_to_atoms(atom_coords, boxes):
  """Maps each box to a list of atoms in that box.

  TODO(rbharath): This does a num_atoms x num_boxes computations. Is
  there a reasonable heuristic we can use to speed this up?
  """
  mapping = {}
  for box_ind, box in enumerate(boxes):
    box_atoms = []
    (x_min, x_max), (y_min, y_max), (z_min, z_max) = box
    logger.info("Handing box %d/%d" % (box_ind, len(boxes)))
    for atom_ind in range(len(atom_coords)):
      atom = atom_coords[atom_ind]
      x_cont = x_min <= atom[0] and atom[0] <= x_max
      y_cont = y_min <= atom[1] and atom[1] <= y_max
      z_cont = z_min <= atom[2] and atom[2] <= z_max
      if x_cont and y_cont and z_cont:
        box_atoms.append(atom_ind)
    mapping[box] = box_atoms
  return mapping


def merge_boxes(box1, box2):
  """Merges two boxes."""
  (x_min1, x_max1), (y_min1, y_max1), (z_min1, z_max1) = box1
  (x_min2, x_max2), (y_min2, y_max2), (z_min2, z_max2) = box2
  x_min = min(x_min1, x_min2)
  y_min = min(y_min1, y_min2)
  z_min = min(z_min1, z_min2)
  x_max = max(x_max1, x_max2)
  y_max = max(y_max1, y_max2)
  z_max = max(z_max1, z_max2)
  return ((x_min, x_max), (y_min, y_max), (z_min, z_max))


def merge_overlapping_boxes(mapping, boxes, threshold=.8):
  """Merge boxes which have an overlap greater than threshold.

  TODO(rbharath): This merge code is terribly inelegant. It's also quadratic
  in number of boxes. It feels like there ought to be an elegant divide and
  conquer approach here. Figure out later...
  """
  num_boxes = len(boxes)
  outputs = []
  for i in range(num_boxes):
    box = boxes[0]
    new_boxes = []
    new_mapping = {}
    # If overlap of box with previously generated output boxes, return
    contained = False
    for output_box in outputs:
      # Carry forward mappings
      new_mapping[output_box] = mapping[output_box]
      if compute_overlap(mapping, box, output_box) == 1:
        contained = True
    if contained:
      continue
    # We know that box has at least one atom not in outputs
    unique_box = True
    for merge_box in boxes[1:]:
      overlap = compute_overlap(mapping, box, merge_box)
      if overlap < threshold:
        new_boxes.append(merge_box)
        new_mapping[merge_box] = mapping[merge_box]
      else:
        # Current box has been merged into box further down list.
        # No need to output current box
        unique_box = False
        merged = merge_boxes(box, merge_box)
        new_boxes.append(merged)
        new_mapping[merged] = list(
            set(mapping[box]).union(set(mapping[merge_box])))
    if unique_box:
      outputs.append(box)
      new_mapping[box] = mapping[box]
    boxes = new_boxes
    mapping = new_mapping
  return outputs, mapping


class BindingPocketFinder(object):
  """Abstract superclass for binding pocket detectors"""

  def find_pockets(self, protein_file, ligand_file):
    """Finds potential binding pockets in proteins."""
    raise NotImplementedError


class ConvexHullPocketFinder(BindingPocketFinder):
  """Implementation that uses convex hull of protein to find pockets.

  Based on https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4112621/pdf/1472-6807-14-18.pdf
  """

  def __init__(self, pad=5):
    self.pad = pad

  def find_all_pockets(self, protein_file):
    """Find list of binding pockets on protein."""
    # protein_coords is (N, 3) tensor
    coords = rdkit_util.load_molecule(protein_file)[0]
    return get_all_boxes(coords, self.pad)

  def find_pockets(self, protein_file, ligand_file):
    """Find list of suitable binding pockets on protein."""
    protein_coords = rdkit_util.load_molecule(
        protein_file, add_hydrogens=False, calc_charges=False)[0]
    ligand_coords = rdkit_util.load_molecule(
        ligand_file, add_hydrogens=False, calc_charges=False)[0]
    boxes = get_all_boxes(protein_coords, self.pad)
    mapping = boxes_to_atoms(protein_coords, boxes)
    pockets, pocket_atoms_map = merge_overlapping_boxes(mapping, boxes)
    pocket_coords = []
    for pocket in pockets:
      atoms = pocket_atoms_map[pocket]
      coords = np.zeros((len(atoms), 3))
      for ind, atom in enumerate(atoms):
        coords[ind] = protein_coords[atom]
      pocket_coords.append(coords)
    return pockets, pocket_atoms_map, pocket_coords


class RFConvexHullPocketFinder(BindingPocketFinder):
  """Uses pre-trained RF model + ConvexHulPocketFinder to select pockets."""

  def __init__(self, pad=5):
    self.pad = pad
    self.convex_finder = ConvexHullPocketFinder(pad)

    # Load binding pocket model
    self.base_dir = tempfile.mkdtemp()
    logger.info("About to download trained model.")
    # TODO(rbharath): Shift refined to full once trained.
    call((
        "wget -nv -c http://deepchem.io.s3-website-us-west-1.amazonaws.com/trained_models/pocket_random_refined_RF.tar.gz"
    ).split())
    call(("tar -zxvf pocket_random_refined_RF.tar.gz").split())
    call(("mv pocket_random_refined_RF %s" % (self.base_dir)).split())
    self.model_dir = os.path.join(self.base_dir, "pocket_random_refined_RF")

    # Fit model on dataset
    self.model = SklearnModel(model_dir=self.model_dir)
    self.model.reload()

    # Create featurizers
    self.pocket_featurizer = BindingPocketFeaturizer()
    self.ligand_featurizer = CircularFingerprint(size=1024)

  def find_pockets(self, protein_file, ligand_file):
    """Compute features for a given complex

    TODO(rbharath): This has a log of code overlap with
    compute_binding_pocket_features in
    examples/binding_pockets/binding_pocket_datasets.py. Find way to refactor
    to avoid code duplication.
    """
    # if not ligand_file.endswith(".sdf"):
    #   raise ValueError("Only .sdf ligand files can be featurized.")
    # ligand_basename = os.path.basename(ligand_file).split(".")[0]
    # ligand_mol2 = os.path.join(
    #     self.base_dir, ligand_basename + ".mol2")
    #
    # # Write mol2 file for ligand
    # obConversion = ob.OBConversion()
    # conv_out = obConversion.SetInAndOutFormats(str("sdf"), str("mol2"))
    # ob_mol = ob.OBMol()
    # obConversion.ReadFile(ob_mol, str(ligand_file))
    # obConversion.WriteFile(ob_mol, str(ligand_mol2))
    #
    # # Featurize ligand
    # mol = Chem.MolFromMol2File(str(ligand_mol2), removeHs=False)
    # if mol is None:
    #   return None, None
    # # Default for CircularFingerprint
    # n_ligand_features = 1024
    # ligand_features = self.ligand_featurizer.featurize([mol])
    #
    # # Featurize pocket
    # pockets, pocket_atoms_map, pocket_coords = self.convex_finder.find_pockets(
    #     protein_file, ligand_file)
    # n_pockets = len(pockets)
    # n_pocket_features = BindingPocketFeaturizer.n_features
    #
    # features = np.zeros((n_pockets, n_pocket_features+n_ligand_features))
    # pocket_features = self.pocket_featurizer.featurize(
    #     protein_file, pockets, pocket_atoms_map, pocket_coords)
    # # Note broadcast operation
    # features[:, :n_pocket_features] = pocket_features
    # features[:, n_pocket_features:] = ligand_features
    # dataset = NumpyDataset(X=features)
    # pocket_preds = self.model.predict(dataset)
    # pocket_pred_proba = np.squeeze(self.model.predict_proba(dataset))
    #
    # # Find pockets which are active
    # active_pockets = []
    # active_pocket_atoms_map = {}
    # active_pocket_coords = []
    # for pocket_ind in range(len(pockets)):
    #   #################################################### DEBUG
    #   # TODO(rbharath): For now, using a weak cutoff. Fix later.
    #   #if pocket_preds[pocket_ind] == 1:
    #   if pocket_pred_proba[pocket_ind][1] > .15:
    #   #################################################### DEBUG
    #     pocket = pockets[pocket_ind]
    #     active_pockets.append(pocket)
    #     active_pocket_atoms_map[pocket] = pocket_atoms_map[pocket]
    #     active_pocket_coords.append(pocket_coords[pocket_ind])
    # return active_pockets, active_pocket_atoms_map, active_pocket_coords
    # # TODO(LESWING)
    raise ValueError("Karl Implement")

deepchem/dock/docking.py

deleted100644 → 0
+0 −119
Original line number Diff line number Diff line
"""
Docks protein-ligand pairs
"""
__author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "MIT"

import logging
import numpy as np
import os
import tempfile
from deepchem.data import DiskDataset
from deepchem.models import SklearnModel
from deepchem.models import MultitaskRegressor
from deepchem.dock.pose_scoring import GridPoseScorer
from deepchem.dock.pose_generation import VinaPoseGenerator
from sklearn.ensemble import RandomForestRegressor
from subprocess import call

logger = logging.getLogger(__name__)


class Docker(object):
  """Abstract Class specifying API for Docking."""

  def dock(self,
           protein_file,
           ligand_file,
           centroid=None,
           box_dims=None,
           dry_run=False):
    raise NotImplementedError


class VinaGridRFDocker(Docker):
  """Vina pose-generation, RF-models on grid-featurization of complexes."""

  def __init__(self, exhaustiveness=10, detect_pockets=False):
    """Builds model."""
    self.base_dir = tempfile.mkdtemp()
    logger.info("About to download trained model.")
    call((
        "wget -nv -c http://deepchem.io.s3-website-us-west-1.amazonaws.com/trained_models/random_full_RF.tar.gz"
    ).split())
    call(("tar -zxvf random_full_RF.tar.gz").split())
    call(("mv random_full_RF %s" % (self.base_dir)).split())
    self.model_dir = os.path.join(self.base_dir, "random_full_RF")

    # Fit model on dataset
    model = SklearnModel(model_dir=self.model_dir)
    model.reload()

    self.pose_scorer = GridPoseScorer(model, feat="grid")
    self.pose_generator = VinaPoseGenerator(
        exhaustiveness=exhaustiveness, detect_pockets=detect_pockets)

  def dock(self,
           protein_file,
           ligand_file,
           centroid=None,
           box_dims=None,
           dry_run=False):
    """Docks using Vina and RF."""
    protein_docked, ligand_docked = self.pose_generator.generate_poses(
        protein_file, ligand_file, centroid, box_dims, dry_run)
    if not dry_run:
      score = self.pose_scorer.score(protein_docked, ligand_docked)
    else:
      score = np.zeros((1,))
    return (score, (protein_docked, ligand_docked))


'''
class VinaGridDNNDocker(object):
  """Vina pose-generation, DNN-models on grid-featurization of complexes."""

  def __init__(self, exhaustiveness=10, detect_pockets=False):
    """Builds model."""
    self.base_dir = tempfile.mkdtemp()
    logger.info("About to download trained model.")
    call((
        "wget -nv -c http://deepchem.io.s3-website-us-west-1.amazonaws.com/trained_models/random_full_DNN.tar.gz"
    ).split())
    call(("tar -zxvf random_full_DNN.tar.gz").split())
    call(("mv random_full_DNN %s" % (self.base_dir)).split())
    self.model_dir = os.path.join(self.base_dir, "random_full_DNN")

    # Fit model on dataset
    pdbbind_tasks = ["-logKd/Ki"]
    n_features = 2052
    model = MultitaskRegressor(
        len(pdbbind_tasks),
        n_features,
        dropouts=[.25],
        learning_rate=0.0003,
        weight_init_stddevs=[.1],
        batch_size=64,
        model_dir=self.model_dir)
    model.reload()

    self.pose_scorer = GridPoseScorer(model, feat="grid")
    self.pose_generator = VinaPoseGenerator(
        exhaustiveness=exhaustiveness, detect_pockets=detect_pockets)

  def dock(self,
           protein_file,
           ligand_file,
           centroid=None,
           box_dims=None,
           dry_run=False):
    """Docks using Vina and DNNs."""
    protein_docked, ligand_docked = self.pose_generator.generate_poses(
        protein_file, ligand_file, centroid, box_dims, dry_run)
    if not dry_run:
      score = self.pose_scorer.score(protein_docked, ligand_docked)
    else:
      score = np.zeros((1,))
    return (score, (protein_docked, ligand_docked))
'''

deepchem/dock/pose_generation.py

deleted100644 → 0
+0 −165
Original line number Diff line number Diff line
"""
Generates protein-ligand docked poses using Autodock Vina.
"""
from deepchem.utils import mol_xyz_util

__author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "MIT"

import logging
import numpy as np
import os
import tempfile
from subprocess import call
from deepchem.feat import hydrogenate_and_compute_partial_charges
from deepchem.dock.binding_pocket import RFConvexHullPocketFinder
from deepchem.utils import rdkit_util

logger = logging.getLogger(__name__)


class PoseGenerator(object):
  """Abstract superclass for all pose-generation routines."""

  def generate_poses(self, protein_file, ligand_file, out_dir=None):
    """Generates the docked complex and outputs files for docked complex."""
    raise NotImplementedError


def write_conf(receptor_filename,
               ligand_filename,
               centroid,
               box_dims,
               conf_filename,
               exhaustiveness=None):
  """Writes Vina configuration file to disk."""
  with open(conf_filename, "w") as f:
    f.write("receptor = %s\n" % receptor_filename)
    f.write("ligand = %s\n\n" % ligand_filename)

    f.write("center_x = %f\n" % centroid[0])
    f.write("center_y = %f\n" % centroid[1])
    f.write("center_z = %f\n\n" % centroid[2])

    f.write("size_x = %f\n" % box_dims[0])
    f.write("size_y = %f\n" % box_dims[1])
    f.write("size_z = %f\n\n" % box_dims[2])

    if exhaustiveness is not None:
      f.write("exhaustiveness = %d\n" % exhaustiveness)


class VinaPoseGenerator(PoseGenerator):
  """Uses Autodock Vina to generate binding poses."""

  def __init__(self, exhaustiveness=10, detect_pockets=True):
    """Initializes Vina Pose generation"""
    current_dir = os.path.dirname(os.path.realpath(__file__))
    self.vina_dir = os.path.join(current_dir, "autodock_vina_1_1_2_linux_x86")
    self.exhaustiveness = exhaustiveness
    self.detect_pockets = detect_pockets
    if self.detect_pockets:
      self.pocket_finder = RFConvexHullPocketFinder()
    if not os.path.exists(self.vina_dir):
      logger.info("Vina not available. Downloading")
      # TODO(rbharath): May want to move this file to S3 so we can ensure it's
      # always available.
      wget_cmd = "wget -nv -c -T 15 http://vina.scripps.edu/download/autodock_vina_1_1_2_linux_x86.tgz"
      call(wget_cmd.split())
      logger.info("Downloaded Vina. Extracting")
      download_cmd = "tar xzvf autodock_vina_1_1_2_linux_x86.tgz"
      call(download_cmd.split())
      logger.info("Moving to final location")
      mv_cmd = "mv autodock_vina_1_1_2_linux_x86 %s" % current_dir
      call(mv_cmd.split())
      logger.info("Cleanup: removing downloaded vina tar.gz")
      rm_cmd = "rm autodock_vina_1_1_2_linux_x86.tgz"
      call(rm_cmd.split())
    self.vina_cmd = os.path.join(self.vina_dir, "bin/vina")

  def generate_poses(self,
                     protein_file,
                     ligand_file,
                     centroid=None,
                     box_dims=None,
                     dry_run=False,
                     out_dir=None):
    """Generates the docked complex and outputs files for docked complex."""
    if out_dir is None:
      out_dir = tempfile.mkdtemp()

    # Prepare receptor
    receptor_name = os.path.basename(protein_file).split(".")[0]
    protein_hyd = os.path.join(out_dir, "%s.pdb" % receptor_name)
    protein_pdbqt = os.path.join(out_dir, "%s.pdbqt" % receptor_name)
    hydrogenate_and_compute_partial_charges(
        protein_file,
        "pdb",
        hyd_output=protein_hyd,
        pdbqt_output=protein_pdbqt,
        protein=True)
    # Get protein centroid and range
    # TODO(rbharath): Need to add some way to identify binding pocket, or this is
    # going to be extremely slow!
    if centroid is not None and box_dims is not None:
      protein_centroid = centroid
    else:
      if not self.detect_pockets:
        receptor_mol = rdkit_util.load_molecule(
            protein_hyd, calc_charges=False, add_hydrogens=False)
        protein_centroid = mol_xyz_util.get_molecule_centroid(receptor_mol[0])
        protein_range = mol_xyz_util.get_molecule_range(receptor_mol[0])
        box_dims = protein_range + 5.0
      else:
        logger.info("About to find putative binding pockets")
        pockets, pocket_atoms_maps, pocket_coords = self.pocket_finder.find_pockets(
            protein_file, ligand_file)
        # TODO(rbharath): Handle multiple pockets instead of arbitrarily selecting
        # first pocket.
        logger.info("Computing centroid and size of proposed pocket.")
        pocket_coord = pocket_coords[0]
        protein_centroid = np.mean(pocket_coord, axis=1)
        pocket = pockets[0]
        (x_min, x_max), (y_min, y_max), (z_min, z_max) = pocket
        x_box = (x_max - x_min) / 2.
        y_box = (y_max - y_min) / 2.
        z_box = (z_max - z_min) / 2.
        box_dims = (x_box, y_box, z_box)

    # Prepare receptor
    ligand_name = os.path.basename(ligand_file).split(".")[0]
    ligand_hyd = os.path.join(out_dir, "%s.pdb" % ligand_name)
    ligand_pdbqt = os.path.join(out_dir, "%s.pdbqt" % ligand_name)

    # TODO(rbharath): Generalize this so can support mol2 files as well.
    hydrogenate_and_compute_partial_charges(
        ligand_file,
        "sdf",
        hyd_output=ligand_hyd,
        pdbqt_output=ligand_pdbqt,
        protein=False)
    # Write Vina conf file
    conf_file = os.path.join(out_dir, "conf.txt")
    write_conf(
        protein_pdbqt,
        ligand_pdbqt,
        protein_centroid,
        box_dims,
        conf_file,
        exhaustiveness=self.exhaustiveness)

    # Define locations of log and output files
    log_file = os.path.join(out_dir, "%s_log.txt" % ligand_name)
    out_pdbqt = os.path.join(out_dir, "%s_docked.pdbqt" % ligand_name)
    # TODO(rbharath): Let user specify the number of poses required.
    if not dry_run:
      logger.info("About to call Vina")
      call(
          "%s --config %s --log %s --out %s" % (self.vina_cmd, conf_file,
                                                log_file, out_pdbqt),
          shell=True)
    # TODO(rbharath): Convert the output pdbqt to a pdb file.

    # Return docked files
    return protein_hyd, out_pdbqt

deepchem/dock/pose_scoring.py

deleted100644 → 0
+0 −50
Original line number Diff line number Diff line
"""
Scores protein-ligand poses using DeepChem.
"""
from deepchem.feat import RdkitGridFeaturizer

__author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "MIT"

import numpy as np
import os
import tempfile
from deepchem.data import NumpyDataset
from subprocess import call


class PoseScorer(object):
  """Abstract superclass for all scoring methods."""

  def score(self, protein_file, ligand_file):
    """Returns a score for a protein/ligand pair."""
    raise NotImplementedError


class GridPoseScorer(object):

  def __init__(self, model, feat="grid"):
    """Initializes a pose-scorer."""
    self.model = model
    if feat == "grid":
      self.featurizer = RdkitGridFeaturizer(
          voxel_width=16.0,
          # TODO: add pi_stack and cation_pi to feature_types (it's not trivial
          # because they require sanitized molecules)
          # feature_types=["ecfp", "splif", "hbond", "pi_stack", "cation_pi",
          # "salt_bridge"],
          feature_types=["ecfp", "splif", "hbond", "salt_bridge"],
          ecfp_power=9,
          splif_power=9,
          flatten=True)
    else:
      raise ValueError("feat not defined.")

  def score(self, protein_file, ligand_file):
    """Returns a score for a protein/ligand pair."""
    features, _ = self.featurizer.featurize_complexes([ligand_file],
                                                      [protein_file])
    dataset = NumpyDataset(X=features, y=None, w=None, ids=None)
    score = self.model.predict(dataset)
    return score
Loading