Commit 536f97d4 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Cleaning up and adding tests

parent 260581b8
Loading
Loading
Loading
Loading
+116 −121
Original line number Diff line number Diff line
@@ -58,10 +58,10 @@ class ChargeVoxelizer(ComplexFeaturizer):
  """

  def __init__(self,
               cutoff=4.5,
               box_width=16.0,
               voxel_width=1.0,
               reduce_to_contacts=True):
               cutoff: float = 4.5,
               box_width: float = 16.0,
               voxel_width: float = 1.0,
               reduce_to_contacts: bool = True):
    """
    Parameters
    ----------
@@ -81,7 +81,7 @@ class ChargeVoxelizer(ComplexFeaturizer):
    self.voxel_width = voxel_width
    self.reduce_to_contacts = reduce_to_contacts

  def _featurize(self, mol_pdb: str, protein_pdb: str):
  def _featurize(self, mol_pdb: str, protein_pdb: str) -> np.ndarray:
    """
    Compute featurization for a single mol/protein complex

@@ -117,7 +117,7 @@ class ChargeVoxelizer(ComplexFeaturizer):
          sum([
              voxelize(
                  convert_atom_to_voxel,
                  hash_function=hash_ecfp_pair,
                  hash_function=None,
                  coordinates=xyz,
                  box_width=self.box_width,
                  voxel_width=self.voxel_width,
@@ -147,10 +147,10 @@ class SaltBridgeVoxelizer(ComplexFeaturizer):
  """

  def __init__(self,
               cutoff=5.0,
               box_width=16.0,
               voxel_width=1.0,
               reduce_to_contacts=True):
               cutoff: float = 5.0,
               box_width: float = 16.0,
               voxel_width: float = 1.0,
               reduce_to_contacts: bool = True):
    """
    Parameters
    ----------
@@ -171,14 +171,16 @@ class SaltBridgeVoxelizer(ComplexFeaturizer):
    self.voxel_width = voxel_width
    self.reduce_to_contacts = reduce_to_contacts

  def _featurize(self, mol_pdb: str, protein_pdb: str):
  def _featurize(self, mol_pdb: str, protein_pdb: str) -> np.ndarray:
    """
    Compute featurization for a single mol/protein complex

    Parameters
    ----------
    molecular_complex: Object
      Some representation of a molecular complex.
    mol_pdb: str
      Filename for ligand molecule
    protein_pdb: str
      Filename for protein molecule
    """
    molecular_complex = (mol_pdb, protein_pdb)
    try:
@@ -202,15 +204,17 @@ class SaltBridgeVoxelizer(ComplexFeaturizer):
      xyzs = [frag1_xyz, frag2_xyz]
      rdks = [frag1[1], frag2[1]]
      pairwise_features.append(
          sum([
              voxelize(
                  convert_atom_pair_to_voxel,
              self.box_width,
              self.voxel_width,
              None,
              xyzs,
                  hash_function=None,
                  coordinates=xyz,
                  box_width=self.box_width,
                  voxel_width=self.voxel_width,
                  feature_list=compute_salt_bridges(
                      frag1[1], frag2[1], distances, cutoff=self.cutoff),
              nb_channel=1))
                  nb_channel=1) for xyz in xyzs
          ]))
    # Features are of shape (voxels_per_edge, voxels_per_edge, voxels_per_edge, 1) so we should concatenate on the last axis.
    return np.concatenate(pairwise_features, axis=-1)

@@ -226,20 +230,19 @@ class CationPiVoxelizer(ComplexFeaturizer):

  Let `voxels_per_edge = int(box_width/voxel_width)`.  Creates a
  tensor output of shape `(voxels_per_edge, voxels_per_edge,
  voxels_per_edge, 1)` for each macromolecular the number of cation-pi
  interactions at each voxel.
  voxels_per_edge, 1)` for each macromolecular complex that counts the
  number of cation-pi interactions at each voxel.
  """

  def __init__(self,
               distance_cutoff=6.5,
               angle_cutoff=30.0,
               box_width=16.0,
               voxel_width=1.0):
    #reduce_to_contacts=True):
               cutoff: float = 6.5,
               angle_cutoff: float = 30.0,
               box_width: float = 16.0,
               voxel_width: float = 1.0):
    """
    Parameters
    ----------
    distance_cutoff: float, optional (default 6.5)
    cutoff: float, optional (default 6.5)
      The distance in angstroms within which atoms must be to
      be considered for a cation-pi interaction between them.
    angle_cutoff: float, optional (default 30.0)
@@ -251,16 +254,13 @@ class CationPiVoxelizer(ComplexFeaturizer):
      is centered on a ligand centroid.
    voxel_width: float, optional (default 1.0)
      Size of a 3D voxel in a grid.
    #reduce_to_contacts: bool, optional
    #  If True, reduce the atoms in the complex to those near a contact
    #  region.
    """
    self.distance_cutoff = distance_cutoff
    self.cutoff = cutoff
    self.angle_cutoff = angle_cutoff
    self.box_width = box_width
    self.voxel_width = voxel_width

  def _featurize(self, mol_pdb: str, protein_pdb: str):
  def _featurize(self, mol_pdb: str, protein_pdb: str) -> np.ndarray:
    """
    Compute featurization for a single mol/protein complex

@@ -281,7 +281,7 @@ class CationPiVoxelizer(ComplexFeaturizer):
      return None
    pairwise_features = []
    # We compute pairwise contact fingerprints
    centroid = compute_contact_centroid(fragments, cutoff=self.distance_cutoff)
    centroid = compute_contact_centroid(fragments, cutoff=self.cutoff)
    for (frag1_ind, frag2_ind) in itertools.combinations(
        range(len(fragments)), 2):
      frag1, frag2 = fragments[frag1_ind], fragments[frag2_ind]
@@ -294,17 +294,17 @@ class CationPiVoxelizer(ComplexFeaturizer):
          sum([
              voxelize(
                  convert_atom_to_voxel,
                  self.box_width,
                  self.voxel_width,
                  None,
                  xyz,
                  hash_function=None,
                  box_width=self.box_width,
                  voxel_width=self.voxel_width,
                  coordinates=xyz,
                  feature_dict=cation_pi_dict,
                  nb_channel=1) for xyz, cation_pi_dict in zip(
                      xyzs,
                      compute_binding_pocket_cation_pi(
                          frag1[1],
                          frag2[1],
                          dist_cutoff=self.distance_cutoff,
                          dist_cutoff=self.cutoff,
                          angle_cutoff=self.angle_cutoff,
                      ))
          ]))
@@ -323,21 +323,21 @@ class PiStackVoxelizer(ComplexFeaturizer):

  Let `voxels_per_edge = int(box_width/voxel_width)`.  Creates a
  tensor output of shape `(voxels_per_edge, voxels_per_edge,
  voxels_per_edge, 2)` for each macromolecular Each voxel has 2
  fields, with the first tracking the number of pi-pi parallel
  voxels_per_edge, 2)` for each macromolecular complex. Each voxel has
  2 fields, with the first tracking the number of pi-pi parallel
  interactions, and the second tracking the number of pi-T
  interactions.
  """

  def __init__(self,
               distance_cutoff=4.4,
               angle_cutoff=30.0,
               box_width=16.0,
               voxel_width=1.0):
               cutoff: float = 4.4,
               angle_cutoff: float = 30.0,
               box_width: float = 16.0,
               voxel_width: float = 1.0):
    """
    Parameters
    ----------
    distance_cutoff: float, optional (default 4.4)
    cutoff: float, optional (default 4.4)
      The distance in angstroms within which atoms must be to
      be considered for a cation-pi interaction between them.
    angle_cutoff: float, optional (default 30.0)
@@ -350,12 +350,12 @@ class PiStackVoxelizer(ComplexFeaturizer):
    voxel_width: float, optional (default 1.0)
      Size of a 3D voxel in a grid.
    """
    self.distance_cutoff = distance_cutoff
    self.cutoff = cutoff
    self.angle_cutoff = angle_cutoff
    self.box_width = box_width
    self.voxel_width = voxel_width

  def _featurize(self, mol_pdb: str, protein_pdb: str):
  def _featurize(self, mol_pdb: str, protein_pdb: str) -> np.ndarray:
    """
    Compute featurization for a single mol/protein complex

@@ -376,7 +376,7 @@ class PiStackVoxelizer(ComplexFeaturizer):
      return None
    pairwise_features = []
    # We compute pairwise contact fingerprints
    centroid = compute_contact_centroid(fragments, cutoff=self.distance_cutoff)
    centroid = compute_contact_centroid(fragments, cutoff=self.cutoff)
    for (frag1_ind, frag2_ind) in itertools.combinations(
        range(len(fragments)), 2):
      frag1, frag2 = fragments[frag1_ind], fragments[frag2_ind]
@@ -385,48 +385,38 @@ class PiStackVoxelizer(ComplexFeaturizer):
      frag2_xyz = subtract_centroid(frag2[0], centroid)
      xyzs = [frag1_xyz, frag2_xyz]
      rdks = [frag1[1], frag2[1]]
      #(lig_xyz, lig_rdk), (prot_xyz, prot_rdk) = mol, protein
      #distances = compute_pairwise_distances(prot_xyz, lig_xyz)
      protein_pi_t, protein_pi_parallel, ligand_pi_t, ligand_pi_parallel = (
          compute_pi_stack(
              frag1[1],
              frag2[1],
              distances,
              dist_cutoff=self.distance_cutoff,
              dist_cutoff=self.cutoff,
              angle_cutoff=self.angle_cutoff))
      pi_parallel_tensor = voxelize(
          convert_atom_to_voxel,
          self.box_width,
          self.voxel_width,
          None,
          frag1_xyz,
          feature_dict=protein_pi_parallel,
          nb_channel=1)
      pi_parallel_tensor += voxelize(
      pi_parallel_tensor = sum([
          voxelize(
              convert_atom_to_voxel,
          self.box_width,
          self.voxel_width,
          None,
          frag2_xyz,
          feature_dict=ligand_pi_parallel,
              hash_function=None,
              box_width=self.box_width,
              voxel_width=self.voxel_width,
              coordinates=xyz,
              feature_dict=feature_dict,
              nb_channel=1)
          for (xyz, feature_dict
              ) in zip(xyzs, [ligand_pi_parallel, protein_pi_parallel])
      ])

      pi_t_tensor = voxelize(
      pi_t_tensor = sum([
          voxelize(
              convert_atom_to_voxel,
          self.box_width,
          self.voxel_width,
          None,
          frag1_xyz,
              hash_function=None,
              box_width=self.box_width,
              voxel_width=self.voxel_width,
              coordinates=frag1_xyz,
              feature_dict=protein_pi_t,
              nb_channel=1)
      pi_t_tensor += voxelize(
          convert_atom_to_voxel,
          self.box_width,
          self.voxel_width,
          None,
          frag2_xyz,
          feature_dict=ligand_pi_t,
          nb_channel=1)
          for (xyz, feature_dict) in zip(xyzs, [ligand_pi_t, protein_pi_t])
      ])

      pairwise_features.append(
          np.concatenate([pi_parallel_tensor, pi_t_tensor], axis=-1))
    # Features are of shape (voxels_per_edge, voxels_per_edge, voxels_per_edge, 2) so we should concatenate on the last axis.
@@ -437,19 +427,19 @@ class HydrogenBondCounter(ComplexFeaturizer):
  """Counts hydrogen bonds between atoms in macromolecular complexes.

  Given a macromolecular complex made up of multiple
  constitutent molecules, count the number hydrogen bonds
  constitutent molecules, count the number of hydrogen bonds
  between atoms in the macromolecular complex.

  Creates a scalar output of shape `(3,)` (assuming the default value
  ofor `distance_bins` with 3 bins) for each macromolecular that
  computes the total number of hydrogen bonds.
  ofor `distance_bins` with 3 bins) for each macromolecular complex
  that computes the total number of hydrogen bonds.
  """

  def __init__(self,
               cutoff=4.5,
               distance_bins=None,
               angle_cutoffs=None,
               reduce_to_contacts=True):
               cutoff: float = 4.5,
               distance_bins: List[Tuple] = None,
               angle_cutoffs: List[float] = None,
               reduce_to_contacts: bool = True):
    """
    Parameters
    ----------
