Unverified commit 5000a7a1, authored by Bharath Ramsundar and committed by GitHub
Browse files

Merge pull request #1372 from peastman/parallelrdkit

Parallelize RdkitGridFeaturizer
parents 905e5509 cf6475f7
Loading
Loading
Loading
Loading
+221 −220
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ from warnings import warn
import time
import tempfile
import hashlib
import multiprocessing
from collections import Counter
from rdkit import Chem
from rdkit.Chem import AllChem
@@ -706,7 +707,8 @@ def get_partial_charge(atom):


def get_formal_charge(atom):
  """Deprecated alias for `get_partial_charge`.

  Emits a DeprecationWarning and then returns whatever
  `get_partial_charge(atom)` returns.

  Parameters
  ----------
  atom:
    Passed through unchanged to `get_partial_charge`.
  """
  warn(
      'get_formal_charge function is deprecated and will be removed'
      ' in version 1.4, use get_partial_charge instead', DeprecationWarning)
  return get_partial_charge(atom)

@@ -966,7 +968,8 @@ class RdkitGridFeaturizer(ComplexFeaturizer):

    for arg in deprecated_args:
      if arg in kwargs and verbose:
        warn('%s argument was removed and it is ignored,'
        warn(
            '%s argument was removed and it is ignored,'
            ' using it will result in error in version 1.4' % arg,
            DeprecationWarning)

@@ -1011,205 +1014,13 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
        "SO", "P3", "P", "P3+", "F", "Cl", "Br", "I"
    ]

    # define methods to calculate available flat features
    # all methods (flat and voxel) must have the same API:
    # f(prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances) -> list of np.ndarrays
    self.FLAT_FEATURES = {
        'ecfp_ligand': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [compute_ecfp_features(
                lig_rdk,
                self.ecfp_degree,
                self.ecfp_power)],

        'ecfp_hashed': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [self._vectorize(
                hash_ecfp,
                feature_dict=ecfp_dict,
                channel_power=self.ecfp_power
            ) for ecfp_dict in featurize_binding_pocket_ecfp(
                prot_xyz,
                prot_rdk,
                lig_xyz,
                lig_rdk,
                distances,
                cutoff=self.cutoffs['ecfp_cutoff'],
                ecfp_degree=self.ecfp_degree)],

        'splif_hashed': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [self._vectorize(
                hash_ecfp_pair,
                feature_dict=splif_dict,
                channel_power=self.splif_power
            ) for splif_dict in featurize_splif(
                prot_xyz,
                prot_rdk,
                lig_xyz,
                lig_rdk,
                self.cutoffs['splif_contact_bins'],
                distances,
                self.ecfp_degree)],

        'hbond_count': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [self._vectorize(
                hash_ecfp_pair,
                feature_list=hbond_list,
                channel_power=0
            ) for hbond_list in compute_hydrogen_bonds(
                prot_xyz,
                prot_rdk,
                lig_xyz,
                lig_rdk,
                distances,
                self.cutoffs['hbond_dist_bins'],
                self.cutoffs['hbond_angle_cutoffs'])]
    }

    def voxelize_pi_stack(prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances):
      protein_pi_t, protein_pi_parallel, ligand_pi_t, ligand_pi_parallel = (
          compute_pi_stack(
              prot_rdk,
              lig_rdk,
              distances,
              dist_cutoff=self.cutoffs['pi_stack_dist_cutoff'],
              angle_cutoff=self.cutoffs['pi_stack_angle_cutoff']))
      pi_parallel_tensor = self._voxelize(
          convert_atom_to_voxel,
          None,
          prot_xyz,
          feature_dict=protein_pi_parallel,
          nb_channel=1)
      pi_parallel_tensor += self._voxelize(
          convert_atom_to_voxel,
          None,
          lig_xyz,
          feature_dict=ligand_pi_parallel,
          nb_channel=1)

      pi_t_tensor = self._voxelize(
          convert_atom_to_voxel,
          None,
          prot_xyz,
          feature_dict=protein_pi_t,
          nb_channel=1)
      pi_t_tensor += self._voxelize(
          convert_atom_to_voxel,
          None,
          lig_xyz,
          feature_dict=ligand_pi_t,
          nb_channel=1)
      return [pi_parallel_tensor, pi_t_tensor]

    # define methods to calculate available voxel features
    self.VOXEL_FEATURES = {
        'ecfp': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [sum([self._voxelize(
                convert_atom_to_voxel,
                hash_ecfp,
                xyz,
                feature_dict=ecfp_dict,
                channel_power=self.ecfp_power
            ) for xyz, ecfp_dict in zip(
                (prot_xyz, lig_xyz), featurize_binding_pocket_ecfp(
                    prot_xyz,
                    prot_rdk,
                    lig_xyz,
                    lig_rdk,
                    distances,
                    cutoff=self.cutoffs['ecfp_cutoff'],
                    ecfp_degree=self.ecfp_degree
                ))])],

        'splif': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [self._voxelize(
                convert_atom_pair_to_voxel,
                hash_ecfp_pair,
                (prot_xyz, lig_xyz),
                feature_dict=splif_dict,
                channel_power=self.splif_power
            ) for splif_dict in featurize_splif(
                prot_xyz,
                prot_rdk,
                lig_xyz,
                lig_rdk,
                self.cutoffs['splif_contact_bins'],
                distances,
                self.ecfp_degree)],

        'sybyl': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [self._voxelize(
                convert_atom_to_voxel,
                lambda x: hash_sybyl(x, sybyl_types=self.sybyl_types),
                xyz,
                feature_dict=sybyl_dict,
                nb_channel=len(self.sybyl_types)
            ) for xyz, sybyl_dict in zip(
                (prot_xyz, lig_xyz), featurize_binding_pocket_sybyl(
                    prot_xyz,
                    prot_rdk,
                    lig_xyz,
                    lig_rdk,
                    distances,
                    cutoff=self.cutoffs['sybyl_cutoff']
                ))],

        'salt_bridge': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [self._voxelize(
                convert_atom_pair_to_voxel,
                None,
                (prot_xyz, lig_xyz),
                feature_list=compute_salt_bridges(
                    prot_xyz,
                    prot_rdk,
                    lig_xyz,
                    lig_rdk,
                    distances,
                    cutoff=self.cutoffs['salt_bridges_cutoff']),
                nb_channel=1
            )],

        'charge': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [sum([self._voxelize(
                convert_atom_to_voxel,
                None,
                xyz,
                feature_dict=compute_charge_dictionary(mol),
                nb_channel=1,
                dtype="np.float16"
            ) for xyz, mol in ((prot_xyz, prot_rdk), (lig_xyz, lig_rdk))])],

        'hbond': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [self._voxelize(
                convert_atom_pair_to_voxel,
                None,
                (prot_xyz, lig_xyz),
                feature_list=hbond_list,
                channel_power=0
            ) for hbond_list in compute_hydrogen_bonds(
                prot_xyz,
                prot_rdk,
                lig_xyz,
                lig_rdk,
                distances,
                self.cutoffs['hbond_dist_bins'],
                self.cutoffs['hbond_angle_cutoffs'])
            ],
        'pi_stack': voxelize_pi_stack,

        'cation_pi': lambda prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances:
            [sum([self._voxelize(
                convert_atom_to_voxel,
                None,
                xyz,
                feature_dict=cation_pi_dict,
                nb_channel=1
            ) for xyz, cation_pi_dict in zip(
                (prot_xyz, lig_xyz), compute_binding_pocket_cation_pi(
                    prot_rdk,
                    lig_rdk,
                    dist_cutoff=self.cutoffs['cation_pi_dist_cutoff'],
                    angle_cutoff=self.cutoffs['cation_pi_angle_cutoff'],
                ))])],
    }
    self.FLAT_FEATURES = [
        'ecfp_ligand', 'ecfp_hashed', 'splif_hashed', 'hbond_count'
    ]
    self.VOXEL_FEATURES = [
        'ecfp', 'splif', 'sybyl', 'salt_bridge', 'charge', 'hbond', 'pi_stack',
        'cation_pi'
    ]

    if feature_types is None:
      feature_types = ['ecfp']
