Commit 77ece2a4 authored by marta-sd's avatar marta-sd
Browse files

deal with py2/py3 integers + change formatting in tests

parent d3c6a575
Loading
Loading
Loading
Loading
+42 −50
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
Test rdkit_grid_featurizer module.
"""
import os
from six import integer_types
import unittest

import numpy as np
@@ -91,6 +92,7 @@ class TestHelperFunctions(unittest.TestCase):
    coords2 = np.random.rand(n2, 3)

    distance = rgf.compute_pairwise_distances(coords1, coords2)
    self.assertEqual(distance.shape, (n1, n2))
    self.assertTrue((distance >= 0).all())
    # random coords between 0 and 1, so the max possible distance in sqrt(2)
    self.assertTrue((distance <= 2.0**0.5).all())
@@ -114,7 +116,7 @@ class TestHelperFunctions(unittest.TestCase):
      for _ in range(10):
        string = random_string(10)
        string_hash = rgf.hash_ecfp(string, power)
        self.assertIsInstance(string_hash, int)
        self.assertIsInstance(string_hash, integer_types)
        self.assertLess(string_hash, 2**power)
        self.assertGreaterEqual(string_hash, 0)

@@ -124,7 +126,7 @@ class TestHelperFunctions(unittest.TestCase):
        string1 = random_string(10)
        string2 = random_string(10)
        pair_hash = rgf.hash_ecfp_pair((string1, string2), power)
        self.assertIsInstance(pair_hash, int)
        self.assertIsInstance(pair_hash, integer_types)
        self.assertLess(pair_hash, 2**power)
        self.assertGreaterEqual(pair_hash, 0)

@@ -150,23 +152,17 @@ class TestHelperFunctions(unittest.TestCase):
  def test_featurize_binding_pocket_ecfp(self):
    prot_xyz, prot_rdk = rgf.load_molecule(self.protein_file)
    lig_xyz, lig_rdk = rgf.load_molecule(self.ligand_file)
      distance = rgf.compute_pairwise_distances(protein_xyz=prot_xyz,
                                                ligand_xyz=lig_xyz)
    distance = rgf.compute_pairwise_distances(
        protein_xyz=prot_xyz, ligand_xyz=lig_xyz)

    # check if results are the same if we provide precomputed distances
    prot_dict, lig_dict = rgf.featurize_binding_pocket_ecfp(
        prot_xyz,
        prot_rdk,
        lig_xyz,
        lig_rdk,
      )
        lig_rdk,)
    prot_dict_dist, lig_dict_dist = rgf.featurize_binding_pocket_ecfp(
        prot_xyz,
        prot_rdk,
        lig_xyz,
        lig_rdk,
        pairwise_distances=distance
      )
        prot_xyz, prot_rdk, lig_xyz, lig_rdk, pairwise_distances=distance)
    # ...but first check if we actually got two dicts
    self.assertIsInstance(prot_dict, dict)
    self.assertIsInstance(lig_dict, dict)
@@ -180,15 +176,13 @@ class TestHelperFunctions(unittest.TestCase):
        prot_rdk,
        lig_xyz,
        lig_rdk,
        cutoff=2.0,
      )
        cutoff=2.0,)
    prot_dict_d6, lig_dict_d6 = rgf.featurize_binding_pocket_ecfp(
        prot_xyz,
        prot_rdk,
        lig_xyz,
        lig_rdk,
        cutoff=6.0,
      )
        cutoff=6.0,)
    self.assertLess(len(prot_dict_d2), len(prot_dict))
    # ligands are typically small so all atoms might be present
    self.assertLessEqual(len(lig_dict_d2), len(lig_dict))
@@ -201,8 +195,7 @@ class TestHelperFunctions(unittest.TestCase):
        prot_rdk,
        lig_xyz,
        lig_rdk,
        ecfp_degree=3,
      )
        ecfp_degree=3,)
    self.assertNotEqual(prot_dict_e3, prot_dict)
    self.assertNotEqual(lig_dict_e3, lig_dict)

@@ -211,16 +204,15 @@ class TestHelperFunctions(unittest.TestCase):
    lig_xyz, lig_rdk = rgf.load_molecule(self.ligand_file)
    prot_num_atoms = prot_rdk.GetNumAtoms()
    lig_num_atoms = lig_rdk.GetNumAtoms()
    distance = rgf.compute_pairwise_distances(protein_xyz=prot_xyz,
                                              ligand_xyz=lig_xyz)
    distance = rgf.compute_pairwise_distances(
        protein_xyz=prot_xyz, ligand_xyz=lig_xyz)

    for bins in ((0, 2), (2, 3)):
      splif_dict = rgf.compute_splif_features_in_range(
          prot_rdk,
          lig_rdk,
          distance,
        bins,
      )
          bins,)

      self.assertIsInstance(splif_dict, dict)
      for (prot_idx, lig_idx), ecfp_pair in splif_dict.items():