Commit 8cf1dd33 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

testing voxelizers

parent 5ad01702
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -28,3 +28,6 @@ from deepchem.feat.smiles_featurizers import SmilesToSeq, SmilesToImage
from deepchem.feat.materials_featurizers import ElementPropertyFingerprint, SineCoulombMatrix, StructureGraphFeaturizer
from deepchem.feat.contact_fingerprints import ContactCircularFingerprint
from deepchem.feat.contact_fingerprints import ContactCircularVoxelizer
from deepchem.feat.grid_featurizers import ChargeVoxelizer 
from deepchem.feat.grid_featurizers import SaltBridgeVoxelizer
from deepchem.feat.grid_featurizers import CationPiVoxelizer
+32 −4
Original line number Diff line number Diff line
@@ -11,6 +11,8 @@ from deepchem.utils.hash_utils import vectorize
from deepchem.utils.voxel_utils import voxelize 
from deepchem.utils.voxel_utils import convert_atom_to_voxel
from deepchem.utils.rdkit_util import compute_all_ecfp
from deepchem.utils.rdkit_util import compute_contact_centroid
from deepchem.utils.rdkit_util import subtract_centroid
from deepchem.utils.rdkit_util import compute_pairwise_distances
from deepchem.utils.rdkit_util import MoleculeLoadException

@@ -186,11 +188,37 @@ class ContactCircularVoxelizer(ComplexFeaturizer):
      return None
    pairwise_features = []
    # We compute pairwise contact fingerprints
    centroid = compute_contact_centroid(fragments, cutoff=self.cutoff)
    ############################################
    #print("centroid")
    #print(centroid)
    ############################################
    for (frag1, frag2) in itertools.combinations(fragments, 2):
      distances = compute_pairwise_distances(frag1[0], frag2[0])
      xyzs = [frag1[0], frag2[0]]
      ############################################
      #print("np.max(frag1[0])")
      #print(np.max(frag1[0]))
      #print("np.min(frag1[0])")
      #print(np.min(frag1[0]))
      ############################################
      frag1_xyz = subtract_centroid(frag1[0], centroid)
      frag2_xyz = subtract_centroid(frag2[0], centroid)
      xyzs = [frag1_xyz, frag2_xyz]
      #(lig_xyz, lig_rdk), (prot_xyz, prot_rdk) = mol, protein
      #distances = compute_pairwise_distances(prot_xyz, lig_xyz)
      ###########################################
      ##print("np.max(frag1[0])")
      ##print(np.max(frag1[0]))
      ##print("np.min(frag1[0])")
      ##print(np.min(frag1[0]))
      #print("np.max(frag1_xyz)")
      #print(np.max(frag1_xyz))
      #print("np.min(frag1_xyz)")
      #print(np.min(frag1_xyz))
      ###########################################
      # TODO(rbharath): I think the reason this isn't making errors is
      # that it's already computing contact map under the hood which
      # prunes out atoms outside the box
      pairwise_features.append(
          sum([
              voxelize(
@@ -211,9 +239,9 @@ class ContactCircularVoxelizer(ComplexFeaturizer):
                                            ecfp_degree=self.radius))
          ])
      )
    ############################################
    print("[feat.shape for feat in pairwise_features]")
    print([feat.shape for feat in pairwise_features])
    #############################################
    #print("[feat.shape for feat in pairwise_features]")
    #print([feat.shape for feat in pairwise_features])
    ############################################
    # Features are of shape (voxels_per_edge, voxels_per_edge, voxels_per_edge, num_feat) so we should concatenate on the last axis.
    return np.concatenate(pairwise_features, axis=-1)
+140 −60
Original line number Diff line number Diff line
"""
Compute various spatial fingerprints for macromolecular complexes.
"""
import itertools
import logging
import numpy as np
from deepchem.utils import rdkit_util 
from deepchem.utils.rdkit_util import get_partial_charge
from deepchem.feat import ComplexFeaturizer
from deepchem.utils.hash_utils import hash_ecfp_pair
@@ -14,6 +17,10 @@ from deepchem.utils.voxel_utils import convert_atom_pair_to_voxel
from deepchem.utils.rdkit_util import compute_pairwise_distances
from deepchem.utils.rdkit_util import compute_pi_stack
from deepchem.utils.rdkit_util import compute_hydrogen_bonds
from deepchem.utils.rdkit_util import MoleculeLoadException
from deepchem.utils.rdkit_util import compute_contact_centroid
from deepchem.utils.rdkit_util import subtract_centroid
from deepchem.utils.rdkit_util import reduce_molecular_complex_to_contacts

