Unverified Commit 4eefa98d authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2017 from nd-02110114/fix-feat

Refactor featurizer
parents 5e025e46 2ddfd44e
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -109,7 +109,7 @@ class Docker(object):
    if self.scoring_model is not None:
      for posed_complex in complexes:
        # TODO: How to handle the failure here?
        features, _ = self.featurizer.featurize_complexes([molecular_complex])
        features, _ = self.featurizer.featurize([molecular_complex])
        dataset = NumpyDataset(X=features)
        score = self.scoring_model.predict(dataset)
        yield (posed_complex, score)
+1 −1
Original line number Diff line number Diff line
@@ -105,7 +105,7 @@ class TestDocking(unittest.TestCase):

    class DummyFeaturizer(ComplexFeaturizer):

      def featurize_complexes(self, complexes, *args, **kwargs):
      def featurize(self, complexes, *args, **kwargs):
        return np.zeros((len(complexes), 5)), None

    class DummyModel(Model):
+1 −1
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ from deepchem.feat.base_classes import UserDefinedFeaturizer
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.graph_features import WeaveFeaturizer
from deepchem.feat.fingerprints import CircularFingerprint
from deepchem.feat.basic import RDKitDescriptors
from deepchem.feat.rdkit_descriptors import RDKitDescriptors
from deepchem.feat.coulomb_matrices import CoulombMatrix
from deepchem.feat.coulomb_matrices import CoulombMatrixEig
from deepchem.feat.coulomb_matrices import BPSymmetryFunctionInput
+2 −3
Original line number Diff line number Diff line
@@ -3,7 +3,6 @@ Atomic coordinate featurizer.
"""
import logging
import numpy as np
from deepchem.utils.save import log
from deepchem.feat import Featurizer
from deepchem.feat import ComplexFeaturizer
from deepchem.utils import rdkit_util, pad_array
@@ -162,7 +161,7 @@ class NeighborListComplexAtomicCoordinates(ComplexFeaturizer):
    self.dtype = object
    self.coordinates_featurizer = AtomicCoordinates()

  def _featurize_complex(self, mol_pdb_file, protein_pdb_file):
  def _featurize(self, mol_pdb_file, protein_pdb_file):
    """
    Compute neighbor list for complex.

@@ -218,7 +217,7 @@ class ComplexNeighborListFragmentAtomicCoordinates(ComplexFeaturizer):
    self.neighborlist_featurizer = NeighborListComplexAtomicCoordinates(
        self.max_num_neighbors, self.neighbor_cutoff)

  def _featurize_complex(self, mol_pdb_file, protein_pdb_file):
  def _featurize(self, mol_pdb_file, protein_pdb_file):
    try:
      frag1_coords, frag1_mol = rdkit_util.load_molecule(
          mol_pdb_file, is_protein=False, sanitize=True, add_hydrogens=False)
+38 −85
Original line number Diff line number Diff line
@@ -12,11 +12,6 @@ logger = logging.getLogger(__name__)
JSON = Dict[str, Any]


def _featurize_complex(featurizer, mol_pdb_file, protein_pdb_file, log_message):
  """Emit a log message, then featurize one mol/protein pair.

  Parameters
  ----------
  featurizer: object
    Object providing a `_featurize_complex(mol_pdb_file, protein_pdb_file)`
    method.
  mol_pdb_file: object
    Forwarded unchanged as the first argument of that method.
  protein_pdb_file: object
    Forwarded unchanged as the second argument of that method.
  log_message: object
    Message passed to `logging.info` before featurizing.

  Returns
  -------
  Whatever `featurizer._featurize_complex` returns.
  """
  logging.info(log_message)
  features = featurizer._featurize_complex(mol_pdb_file, protein_pdb_file)
  return features


class Featurizer(object):
  """Abstract class for calculating a set of features for a datapoint.

@@ -63,18 +58,45 @@ class Featurizer(object):
    Parameters
    ----------
    datapoints: object
       Any blob of data you like. Subclasss should instantiate
       this. 
      Any blob of data you like. Subclasses should instantiate this.
    """
    return self.featurize(datapoints)

  def _featurize(self, datapoint):
    """Calculate features for a single datapoint.

    This base implementation does no work and always raises.

    Parameters
    ----------
    datapoint: object
      a single datapoint in a sequence of objects

    Raises
    ------
    NotImplementedError
      Always raised by this implementation.
    """
    message = 'Featurizer is not defined.'
    raise NotImplementedError(message)


