Commit 61a456ad authored by miaecle's avatar miaecle
Browse files

Merge remote-tracking branch 'remotes/origin/master'

parents fc8e4382 6693d1a7
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -56,7 +56,7 @@ git clone https://github.com/deepchem/deepchem.git # Clone deepchem source
cd deepchem
bash scripts/install_deepchem_conda.sh deepchem
source activate deepchem
pip install tensorflow-gpu==1.2.1                       # If you want GPU support
pip install tensorflow-gpu==1.3.0                      # If you want GPU support
python setup.py install                                 # Manual install
nosetests -v deepchem --nologcapture                    # Run tests
```
@@ -110,7 +110,7 @@ conda install -c deepchem -c rdkit -c conda-forge -c omnia deepchem=1.2.0
    contact your local sysadmin to work out a custom installation. If your
    version of Linux is recent, then the following command will work:
    ```
    pip install tensorflow-gpu==1.2.1
    pip install tensorflow-gpu==1.3.0
    ```

9. `deepchem`: Clone the `deepchem` github repo:
+7 −1
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@ from deepchem.utils.save import log
import tempfile
import time
import shutil
from multiprocessing.dummy import Pool

__author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
@@ -643,8 +644,12 @@ class DiskDataset(Dataset):
        shard_perm = np.random.permutation(num_shards)
      else:
        shard_perm = np.arange(num_shards)
      pool = Pool(1)
      next_shard = pool.apply_async(dataset.get_shard, (shard_perm[0],))
      for i in range(num_shards):
        X, y, w, ids = dataset.get_shard(shard_perm[i])
        X, y, w, ids = next_shard.get()
        if i < num_shards - 1:
          next_shard = pool.apply_async(dataset.get_shard, (shard_perm[i + 1],))
        n_samples = X.shape[0]
        # TODO(rbharath): This happens in tests sometimes, but don't understand why?
        # Handle edge case.
@@ -683,6 +688,7 @@ class DiskDataset(Dataset):
            (X_batch, y_batch, w_batch, ids_batch) = pad_batch(
                shard_batch_size, X_batch, y_batch, w_batch, ids_batch)
          yield (X_batch, y_batch, w_batch, ids_batch)
      pool.close()

    return iterate(self)

+8 −4
Original line number Diff line number Diff line
@@ -134,10 +134,12 @@ def get_atom_adj_matrices(mol,
  return (adj_matrix.astype(np.uint8), atom_matrix.astype(np.uint8))


def featurize_mol(mol, n_atom_types, max_n_atoms, max_valence):

def featurize_mol(mol, n_atom_types, max_n_atoms, max_valence,
                  num_atoms_feature):
  adj_matrix, atom_matrix = get_atom_adj_matrices(mol, n_atom_types,
                                                  max_n_atoms, max_valence)
  if num_atoms_feature:
    return ((adj_matrix, atom_matrix, mol.GetNumAtoms()))
  return ((adj_matrix, atom_matrix))


@@ -147,11 +149,13 @@ class AdjacencyFingerprint(Featurizer):
               n_atom_types=23,
               max_n_atoms=200,
               add_hydrogens=False,
               max_valence=4):
               max_valence=4,
               num_atoms_feature=False):
    self.n_atom_types = n_atom_types
    self.max_n_atoms = max_n_atoms
    self.add_hydrogens = add_hydrogens
    self.max_valence = max_valence
    self.num_atoms_feature = num_atoms_feature

  def featurize(self, rdkit_mols):
    featurized_mols = np.empty((len(rdkit_mols)), dtype=object)
@@ -160,6 +164,6 @@ class AdjacencyFingerprint(Featurizer):
      if self.add_hydrogens:
        mol = Chem.AddHs(mol)
      featurized_mol = featurize_mol(mol, self.n_atom_types, self.max_n_atoms,
                                     self.max_valence)
                                     self.max_valence, self.num_atoms_feature)
      featurized_mols[idx] = featurized_mol
    return (featurized_mols)
+36 −30
Original line number Diff line number Diff line
@@ -28,13 +28,13 @@ TODO(LESWING) add sanitization with rdkit upgrade to 2017.*
def get_ligand_filetype(ligand_filename):
  """Returns the filetype of ligand."""
  if ".mol2" in ligand_filename:
    return ".mol2"
    return "mol2"
  elif ".sdf" in ligand_filename:
    return "sdf"
  elif ".pdbqt" in ligand_filename:
    return ".pdbqt"
    return "pdbqt"
  elif ".pdb" in ligand_filename:
    return ".pdb"
    return "pdb"
  else:
    raise ValueError("Unrecognized_filename")

@@ -74,7 +74,8 @@ def generate_random__unit_vector():
  theta = np.random.uniform(low=0.0, high=2 * np.pi)
  z = np.random.uniform(low=-1.0, high=1.0)
  u = np.array(
      [np.sqrt(1 - z**2) * np.cos(theta), np.sqrt(1 - z**2) * np.sin(theta), z])
      [np.sqrt(1 - z**2) * np.cos(theta),
       np.sqrt(1 - z**2) * np.sin(theta), z])
  return (u)


@@ -142,8 +143,8 @@ def compute_pairwise_distances(protein_xyz, ligand_xyz):
  atom and the j"th ligand atom
  """

  pairwise_distances = np.zeros(
      (np.shape(protein_xyz)[0], np.shape(ligand_xyz)[0]))
  pairwise_distances = np.zeros((np.shape(protein_xyz)[0],
                                 np.shape(ligand_xyz)[0]))
  for j in range(0, np.shape(ligand_xyz)[0]):
    differences = protein_xyz - ligand_xyz[j, :]
    squared_differences = np.square(differences)
@@ -270,18 +271,17 @@ def featurize_binding_pocket_ecfp(protein_xyz,
  ----------
  protein_xyz: np.ndarray
    Of shape (N_protein_atoms, 3)
  protein: PDB object (TODO(rbharath): Correct?)
  protein: rdkit.rdchem.Mol
    Contains more metadata.
  ligand_xyz: np.ndarray
    Of shape (N_ligand_atoms, 3)
  ligand: PDB object (TODO(rbharath): Correct?)
  ligand: rdkit.rdchem.Mol
    Contains more metadata
  pairwise_distances: np.ndarray 
    Array of pairwise protein-ligand distances (Angstroms) 
  cutoff: float
    Cutoff distance for contact consideration.
  """
  features_dict = {}

  if pairwise_distances is None:
    pairwise_distances = compute_pairwise_distances(protein_xyz, ligand_xyz)
@@ -347,8 +347,8 @@ def compute_splif_features_in_range(protein,
  atoms.  Returns a dictionary mapping (protein_index_i, ligand_index_j) -->
  (protein_ecfp_i, ligand_ecfp_j)
  """
  contacts = np.nonzero((pairwise_distances > contact_bin[0]) & (
      pairwise_distances < contact_bin[1]))
  contacts = np.nonzero((pairwise_distances > contact_bin[0]) &
                        (pairwise_distances < contact_bin[1]))
  protein_atoms = set([int(c) for c in contacts[0].tolist()])
  contacts = zip(contacts[0], contacts[1])

@@ -509,8 +509,8 @@ def get_formal_charge(atom):


def is_salt_bridge(atom_i, atom_j):
  if np.abs(2.0 - np.abs(get_formal_charge(atom_i) - get_formal_charge(
      atom_j))) < 0.01:
  if np.abs(2.0 - np.abs(get_formal_charge(atom_i) - get_formal_charge(atom_j))
           ) < 0.01:
    return True
  else:
    return False
@@ -555,8 +555,8 @@ def compute_hbonds_in_range(protein, protein_xyz, ligand, ligand_xyz,
  a distance bin and an angle cutoff.
  """

  contacts = np.nonzero((pairwise_distances > hbond_dist_bin[0]) & (
      pairwise_distances < hbond_dist_bin[1]))
  contacts = np.nonzero((pairwise_distances > hbond_dist_bin[0]) &
                        (pairwise_distances < hbond_dist_bin[1]))
  protein_atoms = set([int(c) for c in contacts[0].tolist()])
  protein_ecfp_dict = compute_all_ecfp(
      protein, indices=protein_atoms, degree=ecfp_degree)
@@ -594,10 +594,15 @@ def convert_atom_to_voxel(molecule_xyz, atom_index, box_width, voxel_width):
  """
  Converts an atom to an i,j,k grid index.
  """
  coordinates = molecule_xyz[atom_index, :]
  from warnings import warn

  indices = np.floor(
      np.abs(molecule_xyz[atom_index, :] + np.array(
          [box_width, box_width, box_width]) / 2.0) / voxel_width).astype(int)
      (molecule_xyz[atom_index, :] + np.array([box_width, box_width, box_width]
                                             ) / 2.0) / voxel_width).astype(int)
  if ((indices < 0) | (indices >= box_width / voxel_width)).any():
    warn(
        'Coordinates are outside of the box (atom id = %s, coords xyz = %s, coords in box = %s'
        % (atom_index, molecule_xyz[atom_index], indices))
  return ([indices])


@@ -617,10 +622,15 @@ def convert_atom_pair_to_voxel(molecule_xyz_tuple, atom_index_pair, box_width,


def compute_charge_dictionary(molecule):
  """Computes partial charges for each atom."""
  """Create a dictionary with partial charges for each atom in the molecule.

  This function assumes that the charges for the molecule are already
  computed (it can be done with rdkit_util.compute_charges(molecule))
  """

  charge_dictionary = {}
  for i, atom in enumerate(ob.OBMolAtomIter(molecule)):
    charge_dictionary[i] = atom.GetPartialCharge()
  for i, atom in enumerate(molecule.GetAtoms()):
    charge_dictionary[i] = get_formal_charge(atom)
  return charge_dictionary


@@ -1090,7 +1100,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
                channel_power=None,
                nb_channel=16,
                dtype="np.int8"):
    # TODO(enf): make array index checking not a try-catch statement.

    if channel_power is not None:
      if channel_power == 0:
        nb_channel = 1
@@ -1110,22 +1120,18 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
      for key, features in feature_dict.items():
        voxels = get_voxels(coordinates, key, self.box_width, self.voxel_width)
        for voxel in voxels:
          try:
          if ((voxel >= 0) & (voxel < self.voxels_per_edge)).all():
            if hash_function is not None:
              feature_tensor[voxel[0], voxel[1], voxel[2], hash_function(
                  features, channel_power)] += 1.0
              feature_tensor[voxel[0], voxel[1], voxel[2],
                             hash_function(features, channel_power)] += 1.0
            else:
              feature_tensor[voxel[0], voxel[1], voxel[3], 0] += features
          except:
            continue
    elif feature_list is not None:
      for key in feature_list:
        voxels = get_voxels(coordinates, key, self.box_width, self.voxel_width)
        for voxel in voxels:
          try:
          if ((voxel >= 0) & (voxel < self.voxels_per_edge)).all():
            feature_tensor[voxel[0], voxel[1], voxel[2], 0] += 1.0
          except:
            continue

    return feature_tensor

+368 −0
Original line number Diff line number Diff line
"""
Test rdkit_grid_featurizer module.
"""
import os
from six import integer_types
import unittest

import numpy as np
np.random.seed(123)

from rdkit.Chem import MolFromMolFile
from rdkit.Chem.AllChem import Mol, ComputeGasteigerCharges

from deepchem.feat import rdkit_grid_featurizer as rgf


def random_string(length, chars=None):
  import string
  if chars is None:
    chars = list(string.ascii_letters + string.ascii_letters + '()[]+-.=#@/\\')
  return ''.join(np.random.choice(chars, length))


class TestHelperFunctions(unittest.TestCase):
  """
  Test functions defined in rdkit_grid_featurizer module.
  """

  def setUp(self):
    # TODO test more formats for ligand
    current_dir = os.path.dirname(os.path.realpath(__file__))
    package_dir = os.path.dirname(os.path.dirname(current_dir))
    self.protein_file = os.path.join(package_dir, 'dock', 'tests',
                                     '1jld_protein.pdb')
    self.ligand_file = os.path.join(package_dir, 'dock', 'tests',
                                    '1jld_ligand.sdf')

  def test_get_ligand_filetype(self):

    supported_extensions = ['mol2', 'sdf', 'pdb', 'pdbqt']
    # some users might try to read smiles with this function
    unsupported_extensions = ['smi', 'ism']

    for extension in supported_extensions:
      fname = 'molecule.%s' % extension
      self.assertEqual(rgf.get_ligand_filetype(fname), extension)

    for extension in unsupported_extensions:
      fname = 'molecule.%s' % extension
      self.assertRaises(ValueError, rgf.get_ligand_filetype, fname)

  def test_load_molecule(self):
    # adding hydrogens and charges is tested in dc.utils
    for add_hydrogens in (True, False):
      for calc_charges in (True, False):
        mol_xyz, mol_rdk = rgf.load_molecule(self.ligand_file, add_hydrogens,
                                             calc_charges)
        num_atoms = mol_rdk.GetNumAtoms()
        self.assertIsInstance(mol_xyz, np.ndarray)
        self.assertIsInstance(mol_rdk, Mol)
        self.assertEqual(mol_xyz.shape, (num_atoms, 3))

  def test_generate_random__unit_vector(self):
    for _ in range(100):
      u = rgf.generate_random__unit_vector()
      # 3D vector with unit length
      self.assertEqual(u.shape, (3,))
      self.assertAlmostEqual(np.linalg.norm(u), 1.0)

  def test_generate_random_rotation_matrix(self):
    # very basic test, we check if rotations actually work in test_rotate_molecules
    for _ in range(100):
      m = rgf.generate_random_rotation_matrix()
      self.assertEqual(m.shape, (3, 3))

  def test_rotate_molecules(self):
    # check if distances do not change
    vectors = np.random.rand(4, 2, 3)
    norms = np.linalg.norm(vectors[:, 1] - vectors[:, 0], axis=1)
    vectors_rot = np.array(rgf.rotate_molecules(vectors))
    norms_rot = np.linalg.norm(vectors_rot[:, 1] - vectors_rot[:, 0], axis=1)
    self.assertTrue(np.allclose(norms, norms_rot))

    # check if it works for molecules with different numbers of atoms
    coords = [np.random.rand(n, 3) for n in (10, 20, 40, 100)]
    coords_rot = rgf.rotate_molecules(coords)
    self.assertEqual(len(coords), len(coords_rot))

  def test_compute_pairwise_distances(self):
    n1 = 10
    n2 = 50
    coords1 = np.random.rand(n1, 3)
    coords2 = np.random.rand(n2, 3)

    distance = rgf.compute_pairwise_distances(coords1, coords2)
    self.assertEqual(distance.shape, (n1, n2))
    self.assertTrue((distance >= 0).all())
    # random coords between 0 and 1, so the max possible distance in sqrt(2)
    self.assertTrue((distance <= 2.0**0.5).all())

  def test_unit_vector(self):
    for _ in range(10):
      vector = np.random.rand(3)
      norm_vector = rgf.unit_vector(vector)
      self.assertAlmostEqual(np.linalg.norm(norm_vector), 1.0)

  def test_angle_between(self):
    for _ in range(10):
      v1 = np.random.rand(3,)
      v2 = np.random.rand(3,)
      angle = rgf.angle_between(v1, v2)
      self.assertLessEqual(angle, np.pi)
      self.assertGreaterEqual(angle, 0.0)

  def test_hash_ecfp(self):
    for power in (2, 16, 64):
      for _ in range(10):
        string = random_string(10)
        string_hash = rgf.hash_ecfp(string, power)
        self.assertIsInstance(string_hash, integer_types)
        self.assertLess(string_hash, 2**power)
        self.assertGreaterEqual(string_hash, 0)

  def test_hash_ecfp_pair(self):
    for power in (2, 16, 64):
      for _ in range(10):
        string1 = random_string(10)
        string2 = random_string(10)
        pair_hash = rgf.hash_ecfp_pair((string1, string2), power)
        self.assertIsInstance(pair_hash, integer_types)
        self.assertLess(pair_hash, 2**power)
        self.assertGreaterEqual(pair_hash, 0)

  def test_compute_all_ecfp(self):
    mol = MolFromMolFile(self.ligand_file)
    num_atoms = mol.GetNumAtoms()
    for degree in range(1, 4):
      # TODO test if dict contains smiles

      ecfp_all = rgf.compute_all_ecfp(mol, degree=degree)
      self.assertIsInstance(ecfp_all, dict)
      self.assertEqual(len(ecfp_all), num_atoms)
      self.assertEqual(list(ecfp_all.keys()), list(range(num_atoms)))

      num_ind = np.random.choice(range(1, num_atoms))
      indices = list(np.random.choice(num_atoms, num_ind, replace=False))

      ecfp_selected = rgf.compute_all_ecfp(mol, indices=indices, degree=degree)
      self.assertIsInstance(ecfp_selected, dict)
      self.assertEqual(len(ecfp_selected), num_ind)
      self.assertEqual(sorted(ecfp_selected.keys()), sorted(indices))

  def test_featurize_binding_pocket_ecfp(self):
    prot_xyz, prot_rdk = rgf.load_molecule(self.protein_file)
    lig_xyz, lig_rdk = rgf.load_molecule(self.ligand_file)
    distance = rgf.compute_pairwise_distances(
        protein_xyz=prot_xyz, ligand_xyz=lig_xyz)

    # check if results are the same if we provide precomputed distances
    prot_dict, lig_dict = rgf.featurize_binding_pocket_ecfp(
        prot_xyz,
        prot_rdk,
        lig_xyz,
        lig_rdk,)
    prot_dict_dist, lig_dict_dist = rgf.featurize_binding_pocket_ecfp(
        prot_xyz, prot_rdk, lig_xyz, lig_rdk, pairwise_distances=distance)
    # ...but first check if we actually got two dicts
    self.assertIsInstance(prot_dict, dict)
    self.assertIsInstance(lig_dict, dict)

    self.assertEqual(prot_dict, prot_dict_dist)
    self.assertEqual(lig_dict, lig_dict_dist)

    # check if we get less features with smaller distance cutoff
    prot_dict_d2, lig_dict_d2 = rgf.featurize_binding_pocket_ecfp(
        prot_xyz,
        prot_rdk,
        lig_xyz,
        lig_rdk,
        cutoff=2.0,)
    prot_dict_d6, lig_dict_d6 = rgf.featurize_binding_pocket_ecfp(
        prot_xyz,
        prot_rdk,
        lig_xyz,
        lig_rdk,
        cutoff=6.0,)
    self.assertLess(len(prot_dict_d2), len(prot_dict))
    # ligands are typically small so all atoms might be present
    self.assertLessEqual(len(lig_dict_d2), len(lig_dict))
    self.assertGreater(len(prot_dict_d6), len(prot_dict))
    self.assertGreaterEqual(len(lig_dict_d6), len(lig_dict))

    # check if using different ecfp_degree changes anything
    prot_dict_e3, lig_dict_e3 = rgf.featurize_binding_pocket_ecfp(
        prot_xyz,
        prot_rdk,
        lig_xyz,
        lig_rdk,
        ecfp_degree=3,)
    self.assertNotEqual(prot_dict_e3, prot_dict)
    self.assertNotEqual(lig_dict_e3, lig_dict)

  def test_compute_splif_features_in_range(self):
    prot_xyz, prot_rdk = rgf.load_molecule(self.protein_file)
    lig_xyz, lig_rdk = rgf.load_molecule(self.ligand_file)
    prot_num_atoms = prot_rdk.GetNumAtoms()
    lig_num_atoms = lig_rdk.GetNumAtoms()
    distance = rgf.compute_pairwise_distances(
        protein_xyz=prot_xyz, ligand_xyz=lig_xyz)

    for bins in ((0, 2), (2, 3)):
      splif_dict = rgf.compute_splif_features_in_range(
          prot_rdk,
          lig_rdk,
          distance,
          bins,)

      self.assertIsInstance(splif_dict, dict)
      for (prot_idx, lig_idx), ecfp_pair in splif_dict.items():

        for idx in (prot_idx, lig_idx):
          self.assertIsInstance(idx, (int, np.int64))
        self.assertGreaterEqual(prot_idx, 0)
        self.assertLess(prot_idx, prot_num_atoms)
        self.assertGreaterEqual(lig_idx, 0)
        self.assertLess(lig_idx, lig_num_atoms)

        for ecfp in ecfp_pair:
          ecfp_idx, ecfp_frag = ecfp.split(',')
          ecfp_idx = int(ecfp_idx)
          self.assertGreaterEqual(ecfp_idx, 0)
          # TODO upperbound?

  def test_featurize_splif(self):
    prot_xyz, prot_rdk = rgf.load_molecule(self.protein_file)
    lig_xyz, lig_rdk = rgf.load_molecule(self.ligand_file)
    distance = rgf.compute_pairwise_distances(
        protein_xyz=prot_xyz, ligand_xyz=lig_xyz)

    bins = [(1, 2), (2, 3)]

    dicts = rgf.featurize_splif(
        prot_xyz,
        prot_rdk,
        lig_xyz,
        lig_rdk,
        contact_bins=bins,
        pairwise_distances=distance,
        ecfp_degree=2)
    expected_dicts = [
        rgf.compute_splif_features_in_range(
            prot_rdk, lig_rdk, distance, c_bin, ecfp_degree=2) for c_bin in bins
    ]
    self.assertIsInstance(dicts, list)
    self.assertEqual(dicts, expected_dicts)

  def test_convert_atom_to_voxel(self):
    # 20 points with coords between -5 and 5, centered at 0
    coords_range = 10
    xyz = (np.random.rand(20, 3) - 0.5) * coords_range
    for idx in np.random.choice(20, 6):
      for box_width in (10, 20, 40):
        for voxel_width in (0.5, 1, 2):
          voxel = rgf.convert_atom_to_voxel(xyz, idx, box_width, voxel_width)
          self.assertIsInstance(voxel, list)
          self.assertEqual(len(voxel), 1)
          self.assertIsInstance(voxel[0], np.ndarray)
          self.assertEqual(voxel[0].shape, (3,))
          self.assertIs(voxel[0].dtype, np.dtype('int'))
          # indices are positive
          self.assertTrue((voxel[0] >= 0).all())
          # coordinates were properly translated and scaled
          self.assertTrue(
              (voxel[0] < (box_width + coords_range) / 2.0 / voxel_width).all())
          self.assertTrue(
              np.allclose(voxel[0],
                          np.floor((xyz[idx] + box_width / 2.0) / voxel_width)))

    # for coordinates outside of the box function should properly transform them
    # to indices and warn the user
    for args in ((np.array([[0, 1, 6]]), 0, 10, 1.0), (np.array([[0, 4, -6]]),
                                                       0, 10, 1.0)):
      # TODO check if function warns. There is assertWarns method in unittest,
      # but it is not implemented in 2.7 and buggy in 3.5 (issue 29620)
      voxel = rgf.convert_atom_to_voxel(*args)
      self.assertTrue(
          np.allclose(voxel[0], np.floor((args[0] + args[2] / 2.0) / args[3])))

  def test_convert_atom_pair_to_voxel(self):
    # 20 points with coords between -5 and 5, centered at 0
    coords_range = 10
    xyz1 = (np.random.rand(20, 3) - 0.5) * coords_range
    xyz2 = (np.random.rand(20, 3) - 0.5) * coords_range
    # 3 pairs of indices
    for idx1, idx2 in np.random.choice(20, (3, 2)):
      for box_width in (10, 20, 40):
        for voxel_width in (0.5, 1, 2):
          v1 = rgf.convert_atom_to_voxel(xyz1, idx1, box_width, voxel_width)
          v2 = rgf.convert_atom_to_voxel(xyz2, idx2, box_width, voxel_width)
          v_pair = rgf.convert_atom_pair_to_voxel((xyz1, xyz2), (idx1, idx2),
                                                  box_width, voxel_width)
          self.assertEqual(len(v_pair), 2)
          self.assertTrue((v1 == v_pair[0]).all())
          self.assertTrue((v2 == v_pair[1]).all())

  def test_compute_charge_dictionary(self):
    for fname in (self.ligand_file, self.protein_file):
      _, mol = rgf.load_molecule(fname)
      ComputeGasteigerCharges(mol)
      charge_dict = rgf.compute_charge_dictionary(mol)
      self.assertEqual(len(charge_dict), mol.GetNumAtoms())
      for i in range(mol.GetNumAtoms()):
        self.assertIn(i, charge_dict)
        self.assertIsInstance(charge_dict[i], (float, int))


class TestRdkitGridFeaturizer(unittest.TestCase):
  """
  Test RdkitGridFeaturizer class defined in rdkit_grid_featurizer module.
  """

  def setUp(self):
    current_dir = os.path.dirname(os.path.realpath(__file__))
    package_dir = os.path.dirname(os.path.dirname(current_dir))
    self.protein_file = os.path.join(package_dir, 'dock', 'tests',
                                     '1jld_protein.pdb')
    self.ligand_file = os.path.join(package_dir, 'dock', 'tests',
                                    '1jld_ligand.sdf')

  def test_voxelize(self):
    prot_xyz, prot_rdk = rgf.load_molecule(self.protein_file)
    lig_xyz, lig_rdk = rgf.load_molecule(self.ligand_file)

    centroid = rgf.compute_centroid(lig_xyz)
    prot_xyz = rgf.subtract_centroid(prot_xyz, centroid)
    lig_xyz = rgf.subtract_centroid(lig_xyz, centroid)

    prot_ecfp_dict, lig_ecfp_dict = (rgf.featurize_binding_pocket_ecfp(
        prot_xyz, prot_rdk, lig_xyz, lig_rdk))

    box_w = 20
    f_power = 5

    rgf_featurizer = rgf.RdkitGridFeaturizer(
        box_width=box_w, ecfp_power=f_power)

    prot_tensor = rgf_featurizer._voxelize(
        rgf.convert_atom_to_voxel,
        rgf.hash_ecfp,
        prot_xyz,
        feature_dict=prot_ecfp_dict,
        channel_power=f_power)
    self.assertEqual(prot_tensor.shape, tuple([box_w] * 3 + [2**f_power]))
    all_features = prot_tensor.sum()
    # protein is too big for the box, some features should be missing
    self.assertGreater(all_features, 0)
    self.assertLess(all_features, prot_rdk.GetNumAtoms())

    lig_tensor = rgf_featurizer._voxelize(
        rgf.convert_atom_to_voxel,
        rgf.hash_ecfp,
        lig_xyz,
        feature_dict=lig_ecfp_dict,
        channel_power=f_power)
    self.assertEqual(lig_tensor.shape, tuple([box_w] * 3 + [2**f_power]))
    all_features = lig_tensor.sum()
    # whole ligand should fit in the box
    self.assertEqual(all_features, lig_rdk.GetNumAtoms())
Loading