Unverified Commit a59823d5 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #879 from marta-sd/RGF_tests

Changes in rdkit_grid_featurizer
parents 457b25e0 49c70053
Loading
Loading
Loading
Loading
+5 −7
Original line number Diff line number Diff line
@@ -34,15 +34,13 @@ class GridPoseScorer(object):
    if feat == "grid":
      self.featurizer = RdkitGridFeaturizer(
          voxel_width=16.0,
          feature_types="voxel_combined",
          # TODO(rbharath, enf): Figure out why pi_stack is slow and cation_pi
          # causes segfaults.
          #voxel_feature_types=["ecfp", "splif", "hbond", "pi_stack", "cation_pi",
          #"salt_bridge"], ecfp_power=9, splif_power=9,
          voxel_feature_types=["ecfp", "splif", "hbond", "salt_bridge"],
          # TODO: add pi_stack and cation_pi to feature_types (it's not trivial
          # because they require sanitized molecules)
          # feature_types=["ecfp", "splif", "hbond", "pi_stack", "cation_pi",
          # "salt_bridge"],
          feature_types=["ecfp", "splif", "hbond", "salt_bridge"],
          ecfp_power=9,
          splif_power=9,
          parallel=True,
          flatten=True)
    else:
      raise ValueError("feat not defined.")
+782 −534

File changed.

Preview size limit exceeded, changes collapsed.

+117 −0
Original line number Diff line number Diff line
3ws9_ligand

Created by X-TOOL on Sat Nov 28 16:04:28 2015
 46 50  0  0  0  0  0  0  0  0999 V2000
   17.9520    7.9520   55.8870  C 0  0  0  2  0  3
   19.0490    8.5940   56.4540  C 0  0  0  2  0  3
   17.8800    6.5720   55.8140  C 0  0  0  2  0  3
   24.2190    5.8580   55.5190  C 0  0  0  2  0  3
   22.8760    6.1220   55.6870  C 0  0  0  2  0  3
   20.0970    7.8390   56.9560  C 0  0  0  2  0  3
   18.9340    5.8230   56.3200  C 0  0  0  2  0  3
   24.4600    4.6370   57.5640  C 0  0  0  2  0  3
   17.0690    1.9920   61.1550  C 0  0  0  2  0  3
   24.9970    5.1470   56.4080  C 0  0  0  1  0  3
   23.1230    4.8880   57.7500  C 0  0  0  1  0  3
   22.3640    5.6030   56.8470  C 0  0  0  1  0  3
   20.0370    6.4530   56.8870  C 0  0  0  1  0  3
   17.5790    2.4330   62.3580  C 0  0  0  1  0  3
   18.9920    2.8850   60.7480  C 0  0  0  1  0  3
   21.1470    4.9560   58.5550  C 0  0  0  1  0  3
   17.1390    2.4280   63.6860  C 0  0  0  2  0  3
   17.9060    2.9730   64.6400  C 0  0  0  2  0  3
   19.1960    3.5540   64.1810  C 0  0  0  2  0  3
   26.4440    4.9360   56.1160  C 0  0  0  4  0  4
   20.2170    3.4070   60.1380  C 0  0  0  3  0  4
   20.0030    4.7400   59.4610  C 0  0  0  3  0  4
   17.9710    2.2830   60.1590  N 0  0  0  1  0  2
   22.3580    4.4990   58.8010  N 0  0  0  1  0  2
   19.6510    3.5860   62.9680  N 0  0  0  1  0  2
   21.0990    5.6500   57.3710  N 0  0  0  1  0  3
   18.8000    3.0020   62.0940  N 0  0  0  1  0  3
   17.1354    8.5465   55.4936  H 0  0  0  1  0  1
   19.0834    9.6764   56.5026  H 0  0  0  1  0  1
   17.0185    6.0857   55.3709  H 0  0  0  1  0  1
   24.6977    6.2368   54.6232  H 0  0  0  1  0  1
   22.2847    6.6838   54.9729  H 0  0  0  1  0  1
   20.9576    8.3266   57.3994  H 0  0  0  1  0  1
   18.8971    4.7406   56.2729  H 0  0  0  1  0  1
   25.0500    4.0757   58.2795  H 0  0  0  1  0  1
   16.1099    1.4951   61.0167  H 0  0  0  1  0  1
   16.1785    1.9837   63.9429  H 0  0  0  1  0  1
   17.6038    2.9895   65.6861  H 0  0  0  1  0  1
   19.8268    4.0012   64.9478  H 0  0  0  1  0  1
   26.6971    5.4095   55.1559  H 0  0  0  1  0  1
   26.6532    3.8576   56.0593  H 0  0  0  1  0  1
   27.0491    5.3853   56.9172  H 0  0  0  1  0  1
   20.5740    2.6844   59.3894  H 0  0  0  1  0  1
   20.9781    3.5281   60.9229  H 0  0  0  1  0  1
   19.0648    4.7251   58.8869  H 0  0  0  1  0  1
   19.9604    5.5421   60.2125  H 0  0  0  1  0  1
  2  1  4  0  0  1
  3  1  4  0  0  1
  6  2  4  0  0  1
  7  3  4  0  0  1
  5  4  4  0  0  1
  4 10  4  0  0  1
 12  5  4  0  0  1
 13  6  4  0  0  1
 13  7  4  0  0  1
  8 10  4  0  0  1
 11  8  4  0  0  1
 14  9  2  0  0  1
 23  9  1  0  0  1
 10 20  1  0  0  2
 12 11  4  0  0  1
 24 11  1  0  0  1
 26 12  1  0  0  1
 26 13  1  0  0  2
 14 17  1  0  0  1
 27 14  1  0  0  1
 21 15  1  0  0  2
 15 23  2  0  0  1
 15 27  1  0  0  1
 22 16  1  0  0  2
 16 24  2  0  0  1
 16 26  1  0  0  1
 17 18  2  0  0  1
 18 19  1  0  0  1
 19 25  2  0  0  1
 21 22  1  0  0  2
 25 27  1  0  0  1
  1 28  1  0  0  2
  2 29  1  0  0  2
  3 30  1  0  0  2
  4 31  1  0  0  2
  5 32  1  0  0  2
  6 33  1  0  0  2
  7 34  1  0  0  2
  8 35  1  0  0  2
  9 36  1  0  0  2
 17 37  1  0  0  2
 18 38  1  0  0  2
 19 39  1  0  0  2
 20 40  1  0  0  2
 20 41  1  0  0  2
 20 42  1  0  0  2
 21 43  1  0  0  2
 21 44  1  0  0  2
 22 45  1  0  0  2
 22 46  1  0  0  2
