Commit b09844ec authored by marta-sd's avatar marta-sd
Browse files

test modified RdkitGridFeaturizer._voxelize

parent 9b0afeee
Loading
Loading
Loading
Loading
+54 −0
Original line number Diff line number Diff line
@@ -312,3 +312,57 @@ class TestHelperFunctions(unittest.TestCase):
      for i in range(mol.GetNumAtoms()):
        self.assertIn(i, charge_dict)
        self.assertIsInstance(charge_dict[i], (float, int))


class TestRdkitGridFeaturizer(unittest.TestCase):
  """
  Test RdkitGridFeaturizer class defined in rdkit_grid_featurizer module.
  """

  def setUp(self):
    current_dir = os.path.dirname(os.path.realpath(__file__))
    package_dir = os.path.dirname(os.path.dirname(current_dir))
    self.protein_file = os.path.join(package_dir, 'dock', 'tests',
                                     '1jld_protein.pdb')
    self.ligand_file = os.path.join(package_dir, 'dock', 'tests',
                                    '1jld_ligand.sdf')

  def test_voxelize(self):
    prot_xyz, prot_rdk = rgf.load_molecule(self.protein_file)
    lig_xyz, lig_rdk = rgf.load_molecule(self.ligand_file)

    centroid = rgf.compute_centroid(lig_xyz)
    prot_xyz = rgf.subtract_centroid(prot_xyz, centroid)
    lig_xyz = rgf.subtract_centroid(lig_xyz, centroid)

    prot_ecfp_dict, lig_ecfp_dict = (rgf.featurize_binding_pocket_ecfp(
        prot_xyz, prot_rdk, lig_xyz, lig_rdk))

    box_w = 20
    f_power = 5

    rgf_featurizer = rgf.RdkitGridFeaturizer(
        box_width=box_w, ecfp_power=f_power)

    prot_tensor = rgf_featurizer._voxelize(
        rgf.convert_atom_to_voxel,
        rgf.hash_ecfp,
        prot_xyz,
        feature_dict=prot_ecfp_dict,
        channel_power=f_power)
    self.assertEqual(prot_tensor.shape, tuple([box_w] * 3 + [2**f_power]))
    all_features = prot_tensor.sum()
    # protein is too big for the box, some features should be missing
    self.assertGreater(all_features, 0)
    self.assertLess(all_features, prot_rdk.GetNumAtoms())

    lig_tensor = rgf_featurizer._voxelize(
        rgf.convert_atom_to_voxel,
        rgf.hash_ecfp,
        lig_xyz,
        feature_dict=lig_ecfp_dict,
        channel_power=f_power)
    self.assertEqual(lig_tensor.shape, tuple([box_w] * 3 + [2**f_power]))
    all_features = lig_tensor.sum()
    # whole ligand should fit in the box
    self.assertEqual(all_features, lig_rdk.GetNumAtoms())