Commit 65cd15b9 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by Bharath Ramsundar
Browse files

Debugging splif fingerprints/voxelizer

parent d0b20b17
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -39,6 +39,8 @@ from deepchem.feat.complex_featurizers import NeighborListComplexAtomicCoordinat
from deepchem.feat.complex_featurizers import ComplexNeighborListFragmentAtomicCoordinates
from deepchem.feat.complex_featurizers import ContactCircularFingerprint
from deepchem.feat.complex_featurizers import ContactCircularVoxelizer
from deepchem.feat.complex_featurizers import SplifFingerprint
from deepchem.feat.complex_featurizers import SplifVoxelizer
from deepchem.feat.complex_featurizers import ChargeVoxelizer
from deepchem.feat.complex_featurizers import SaltBridgeVoxelizer
from deepchem.feat.complex_featurizers import CationPiVoxelizer
+2 −0
Original line number Diff line number Diff line
@@ -14,3 +14,5 @@ from deepchem.feat.complex_featurizers.grid_featurizers import CationPiVoxelizer
from deepchem.feat.complex_featurizers.grid_featurizers import PiStackVoxelizer
from deepchem.feat.complex_featurizers.grid_featurizers import HydrogenBondVoxelizer
from deepchem.feat.complex_featurizers.grid_featurizers import HydrogenBondCounter
from deepchem.feat.complex_featurizers.splif_fingerprints import SplifFingerprint
from deepchem.feat.complex_featurizers.splif_fingerprints import SplifVoxelizer
+55 −59
Original line number Diff line number Diff line
@@ -2,26 +2,34 @@
SPLIF Fingerprints for molecular complexes.
"""
import logging
import itertools
import numpy as np
from deepchem.utils.hash_utils import hash_ecfp_pair
from deepchem.utils.rdkit_util import compute_all_ecfp
from deepchem.utils.rdkit_utils import load_complex
from deepchem.utils.rdkit_utils import compute_all_ecfp
from deepchem.utils.rdkit_utils import MoleculeLoadException
from deepchem.utils.rdkit_utils import compute_contact_centroid
from deepchem.utils.rdkit_utils import reduce_molecular_complex_to_contacts
from deepchem.feat import ComplexFeaturizer
from deepchem.utils.hash_utils import vectorize
from deepchem.utils.voxel_utils import voxelize
from deepchem.utils.voxel_utils import convert_atom_to_voxel
from deepchem.utils.voxel_utils import convert_atom_pair_to_voxel
from deepchem.utils.geometry_utils import compute_pairwise_distances
from deepchem.utils.geometry_utils import subtract_centroid

from typing import Tuple, Dict, List

logger = logging.getLogger(__name__)

SPLIF_CONTACT_BINS = [(0, 2.0), (2.0, 3.0), (3.0, 4.5)]


def compute_splif_features_in_range(frag1,
                                    frag2,
                                    pairwise_distances,
                                    contact_bin,
                                    ecfp_degree=2):
def compute_splif_features_in_range(frag1: Tuple,
                                    frag2: Tuple,
                                    pairwise_distances: np.ndarray,
                                    contact_bin: List,
                                    ecfp_degree: int = 2) -> Dict:
  """Computes SPLIF features for close atoms in molecular complexes.

  Finds all frag1 atoms that are > contact_bin[0] and <