logger = logging.getLogger(__name__)

@@ -49,34 +56,58 @@ class ChargeVoxelizer(ComplexFeaturizer):
  complex that computes the effective charge at each voxel.
  """
  def __init__(self,
               cutoff=4.5,
               box_width=16.0,
               voxel_width=1.0,
               reduce_to_contacts=True):
    """Store the voxel-grid settings used when featurizing a complex.

    Parameters
    ----------
    cutoff: float (default 4.5)
      Distance cutoff in angstroms for molecules in complex.
    box_width: float, optional (default 16.0)
      Size of a box in which voxel features are calculated. The
      featurizer shifts coordinates by the contact centroid of the
      complex before voxelizing.
    voxel_width: float, optional (default 1.0)
      Size of a 3D voxel in a grid.
    reduce_to_contacts: bool, optional (default True)
      If True, reduce the atoms in the complex to those near a contact
      region.
    """
    self.cutoff = cutoff
    self.box_width = box_width
    self.voxel_width = voxel_width
    # Number of voxels along each edge of the box; int() truncates any
    # fractional remainder of box_width / voxel_width.
    self.voxels_per_edge = int(self.box_width / self.voxel_width)
    self.reduce_to_contacts = reduce_to_contacts

  def _featurize_complex(self, mol, protein):
  def _featurize_complex(self, molecular_complex):
    """
    Compute featurization for a single mol/protein complex

    Parameters
    ----------
    mol: object
      Representation of the molecule
    protein: object
      Representation of the protein
    """
    (lig_xyz, lig_rdk), (prot_xyz, prot_rdk) = mol, protein
    return [
    molecular_complex: Object
      Some representation of a molecular complex.
    """
    try:
      fragments = rdkit_util.load_complex(molecular_complex, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
      return None
    pairwise_features = []
    # We compute pairwise contact fingerprints
    centroid = compute_contact_centroid(fragments, cutoff=self.cutoff)
    if self.reduce_to_contacts:
      fragments = reduce_molecular_complex_to_contacts(fragments, self.cutoff)
    # We compute pairwise contact fingerprints
    for (frag1_ind, frag2_ind) in itertools.combinations(range(len(fragments)), 2):
      frag1, frag2 = fragments[frag1_ind], fragments[frag2_ind]
      frag1_xyz = subtract_centroid(frag1[0], centroid)
      frag2_xyz = subtract_centroid(frag2[0], centroid)
      xyzs = [frag1_xyz, frag2_xyz]
      rdks = [frag1[1], frag2[1]]
      pairwise_features.append(
        sum([
            voxelize(
                convert_atom_to_voxel,
@@ -88,9 +119,11 @@ class ChargeVoxelizer(ComplexFeaturizer):
                feature_dict=compute_charge_dictionary(mol),
                nb_channel=1,
                dtype="np.float16")
            for xyz, mol in ((prot_xyz, prot_rdk), (lig_xyz, lig_rdk))
            for xyz, mol in zip(xyzs, rdks)
        ])
    ]
      )
    # Features are of shape (voxels_per_edge, voxels_per_edge, voxels_per_edge, 1) so we should concatenate on the last axis.
    return np.concatenate(pairwise_features, axis=-1)

class SaltBridgeVoxelizer(ComplexFeaturizer):
  """Localize salt bridges between atoms in macromolecular complexes.
@@ -110,7 +143,8 @@ class SaltBridgeVoxelizer(ComplexFeaturizer):
  def __init__(self, 
               cutoff=5.0,
               box_width=16.0,
               voxel_width=1.0):
               voxel_width=1.0,
               reduce_to_contacts=True):
    """
    Parameters
    ----------