def _featurize_callback(featurizer, mol_pdb_file, protein_pdb_file,
                        log_message):
  """Worker function submitted to `apply_async` by ComplexFeaturizer.

  It is defined at module scope because `apply_async` doesn't execute
  a nested function. Background:
  https://stackoverflow.com/questions/56533827/pool-apply-async-nested-function-is-not-executed

  Parameters
  ----------
  featurizer: object
    Object providing a `_featurize(mol_pdb_file, protein_pdb_file)` method.
  mol_pdb_file: object
    Forwarded unchanged as the first argument of `_featurize`.
  protein_pdb_file: object
    Forwarded unchanged as the second argument of `_featurize`.
  log_message: object
    Message passed to `logging.info` before featurizing.

  Returns
  -------
  Whatever `featurizer._featurize` returns.
  """
  logging.info(log_message)
  result = featurizer._featurize(mol_pdb_file, protein_pdb_file)
  return result

class ComplexFeaturizer(object):

class ComplexFeaturizer(Featurizer):
  """"
  Abstract class for calculating features for mol/protein complexes.
  """

  def featurize_complexes(self, mol_files, protein_pdbs):
  def featurize(self, mol_files, protein_pdbs):
    """
    Calculate features for mol/protein complexes.

@@ -92,12 +114,13 @@ class ComplexFeaturizer(object):
    failures: list
      Indices of complexes that failed to featurize.
    """

    pool = multiprocessing.Pool()
    results = []
    for i, (mol_file, protein_pdb) in enumerate(zip(mol_files, protein_pdbs)):
      log_message = "Featurizing %d / %d" % (i, len(mol_files))
      results.append(
          pool.apply_async(_featurize_complex,
          pool.apply_async(_featurize_callback,
                           (self, mol_file, protein_pdb, log_message)))
    pool.close()
    features = []
@@ -112,7 +135,7 @@ class ComplexFeaturizer(object):
    features = np.asarray(features)
    return features, failures

  def _featurize_complex(self, mol_pdb, complex_pdb):
  def _featurize(self, mol_pdb, complex_pdb):
    """
    Calculate features for single mol/protein complex.

@@ -187,28 +210,6 @@ class MolecularFeaturizer(Featurizer):
    features = np.asarray(features)
    return features

  def _featurize(self, mol):
    """
    Calculate features for a single molecule.

    This base implementation does no work and always raises.

    Parameters
    ----------
    mol : RDKit Mol
        Molecule.

    Raises
    ------
    NotImplementedError
        Always raised by this implementation.
    """
    message = 'Featurizer is not defined.'
    raise NotImplementedError(message)

  def __call__(self, molecules):
    """
    Calculate features for molecules by delegating to `featurize`.

    Parameters
    ----------
    molecules: iterable
        An iterable yielding RDKit Mol objects or SMILES strings.

    Returns
    -------
    Whatever `self.featurize(molecules)` returns.
    """
    features = self.featurize(molecules)
    return features


class StructureFeaturizer(Featurizer):
  """
@@ -282,30 +283,6 @@ class StructureFeaturizer(Featurizer):
    features = np.asarray(features)
    return features

  def _featurize(self, structure: "pymatgen.Structure"):
    """Calculate features for a single crystal structure.

    This base implementation does no work and always raises.

    Parameters
    ----------
    structure: pymatgen.Structure object
      Structure object with 3D coordinates and periodic lattice.

    Raises
    ------
    NotImplementedError
      Always raised by this implementation.

    """

    message = 'Featurizer is not defined.'
    raise NotImplementedError(message)

  def __call__(self, structures: Iterable[dict]):
    """Calculate features for crystal structures via `featurize`.

    Parameters
    ----------
    structures: Iterable[dict]
      An iterable of crystal structure dictionaries.

    Returns
    -------
    Whatever `self.featurize(structures)` returns.

    """

    features = self.featurize(structures)
    return features


class CompositionFeaturizer(Featurizer):
  """
@@ -377,30 +354,6 @@ class CompositionFeaturizer(Featurizer):
    features = np.asarray(features)
    return features

  def _featurize(self, composition: "pymatgen.Composition"):
    """Calculate features for a single crystal composition.

    This base implementation does no work and always raises.

    Parameters
    ----------
    composition: pymatgen.Composition object
      Composition object for 3D inorganic crystal.

    Raises
    ------
    NotImplementedError
      Always raised by this implementation.

    """

    message = 'Featurizer is not defined.'
    raise NotImplementedError(message)

  def __call__(self, compositions: Iterable[str]):
    """Calculate features for crystal compositions via `featurize`.

    Parameters
    ----------
    compositions: Iterable[str]
      An iterable of crystal compositions.

    Returns
    -------
    Whatever `self.featurize(compositions)` returns.

    """

    features = self.featurize(compositions)
    return features


class UserDefinedFeaturizer(Featurizer):
  """Directs usage of user-computed featurizations."""
Loading