M  END
> <MOLECULAR_FORMULA>
C22H19N5

> <MOLECULAR_WEIGHT>
353.3

> <NUM_HB_ATOMS>
5  

> <NUM_ROTOR>
3  

> <XLOGP2>
4.58 

$$$$
+4456 −0

File added.

Preview size limit exceeded, changes collapsed.

+332 −71
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ import unittest
import numpy as np
np.random.seed(123)

from rdkit.Chem import MolFromMolFile
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.AllChem import Mol, ComputeGasteigerCharges

from deepchem.feat import rdkit_grid_featurizer as rgf
@@ -23,17 +23,15 @@ def random_string(length, chars=None):

class TestHelperFunctions(unittest.TestCase):
  """
  Test functions defined in rdkit_grid_featurizer module.
  Test helper functions defined in rdkit_grid_featurizer module.
  """

  def setUp(self):
    # TODO test more formats for ligand
    current_dir = os.path.dirname(os.path.realpath(__file__))
    package_dir = os.path.dirname(os.path.dirname(current_dir))
    self.protein_file = os.path.join(package_dir, 'dock', 'tests',
                                     '1jld_protein.pdb')
    self.ligand_file = os.path.join(package_dir, 'dock', 'tests',
                                    '1jld_ligand.sdf')
    self.protein_file = os.path.join(current_dir,
                                     '3ws9_protein_fixer_rdkit.pdb')
    self.ligand_file = os.path.join(current_dir, '3ws9_ligand.sdf')

  def test_get_ligand_filetype(self):