@@ -1249,7 +1060,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):

      elif feature_type == 'flat_combined':
        self.feature_types += [(True, ftype)
                               for ftype in sorted(self.FLAT_FEATURES.keys())
                               for ftype in sorted(self.FLAT_FEATURES)
                               if ftype not in ignored_features]
        if self.flatten is False:
          if self.verbose:
@@ -1258,14 +1069,14 @@ class RdkitGridFeaturizer(ComplexFeaturizer):

      elif feature_type == 'voxel_combined':
        self.feature_types += [(False, ftype)
                               for ftype in sorted(self.VOXEL_FEATURES.keys())
                               for ftype in sorted(self.VOXEL_FEATURES)
                               if ftype not in ignored_features]
      elif feature_type == 'all_combined':
        self.feature_types += [(True, ftype)
                               for ftype in sorted(self.FLAT_FEATURES.keys())
                               for ftype in sorted(self.FLAT_FEATURES)
                               if ftype not in ignored_features]
        self.feature_types += [(False, ftype)
                               for ftype in sorted(self.VOXEL_FEATURES.keys())
                               for ftype in sorted(self.VOXEL_FEATURES)
                               if ftype not in ignored_features]
        if self.flatten is False:
          if self.verbose:
@@ -1274,8 +1085,152 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
      elif self.verbose:
        warn('Ignoring unknown feature %s' % feature_type)

  def _featurize_complex(self, ligand_ext, ligand_lines, protein_pdb_lines):
  def _compute_feature(self, feature_name, prot_xyz, prot_rdk, lig_xyz, lig_rdk,
                       distances):
    if feature_name == 'ecfp_ligand':
      return [compute_ecfp_features(lig_rdk, self.ecfp_degree, self.ecfp_power)]
    if feature_name == 'ecfp_hashed':
      return [
          self._vectorize(
              hash_ecfp, feature_dict=ecfp_dict, channel_power=self.ecfp_power)
          for ecfp_dict in featurize_binding_pocket_ecfp(
              prot_xyz,
              prot_rdk,
              lig_xyz,
              lig_rdk,
              distances,
              cutoff=self.cutoffs['ecfp_cutoff'],
              ecfp_degree=self.ecfp_degree)
      ]
    if feature_name == 'splif_hashed':
      return [
          self._vectorize(
              hash_ecfp_pair,
              feature_dict=splif_dict,
              channel_power=self.splif_power) for splif_dict in featurize_splif(
                  prot_xyz, prot_rdk, lig_xyz, lig_rdk, self.cutoffs[
                      'splif_contact_bins'], distances, self.ecfp_degree)
      ]
    if feature_name == 'hbond_count':
      return [
          self._vectorize(
              hash_ecfp_pair, feature_list=hbond_list, channel_power=0)
          for hbond_list in compute_hydrogen_bonds(
              prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances, self.cutoffs[
                  'hbond_dist_bins'], self.cutoffs['hbond_angle_cutoffs'])
      ]
    if feature_name == 'ecfp':
      return [
          sum([
              self._voxelize(
                  convert_atom_to_voxel,
                  hash_ecfp,
                  xyz,
                  feature_dict=ecfp_dict,
                  channel_power=self.ecfp_power)
              for xyz, ecfp_dict in zip((prot_xyz, lig_xyz),
                                        featurize_binding_pocket_ecfp(
                                            prot_xyz,
                                            prot_rdk,
                                            lig_xyz,
                                            lig_rdk,
                                            distances,
                                            cutoff=self.cutoffs['ecfp_cutoff'],
                                            ecfp_degree=self.ecfp_degree))
          ])
      ]
    if feature_name == 'splif':
      return [
          self._voxelize(
              convert_atom_pair_to_voxel,
              hash_ecfp_pair, (prot_xyz, lig_xyz),
              feature_dict=splif_dict,
              channel_power=self.splif_power) for splif_dict in featurize_splif(
                  prot_xyz, prot_rdk, lig_xyz, lig_rdk, self.cutoffs[
                      'splif_contact_bins'], distances, self.ecfp_degree)
      ]
    if feature_name == 'sybyl':
      return [
          self._voxelize(
              convert_atom_to_voxel,
              lambda x: hash_sybyl(x, sybyl_types=self.sybyl_types),
              xyz,
              feature_dict=sybyl_dict,
              nb_channel=len(self.sybyl_types))
          for xyz, sybyl_dict in zip((prot_xyz, lig_xyz),
                                     featurize_binding_pocket_sybyl(
                                         prot_xyz,
                                         prot_rdk,
                                         lig_xyz,
                                         lig_rdk,
                                         distances,
                                         cutoff=self.cutoffs['sybyl_cutoff']))
      ]
    if feature_name == 'salt_bridge':
      return [
          self._voxelize(
              convert_atom_pair_to_voxel,
              None, (prot_xyz, lig_xyz),
              feature_list=compute_salt_bridges(
                  prot_xyz,
                  prot_rdk,
                  lig_xyz,
                  lig_rdk,
                  distances,
                  cutoff=self.cutoffs['salt_bridges_cutoff']),
              nb_channel=1)
      ]
    if feature_name == 'charge':
      return [
          sum([
              self._voxelize(
                  convert_atom_to_voxel,
                  None,
                  xyz,
                  feature_dict=compute_charge_dictionary(mol),
                  nb_channel=1,
                  dtype="np.float16")
              for xyz, mol in ((prot_xyz, prot_rdk), (lig_xyz, lig_rdk))
          ])
      ]
    if feature_name == 'hbond':
      return [
          self._voxelize(
              convert_atom_pair_to_voxel,
              None, (prot_xyz, lig_xyz),
              feature_list=hbond_list,
              channel_power=0) for hbond_list in compute_hydrogen_bonds(
                  prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances, self.cutoffs[
                      'hbond_dist_bins'], self.cutoffs['hbond_angle_cutoffs'])
      ]
    if feature_name == 'pi_stack':
      return self._voxelize_pi_stack(prot_xyz, prot_rdk, lig_xyz, lig_rdk,
                                     distances)
    if feature_name == 'cation_pi':
      return [
          sum([
              self._voxelize(
                  convert_atom_to_voxel,
                  None,
                  xyz,
                  feature_dict=cation_pi_dict,
                  nb_channel=1) for xyz, cation_pi_dict in zip(
                      (prot_xyz, lig_xyz),
                      compute_binding_pocket_cation_pi(
                          prot_rdk,
                          lig_rdk,
                          dist_cutoff=self.cutoffs['cation_pi_dist_cutoff'],
                          angle_cutoff=self.cutoffs['cation_pi_angle_cutoff'],
                      ))
          ])
      ]
    raise ValueError('Unknown feature type "%s"' % feature_name)

  def _featurize_complex(self, ligand_ext, ligand_lines, protein_pdb_lines,
                         log_message):
    tempdir = tempfile.mkdtemp()
    if log_message is not None:
      log(log_message)

    ############################################################## TIMING
    time1 = time.time()
