Unverified Commit f49416be authored by Nathan Frey's avatar Nathan Frey Committed by GitHub
Browse files

Merge pull request #2318 from ncfrey/complex-featurizer-fix

Complex featurizer fix
parents e20aa6d1 3df372ba
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
@@ -131,8 +131,7 @@ class Docker(object):
        # check whether self.featurizer is instance of ComplexFeaturizer or not
        assert isinstance(self.featurizer, ComplexFeaturizer)
        # TODO: How to handle the failure here?
        (protein_file, ligand_file) = molecular_complex
        features, _ = self.featurizer.featurize([protein_file], [ligand_file])
        features = self.featurizer.featurize([molecular_complex])
        dataset = NumpyDataset(X=features)
        score = self.scoring_model.predict(dataset)
        yield (posed_complex, score)
+1 −1
Original line number Diff line number Diff line
@@ -104,7 +104,7 @@ class TestDocking(unittest.TestCase):
    class DummyFeaturizer(ComplexFeaturizer):

      def featurize(self, complexes, *args, **kwargs):
        return np.zeros((len(complexes), 5)), None
        return np.zeros((len(complexes), 5))

    class DummyModel(Model):

+25 −39
Original line number Diff line number Diff line
@@ -4,8 +4,7 @@ Feature calculations.
import inspect
import logging
import numpy as np
import multiprocessing
from typing import Any, Dict, List, Iterable, Sequence, Tuple, Union
from typing import Any, Dict, Iterable, Tuple, Union, cast

from deepchem.utils import get_print_threshold
from deepchem.utils.typing import PymatgenStructure
@@ -150,70 +149,57 @@ class Featurizer(object):
    return self.__class__.__name__ + override_args_info


class ComplexFeaturizer(object):
class ComplexFeaturizer(Featurizer):
  """"
  Abstract class for calculating features for mol/protein complexes.
  """

  def featurize(self, mol_files: Sequence[str],
                protein_pdbs: Sequence[str]) -> Tuple[np.ndarray, List]:
  def featurize(self,
                complexes: Iterable[Tuple[str, str]],
                log_every_n: int = 100) -> np.ndarray:
    """
    Calculate features for mol/protein complexes.

    Parameters
    ----------
    mols: List[str]
      List of PDB filenames for molecules.
    protein_pdbs: List[str]
      List of PDB filenames for proteins.
    complexes: Iterable[Tuple[str, str]]
      List of filenames (PDB, SDF, etc.) for ligand molecules and proteins.
      Each element should be a tuple of the form (ligand_filename,
      protein_filename).

    Returns
    -------
    features: np.ndarray
      Array of features
    failures: List
      Indices of complexes that failed to featurize.
    """

    pool = multiprocessing.Pool()
    results = []
    for i, (mol_file, protein_pdb) in enumerate(zip(mol_files, protein_pdbs)):
      log_message = "Featurizing %d / %d" % (i, len(mol_files))
      results.append(
          pool.apply_async(ComplexFeaturizer._featurize_callback,
                           (self, mol_file, protein_pdb, log_message)))
    pool.close()
    if not isinstance(complexes, Iterable):
      complexes = [cast(Tuple[str, str], complexes)]
    features = []
    failures = []
    for ind, result in enumerate(results):
      new_features = result.get()
      # Handle loading failures which return None
      if new_features is not None:
        features.append(new_features)
      else:
        failures.append(ind)
    for i, point in enumerate(complexes):
      if i % log_every_n == 0:
        logger.info("Featurizing datapoint %i" % i)
      try:
        features.append(self._featurize(point))
      except:
        logger.warning(
            "Failed to featurize datapoint %i. Appending empty array." % i)
        features.append(np.array([]))

    features = np.asarray(features)
    return features, failures
    return features

  def _featurize(self, mol_pdb: str, complex_pdb: str):
  def _featurize(self, complex: Tuple[str, str]):
    """
    Calculate features for single mol/protein complex.

    Parameters
    ----------
    mol_pdb : str
      The PDB filename.
    complex_pdb : str
      The PDB filename.
    complex: Tuple[str, str]
      Filenames for molecule and protein.
    """
    raise NotImplementedError('Featurizer is not defined.')

  @staticmethod
  def _featurize_callback(featurizer, mol_pdb_file, protein_pdb_file,
                          log_message):
    logging.info(log_message)
    return featurizer._featurize(mol_pdb_file, protein_pdb_file)