@@ -122,39 +156,61 @@ class SaltBridgeVoxelizer(ComplexFeaturizer):
      is centered on a ligand centroid.
    voxel_width: float, optional (default 1.0)
      Size of a 3D voxel in a grid.
    reduce_to_contacts: bool, optional
      If True, reduce the atoms in the complex to those near a contact
      region.
    """
    self.cutoff = cutoff
    self.box_width = box_width
    self.voxel_width = voxel_width
    self.voxels_per_edge = int(self.box_width / self.voxel_width)
    self.reduce_to_contacts = reduce_to_contacts

  def _featurize_complex(self, mol, protein):
  def _featurize_complex(self, molecular_complex):
    """
    Compute featurization for a single mol/protein complex

    Parameters
    ----------
    mol: object
      Representation of the molecule
    protein: object
      Representation of the protein
    """
    (lig_xyz, lig_rdk), (prot_xyz, prot_rdk) = mol, protein
    distances = compute_pairwise_distances(prot_xyz, lig_xyz)
    return [
    molecular_complex: Object
      Some representation of a molecular complex.
    """
    try:
      fragments = rdkit_util.load_complex(molecular_complex, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
      return None
    pairwise_features = []
    # We compute pairwise contact fingerprints
    centroid = compute_contact_centroid(fragments, cutoff=self.cutoff)
    if self.reduce_to_contacts:
      fragments = reduce_molecular_complex_to_contacts(fragments, self.cutoff)
    #(lig_xyz, lig_rdk), (prot_xyz, prot_rdk) = mol, protein
    #distances = compute_pairwise_distances(prot_xyz, lig_xyz)
    for (frag1_ind, frag2_ind) in itertools.combinations(range(len(fragments)), 2):
      frag1, frag2 = fragments[frag1_ind], fragments[frag2_ind]
      distances = compute_pairwise_distances(frag1[0], frag2[0])
      frag1_xyz = subtract_centroid(frag1[0], centroid)
      frag2_xyz = subtract_centroid(frag2[0], centroid)
      xyzs = [frag1_xyz, frag2_xyz]
      rdks = [frag1[1], frag2[1]]
      pairwise_features.append( 
          voxelize(
              convert_atom_pair_to_voxel,
              self.voxels_per_edge,
              self.box_width,
              self.voxel_width,
            None, (prot_xyz, lig_xyz),
              None, xyzs,
              feature_list=compute_salt_bridges(
                prot_rdk,
                lig_rdk,
                  frag1[1],
                  frag2[1],
                  distances,
                  cutoff=self.cutoff),
              nb_channel=1)
    ]
      )
    # Features are of shape (voxels_per_edge, voxels_per_edge, voxels_per_edge, 1) so we should concatenate on the last axis.
    return np.concatenate(pairwise_features, axis=-1)

class CationPiVoxelizer(ComplexFeaturizer):
  """Localize cation-Pi interactions between atoms in macromolecular complexes.
@@ -173,7 +229,8 @@ class CationPiVoxelizer(ComplexFeaturizer):
               distance_cutoff=6.5,
               angle_cutoff=30.0,
               box_width=16.0,
               voxel_width=1.0):
               voxel_width=1.0,
               reduce_to_contacts=True):
    """
    Parameters
    ----------