@@ -1301,7 +1256,7 @@ class RdkitGridFeaturizer(ComplexFeaturizer):

    features_dict = self._transform(protein_pdb_file, ligand_file)
    shutil.rmtree(tempdir)
    return features_dict.values()
    return list(features_dict.values())

  def featurize_complexes(self, mol_files, protein_pdbs, log_every_n=1000):
    """
@@ -1314,17 +1269,24 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
    protein_pdbs: list
      List of PDB filenames for proteins.
    """
    features = []
    pool = multiprocessing.Pool()
    results = []
    for i, (mol_file, protein_pdb) in enumerate(zip(mol_files, protein_pdbs)):
      if i % log_every_n == 0:
        log("Featurizing %d / %d" % (i, len(mol_files)))
      log_message = "Featurizing %d / %d" % (
          i, len(mol_files)) if i % log_every_n == 0 else None
      ligand_ext = get_ligand_filetype(mol_file)
      with open(mol_file) as mol_f:
        mol_lines = mol_f.readlines()
      with open(protein_pdb) as protein_file:
        protein_pdb_lines = protein_file.readlines()
      features += self._featurize_complex(ligand_ext, mol_lines,
                                          protein_pdb_lines)
      results.append(
          pool.apply_async(
              _featurize_complex,
              (self, ligand_ext, mol_lines, protein_pdb_lines, log_message)))
    pool.close()
    features = []
    for result in results:
      features += result.get()
    features = np.asarray(features)
    return features