@@ -60,7 +58,7 @@ class TestHelperFunctions(unittest.TestCase):
        self.assertIsInstance(mol_rdk, Mol)
        self.assertEqual(mol_xyz.shape, (num_atoms, 3))

  def test_generate_random__unit_vector(self):
  def test_generate_random_unit_vector(self):
    for _ in range(100):
      u = rgf.generate_random__unit_vector()
      # 3D vector with unit length
@@ -98,6 +96,12 @@ class TestHelperFunctions(unittest.TestCase):
    # random coords between 0 and 1, so the max possible distance in sqrt(2)
    self.assertTrue((distance <= 2.0**0.5).all())

    # check if correct distance metric was used
    coords1 = np.array([[0, 0, 0], [1, 0, 0]])
    coords2 = np.array([[1, 0, 0], [2, 0, 0], [3, 0, 0]])
    distance = rgf.compute_pairwise_distances(coords1, coords2)
    self.assertTrue((distance == [[1, 2, 3], [0, 1, 2]]).all())

  def test_unit_vector(self):
    for _ in range(10):
      vector = np.random.rand(3)
@@ -111,6 +115,8 @@ class TestHelperFunctions(unittest.TestCase):
      angle = rgf.angle_between(v1, v2)
      self.assertLessEqual(angle, np.pi)
      self.assertGreaterEqual(angle, 0.0)
      self.assertAlmostEqual(rgf.angle_between(v1, v1), 0.0)
      self.assertAlmostEqual(rgf.angle_between(v1, -v1), np.pi)

  def test_hash_ecfp(self):
    for power in (2, 16, 64):
@@ -131,8 +137,208 @@ class TestHelperFunctions(unittest.TestCase):
        self.assertLess(pair_hash, 2**power)
        self.assertGreaterEqual(pair_hash, 0)

  def test_convert_atom_to_voxel(self):
    # 20 points with coords between -5 and 5, centered at 0
    coords_range = 10
    xyz = (np.random.rand(20, 3) - 0.5) * coords_range
    for idx in np.random.choice(20, 6):
      for box_width in (10, 20, 40):
        for voxel_width in (0.5, 1, 2):
          voxel = rgf.convert_atom_to_voxel(xyz, idx, box_width, voxel_width)
          self.assertIsInstance(voxel, list)
          self.assertEqual(len(voxel), 1)
          self.assertIsInstance(voxel[0], np.ndarray)
          self.assertEqual(voxel[0].shape, (3,))
          self.assertIs(voxel[0].dtype, np.dtype('int'))
          # indices are positive
          self.assertTrue((voxel[0] >= 0).all())
          # coordinates were properly translated and scaled
          self.assertTrue(
              (voxel[0] < (box_width + coords_range) / 2.0 / voxel_width).all())
          self.assertTrue(
              np.allclose(voxel[0],
                          np.floor((xyz[idx] + box_width / 2.0) / voxel_width)))

    # for coordinates outside of the box function should properly transform them
    # to indices and warn the user
    for args in ((np.array([[0, 1, 6]]), 0, 10, 1.0), (np.array([[0, 4, -6]]),
                                                       0, 10, 1.0)):
      # TODO check if function warns. There is assertWarns method in unittest,
      # but it is not implemented in 2.7 and buggy in 3.5 (issue 29620)
      voxel = rgf.convert_atom_to_voxel(*args)
      self.assertTrue(
          np.allclose(voxel[0], np.floor((args[0] + args[2] / 2.0) / args[3])))

  def test_convert_atom_pair_to_voxel(self):
    # 20 points with coords between -5 and 5, centered at 0
    coords_range = 10
    xyz1 = (np.random.rand(20, 3) - 0.5) * coords_range
    xyz2 = (np.random.rand(20, 3) - 0.5) * coords_range
    # 3 pairs of indices
    for idx1, idx2 in np.random.choice(20, (3, 2)):
      for box_width in (10, 20, 40):
        for voxel_width in (0.5, 1, 2):
          v1 = rgf.convert_atom_to_voxel(xyz1, idx1, box_width, voxel_width)
          v2 = rgf.convert_atom_to_voxel(xyz2, idx2, box_width, voxel_width)
          v_pair = rgf.convert_atom_pair_to_voxel((xyz1, xyz2), (idx1, idx2),
                                                  box_width, voxel_width)
          self.assertEqual(len(v_pair), 2)
          self.assertTrue((v1 == v_pair[0]).all())
          self.assertTrue((v2 == v_pair[1]).all())

  def test_compute_charge_dictionary(self):
    for fname in (self.ligand_file, self.protein_file):
      _, mol = rgf.load_molecule(fname)
      ComputeGasteigerCharges(mol)
      charge_dict = rgf.compute_charge_dictionary(mol)
      self.assertEqual(len(charge_dict), mol.GetNumAtoms())
      for i in range(mol.GetNumAtoms()):
        self.assertIn(i, charge_dict)
        self.assertIsInstance(charge_dict[i], (float, int))