@@ -479,15 +469,18 @@ class HydrogenBondCounter(ComplexFeaturizer):
      self.angle_cutoffs = angle_cutoffs
    self.reduce_to_contacts = reduce_to_contacts

  def _featurize_complex(self, molecular_complex):
  def _featurize(self, mol_pdb: str, protein_pdb: str) -> np.ndarray:
    """
    Compute featurization for a single mol/protein complex

    Parameters
    ----------
    molecular_complex: Object
      Some representation of a molecular complex.
    mol_pdb: str
      Filename for ligand molecule
    protein_pdb: str
      Filename for protein molecule
    """
    molecular_complex = (mol_pdb, protein_pdb)
    try:
      fragments = rdkit_utils.load_complex(
          molecular_complex, add_hydrogens=False)
@@ -509,8 +502,6 @@ class HydrogenBondCounter(ComplexFeaturizer):
      frag2_xyz = subtract_centroid(frag2[0], centroid)
      xyzs = [frag1_xyz, frag2_xyz]
      rdks = [frag1[1], frag2[1]]
      #(lig_xyz, lig_rdk), (prot_xyz, prot_rdk) = mol, protein
      #distances = compute_pairwise_distances(prot_xyz, lig_xyz)
      pairwise_features.append(
          np.concatenate(
              [
@@ -538,17 +529,17 @@ class HydrogenBondVoxelizer(ComplexFeaturizer):
  Let `voxels_per_edge = int(box_width/voxel_width)`.  Creates a
  tensor output of shape `(voxels_per_edge, voxels_per_edge,
  voxels_per_edge, 3)` (assuming the default for `distance_bins` which
  has 3 bins) for each macromolecular the number of hydrogen bonds at
  each voxel.
  has 3 bins) for each macromolecular complex that counts the number
  of hydrogen bonds at each voxel.
  """

  def __init__(self,
               cutoff=4.5,
               distance_bins=None,
               angle_cutoffs=None,
               box_width=16.0,
               voxel_width=1.0,
               reduce_to_contacts=True):
               cutoff: float = 4.5,
               distance_bins: List[Tuple] = None,
               angle_cutoffs: List[float] = None,
               box_width: float = 16.0,
               voxel_width: float = 1.0,
               reduce_to_contacts: bool = True):
    """
    Parameters
    ----------
@@ -585,15 +576,18 @@ class HydrogenBondVoxelizer(ComplexFeaturizer):
    self.voxel_width = voxel_width
    self.reduce_to_contacts = reduce_to_contacts

  def _featurize_complex(self, molecular_complex):
  def _featurize(self, mol_pdb: str, protein_pdb: str) -> np.ndarray:
    """
    Compute featurization for a single mol/protein complex

    Parameters
    ----------
    molecular_complex: Object
      Some representation of a molecular complex.
    mol_pdb: str
      Filename for ligand molecule
    protein_pdb: str
      Filename for protein molecule
    """
    molecular_complex = (mol_pdb, protein_pdb)
    try:
      fragments = rdkit_utils.load_complex(
          molecular_complex, add_hydrogens=False)
@@ -617,15 +611,16 @@ class HydrogenBondVoxelizer(ComplexFeaturizer):
      pairwise_features.append(
          np.concatenate(
              [
                  sum([
                      voxelize(
                          convert_atom_pair_to_voxel,
                      self.box_width,
                      self.voxel_width,
                      #None, (prot_xyz, lig_xyz),
                      None,
                      xyzs,
                          hash_function=None,
                          box_width=self.box_width,
                          voxel_width=self.voxel_width,
                          coordinates=xyz,
                          feature_list=hbond_list,
                      nb_channel=1) for hbond_list in compute_hydrogen_bonds(
                          nb_channel=1) for xyz in xyzs
                  ]) for hbond_list in compute_hydrogen_bonds(
                      frag1, frag2, distances, self.distance_bins,
                      self.angle_cutoffs)
              ],
+60 −5
Original line number Diff line number Diff line
@@ -15,23 +15,78 @@ def test_charge_voxelizer():
  voxelizer = dc.feat.ChargeVoxelizer(
      cutoff=cutoff, box_width=box_width, voxel_width=voxel_width)
  features, failures = voxelizer.featurize([ligand_file], [protein_file])
  # TODO: Add shape test


def test_salt_bridge_voxelizer():
  pass
  current_dir = os.path.dirname(os.path.realpath(__file__))
  protein_file = os.path.join(current_dir, 'data',
                              '3ws9_protein_fixer_rdkit.pdb')
  ligand_file = os.path.join(current_dir, 'data', '3ws9_ligand.sdf')

  cutoff = 4.5
  box_width = 16
  voxel_width = 1.0
  voxelizer = dc.feat.SaltBridgeVoxelizer(
      cutoff=cutoff, box_width=box_width, voxel_width=voxel_width)
  features, failures = voxelizer.featurize([ligand_file], [protein_file])
  # TODO: Add shape test


def test_cation_pi_voxelizer():
  pass
  current_dir = os.path.dirname(os.path.realpath(__file__))
  protein_file = os.path.join(current_dir, 'data',
                              '3ws9_protein_fixer_rdkit.pdb')
  ligand_file = os.path.join(current_dir, 'data', '3ws9_ligand.sdf')

  cutoff = 4.5
  box_width = 16
  voxel_width = 1.0
  voxelizer = dc.feat.CationPiVoxelizer(
      cutoff=cutoff, box_width=box_width, voxel_width=voxel_width)
  features, failures = voxelizer.featurize([ligand_file], [protein_file])
  # TODO: Add shape test


def test_pi_stack_voxelizer():
  pass
  current_dir = os.path.dirname(os.path.realpath(__file__))
  protein_file = os.path.join(current_dir, 'data',
                              '3ws9_protein_fixer_rdkit.pdb')
  ligand_file = os.path.join(current_dir, 'data', '3ws9_ligand.sdf')

  cutoff = 4.5
  box_width = 16
  voxel_width = 1.0
  voxelizer = dc.feat.PiStackVoxelizer(
      cutoff=cutoff, box_width=box_width, voxel_width=voxel_width)
  features, failures = voxelizer.featurize([ligand_file], [protein_file])
  # TODO: Add shape test


# TODO: This is failing, something about the hydrogen bond counting?
def test_hydrogen_bond_counter():
  pass
  current_dir = os.path.dirname(os.path.realpath(__file__))
  protein_file = os.path.join(current_dir, 'data',
                              '3ws9_protein_fixer_rdkit.pdb')
  ligand_file = os.path.join(current_dir, 'data', '3ws9_ligand.sdf')

  cutoff = 4.5
  featurizer = dc.feat.HydrogenBondCounter(cutoff=cutoff)
  features, failures = featurizer.featurize([ligand_file], [protein_file])
  # TODO: Add shape test


# TODO: This is failing, something about the hydrogen bond counting?
def test_hydrogen_bond_voxelizer():
  pass
  current_dir = os.path.dirname(os.path.realpath(__file__))
  protein_file = os.path.join(current_dir, 'data',
                              '3ws9_protein_fixer_rdkit.pdb')
  ligand_file = os.path.join(current_dir, 'data', '3ws9_ligand.sdf')

  cutoff = 4.5
  box_width = 16
  voxel_width = 1.0
  voxelizer = dc.feat.HydrogenBondVoxelizer(
      cutoff=cutoff, box_width=box_width, voxel_width=voxel_width)
  features, failures = voxelizer.featurize([ligand_file], [protein_file])
  # TODO: Add shape test
+27 −22
Original line number Diff line number Diff line
@@ -61,9 +61,9 @@ def is_hydrogen_bond(frag1,
  Parameters
  ----------
  frag1: tuple
    Tuple of (coords, rdkit mol / MolecularFragment
    Tuple of (coords, rdkit mol / MolecularFragment)
  frag2: tuple
    Tuple of (coords, rdkit mol / MolecularFragment
    Tuple of (coords, rdkit mol / MolecularFragment)
  contact: Tuple
    Tuple of indices for (atom_i, atom_j) contact. 
  hbond_distance_cutoff: float, optional
@@ -277,7 +277,7 @@ def compute_pi_stack(mol1,
    mol1: rdkit.rdchem.Mol
      First molecule.
    mol2: rdkit.rdchem.Mol
      First molecule.
      Second molecule.
    pairwise_distances: np.ndarray (optional)
      Array of pairwise interatomic distances (Angstroms)
    dist_cutoff: float
@@ -390,16 +390,16 @@ def is_pi_t(ring1_center,
  return False


def is_pi_parallel(ring1_center,
                   ring1_normal,
                   ring2_center,
                   ring2_normal,
                   dist_cutoff=8.0,
                   angle_cutoff=30.0):
def is_pi_parallel(ring1_center: np.ndarray,
                   ring1_normal: np.ndarray,
                   ring2_center: np.ndarray,
                   ring2_normal: np.ndarray,
                   dist_cutoff: float = 8.0,
                   angle_cutoff: float = 30.0) -> bool:
  """Check if two aromatic rings form a parallel pi-pi contact.

  Parameters:
  -----------
  Parameters
  ----------
  ring1_center, ring2_center: np.ndarray
    Positions of centers of the two rings. Can be computed with the
    compute_ring_center function.
@@ -411,6 +411,11 @@ def is_pi_parallel(ring1_center,
  angle_cutoff: float
    Angle cutoff. Max allowed deviation from the ideal (0deg) angle between
    the rings (in degrees).

  Returns
  -------
  bool
    True if two aromatic rings form a parallel pi-pi.
  """

  dist = np.linalg.norm(ring1_center - ring2_center)
+2 −12
Original line number Diff line number Diff line
@@ -503,7 +503,7 @@ def compute_ring_center(mol, ring_indices):


def get_contact_atom_indices(fragments: List, cutoff: float = 4.5) -> List:
  """Compute that atoms close to contact region.
  """Compute the atoms close to contact region.

  Molecular complexes can get very large. This can make it unwieldy to
  compute functions on them. To improve memory usage, it can be very
@@ -528,7 +528,7 @@ def get_contact_atom_indices(fragments: List, cutoff: float = 4.5) -> List:
  is a list of atom indices from that molecule which should be kept, in
  sorted order.
  """
  # indices to atoms to keep
  # indices of atoms to keep
  keep_inds: List[Set] = [set([]) for _ in fragments]
  for (ind1, ind2) in itertools.combinations(range(len(fragments)), 2):
    frag1, frag2 = fragments[ind1], fragments[ind2]
@@ -545,16 +545,6 @@ def get_contact_atom_indices(fragments: List, cutoff: float = 4.5) -> List:
  keep_ind_lists = [sorted(list(keep)) for keep in keep_inds]
  return keep_ind_lists

  # Now extract atoms
  #atoms_to_keep = []
  #for i, frag_keep_inds in enumerate(keep_inds):
  #  frag = fragments[i]
  #  mol = frag[1]
  #  atoms = mol.GetAtoms()
  #  frag_keep = [atoms[keep_ind] for keep_ind in frag_keep_inds]
  #  atoms_to_keep.append(frag_keep)
  #return atoms_to_keep


def get_mol_subset(coords, mol, atom_indices_to_keep):
  """Strip a subset of the atoms in this molecule