Commit 637f10b1 authored by marta-sd's avatar marta-sd
Browse files

new implementation of _transform

parent 11df1a7d
Loading
Loading
Loading
Loading
+213 −278
Original line number Diff line number Diff line
@@ -19,7 +19,6 @@ from deepchem.utils.rdkit_util import load_molecule
import numpy as np
from scipy.spatial.distance import cdist
from copy import deepcopy
from functools import partial
from deepchem.feat import ComplexFeaturizer
from deepchem.utils.save import log
"""
@@ -650,14 +649,13 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
  def __init__(self,
               nb_rotations=0,
               nb_reflections=0,
               feature_types="ecfp",
               feature_types=None,
               ecfp_degree=2,
               ecfp_power=3,
               splif_power=3,
               ligand_only=False,
               box_width=16.0,
               voxel_width=1.0,
               voxel_feature_types=[],
               flatten=False,
               verbose=True,
               **kwargs):
@@ -665,7 +663,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
    # check if user tries to set removed arguments
    deprecated_args = [
        'box_x', 'box_y', 'box_z', 'save_intermediates', 'voxelize_features',
        'parallel'
        'parallel', 'voxel_feature_types'
    ]
    for arg in deprecated_args:
      if arg in kwargs:
@@ -681,7 +679,6 @@ class RdkitGridFeaturizer(ComplexFeaturizer):

    self.nb_rotations = nb_rotations
    self.nb_reflections = nb_reflections
    self.feature_types = feature_types

    self.ligand_only = ligand_only

@@ -692,7 +689,6 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
    self.box_width = float(box_width)
    self.voxel_width = float(voxel_width)
    self.voxels_per_edge = self.box_width / self.voxel_width
    self.voxel_feature_types = voxel_feature_types

    self.sybyl_types = [
        "C3", "C2", "C1", "Cac", "Car", "N3", "N3+", "Npl", "N2", "N1", "Ng+",
@@ -702,6 +698,190 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
        "SO", "P3", "P", "P3+", "F", "Cl", "Br", "I"
    ]

    # Registry of "flat" (vector-valued) feature generators, keyed by the
    # feature name accepted in `feature_types`. Every entry is a callable with
    # the same signature
    #   (prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances)
    # and every entry returns a *list* of arrays, so callers can extend a
    # running list with the result regardless of how many arrays one feature
    # produces.
    self.FLAT_FEATURES = {
        # Fingerprint of the ligand only: the protein arguments and the
        # distance matrix are accepted for signature uniformity but unused.
        'ecfp_ligand': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [compute_ecfp_features(
                lig_rdk,
                self.ecfp_degree,
                self.ecfp_power)],

        # One hashed vector per dictionary returned by
        # featurize_binding_pocket_ecfp (called with a fixed cutoff of 4.5).
        # Presumably that is one dict for the protein and one for the ligand
        # -- verify against featurize_binding_pocket_ecfp.
        'ecfp_hashed': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [self._vectorize(
                hash_ecfp,
                feature_dict=ecfp_dict,
                channel_power=self.ecfp_power
            ) for ecfp_dict in featurize_binding_pocket_ecfp(
                prot_xyz,
                prot_rdk,
                lig_xyz,
                lig_rdk,
                distances,
                cutoff=4.5,
                ecfp_degree=self.ecfp_degree)],

        # One hashed vector per dictionary returned by featurize_splif, which
        # is driven by self.contact_bins (presumably one dict per contact bin
        # -- TODO confirm).
        'splif_hashed': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [self._vectorize(
                hash_ecfp_pair,
                feature_dict=splif_dict,
                channel_power=self.splif_power
            ) for splif_dict in featurize_splif(
                prot_xyz,
                prot_rdk,
                lig_xyz,
                lig_rdk,
                self.contact_bins,
                distances,
                self.ecfp_degree)],

        # One vector per list returned by compute_hydrogen_bonds.
        # NOTE(review): this passes hash_ecfp_pair together with
        # `feature_list` and channel_power=0, whereas the voxel 'hbond'
        # feature passes None as the hash function -- confirm that
        # _vectorize ignores the hash function when given a feature_list.
        'hbond_count': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [self._vectorize(
                hash_ecfp_pair,
                feature_list=hbond_list,
                channel_power=0
            ) for hbond_list in compute_hydrogen_bonds(
                prot_xyz,
                prot_rdk,
                lig_xyz,
                lig_rdk,
                distances,
                self.hbond_dist_bins,
                self.hbond_angle_cutoffs)]
    }

    # Registry of voxel (grid-valued) feature generators, keyed by the feature
    # name accepted in `feature_types`. Every entry is a callable with the
    # signature
    #   (prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances)
    # that returns a *list* of whatever self._voxelize produces. Coordinates
    # are passed in explicitly, so the same callables can be applied to
    # transformed (e.g. rotated) copies of the system.
    self.VOXEL_FEATURES = {
        # Protein and ligand are voxelized separately (pairing each coordinate
        # set with its dictionary via zip) and then added together with sum(),
        # giving a single combined tensor in the returned list.
        'ecfp': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [sum([self._voxelize(
                convert_atom_to_voxel,
                hash_ecfp,
                xyz,
                feature_dict=ecfp_dict,
                channel_power=self.ecfp_power
            ) for xyz, ecfp_dict in zip(
                (prot_xyz, lig_xyz), featurize_binding_pocket_ecfp(
                    prot_xyz,
                    prot_rdk,
                    lig_xyz,
                    lig_rdk,
                    distances,
                    cutoff=4.5,
                    ecfp_degree=self.ecfp_degree
                ))])],

        # Pairwise feature: both coordinate sets are passed as a tuple and one
        # tensor is produced per dictionary returned by featurize_splif.
        'splif': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [self._voxelize(
                convert_atom_pair_to_voxel,
                hash_ecfp_pair,
                (prot_xyz, lig_xyz),
                feature_dict=splif_dict,
                channel_power=self.splif_power
            ) for splif_dict in featurize_splif(
                prot_xyz,
                prot_rdk,
                lig_xyz,
                lig_rdk,
                self.contact_bins,
                distances,
                self.ecfp_degree)],

        # The inner lambda binds self.sybyl_types into hash_sybyl so that it
        # can be called with a single argument.
        # NOTE(review): unlike 'ecfp' and 'charge', the per-molecule tensors
        # are NOT summed here, so this returns one tensor per zipped
        # (coordinates, dict) pair instead of a single combined tensor --
        # confirm this is intended.
        'sybyl': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [self._voxelize(
                convert_atom_to_voxel,
                lambda x: hash_sybyl(x, sybyl_types=self.sybyl_types),
                xyz,
                feature_dict=sybyl_dict,
                nb_channel=len(self.sybyl_types)
            ) for xyz, sybyl_dict in zip(
                (prot_xyz, lig_xyz), featurize_binding_pocket_sybyl(
                    prot_xyz,
                    prot_rdk,
                    lig_xyz,
                    lig_rdk,
                    distances,
                    cutoff=7.0
                ))],

        # Single one-channel tensor built from the list returned by
        # compute_salt_bridges; no hash function is supplied.
        'salt_bridge': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [self._voxelize(
                convert_atom_pair_to_voxel,
                None,
                (prot_xyz, lig_xyz),
                feature_list=compute_salt_bridges(
                    prot_xyz,
                    prot_rdk,
                    lig_xyz,
                    lig_rdk,
                    distances),
                nb_channel=1
            )],

        # Protein and ligand charge tensors are voxelized separately and then
        # added together with sum(). `distances` is not used by this feature.
        # NOTE(review): dtype is passed as the *string* "np.float16", not the
        # numpy type; how _voxelize interprets it is not visible here.
        'charge': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [sum([self._voxelize(
                convert_atom_to_voxel,
                None,
                xyz,
                feature_dict=compute_charge_dictionary(mol),
                nb_channel=1,
                dtype="np.float16"
            ) for xyz, mol in ((prot_xyz, prot_rdk), (lig_xyz, lig_rdk))])],

        # One tensor per list returned by compute_hydrogen_bonds, which is
        # driven by self.hbond_dist_bins and self.hbond_angle_cutoffs.
        'hbond': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [self._voxelize(
                convert_atom_pair_to_voxel,
                None,
                (prot_xyz, lig_xyz),
                feature_list=hbond_list,
                channel_power=0
            ) for hbond_list in compute_hydrogen_bonds(
                prot_xyz,
                prot_rdk,
                lig_xyz,
                lig_rdk,
                distances,
                self.hbond_dist_bins,
                self.hbond_angle_cutoffs)
            ]
    }

    # Resolve the user-supplied feature names into self.feature_types, a list
    # of (is_flat, name) tuples: is_flat=True means `name` is a key of
    # self.FLAT_FEATURES, is_flat=False means it is a key of
    # self.VOXEL_FEATURES.
    if feature_types is None:
      feature_types = ['ecfp_ligand']

    # NOTE(review): `feature_types` is iterated directly, so a bare string
    # such as "ecfp" would be walked character by character and every
    # character reported as an unknown feature -- callers should pass a list.
    self.feature_types = []

    for feature_type in feature_types:
      if feature_type in self.FLAT_FEATURES:
        self.feature_types.append((True, feature_type))
        # Any flat feature forces flattened output; warn only when this
        # overrides the current setting. (self.flatten is assumed to have
        # been assigned earlier in __init__ -- not visible in this hunk.)
        if self.flatten is False:
          warn('%s feature is used, output will be flatten' % feature_type)
          self.flatten = True

      elif feature_type in self.VOXEL_FEATURES:
        self.feature_types.append((False, feature_type))

      # Shorthand: every flat feature, in sorted-name order so the layout of
      # the output does not depend on dict ordering.
      elif feature_type == 'flat_combined':
        self.feature_types += list(
            zip([True] * len(self.FLAT_FEATURES),
                sorted(self.FLAT_FEATURES.keys())))
        if self.flatten is False:
          warn('flat features are used, output will be flatten')
          self.flatten = True

      # Shorthand: every voxel feature, in sorted-name order. Does not touch
      # self.flatten.
      elif feature_type == 'voxel_combined':
        self.feature_types += list(
            zip([False] * len(self.VOXEL_FEATURES),
                sorted(self.VOXEL_FEATURES.keys())))
      # Shorthand: all flat features followed by all voxel features; because
      # flat features are included, flattened output is forced here as well.
      elif feature_type == 'all_combined':
        self.feature_types += list(
            zip([True] * len(self.FLAT_FEATURES),
                sorted(self.FLAT_FEATURES.keys())))
        self.feature_types += list(
            zip([False] * len(self.VOXEL_FEATURES),
                sorted(self.VOXEL_FEATURES.keys())))
        if self.flatten is False:
          warn('flat feature are used, output will be flatten')
          self.flatten = True
      # Unknown names are skipped with a warning rather than raising.
      else:
        warn('Ignoring unknown feature %s' % feature_type)

  def _featurize_complex(self, ligand_ext, ligand_lines, protein_pdb_lines):
    tempdir = tempfile.mkdtemp()

@@ -774,7 +954,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
    ############################################################## TIMING

    if not self.ligand_only:
      protein_xyz, protein_ob = load_molecule(protein_pdb, calc_charges=True)
      protein_xyz, protein_rdk = load_molecule(protein_pdb, calc_charges=True)
    ############################################################## TIMING
    time2 = time.time()
    log("TIMING: Loading protein coordinates took %0.3f s" % (time2 - time1),
@@ -783,18 +963,13 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
    ############################################################## TIMING
    time1 = time.time()
    ############################################################## TIMING
    ligand_xyz, ligand_ob = load_molecule(ligand_file, calc_charges=True)
    ligand_xyz, ligand_rdk = load_molecule(ligand_file, calc_charges=True)
    ############################################################## TIMING
    time2 = time.time()
    log("TIMING: Loading ligand coordinates took %0.3f s" % (time2 - time1),
        self.verbose)
    ############################################################## TIMING

    if "ecfp" in self.feature_types:
      ecfp_array = compute_ecfp_features(ligand_ob, self.ecfp_degree,
                                         self.ecfp_power)
      return ({(0, 0): ecfp_array})

    ############################################################## TIMING
    time1 = time.time()
    ############################################################## TIMING
@@ -808,282 +983,42 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
        self.verbose)
    ############################################################## TIMING

    if "splif" in self.feature_types:
      splif_array = self._featurize_splif(protein_xyz, protein_ob, ligand_xyz,
                                          ligand_ob)
      return ({(0, 0): splif_array})

    if "flat_combined" in self.feature_types:
      return (self._compute_flat_features(protein_xyz, protein_ob, ligand_xyz,
                                          ligand_ob))

    pairwise_distances = compute_pairwise_distances(protein_xyz, ligand_xyz)
    if "ecfp" in self.voxel_feature_types:
      ############################################################## TIMING
      time1 = time.time()
      ############################################################## TIMING
      protein_ecfp_dict, ligand_ecfp_dict = (featurize_binding_pocket_ecfp(
          protein_xyz,
          protein_ob,
          ligand_xyz,
          ligand_ob,
          pairwise_distances,
          cutoff=4.5,
          ecfp_degree=self.ecfp_degree))
      ############################################################## TIMING
      time2 = time.time()
      log("TIMING: ecfp voxel computataion took %0.3f s" % (time2 - time1),
          self.verbose)
      ############################################################## TIMING
    if "splif" in self.voxel_feature_types:
      ############################################################## TIMING
      time1 = time.time()
      ############################################################## TIMING
      splif_dicts = featurize_splif(protein_xyz, protein_ob, ligand_xyz,
                                    ligand_ob, self.contact_bins,
                                    pairwise_distances, self.ecfp_degree)
      ############################################################## TIMING
      time2 = time.time()
      log("TIMING: splif voxel computataion took %0.3f s" % (time2 - time1),
          self.verbose)
      ############################################################## TIMING

    if "hbond" in self.voxel_feature_types:
      ############################################################## TIMING
      time1 = time.time()
      ############################################################## TIMING
      hbond_list = compute_hydrogen_bonds(
          protein_xyz, protein_ob, ligand_xyz, ligand_ob, pairwise_distances,
          self.hbond_dist_bins, self.hbond_angle_cutoffs)
      ############################################################## TIMING
      time2 = time.time()
      log("TIMING: hbond voxel computataion took %0.3f s" % (time2 - time1),
          self.verbose)
      ############################################################## TIMING

    if "sybyl" in self.voxel_feature_types:
      ############################################################## TIMING
      time1 = time.time()
      ############################################################## TIMING
      protein_sybyl_dict, ligand_sybyl_dict = featurize_binding_pocket_sybyl(
          protein_xyz,
          protein_ob,
          ligand_xyz,
          ligand_ob,
          pairwise_distances,
          cutoff=7.0)
      ############################################################## TIMING
      time2 = time.time()
      log("TIMING: sybyl voxel computataion took %0.3f s" % (time2 - time1),
          self.verbose)
      ############################################################## TIMING

    if "pi_stack" in self.voxel_feature_types:
      ############################################################## TIMING
      time1 = time.time()
      ############################################################## TIMING
      protein_pi_t, protein_pi_parallel, ligand_pi_t, ligand_pi_parallel = (
          compute_pi_stack(protein_xyz, protein_ob, ligand_xyz, ligand_ob,
                           pairwise_distances))
      ############################################################## TIMING
      time2 = time.time()
      log("TIMING: pi_stack voxel computataion took %0.3f s" % (time2 - time1),
          self.verbose)
      ############################################################## TIMING

    if "cation_pi" in self.voxel_feature_types:
      ############################################################## TIMING
      time1 = time.time()
      ############################################################## TIMING
      protein_cation_pi, ligand_cation_pi = (compute_binding_pocket_cation_pi(
          protein_xyz, protein_ob, ligand_xyz, ligand_ob))
      ############################################################## TIMING
      time2 = time.time()
      log("TIMING: cation_pi voxel computataion took %0.3f s" % (time2 - time1),
          self.verbose)
      ############################################################## TIMING

    if "salt_bridge" in self.voxel_feature_types:
      ############################################################## TIMING
      time1 = time.time()
      ############################################################## TIMING
      salt_bridge_list = compute_salt_bridges(
          protein_xyz, protein_ob, ligand_xyz, ligand_ob, pairwise_distances)
      ############################################################## TIMING
      time2 = time.time()
      log("TIMING: salt_bridge voxel computataion took %0.3f s" %
          (time2 - time1), self.verbose)
      ############################################################## TIMING

    if "charge" in self.voxel_feature_types:
      ############################################################## TIMING
      time1 = time.time()
      ############################################################## TIMING
      protein_charge_dictionary = compute_charge_dictionary(protein_ob)
      ligand_charge_dictionary = compute_charge_dictionary(ligand_ob)
      ############################################################## TIMING
      time2 = time.time()
      log("TIMING: charge voxel computataion took %0.3f s" % (time2 - time1),
          self.verbose)
      ############################################################## TIMING

    transformed_systems = {}
    transformed_systems[(0, 0)] = [protein_xyz, ligand_xyz]

    for i in range(0, int(self.nb_rotations)):
    for i in range(self.nb_rotations):
      rotated_system = rotate_molecules([protein_xyz, ligand_xyz])
      transformed_systems[(i + 1, 0)] = rotated_system
      for j in range(0, int(self.nb_reflections)):
        reflected_system = self._reflect_molecule(rotated_system)
        transformed_systems[(i + 1, j + 1)] = reflected_system
    # FIXME: _reflect_molecule is not implemented
    #   for j in range(self.nb_reflections):
    #     reflected_system = self._reflect_molecule(rotated_system)
    #     transformed_systems[(i + 1, j + 1)] = reflected_system

    if "voxel_combined" in self.feature_types:
    features = {}
      for system_id, system in transformed_systems.items():
        protein_xyz = system[0]
        ligand_xyz = system[1]
        feature_tensors = []
        if "ecfp" in self.voxel_feature_types:
          ecfp_tensor = self._voxelize(
              convert_atom_to_voxel,
              hash_ecfp,
              protein_xyz,
              feature_dict=protein_ecfp_dict,
              channel_power=self.ecfp_power)
          ecfp_tensor += self._voxelize(
              convert_atom_to_voxel,
              hash_ecfp,
              ligand_xyz,
              feature_dict=ligand_ecfp_dict,
              channel_power=self.ecfp_power)
          feature_tensors.append(ecfp_tensor)
          print("Completed ecfp tensor")

        if "splif" in self.voxel_feature_types:
          feature_tensors += [
              self._voxelize(
                  convert_atom_pair_to_voxel,
                  hash_ecfp_pair, (protein_xyz, ligand_xyz),
                  feature_dict=splif_dict,
                  channel_power=self.splif_power) for splif_dict in splif_dicts
          ]
          print("Completed splif tensor")

        if "hbond" in self.voxel_feature_types:
          feature_tensors += [
              self._voxelize(
                  convert_atom_pair_to_voxel,
                  None, (protein_xyz, ligand_xyz),
                  feature_list=hbond,
                  channel_power=0) for hbond in hbond_list
          ]
          print("Completed hbond tensor")

        if "sybyl" in self.voxel_feature_types:
          sybyl_partial = partial(hash_sybyl, sybyl_types=self.sybyl_types)
          sybyl_tensor = self._voxelize(
              convert_atom_to_voxel,
              hash_sybyl,
              protein_xyz,
              feature_dict=protein_sybyl_dict,
              nb_channel=len(self.sybyl_types))
          sybyl_tensor += self._voxelize(
              convert_atom_to_voxel,
              hash_sybyl,
              ligand_xyz,
              feature_dict=ligand_sybyl_dict,
              nb_channel=len(self.sybyl_types))
          feature_tensors.append(sybyl_tensor)
          print("Completed sybyl tensor")

        if "pi_stack" in self.voxel_feature_types:
          pi_parallel_tensor = self._voxelize(
              convert_atom_to_voxel,
              None,
              protein_xyz,
              feature_dict=protein_pi_parallel,
              nb_channel=1)
          pi_parallel_tensor += self._voxelize(
              convert_atom_to_voxel,
              None,
              ligand_xyz,
              feature_dict=ligand_pi_parallel,
              nb_channel=1)
          feature_tensors.append(pi_parallel_tensor)

          pi_t_tensor = self._voxelize(
              convert_atom_to_voxel,
              None,
              protein_xyz,
              feature_dict=protein_pi_t,
              nb_channel=1)
          pi_t_tensor += self._voxelize(
              convert_atom_to_voxel,
              None,
              ligand_xyz,
              feature_dict=ligand_pi_t,
              nb_channel=1)
          feature_tensors.append(pi_t_tensor)
          print("Completed pi_stack tensor")

        if "cation_pi" in self.voxel_feature_types:
          cation_pi_tensor = self._voxelize(
              convert_atom_to_voxel,
              None,
              protein_xyz,
              feature_dict=protein_cation_pi,
              nb_channel=1)
          cation_pi_tensor += self._voxelize(
              convert_atom_to_voxel,
              None,
              ligand_xyz,
              feature_dict=ligand_cation_pi,
              nb_channel=1)
          feature_tensors.append(cation_pi_tensor)
          print("Completed cation_pi tensor.")

        if "salt_bridge" in self.voxel_feature_types:
          salt_bridge_tensor = self._voxelize(
              convert_atom_pair_to_voxel,
              None, (protein_xyz, ligand_xyz),
              feature_list=salt_bridge_list,
              nb_channel=1)
          feature_tensors.append(salt_bridge_tensor)

          print("Completed salt_bridge tensor.")
    for system_id, (protein_xyz, ligand_xyz) in transformed_systems.items():
      feature_arrays = []
      for is_flat, function_name in self.feature_types:
        if is_flat:
          function = self.FLAT_FEATURES[function_name]
        else:
          function = self.VOXEL_FEATURES[function_name]

        if "charge" in self.voxel_feature_types:
          charge_tensor = self._voxelize(
              convert_atom_to_voxel,
              None,
        feature_arrays += function(
            protein_xyz,
              feature_dict=protein_charge_dictionary,
              nb_channel=1,
              dtype="np.float16")
          charge_tensor += self._voxelize(
              convert_atom_to_voxel,
              None,
            protein_rdk,
            ligand_xyz,
              feature_dict=ligand_charge_dictionary,
              nb_channel=1,
              dtype="np.float16")
          feature_tensors.append(charge_tensor)

          print("Completed salt_bridge tensor.")

        if "charge" in self.voxel_feature_types:
          feature_tensor = np.concatenate(
              feature_tensors, axis=3).astype(np.float16)
        else:
          feature_tensor = np.concatenate(
              feature_tensors, axis=3).astype(np.int8)
            ligand_rdk,
            pairwise_distances,)

        if self.flatten:
          feature_tensor = feature_tensor.flatten()

        features[system_id] = feature_tensor
          features[system_id] = np.concatenate(
              [feature_array.flatten() for feature_array in feature_arrays])
        else:
          features[system_id] = np.concatenate(feature_arrays, axis=-1)

      return (features)
    return features

  def _voxelize(self,
                get_voxels,
+13 −9
Original line number Diff line number Diff line
@@ -333,20 +333,24 @@ class TestRdkitGridFeaturizer(unittest.TestCase):
    self.ligand_file = os.path.join(package_dir, 'dock', 'tests',
                                    '1jld_ligand.sdf')

  def test_init(self):
  def test_featurizer(self):
    ecfp_power = 5
    splif_power = 5

    # just check if it doesn't throw any error for the use-case from examples
    # just check if it works for the use-case from examples
    featurizer = rgf.RdkitGridFeaturizer(
        voxel_width=16.0,
        feature_types="voxel_combined",
        voxel_feature_types=[
            "ecfp", "splif", "hbond", "pi_stack", "cation_pi", "salt_bridge"
        ],
        ecfp_power=5,
        splif_power=5,
        parallel=True,
        feature_types=["ecfp", "splif", "hbond", "salt_bridge"],
        ecfp_power=ecfp_power,
        splif_power=splif_power,
        flatten=True)
    self.assertIsInstance(featurizer, rgf.RdkitGridFeaturizer)
    feature_tensor = featurizer.featurize_complexes([self.ligand_file],
                                                    [self.protein_file])
    self.assertIsInstance(feature_tensor, np.ndarray)
    total_len = (2**ecfp_power + len(featurizer.contact_bins) * 2**splif_power +
                 len(featurizer.hbond_dist_bins) + 1)
    self.assertEqual(feature_tensor.shape, (1, total_len))

  def test_voxelize(self):
    prot_xyz, prot_rdk = rgf.load_molecule(self.protein_file)