class TestPiInteractions(unittest.TestCase):

  def setUp(self):
    current_dir = os.path.dirname(os.path.realpath(__file__))

    # simple flat ring
    self.cycle4 = MolFromSmiles('C1CCC1')
    self.cycle4.Compute2DCoords()

    # load and sanitize two real molecules
    _, self.prot = rgf.load_molecule(
        os.path.join(current_dir, '3ws9_protein_fixer_rdkit.pdb'),
        add_hydrogens=False,
        calc_charges=False,
        sanitize=True)

    _, self.lig = rgf.load_molecule(
        os.path.join(current_dir, '3ws9_ligand.sdf'),
        add_hydrogens=False,
        calc_charges=False,
        sanitize=True)

  def test_compute_ring_center(self):
    # FIXME might break with different version of rdkit
    self.assertTrue(
        np.allclose(rgf.compute_ring_center(self.cycle4, range(4)), 0))

  def test_compute_ring_normal(self):
    # FIXME might break with different version of rdkit
    normal = rgf.compute_ring_normal(self.cycle4, range(4))
    self.assertTrue(
        np.allclose(np.abs(normal / np.linalg.norm(normal)), [0, 0, 1]))

  def test_is_pi_parallel(self):
    ring1_center = np.array([0.0, 0.0, 0.0])
    ring2_center_true = np.array([4.0, 0.0, 0.0])
    ring2_center_false = np.array([10.0, 0.0, 0.0])
    ring1_normal_true = np.array([1.0, 0.0, 0.0])
    ring1_normal_false = np.array([0.0, 1.0, 0.0])

    for ring2_normal in (np.array([2.0, 0, 0]), np.array([-3.0, 0, 0])):
      # parallel normals
      self.assertTrue(
          rgf.is_pi_parallel(ring1_center, ring1_normal_true, ring2_center_true,
                             ring2_normal))
      # perpendicular normals
      self.assertFalse(
          rgf.is_pi_parallel(ring1_center, ring1_normal_false,
                             ring2_center_true, ring2_normal))
      # too far away
      self.assertFalse(
          rgf.is_pi_parallel(ring1_center, ring1_normal_true,
                             ring2_center_false, ring2_normal))

  def test_is_pi_t(self):
    ring1_center = np.array([0.0, 0.0, 0.0])
    ring2_center_true = np.array([4.0, 0.0, 0.0])
    ring2_center_false = np.array([10.0, 0.0, 0.0])
    ring1_normal_true = np.array([0.0, 1.0, 0.0])
    ring1_normal_false = np.array([1.0, 0.0, 0.0])

    for ring2_normal in (np.array([2.0, 0, 0]), np.array([-3.0, 0, 0])):
      # perpendicular normals
      self.assertTrue(
          rgf.is_pi_t(ring1_center, ring1_normal_true, ring2_center_true,
                      ring2_normal))
      # parallel normals
      self.assertFalse(
          rgf.is_pi_t(ring1_center, ring1_normal_false, ring2_center_true,
                      ring2_normal))
      # too far away
      self.assertFalse(
          rgf.is_pi_t(ring1_center, ring1_normal_true, ring2_center_false,
                      ring2_normal))

  def test_compute_pi_stack(self):
    # order of the molecules shouldn't matter
    dicts1 = rgf.compute_pi_stack(self.prot, self.lig)
    dicts2 = rgf.compute_pi_stack(self.lig, self.prot)
    for i, j in ((0, 2), (1, 3)):
      self.assertEqual(dicts1[i], dicts2[j])
      self.assertEqual(dicts1[j], dicts2[i])

    # with this criteria we should find both types of stacking
    for d in rgf.compute_pi_stack(
        self.lig, self.prot, dist_cutoff=7, angle_cutoff=40.):
      self.assertGreater(len(d), 0)

  def test_is_cation_pi(self):
    cation_position = np.array([[2.0, 0.0, 0.0]])
    ring_center_true = np.array([4.0, 0.0, 0.0])
    ring_center_false = np.array([10.0, 0.0, 0.0])
    ring_normal_true = np.array([1.0, 0.0, 0.0])
    ring_normal_false = np.array([0.0, 1.0, 0.0])

    # parallel normals
    self.assertTrue(
        rgf.is_cation_pi(cation_position, ring_center_true, ring_normal_true))
    # perpendicular normals
    self.assertFalse(
        rgf.is_cation_pi(cation_position, ring_center_true, ring_normal_false))
    # too far away
    self.assertFalse(
        rgf.is_cation_pi(cation_position, ring_center_false, ring_normal_true))

  def test_compute_cation_pi(self):
    # TODO find better example, currently dicts are empty
    dicts1 = rgf.compute_cation_pi(self.prot, self.lig)
    dicts2 = rgf.compute_cation_pi(self.lig, self.prot)

  def test_compute_binding_pocket_cation_pi(self):
    # TODO find better example, currently dicts are empty
    prot_dict, lig_dict = rgf.compute_binding_pocket_cation_pi(
        self.prot, self.lig)

    exp_prot_dict, exp_lig_dict = rgf.compute_cation_pi(self.prot, self.lig)
    add_lig, add_prot = rgf.compute_cation_pi(self.lig, self.prot)
    for exp_dict, to_add in ((exp_prot_dict, add_prot), (exp_lig_dict,
                                                         add_lig)):
      for atom_idx, count in to_add.items():
        if atom_idx not in exp_dict:
          exp_dict[atom_idx] = count
        else:
          exp_dict[atom_idx] += count

    self.assertEqual(prot_dict, exp_prot_dict)
    self.assertEqual(lig_dict, exp_lig_dict)