class MolecularFeaturizer(Featurizer):
  """Abstract class for calculating a set of features for a
+9 −7
Original line number Diff line number Diff line
@@ -11,6 +11,8 @@ from deepchem.utils.data_utils import pad_array
from deepchem.utils.rdkit_utils import MoleculeLoadException, get_xyz_from_mol, \
  load_molecule, merge_molecules_xyz, merge_molecules

from typing import Tuple


def compute_neighbor_list(coords, neighbor_cutoff, max_num_neighbors,
                          periodic_box_size):
@@ -98,7 +100,7 @@ class NeighborListComplexAtomicCoordinates(ComplexFeaturizer):
  """
  Adjacency list of neighbors for protein-ligand complexes in 3-space.

  Neighbors dtermined by user-dfined distance cutoff.
  Neighbors determined by user-defined distance cutoff.
  """

  def __init__(self, max_num_neighbors=None, neighbor_cutoff=4):
@@ -112,17 +114,16 @@ class NeighborListComplexAtomicCoordinates(ComplexFeaturizer):
    # Type of data created by this featurizer
    self.dtype = object

  def _featurize(self, mol_pdb_file, protein_pdb_file):
  def _featurize(self, complex: Tuple[str, str]):
    """
    Compute neighbor list for complex.

    Parameters
    ----------
    mol_pdb_file: str
      Filename for ligand pdb file.
    protein_pdb_file: str
      Filename for protein pdb file.
    complex: Tuple[str, str]
      Filenames for molecule and protein.
    """
    mol_pdb_file, protein_pdb_file = complex
    mol_coords, ob_mol = load_molecule(mol_pdb_file)
    protein_coords, protein_mol = load_molecule(protein_pdb_file)
    system_coords = merge_molecules_xyz([mol_coords, protein_coords])
@@ -168,7 +169,8 @@ class ComplexNeighborListFragmentAtomicCoordinates(ComplexFeaturizer):
    self.neighborlist_featurizer = NeighborListComplexAtomicCoordinates(
        self.max_num_neighbors, self.neighbor_cutoff)

  def _featurize(self, mol_pdb_file, protein_pdb_file):
  def _featurize(self, complex):
    mol_pdb_file, protein_pdb_file = complex
    try:
      frag1_coords, frag1_mol = load_molecule(
          mol_pdb_file, is_protein=False, sanitize=True, add_hydrogens=False)
+8 −13
Original line number Diff line number Diff line
@@ -93,19 +93,17 @@ class ContactCircularFingerprint(ComplexFeaturizer):
    self.radius = radius
    self.size = size

  def _featurize(self, mol_pdb: str, protein_pdb: str):
  def _featurize(self, complex: Tuple[str, str]):
    """
    Compute featurization for a molecular complex

    Parameters
    ----------
    mol_pdb: str
      Filename for ligand molecule
    protein_pdb: str
      Filename for protein molecule
    complex: Tuple[str, str]
      Filenames for molecule and protein.
    """
    try:
      fragments = load_complex((mol_pdb, protein_pdb), add_hydrogens=False)
      fragments = load_complex(complex, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
@@ -181,20 +179,17 @@ class ContactCircularVoxelizer(ComplexFeaturizer):
    self.voxels_per_edge = int(self.box_width / self.voxel_width)
    self.flatten = flatten

  def _featurize(self, mol_pdb: str, protein_pdb: str):
  def _featurize(self, complex):
    """
    Compute featurization for a molecular complex

    Parameters
    ----------
    mol_pdb: str
      Filename for ligand molecule
    protein_pdb: str
      Filename for protein molecule
    complex: Tuple[str, str]
      Filenames for molecule and protein.
    """
    molecular_complex = (mol_pdb, protein_pdb)
    try:
      fragments = load_complex(molecular_complex, add_hydrogens=False)
      fragments = load_complex(complex, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
Loading