Commit 8ffc36fe authored by marta-sd
Browse files

use kwargs to specify custom bins and cutoffs

parent 9ee36e9c
Loading
Loading
Loading
Loading
+50 −23
Original line number Diff line number Diff line
@@ -524,7 +524,7 @@ def is_cation_pi(cation_position,
  return False


def compute_cation_pi(mol1, mol2, charge_tolerance=0.01):
def compute_cation_pi(mol1, mol2, charge_tolerance=0.01, **kwargs):
  """Finds aromatic rings in mo1 interacting with cations in mol2"""
  mol1_pi = Counter()
  mol2_cation = Counter()
@@ -541,15 +541,15 @@ def compute_cation_pi(mol1, mol2, charge_tolerance=0.01):
      for atom in mol2.GetAtoms():
        if atom.GetFormalCharge() > 1.0 - charge_tolerance:
          cation_position = np.array(conformer.GetAtomPosition(atom.GetIdx()))
          if is_cation_pi(cation_position, ring_center, ring_normal):
          if is_cation_pi(cation_position, ring_center, ring_normal, **kwargs):
            mol1_pi.update(ring)
            mol2_cation.update([atom.GetIndex()])
  return mol1_pi, mol2_cation


def compute_binding_pocket_cation_pi(protein, ligand):
  protein_pi, ligand_cation = compute_cation_pi(protein, ligand)
  ligand_pi, protein_cation = compute_cation_pi(ligand, protein)
def compute_binding_pocket_cation_pi(protein, ligand, **kwargs):
  protein_pi, ligand_cation = compute_cation_pi(protein, ligand, **kwargs)
  ligand_pi, protein_cation = compute_cation_pi(ligand, protein, **kwargs)

  protein_cation_pi = Counter()
  protein_cation_pi.update(protein_pi)
@@ -585,11 +585,15 @@ def is_salt_bridge(atom_i, atom_j):
  return False


def compute_salt_bridges(protein_xyz, protein, ligand_xyz, ligand,
                         pairwise_distances):
def compute_salt_bridges(protein_xyz,
                         protein,
                         ligand_xyz,
                         ligand,
                         pairwise_distances,
                         cutoff=5.0):
  salt_bridge_contacts = []

  contacts = np.nonzero(pairwise_distances < 5.0)
  contacts = np.nonzero(pairwise_distances < cutoff)
  contacts = zip(contacts[0], contacts[1])
  for contact in contacts:
    protein_atom = protein.GetAtoms()[int(contact[0])]
@@ -754,13 +758,28 @@ class RdkitGridFeaturizer(ComplexFeaturizer):

    self.ligand_only = ligand_only

    self.hbond_dist_bins = [(2.2, 2.5), (2.5, 3.2), (3.2, 4.0)]
    self.hbond_angle_cutoffs = [5, 50, 90]
    self.contact_bins = [(0, 2.0), (2.0, 3.0), (3.0, 4.5)]
    # default values
    self.cutoffs = {
        'hbond_dist_bins': [(2.2, 2.5), (2.5, 3.2), (3.2, 4.0)],
        'hbond_angle_cutoffs': [5, 50, 90],
        'splif_contact_bins': [(0, 2.0), (2.0, 3.0), (3.0, 4.5)],
        'ecfp_cutoff': 4.5,
        'sybyl_cutoff': 7.0,
        'salt_bridges_cutoff': 5.0,
        'pi_stack_dist_cutoff': 4.4,
        'pi_stack_angle_cutoff': 30.0,
        'cation_pi_dist_cutoff': 6.5,
        'cation_pi_angle_cutoff': 30.0,
    }

    # update with cutoffs specified by the user
    for arg, value in kwargs.items():
      if arg in self.cutoffs:
        self.cutoffs[arg] = value

    self.box_width = float(box_width)
    self.voxel_width = float(voxel_width)
    self.voxels_per_edge = self.box_width / self.voxel_width
    self.voxels_per_edge = int(self.box_width / self.voxel_width)

    self.sybyl_types = [
        "C3", "C2", "C1", "Cac", "Car", "N3", "N3+", "Npl", "N2", "N1", "Ng+",
@@ -788,7 +807,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
                lig_xyz,
                lig_rdk,
                distances,
                cutoff=4.5,
                cutoff=self.cutoffs['ecfp_cutoff'],
                ecfp_degree=self.ecfp_degree)],

        'splif_hashed': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