class TestFeaturizationFunctions(unittest.TestCase):
  """
  Test functions calculating features defined in rdkit_grid_featurizer module.
  """

  def setUp(self):
    current_dir = os.path.dirname(os.path.realpath(__file__))
    self.protein_file = os.path.join(current_dir,
                                     '3ws9_protein_fixer_rdkit.pdb')
    self.ligand_file = os.path.join(current_dir, '3ws9_ligand.sdf')

  def test_compute_all_ecfp(self):
    mol = MolFromMolFile(self.ligand_file)
    _, mol = rgf.load_molecule(self.ligand_file)
    num_atoms = mol.GetNumAtoms()
    for degree in range(1, 4):
      # TODO test if dict contains smiles
@@ -254,65 +460,6 @@ class TestHelperFunctions(unittest.TestCase):
    self.assertIsInstance(dicts, list)
    self.assertEqual(dicts, expected_dicts)

  def test_convert_atom_to_voxel(self):
    # 20 points with coords between -5 and 5, centered at 0
    coords_range = 10
    xyz = (np.random.rand(20, 3) - 0.5) * coords_range
    for idx in np.random.choice(20, 6):
      for box_width in (10, 20, 40):
        for voxel_width in (0.5, 1, 2):
          voxel = rgf.convert_atom_to_voxel(xyz, idx, box_width, voxel_width)
          self.assertIsInstance(voxel, list)
          self.assertEqual(len(voxel), 1)
          self.assertIsInstance(voxel[0], np.ndarray)
          self.assertEqual(voxel[0].shape, (3,))
          self.assertIs(voxel[0].dtype, np.dtype('int'))
          # indices are positive
          self.assertTrue((voxel[0] >= 0).all())
          # coordinates were properly translated and scaled
          self.assertTrue(
              (voxel[0] < (box_width + coords_range) / 2.0 / voxel_width).all())
          self.assertTrue(
              np.allclose(voxel[0],
                          np.floor((xyz[idx] + box_width / 2.0) / voxel_width)))

    # for coordinates outside of the box function should properly transform them
    # to indices and warn the user
    for args in ((np.array([[0, 1, 6]]), 0, 10, 1.0), (np.array([[0, 4, -6]]),
                                                       0, 10, 1.0)):
      # TODO check if function warns. There is assertWarns method in unittest,
      # but it is not implemented in 2.7 and buggy in 3.5 (issue 29620)
      voxel = rgf.convert_atom_to_voxel(*args)
      self.assertTrue(
          np.allclose(voxel[0], np.floor((args[0] + args[2] / 2.0) / args[3])))

  def test_convert_atom_pair_to_voxel(self):
    # 20 points with coords between -5 and 5, centered at 0
    coords_range = 10
    xyz1 = (np.random.rand(20, 3) - 0.5) * coords_range
    xyz2 = (np.random.rand(20, 3) - 0.5) * coords_range
    # 3 pairs of indices
    for idx1, idx2 in np.random.choice(20, (3, 2)):
      for box_width in (10, 20, 40):
        for voxel_width in (0.5, 1, 2):
          v1 = rgf.convert_atom_to_voxel(xyz1, idx1, box_width, voxel_width)
          v2 = rgf.convert_atom_to_voxel(xyz2, idx2, box_width, voxel_width)
          v_pair = rgf.convert_atom_pair_to_voxel((xyz1, xyz2), (idx1, idx2),
                                                  box_width, voxel_width)
          self.assertEqual(len(v_pair), 2)
          self.assertTrue((v1 == v_pair[0]).all())
          self.assertTrue((v2 == v_pair[1]).all())

  def test_compute_charge_dictionary(self):
    for fname in (self.ligand_file, self.protein_file):
      _, mol = rgf.load_molecule(fname)
      ComputeGasteigerCharges(mol)
      charge_dict = rgf.compute_charge_dictionary(mol)
      self.assertEqual(len(charge_dict), mol.GetNumAtoms())
      for i in range(mol.GetNumAtoms()):
        self.assertIn(i, charge_dict)
        self.assertIsInstance(charge_dict[i], (float, int))