@@ -1388,18 +1350,16 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
    for system_id, (protein_xyz, ligand_xyz) in transformed_systems.items():
      feature_arrays = []
      for is_flat, function_name in self.feature_types:
        if is_flat:
          function = self.FLAT_FEATURES[function_name]
        else:
          function = self.VOXEL_FEATURES[function_name]

        feature_arrays += function(
        result = self._compute_feature(
            function_name,
            protein_xyz,
            protein_rdk,
            ligand_xyz,
            ligand_rdk,
            pairwise_distances,
        )
        feature_arrays += result

        if self.flatten:
          features[system_id] = np.concatenate(
@@ -1453,6 +1413,41 @@ class RdkitGridFeaturizer(ComplexFeaturizer):

    return feature_tensor

  def _voxelize_pi_stack(self, prot_xyz, prot_rdk, lig_xyz, lig_rdk, distances):
    """Build the two pi-stacking grids for a protein-ligand complex.

    Calls `compute_pi_stack` once, then voxelizes each of its four returned
    dictionaries into a single-channel grid with `self._voxelize`. The ligand
    grid is added in place onto the protein grid for each of the two kinds.

    Returns
    -------
    list
      `[pi_parallel_tensor, pi_t_tensor]`, in that order.
    """
    pi_stack = compute_pi_stack(
        prot_rdk,
        lig_rdk,
        distances,
        dist_cutoff=self.cutoffs['pi_stack_dist_cutoff'],
        angle_cutoff=self.cutoffs['pi_stack_angle_cutoff'])
    prot_t, prot_parallel, lig_t, lig_parallel = pi_stack

    def one_channel_grid(xyz, atom_dict):
      # All four grids share the same voxelizer settings.
      return self._voxelize(
          convert_atom_to_voxel,
          None,
          xyz,
          feature_dict=atom_dict,
          nb_channel=1)

    parallel_grid = one_channel_grid(prot_xyz, prot_parallel)
    parallel_grid += one_channel_grid(lig_xyz, lig_parallel)

    t_grid = one_channel_grid(prot_xyz, prot_t)
    t_grid += one_channel_grid(lig_xyz, lig_t)

    return [parallel_grid, t_grid]

  def _vectorize(self,
                 hash_function,
                 feature_dict=None,
@@ -1469,3 +1464,9 @@ class RdkitGridFeaturizer(ComplexFeaturizer):
      feature_vector[0] += len(feature_list)

    return feature_vector


def _featurize_complex(featurizer, ligand_ext, ligand_lines, protein_pdb_lines,
                       log_message):
  return featurizer._featurize_complex(ligand_ext, ligand_lines,
                                       protein_pdb_lines, log_message)