@@ -189,6 +246,9 @@ class CationPiVoxelizer(ComplexFeaturizer):
      is centered on a ligand centroid.
    voxel_width: float, optional (default 1.0)
      Size of a 3D voxel in a grid.
    reduce_to_contacts: bool, optional
      If True, reduce the atoms in the complex to those near a contact
      region.
    """
    self.distance_cutoff = distance_cutoff
    self.angle_cutoff = angle_cutoff
@@ -196,20 +256,38 @@ class CationPiVoxelizer(ComplexFeaturizer):
    self.voxel_width = voxel_width
    self.voxels_per_edge = int(self.box_width / self.voxel_width)

  def _featurize_complex(self, mol, protein):
  def _featurize_complex(self, molecular_complex):
    """
    Compute featurization for a single mol/protein complex

    Parameters
    ----------
    mol: object
      Representation of the molecule
    protein: object
      Representation of the protein
    """
    (lig_xyz, lig_rdk), (prot_xyz, prot_rdk) = mol, protein
    distances = compute_pairwise_distances(prot_xyz, lig_xyz)
    return [
    molecular_complex: Object
      Some representation of a molecular complex.
    """
    try:
      fragments = rdkit_util.load_complex(molecular_complex, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
      return None
    pairwise_features = []
    # We compute pairwise contact fingerprints
    centroid = compute_contact_centroid(fragments, cutoff=self.cutoff)
    if self.reduce_to_contacts:
      fragments = reduce_molecular_complex_to_contacts(fragments, self.cutoff)
    #(lig_xyz, lig_rdk), (prot_xyz, prot_rdk) = mol, protein
    #distances = compute_pairwise_distances(prot_xyz, lig_xyz)
    for (frag1_ind, frag2_ind) in itertools.combinations(range(len(fragments)), 2):
      frag1, frag2 = fragments[frag1_ind], fragments[frag2_ind]
      distances = compute_pairwise_distances(frag1[0], frag2[0])
    #(lig_xyz, lig_rdk), (prot_xyz, prot_rdk) = mol, protein
    #distances = compute_pairwise_distances(prot_xyz, lig_xyz)
      frag1_xyz = subtract_centroid(frag1[0], centroid)
      frag2_xyz = subtract_centroid(frag2[0], centroid)
      xyzs = [frag1_xyz, frag2_xyz]
      rdks = [frag1[1], frag2[1]]
      pairwise_features.append(
          sum([
              voxelize(
                  convert_atom_to_voxel,
@@ -220,15 +298,17 @@ class CationPiVoxelizer(ComplexFeaturizer):
                  xyz,
                  feature_dict=cation_pi_dict,
                  nb_channel=1) for xyz, cation_pi_dict in zip(
                    (prot_xyz, lig_xyz),
                      xyzs,
                      compute_binding_pocket_cation_pi(
                        prot_rdk,
                        lig_rdk,
                          frag1[1],
                          frag2[1],
                          dist_cutoff=self.distance_cutoff,
                          angle_cutoff=self.angle_cutoff,
                      ))
          ])
    ]
      )
    # Features are of shape (voxels_per_edge, voxels_per_edge, voxels_per_edge, 1) so we should concatenate on the last axis.
    return np.concatenate(pairwise_features, axis=-1)

class PiStackVoxelizer(ComplexFeaturizer):
  """Localize Pi stacking interactions between atoms in macromolecular complexes.
+66 −18
Original line number Diff line number Diff line
@@ -483,6 +483,9 @@ class MolecularFragment(object):
  example, if two molecules form a molecular complex, it may be useful
  to create two fragments which represent the subsets of each molecule
  that's close to the other molecule (in the contact region).

  Ideally, we'd be able to do this in RDKit directly, but manipulating
  molecular fragments doesn't seem to be supported functionality.
  """

  def __init__(self, atoms):
@@ -491,9 +494,10 @@ class MolecularFragment(object):
    Parameters
    ----------
    atoms: list
      Each entry in this list should be an rdkit Atom object.
      Each entry in this list should be an RdkitAtom
    """
    self.atoms = [AtomShim(x) for x in atoms]
    #self.atoms = [AtomShim(x) for x in atoms]
    self.atoms = [AtomShim(x.GetAtomicNum(), get_partial_charge(x)) for x in atoms]
    #self.atoms = atoms 

  def GetAtoms(self):
@@ -513,19 +517,25 @@ class AtomShim(object):
  the basic information in an AtomShim seems to avoid issues.
  """

  def __init__(self, atomic_num, partial_charge):
    """Initialize this object

    Parameters
    ----------
    atomic_num: int
      Atomic number for this atom.
    partial_charge: float
      The partial Gasteiger charge for this atom
    """
    # Only these two plain values are kept; they are what
    # GetAtomicNum() and GetPartialCharge() hand back.
    self.atomic_num = atomic_num
    self.partial_charge = partial_charge

  def GetAtomicNum(self):
    """Return the atomic number stored on this shim."""
    return self.atomic_num

  def GetPartialCharge(self):
    """Return the partial charge stored on this shim."""
    return self.partial_charge