@@ -801,7 +820,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
                prot_rdk,
                lig_xyz,
                lig_rdk,
                self.contact_bins,
                self.cutoffs['splif_contact_bins'],
                distances,
                self.ecfp_degree)],

@@ -816,13 +835,18 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
                lig_xyz,
                lig_rdk,
                distances,
                self.hbond_dist_bins,
                self.hbond_angle_cutoffs)]
                self.cutoffs['hbond_dist_bins'],
                self.cutoffs['hbond_angle_cutoffs'])]
    }

    def voxelize_pi_stack(prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances):
      protein_pi_t, protein_pi_parallel, ligand_pi_t, ligand_pi_parallel = (
          compute_pi_stack(prot_rdk, lig_rdk, distances))
          compute_pi_stack(
              prot_rdk,
              lig_rdk,
              distances,
              dist_cutoff=self.cutoffs['pi_stack_dist_cutoff'],
              angle_cutoff=self.cutoffs['pi_stack_angle_cutoff']))
      pi_parallel_tensor = self._voxelize(
          convert_atom_to_voxel,
          None,
@@ -865,7 +889,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
                    lig_xyz,
                    lig_rdk,
                    distances,
                    cutoff=4.5,
                    cutoff=self.cutoffs['ecfp_cutoff'],
                    ecfp_degree=self.ecfp_degree
                ))])],

@@ -881,7 +905,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
                prot_rdk,
                lig_xyz,
                lig_rdk,
                self.contact_bins,
                self.cutoffs['splif_contact_bins'],
                distances,
                self.ecfp_degree)],

@@ -899,7 +923,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
                    lig_xyz,
                    lig_rdk,
                    distances,
                    cutoff=7.0
                    cutoff=self.cutoffs['sybyl_cutoff']
                ))],

        'salt_bridge': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
@@ -912,7 +936,8 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
                    prot_rdk,
                    lig_xyz,
                    lig_rdk,
                    distances),
                    distances,
                    cutoff=self.cutoffs['salt_bridges_cutoff']),
                nb_channel=1
            )],

@@ -939,8 +964,8 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
                lig_xyz,
                lig_rdk,
                distances,
                self.hbond_dist_bins,
                self.hbond_angle_cutoffs)
                self.cutoffs['hbond_dist_bins'],
                self.cutoffs['hbond_angle_cutoffs'])
            ],
        'pi_stack': voxelize_pi_stack,

@@ -955,6 +980,8 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
                (prot_xyz, lig_xyz), compute_binding_pocket_cation_pi(
                    prot_rdk,
                    lig_rdk,
                    dist_cutoff=self.cutoffs['cation_pi_dist_cutoff'],
                    angle_cutoff=self.cutoffs['cation_pi_angle_cutoff'],
                ))])],
    }

+7 −3
Original line number Diff line number Diff line
@@ -487,8 +487,9 @@ class TestRdkitGridFeaturizer(unittest.TestCase):
    feature_tensor = featurizer.featurize_complexes([self.ligand_file],
                                                    [self.protein_file])
    self.assertIsInstance(feature_tensor, np.ndarray)
    total_len = (2**ecfp_power + len(featurizer.contact_bins) * 2**splif_power +
                 len(featurizer.hbond_dist_bins) + 4)
    total_len = (2**ecfp_power +
                 len(featurizer.cutoffs['splif_contact_bins']) * 2**splif_power
                 + len(featurizer.cutoffs['hbond_dist_bins']) + 4)
    self.assertEqual(feature_tensor.shape, (1, total_len))

  def test_voxelize(self):
@@ -506,7 +507,10 @@ class TestRdkitGridFeaturizer(unittest.TestCase):
    f_power = 5

    rgf_featurizer = rgf.RdkitGridFeaturizer(
        box_width=box_w, ecfp_power=f_power)
        box_width=box_w,
        ecfp_power=f_power,
        feature_types=['all_combined'],
        flatten=True)

    prot_tensor = rgf_featurizer._voxelize(
        rgf.convert_atom_to_voxel,