class TestRdkitGridFeaturizer(unittest.TestCase):
  """
@@ -327,6 +474,116 @@ class TestRdkitGridFeaturizer(unittest.TestCase):
    self.ligand_file = os.path.join(package_dir, 'dock', 'tests',
                                    '1jld_ligand.sdf')

  def test_default_featurizer(self):
    # test if default parameters work
    featurizer = rgf.RdkitGridFeaturizer()
    self.assertIsInstance(featurizer, rgf.RdkitGridFeaturizer)
    feature_tensor = featurizer.featurize_complexes([self.ligand_file],
                                                    [self.protein_file])
    self.assertIsInstance(feature_tensor, np.ndarray)

  def test_example_featurizer(self):
    # check if use-case from examples works
    featurizer = rgf.RdkitGridFeaturizer(
        voxel_width=16.0,
        feature_types=['ecfp', 'splif', 'hbond', 'salt_bridge'],
        ecfp_power=9,
        splif_power=9,
        flatten=True)
    feature_tensor = featurizer.featurize_complexes([self.ligand_file],
                                                    [self.protein_file])
    self.assertIsInstance(feature_tensor, np.ndarray)

  def test_force_flatten(self):
    # test if input is flattened when flat features are used
    featurizer = rgf.RdkitGridFeaturizer(
        feature_types=['ecfp_hashed'], flatten=False)
    featurizer.flatten = True  # False should be ignored with ecfp_hashed
    feature_tensor = featurizer.featurize_complexes([self.ligand_file],
                                                    [self.protein_file])
    self.assertIsInstance(feature_tensor, np.ndarray)
    self.assertEqual(feature_tensor.shape, (1, 2 * 2**featurizer.ecfp_power))

  def test_combined(self):
    ecfp_power = 5
    splif_power = 5
    # test voxel features
    featurizer = rgf.RdkitGridFeaturizer(
        voxel_width=1.0,
        box_width=20.0,
        feature_types=['voxel_combined'],
        ecfp_power=ecfp_power,
        splif_power=splif_power,
        flatten=False,
        sanitize=True)
    feature_tensor = featurizer.featurize_complexes([self.ligand_file],
                                                    [self.protein_file])
    self.assertIsInstance(feature_tensor, np.ndarray)
    voxel_total_len = (
        2**ecfp_power +
        len(featurizer.cutoffs['splif_contact_bins']) * 2**splif_power +
        len(featurizer.cutoffs['hbond_dist_bins']) + 5)
    self.assertEqual(feature_tensor.shape, (1, 20, 20, 20, voxel_total_len))

    # test flat features
    featurizer = rgf.RdkitGridFeaturizer(
        voxel_width=1.0,
        feature_types=['flat_combined'],
        ecfp_power=ecfp_power,
        splif_power=splif_power,
        sanitize=True)
    feature_tensor = featurizer.featurize_complexes([self.ligand_file],
                                                    [self.protein_file])
    self.assertIsInstance(feature_tensor, np.ndarray)
    flat_total_len = (
        3 * 2**ecfp_power +
        len(featurizer.cutoffs['splif_contact_bins']) * 2**splif_power +
        len(featurizer.cutoffs['hbond_dist_bins']))
    self.assertEqual(feature_tensor.shape, (1, flat_total_len))

    # check if aromatic features are ignores if sanitize=False
    featurizer = rgf.RdkitGridFeaturizer(
        voxel_width=16.0,
        feature_types=['all_combined'],
        ecfp_power=ecfp_power,
        splif_power=splif_power,
        flatten=True,
        sanitize=False)

    self.assertTrue('pi_stack' not in featurizer.feature_types)
    self.assertTrue('cation_pi' not in featurizer.feature_types)
    feature_tensor = featurizer.featurize_complexes([self.ligand_file],
                                                    [self.protein_file])
    self.assertIsInstance(feature_tensor, np.ndarray)
    total_len = voxel_total_len + flat_total_len - 3 - 2**ecfp_power
    self.assertEqual(feature_tensor.shape, (1, total_len))

  def test_custom_cutoffs(self):
    custom_cutoffs = {
        'hbond_dist_bins': [(2., 3.), (3., 3.5)],
        'hbond_angle_cutoffs': [5, 90],
        'splif_contact_bins': [(0, 3.5), (3.5, 6.0)],
        'ecfp_cutoff': 5.0,
        'sybyl_cutoff': 3.0,
        'salt_bridges_cutoff': 4.0,
        'pi_stack_dist_cutoff': 5.0,
        'pi_stack_angle_cutoff': 15.0,
        'cation_pi_dist_cutoff': 5.5,
        'cation_pi_angle_cutoff': 20.0,
    }
    rgf_featurizer = rgf.RdkitGridFeaturizer(**custom_cutoffs)
    self.assertEqual(rgf_featurizer.cutoffs, custom_cutoffs)

  def test_rotations(self):
    featurizer = rgf.RdkitGridFeaturizer(
        nb_rotations=3,
        feature_types=['voxel_combined'],
        flatten=False,
        sanitize=True)
    feature_tensors = featurizer.featurize_complexes([self.ligand_file],
                                                     [self.protein_file])
    self.assertEqual(len(feature_tensors), 4)

  def test_voxelize(self):
    prot_xyz, prot_rdk = rgf.load_molecule(self.protein_file)
    lig_xyz, lig_rdk = rgf.load_molecule(self.ligand_file)
@@ -335,14 +592,18 @@ class TestRdkitGridFeaturizer(unittest.TestCase):
    prot_xyz = rgf.subtract_centroid(prot_xyz, centroid)
    lig_xyz = rgf.subtract_centroid(lig_xyz, centroid)

    prot_ecfp_dict, lig_ecfp_dict = (rgf.featurize_binding_pocket_ecfp(
        prot_xyz, prot_rdk, lig_xyz, lig_rdk))
    prot_ecfp_dict, lig_ecfp_dict = rgf.featurize_binding_pocket_ecfp(
        prot_xyz, prot_rdk, lig_xyz, lig_rdk)

    box_w = 20
    f_power = 5

    rgf_featurizer = rgf.RdkitGridFeaturizer(
        box_width=box_w, ecfp_power=f_power)
        box_width=box_w,
        ecfp_power=f_power,
        feature_types=['all_combined'],
        flatten=True,
        sanitize=True)

    prot_tensor = rgf_featurizer._voxelize(
        rgf.convert_atom_to_voxel,
Loading