def get_mol_subset(coords, mol, atom_indices_to_keep):
  """Strip a subset of the atoms in this molecule

@@ -533,7 +543,7 @@ def get_mol_subset(coords, mol, atom_indices_to_keep):
  ----------
  coords: Numpy ndarray
    Must be of shape (N, 3) and correspond to coordinates of mol.
  mol: Rdkit mol
  mol: Rdkit mol or `MolecularFragment`
    The molecule to strip
  atom_indices_to_keep: list
    List of the indices of the atoms to keep. Each index is a unique
@@ -545,15 +555,18 @@ def get_mol_subset(coords, mol, atom_indices_to_keep):
  coordinates with hydrogen coordinates. mol_frag is a
  `MolecularFragment`. 
  """

  from rdkit import Chem
  indexes_to_keep = []
  atoms_to_keep = []
  #####################################################
  # Compute partial charges on molecule if rdkit
  if isinstance(mol, Chem.Mol):
    compute_charges(mol)
  #####################################################
  atoms = list(mol.GetAtoms())
  for index in atom_indices_to_keep:
    indexes_to_keep.append(index)
    #atomic_numbers.append(atom.GetAtomicNum())
    atoms_to_keep.append(atoms[index])
  #mol = MolecularFragment(atomic_numbers)
  mol_frag = MolecularFragment(atoms_to_keep)
  coords = coords[indexes_to_keep]
  return coords, mol_frag
@@ -685,8 +698,17 @@ def compute_centroid(coordinates):
def subtract_centroid(xyz, centroid):
  """Shift coordinates so that `centroid` becomes the origin.

  The centroid is subtracted from every row of `xyz` using an
  augmented assignment, so a numpy array passed in as `xyz` is
  modified in place and that same array is returned.

  Parameters
  ----------
  xyz: numpy array
    Of shape `(N, 3)`
  centroid: numpy array
    Of shape `(3,)`

  Returns
  -------
  xyz: numpy array
    The shifted coordinates.
  """
  offset = np.transpose(centroid)
  xyz -= offset
  return xyz
@@ -758,16 +780,42 @@ def rotate_molecules(mol_coordinates_list):
  mol_coordinates_list: list
    Elements of list must be (N_atoms, 3) shaped arrays
  """
  from rdkit.Chem import rdmolops
  if len(molecules) == 0:
    return None
  elif len(molecules) == 1:
    return molecules[0]
  R = generate_random_rotation_matrix()
  rotated_coordinates_list = []

  for mol_coordinates in mol_coordinates_list:
    coordinates = deepcopy(mol_coordinates)
    rotated_coordinates = np.transpose(np.dot(R, np.transpose(coordinates)))
    rotated_coordinates_list.append(rotated_coordinates)

  return (rotated_coordinates_list)

def get_partial_charge(atom):
  """Get partial charge of a given atom.

  For an rdkit atom the charge is read from its `_GasteigerCharge`
  property; a missing property or a NaN value is reported as 0.0.
  Any other object is asked for its charge via `GetPartialCharge()`.

  Parameters
  ----------
  atom: rdkit atom or `AtomShim` object
    Either an rdkit atom or `AtomShim`

  Returns
  -------
  The partial charge of `atom`. This is a float for rdkit atoms;
  otherwise it is whatever `atom.GetPartialCharge()` returns.
  """
  import math
  from rdkit import Chem
  if isinstance(atom, Chem.Atom):
    try:
      value = atom.GetProp(str("_GasteigerCharge"))
    except KeyError:
      # Charges were never computed for this atom.
      return 0.0
    charge = float(value)
    # The property is stored as text and may read 'nan' or '-nan';
    # both parse to NaN and are treated as "no charge".
    if math.isnan(charge):
      return 0.0
    return charge
  return atom.GetPartialCharge()

def is_salt_bridge(atom_i, atom_j):
  """Check if two atoms have correct charges to form a salt bridge

  The pair qualifies when the absolute difference between the two
  partial charges lies within 0.01 of 2.0.
  """
  charge_gap = np.abs(get_partial_charge(atom_i) - get_partial_charge(atom_j))
  deviation = np.abs(2.0 - charge_gap)
  return bool(deviation < 0.01)


def is_hydrogen_bond(protein_xyz,
+1 −1
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ from deepchem.molnet.load_function.pdbbind_datasets import get_pdbbind_molecular

complex_files = get_pdbbind_molecular_complex_files(subset="core", version="v2015", interactions="protein-ligand", load_binding_pocket=False)
n_complexes = len(complex_files)
n_featurize = n_complexes
n_featurize = 2
core_subset = complex_files[:n_featurize]

max_num_neighbors = 4
Loading