Commit 81521621 authored by seyonechithrananda
Browse files

mypy: change all featurizers to use `datapoints`/`datapoint` as the input parameter name

parent a8463d06
Loading
Loading
Loading
Loading
+44 −19
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ Feature calculations.
import inspect
import logging
import numpy as np
from typing import Any, Dict, Iterable, Tuple, Union, cast
from typing import Any, Dict, Iterable, Optional, Tuple, Union, cast

from deepchem.utils import get_print_threshold
from deepchem.utils.typing import PymatgenStructure
@@ -44,7 +44,7 @@ class Featurizer(object):
      A numpy array containing a featurized representation of `datapoints`.
    """
    datapoints = list(datapoints)
    features: list[any] = []
    features = []
    for i, point in enumerate(datapoints):
      if i % log_every_n == 0:
        logger.info("Featurizing datapoint %i" % i)
@@ -158,7 +158,7 @@ class ComplexFeaturizer(Featurizer):
  """

  def featurize(self,
                datapoints: Iterable[Tuple[str, str]] = None,
                datapoints: Optional[Iterable[Tuple[str, str]]] = None,
                log_every_n: int = 100,
                **kwargs) -> np.ndarray:
    """
@@ -212,7 +212,7 @@ class ComplexFeaturizer(Featurizer):

    return np.asarray(features)

  def _featurize(self, datapoints: Iterable[Tuple[str, str]] = None, **kwargs):
  def _featurize(self, datapoint: Optional[Tuple[str, str]] = None, **kwargs):
    """
    Calculate features for single mol/protein complex.

@@ -242,12 +242,12 @@ class MolecularFeaturizer(Featurizer):
  The subclasses of this class require RDKit to be installed.
  """

  def featurize(self, molecules, log_every_n=1000, **kwargs) -> np.ndarray:
  def featurize(self, datapoints, log_every_n=1000, **kwargs) -> np.ndarray:
    """Calculate features for molecules.

    Parameters
    ----------
    molecules: rdkit.Chem.rdchem.Mol / SMILES string / iterable
    datapoints: rdkit.Chem.rdchem.Mol / SMILES string / iterable
      RDKit Mol, or SMILES string or iterable sequence of RDKit mols/SMILES
      strings.
    log_every_n: int, default 1000
@@ -266,15 +266,21 @@ class MolecularFeaturizer(Featurizer):
    except ModuleNotFoundError:
      raise ImportError("This class requires RDKit to be installed.")

    if 'molecules' in kwargs:
      datapoints = kwargs.get("molecules")
      raise DeprecationWarning(
          'Molecules is being phased out as a parameter, please pass "datapoints" instead.'
      )

    # Special case handling of single molecule
    if isinstance(molecules, str) or isinstance(molecules, Mol):
      molecules = [molecules]
    if isinstance(datapoints, str) or isinstance(datapoints, Mol):
      datapoints = [datapoints]
    else:
      # Convert iterables to list
      molecules = list(molecules)
      datapoints = list(datapoints)

    features: list = []
    for i, mol in enumerate(molecules):
    for i, mol in enumerate(datapoints):
      if i % log_every_n == 0:
        logger.info("Featurizing datapoint %i" % i)

@@ -323,14 +329,15 @@ class MaterialStructureFeaturizer(Featurizer):
  """

  def featurize(self,
                structures: Iterable[Union[Dict[str, Any], PymatgenStructure]],
                datapoints: Optional[Iterable[Union[Dict[str, Any],
                                                    PymatgenStructure]]] = None,
                log_every_n: int = 1000,
                **kwargs) -> np.ndarray:
    """Calculate features for crystal structures.

    Parameters
    ----------
    structures: Iterable[Union[Dict, pymatgen.core.Structure]]
    datapoints: Iterable[Union[Dict, pymatgen.core.Structure]]
      Iterable sequence of pymatgen structure dictionaries
      or pymatgen.core.Structure. Please confirm the dictionary representations
      of pymatgen.core.Structure from https://pymatgen.org/pymatgen.core.structure.html.