@@ -33,11 +41,11 @@ def compute_splif_features_in_range(frag1,
  Parameters
  ----------
  frag1: Tuple
    A tuple of (coords, mol) returned by `rdkit_util.load_molecule`.
    A tuple of (coords, mol) returned by `load_molecule`.
  frag2: Tuple
    A tuple of (coords, mol) returned by `rdkit_util.load_molecule`.
    A tuple of (coords, mol) returned by `load_molecule`.
  contact_bins: np.ndarray
    TODO 
    Ranges of pair distances which are placed in separate bins.
  pairwise_distances: np.ndarray
    Array of pairwise fragment-fragment distances (Angstroms)
  ecfp_degree: int
@@ -49,13 +57,13 @@ def compute_splif_features_in_range(frag1,
  contacts = zip(contacts[0], contacts[1])

  frag1_ecfp_dict = compute_all_ecfp(
      frag1, indices=frag1_atoms, degree=ecfp_degree)
  frag2_ecfp_dict = compute_all_ecfp(frag2, degree=ecfp_degree)
      frag1[1], indices=frag1_atoms, degree=ecfp_degree)
  frag2_ecfp_dict = compute_all_ecfp(frag2[1], degree=ecfp_degree)
  splif_dict = {
      contact: (frag1_ecfp_dict[contact[0]], frag2_ecfp_dict[contact[1]])
      for contact in contacts
  }
  return (splif_dict)
  return splif_dict


def featurize_splif(frag1, frag2, contact_bins, pairwise_distances,
@@ -70,15 +78,15 @@ def featurize_splif(frag1, frag2, contact_bins, pairwise_distances,
  Parameters
  ----------
  frag1: Tuple
    A tuple of (coords, mol) returned by `rdkit_util.load_molecule`.
    A tuple of (coords, mol) returned by `load_molecule`.
  frag2: Tuple
    A tuple of (coords, mol) returned by `rdkit_util.load_molecule`.
    A tuple of (coords, mol) returned by `load_molecule`.
  contact_bins: np.ndarray
    TODO 
    Ranges of pair distances which are placed in separate bins.
  pairwise_distances: np.ndarray
    Array of pairwise fragment-fragment distances (Angstroms)
  ecfp_degree: int
    ECFP radius
    ECFP radius, the graph distance at which fragments are computed.

  Returns
  -------
@@ -91,7 +99,7 @@ def featurize_splif(frag1, frag2, contact_bins, pairwise_distances,
        compute_splif_features_in_range(frag1, frag2, pairwise_distances,
                                        contact_bin, ecfp_degree))

  return (splif_dicts)
  return splif_dicts


class SplifFingerprint(ComplexFeaturizer):
@@ -118,7 +126,7 @@ class SplifFingerprint(ComplexFeaturizer):
  for direct contacts instead of the entire contact region.

  For a macromolecular complex, returns a vector of shape
  `(2*size,)`
  `(len(contact_bins)*size,)`
  """

  def __init__(self, contact_bins=None, radius=2, size=8):
@@ -140,18 +148,20 @@ class SplifFingerprint(ComplexFeaturizer):
    self.size = size
    self.radius = radius

  def _featurize_complex(self, molecular_complex):
  def _featurize(self, mol_pdb: str, complex_pdb: str):
    """
    Compute featurization for a molecular complex

    Parameters
    ----------
    molecular_complex: Object
      Some representation of a molecular complex.
    mol_pdb: str
      Filename for ligand molecule
    complex_pdb: str
      Filename for protein molecule
    """
    molecular_complex = (mol_pdb, complex_pdb)
    try:
      fragments = rdkit_util.load_complex(
          molecular_complex, add_hydrogens=False)
      fragments = load_complex(molecular_complex, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
@@ -161,15 +171,13 @@ class SplifFingerprint(ComplexFeaturizer):
    for (frag1, frag2) in itertools.combinations(fragments, 2):
      # Get coordinates
      distances = compute_pairwise_distances(frag1[0], frag2[0])
      #(lig_xyz, lig_rdk), (prot_xyz, prot_rdk) = mol, protein
      #distances = compute_pairwise_distances(prot_xyz, lig_xyz)
      vectors = [
          vectorize(hash_ecfp_pair, feature_dict=splif_dict,
                    size=self.size) for splif_dict in featurize_splif(
                        prot_xyz, prot_rdk, lig_xyz, lig_rdk, self.contact_bins,
                        distances, self.radius)
                        frag1, frag2, self.contact_bins, distances, self.radius)
      ]
      pairwse_features += vector
      pairwise_features += vectors
    pairwise_features = np.concatenate(pairwise_features)
    return pairwise_features

@@ -196,15 +204,17 @@ class SplifVoxelizer(ComplexFeaturizer):
  """

  def __init__(self,
               contact_bins=None,
               radius=2,
               size=8,
               box_width=16.0,
               voxel_width=1.0,
               reduce_to_contacts=True):
               cutoff: float = 4.5,
               contact_bins: List = None,
               radius: int = 2,
               size: int = 8,
               box_width: float = 16.0,
               voxel_width: float = 1.0):
    """
    Parameters
    ----------
    cutoff: float (default 4.5)
      Distance cutoff in angstroms for molecules in complex.
    contact_bins: list[tuple] 
      List of contact bins. If not specified is set to default
      `[(0, 2.0), (2.0, 3.0), (3.0, 4.5)]`.
@@ -217,10 +227,8 @@ class SplifVoxelizer(ComplexFeaturizer):
      is centered on a ligand centroid.
    voxel_width: float, optional (default 1.0)
      Size of a 3D voxel in a grid.
    reduce_to_contacts: bool, optional
      If True, reduce the atoms in the complex to those near a contact
      region.
    """
    self.cutoff = cutoff
    if contact_bins is None:
      self.contact_bins = SPLIF_CONTACT_BINS
    else:
@@ -230,29 +238,21 @@ class SplifVoxelizer(ComplexFeaturizer):
    self.box_width = box_width
    self.voxel_width = voxel_width
    self.voxels_per_edge = int(self.box_width / self.voxel_width)
    self.reduce_to_contacts = reduce_to_contacts

  def _featurize_complex(self, molecular_complex):
  def _featurize(self, mol_pdb: str, complex_pdb: str):
    """
    Compute featurization for a single mol/protein complex

    TODO(rbharath): This is very not ergonomic. I'd much prefer
    returning an vector instead of a list of two vectors. In
    addition, there's a question of efficiency.
    RdkitGridFeaturizer caches rotated versions etc internally.
    To make things work out of box, we are accepting that
    kludgey input. This needs to be cleaned up before full
    merge.
    Compute featurization for a molecular complex

    Parameters
    ----------
    molecular_complex: Object
      A representation of a molecular complex, produced by
      `rdkit_util.load_complex`.
    mol_pdb: str
      Filename for ligand molecule
    complex_pdb: str
      Filename for protein molecule
    """
    molecular_complex = (mol_pdb, complex_pdb)
    try:
      fragments = rdkit_util.load_complex(
          molecular_complex, add_hydrogens=False)
      fragments = load_complex(molecular_complex, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
@@ -260,15 +260,11 @@ class SplifVoxelizer(ComplexFeaturizer):
    pairwise_features = []
    # We compute pairwise contact fingerprints
    centroid = compute_contact_centroid(fragments, cutoff=self.cutoff)
    if self.reduce_to_contacts:
      fragments = reduce_molecular_complex_to_contacts(fragments, self.cutoff)
    for (frag1, frag2) in itertools.combinations(fragments, 2):
      distances = compute_pairwise_distances(frag1[0], frag2[0])
      frag1_xyz = subtract_centroid(frag1[0], centroid)
      frag2_xyz = subtract_centroid(frag2[0], centroid)
      xyzs = [frag1_xyz, frag2_xyz]
      #(lig_xyz, lig_rdk), (prot_xyz, prot_rdk) = mol, protein
      #distances = compute_pairwise_distances(prot_xyz, lig_xyz)
      pairwise_features.append(
          np.concatenate(
              [
@@ -279,9 +275,9 @@ class SplifVoxelizer(ComplexFeaturizer):
                      hash_ecfp_pair,
                      xyzs,
                      feature_dict=splif_dict,
                      nb_channel=self.size) for splif_dict in featurize_splif(
                          prot_xyz, prot_rdk, lig_xyz, lig_rdk,
                          self.contact_bins, distances, self.radius)
                      nb_channel=self.size)
                  for splif_dict in featurize_splif(
                      frag1, frag2, self.contact_bins, distances, self.radius)
              ],
              axis=-1))
    # Features are of shape (voxels_per_edge, voxels_per_edge, voxels_per_edge, 1) so we should concatenate on the last axis.
+22 −2
Original line number Diff line number Diff line
import unittest
import os
import deepchem as dc


@@ -8,7 +9,26 @@ class TestSplifFingerprints(unittest.TestCase):
  def setUp(self):
    # TODO test more formats for ligand
    current_dir = os.path.dirname(os.path.realpath(__file__))
    self.protein_file = os.path.join(current_dir,
    self.protein_file = os.path.join(current_dir, 'data',
                                     '3ws9_protein_fixer_rdkit.pdb')
    self.ligand_file = os.path.join(current_dir, '3ws9_ligand.sdf')
    self.ligand_file = os.path.join(current_dir, 'data', '3ws9_ligand.sdf')
    self.complex_files = [(self.protein_file, self.ligand_file)]

  def test_splif_shape(self):
    size = 8
    featurizer = dc.feat.SplifFingerprint(size=size)
    features, failures = featurizer.featurize([self.ligand_file],
                                              [self.protein_file])
    assert features.shape == (1, 3 * size)

  def test_splif_voxels_shape(self):
    box_width = 48
    voxel_width = 2
    voxels_per_edge = box_width / voxel_width
    size = 8
    voxelizer = dc.feat.SplifVoxelizer(
        box_width=box_width, voxel_width=voxel_width, size=size)
    features, failures = voxelizer._featurize(self.ligand_file,
                                              self.protein_file)
    assert features.shape == (1, voxels_per_edge, voxels_per_edge,
                              voxels_per_edge, size)
+10 −0
Original line number Diff line number Diff line
@@ -119,6 +119,16 @@ class MolecularFragment(object):
    """
    return self.atoms

  def GetNumAtoms(self) -> int:
    """Returns the number of atoms

    Returns
    -------
    int
      Number of atoms in this fragment.
    """
    return len(self.atoms)

  def GetCoords(self) -> np.ndarray:
    """Returns 3D coordinates for this fragment as numpy array.