@@ -341,16 +348,25 @@ class MaterialStructureFeaturizer(Featurizer):
    -------
    features: np.ndarray
      A numpy array containing a featurized representation of
      `structures`.
      `datapoints`.
    """
    try:
      from pymatgen.core import Structure
    except ModuleNotFoundError:
      raise ImportError("This class requires pymatgen to be installed.")

    structures = list(structures)
    if 'structures' in kwargs:
      datapoints = kwargs.get("structures")
      raise DeprecationWarning(
          'Structures is being phased out as a parameter, please pass "datapoints" instead.'
      )

    if not isinstance(datapoints, Iterable):
      datapoints = [cast(Union[Dict[str, Any], PymatgenStructure], datapoints)]

    datapoints = list(datapoints)
    features = []
    for idx, structure in enumerate(structures):
    for idx, structure in enumerate(datapoints):
      if idx % log_every_n == 0:
        logger.info("Featurizing datapoint %i" % idx)
      try:
@@ -389,14 +405,14 @@ class MaterialCompositionFeaturizer(Featurizer):
  """

  def featurize(self,
                compositions: Iterable[str],
                datapoints: Optional[Iterable[str]] = None,
                log_every_n: int = 1000,
                **kwargs) -> np.ndarray:
    """Calculate features for crystal compositions.

    Parameters
    ----------
    compositions: Iterable[str]
    datapoints: Iterable[str]
      Iterable sequence of composition strings, e.g. "MoS2".
    log_every_n: int, default 1000
      Logging messages reported every `log_every_n` samples.
@@ -412,9 +428,18 @@ class MaterialCompositionFeaturizer(Featurizer):
    except ModuleNotFoundError:
      raise ImportError("This class requires pymatgen to be installed.")

    compositions = list(compositions)
    if 'compositions' in kwargs and datapoints is None:
      datapoints = kwargs.get("compositions")
      raise DeprecationWarning(
          'Compositions is being phased out as a parameter, please pass "datapoints" instead.'
      )

    if not isinstance(datapoints, Iterable):
      datapoints = [cast(str, datapoints)]

    datapoints = list(datapoints)
    features = []
    for idx, composition in enumerate(compositions):
    for idx, composition in enumerate(datapoints):
      if idx % log_every_n == 0:
        logger.info("Featurizing datapoint %i" % idx)
      try:
+10 −4
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ from deepchem.utils.data_utils import pad_array
from deepchem.utils.rdkit_utils import MoleculeLoadException, get_xyz_from_mol, \
  load_molecule, merge_molecules_xyz, merge_molecules

from typing import Tuple
from typing import Tuple, Optional, Iterable, cast


def compute_neighbor_list(coords, neighbor_cutoff, max_num_neighbors,
@@ -118,16 +118,22 @@ class NeighborListComplexAtomicCoordinates(ComplexFeaturizer):
    # Type of data created by this featurizer
    self.dtype = object

  def _featurize(self, complex: Tuple[str, str], **kwargs):
  def _featurize(self, datapoint, **kwargs):
    """
    Compute neighbor list for complex.

    Parameters
    ----------
    complex: Tuple[str, str]
    datapoint: Tuple[str, str]
      Filenames for molecule and protein.
    """
    mol_pdb_file, protein_pdb_file = complex
    if 'complex' in kwargs:
      datapoint = kwargs.get("complex")
      raise DeprecationWarning(
          'Complex is being phased out as a parameter, please pass "datapoint" instead.'
      )

    mol_pdb_file, protein_pdb_file = datapoint
    mol_coords, ob_mol = load_molecule(mol_pdb_file)
    protein_coords, protein_mol = load_molecule(protein_pdb_file)
    system_coords = merge_molecules_xyz([mol_coords, protein_coords])
+9 −3
Original line number Diff line number Diff line
@@ -93,17 +93,23 @@ class ContactCircularFingerprint(ComplexFeaturizer):
    self.radius = radius
    self.size = size

  def _featurize(self, complex: Tuple[str, str], **kwargs):
  def _featurize(self, datapoint, **kwargs):
    """
    Compute featurization for a molecular complex

    Parameters
    ----------
    complex: Tuple[str, str]
    datapoint: Tuple[str, str]
      Filenames for molecule and protein.
    """
    if 'complex' in kwargs:
      datapoint = kwargs.get("complex")
      raise DeprecationWarning(
          'Complex is being phased out as a parameter, please pass "datapoint" instead.'
      )

    try:
      fragments = load_complex(complex, add_hydrogens=False)
      fragments = load_complex(datapoint, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
+52 −23
Original line number Diff line number Diff line
@@ -81,18 +81,23 @@ class ChargeVoxelizer(ComplexFeaturizer):
    self.voxel_width = voxel_width
    self.reduce_to_contacts = reduce_to_contacts

  def _featurize(self, complex: Tuple[str, str],
                 **kwargs) -> Optional[np.ndarray]:
  def _featurize(self, datapoint, **kwargs) -> Optional[np.ndarray]:
    """
    Compute featurization for a single mol/protein complex

    Parameters
    ----------
    complex: Tuple[str, str]
    datapoint: Tuple[str, str]
      Filenames for molecule and protein.
    """
    if 'complex' in kwargs:
      datapoint = kwargs.get("complex")
      raise DeprecationWarning(
          'Complex is being phased out as a parameter, please pass "datapoint" instead.'
      )

    try:
      fragments = rdkit_utils.load_complex(complex, add_hydrogens=False)
      fragments = rdkit_utils.load_complex(datapoint, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
@@ -168,18 +173,23 @@ class SaltBridgeVoxelizer(ComplexFeaturizer):
    self.voxel_width = voxel_width
    self.reduce_to_contacts = reduce_to_contacts

  def _featurize(self, complex: Tuple[str, str],
                 **kwargs) -> Optional[np.ndarray]:
  def _featurize(self, datapoint, **kwargs) -> Optional[np.ndarray]:
    """
    Compute featurization for a single mol/protein complex

    Parameters
    ----------
    complex: Tuple[str, str]
    datapoint: Tuple[str, str]
      Filenames for molecule and protein.
    """
    if 'complex' in kwargs:
      datapoint = kwargs.get("complex")
      raise DeprecationWarning(
          'Complex is being phased out as a parameter, please pass "datapoint" instead.'
      )

    try:
      fragments = rdkit_utils.load_complex(complex, add_hydrogens=False)
      fragments = rdkit_utils.load_complex(datapoint, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
@@ -254,18 +264,23 @@ class CationPiVoxelizer(ComplexFeaturizer):
    self.box_width = box_width
    self.voxel_width = voxel_width

  def _featurize(self, complex: Tuple[str, str],
                 **kwargs) -> Optional[np.ndarray]:
  def _featurize(self, datapoint, **kwargs) -> Optional[np.ndarray]:
    """
    Compute featurization for a single mol/protein complex

    Parameters
    ----------
    complex: Tuple[str, str]
    datapoint: Tuple[str, str]
      Filenames for molecule and protein.
    """
    if 'complex' in kwargs:
      datapoint = kwargs.get("complex")
      raise DeprecationWarning(
          'Complex is being phased out as a parameter, please pass "datapoint" instead.'
      )

    try:
      fragments = rdkit_utils.load_complex(complex, add_hydrogens=False)
      fragments = rdkit_utils.load_complex(datapoint, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
@@ -346,17 +361,23 @@ class PiStackVoxelizer(ComplexFeaturizer):
    self.box_width = box_width
    self.voxel_width = voxel_width

  def _featurize(self, complex, **kwargs) -> Optional[np.ndarray]:
  def _featurize(self, datapoint, **kwargs) -> Optional[np.ndarray]:
    """
    Compute featurization for a single mol/protein complex

    Parameters
    ----------
    complex: Tuple[str, str]
    datapoint: Tuple[str, str]
      Filenames for molecule and protein.
    """
    if 'complex' in kwargs:
      datapoint = kwargs.get("complex")
      raise DeprecationWarning(
          'Complex is being phased out as a parameter, please pass "datapoint" instead.'
      )

    try:
      fragments = rdkit_utils.load_complex(complex, add_hydrogens=False)
      fragments = rdkit_utils.load_complex(datapoint, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
@@ -458,18 +479,22 @@ class HydrogenBondCounter(ComplexFeaturizer):
      self.angle_cutoffs = angle_cutoffs
    self.reduce_to_contacts = reduce_to_contacts

  def _featurize(self, complex: Tuple[str, str],
                 **kwargs) -> Optional[np.ndarray]:
  def _featurize(self, datapoint, **kwargs) -> Optional[np.ndarray]:
    """
    Compute featurization for a single mol/protein complex

    Parameters
    ----------
    complex: Tuple[str, str]
    datapoint: Tuple[str, str]
      Filenames for molecule and protein.
    """
    if 'complex' in kwargs:
      datapoint = kwargs.get("complex")
      raise DeprecationWarning(
          'Complex is being phased out as a parameter, please pass "datapoint" instead.'
      )
    try:
      fragments = rdkit_utils.load_complex(complex, add_hydrogens=False)
      fragments = rdkit_utils.load_complex(datapoint, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
@@ -564,18 +589,22 @@ class HydrogenBondVoxelizer(ComplexFeaturizer):
    self.voxel_width = voxel_width
    self.reduce_to_contacts = reduce_to_contacts

  def _featurize(self, complex: Tuple[str, str],
                 **kwargs) -> Optional[np.ndarray]:
  def _featurize(self, datapoint, **kwargs) -> Optional[np.ndarray]:
    """
    Compute featurization for a single mol/protein complex

    Parameters
    ----------
    complex: Tuple[str, str]
    datapoint: Tuple[str, str]
      Filenames for molecule and protein.
    """
    if 'complex' in kwargs:
      datapoint = kwargs.get("complex")
      raise DeprecationWarning(
          'Complex is being phased out as a parameter, please pass "datapoint" instead.'
      )
    try:
      fragments = rdkit_utils.load_complex(complex, add_hydrogens=False)
      fragments = rdkit_utils.load_complex(datapoint, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
+18 −6
Original line number Diff line number Diff line
@@ -146,17 +146,23 @@ class SplifFingerprint(ComplexFeaturizer):
    self.size = size
    self.radius = radius

  def _featurize(self, complex: Tuple[str, str], **kwargs):
  def _featurize(self, datapoint, **kwargs):
    """
    Compute featurization for a molecular complex

    Parameters
    ----------
    complex: Tuple[str, str]
    datapoint: Tuple[str, str]
      Filenames for molecule and protein.
    """
    if 'complex' in kwargs:
      datapoint = kwargs.get("complex")
      raise DeprecationWarning(
          'Complex is being phased out as a parameter, please pass "datapoint" instead.'
      )

    try:
      fragments = load_complex(complex, add_hydrogens=False)
      fragments = load_complex(datapoint, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
@@ -236,17 +242,23 @@ class SplifVoxelizer(ComplexFeaturizer):
    self.voxel_width = voxel_width
    self.voxels_per_edge = int(self.box_width / self.voxel_width)

  def _featurize(self, complex: Tuple[str, str], **kwargs):
  def _featurize(self, datapoint, **kwargs):
    """
    Compute featurization for a molecular complex

    Parameters
    ----------
    complex: Tuple[str, str]
    datapoint: Tuple[str, str]
      Filenames for molecule and protein.
    """
    if 'complex' in kwargs:
      datapoint = kwargs.get("complex")
      raise DeprecationWarning(
          'Complex is being phased out as a parameter, please pass "datapoint" instead.'
      )

    try:
      fragments = load_complex(complex, add_hydrogens=False)
      fragments = load_complex(datapoint, add_hydrogens=False)

    except MoleculeLoadException:
      logger.warning("This molecule cannot be loaded by Rdkit. Returning None")
Loading