Commit 80e8677b authored by nd-02110114's avatar nd-02110114
Browse files

♻️ refactor codes

parent 734740c4
Loading
Loading
Loading
Loading
+43 −63
Original line number Diff line number Diff line
@@ -2,10 +2,9 @@
Feature calculations.
"""
import logging
import types
import numpy as np
import multiprocessing
from typing import Any, Dict, List, Iterable, Sequence, Tuple, Union
from typing import Any, Dict, List, Iterable, Sequence, Tuple

logger = logging.getLogger(__name__)

@@ -74,24 +73,6 @@ class Featurizer(object):
    raise NotImplementedError('Featurizer is not defined.')


def _featurize_callback(featurizer, mol_pdb_file, protein_pdb_file,
                        log_message):
  """Log a progress message and featurize one mol/protein pair.

  This is the worker function passed to `pool.apply_async` by
  `ComplexFeaturizer`. It is defined at module level because
  `apply_async` does not execute a nested function; see
  https://stackoverflow.com/questions/56533827/pool-apply-async-nested-function-is-not-executed

  Parameters
  ----------
  featurizer: object
    Object providing a `_featurize(mol_pdb_file, protein_pdb_file)` method.
  mol_pdb_file: object
    First argument forwarded to `featurizer._featurize`.
  protein_pdb_file: object
    Second argument forwarded to `featurizer._featurize`.
  log_message: str
    Message emitted through `logging.info` before featurizing.

  Returns
  -------
  The value returned by `featurizer._featurize` for this pair.
  """
  # Report progress first so the message appears even if featurization fails.
  logging.info(log_message)
  pair = (mol_pdb_file, protein_pdb_file)
  return featurizer._featurize(*pair)


class ComplexFeaturizer(object):
  """
  Abstract class for calculating features for mol/protein complexes.
@@ -122,7 +103,7 @@ class ComplexFeaturizer(object):
    for i, (mol_file, protein_pdb) in enumerate(zip(mol_files, protein_pdbs)):
      log_message = "Featurizing %d / %d" % (i, len(mol_files))
      results.append(
          pool.apply_async(_featurize_callback,
          pool.apply_async(ComplexFeaturizer._featurize_callback,
                           (self, mol_file, protein_pdb, log_message)))
    pool.close()
    features = []
@@ -150,6 +131,12 @@ class ComplexFeaturizer(object):
    """
    raise NotImplementedError('Featurizer is not defined.')

  @staticmethod
  def _featurize_callback(featurizer, mol_pdb_file, protein_pdb_file,
                          log_message):
    logging.info(log_message)
    return featurizer._featurize(mol_pdb_file, protein_pdb_file)


class MolecularFeaturizer(Featurizer):
  """Abstract class for calculating a set of features for a
@@ -164,12 +151,12 @@ class MolecularFeaturizer(Featurizer):
  Child classes need to implement the _featurize method for
  calculating features for a single molecule.

  Note
  ----
  In general, subclasses of this class will require RDKit to be installed.
  Notes
  -----
  The subclasses of this class require RDKit to be installed.
  """

  def featurize(self, molecules, log_every_n=1000):
  def featurize(self, molecules, log_every_n=1000, canonical=False):
    """Calculate features for molecules.

    Parameters
@@ -177,41 +164,42 @@ class MolecularFeaturizer(Featurizer):
    molecules: RDKit Mol / SMILES string / iterable
      RDKit Mol, or SMILES string or iterable sequence of RDKit mols/SMILES
      strings.
    log_every_n: int, default 1000
      Logging messages reported every `log_every_n` samples.
    canonical: bool, default False
      Whether to use a canonical order of atoms returned by RDKit

    Returns
    -------
    A numpy array containing a featurized representation of
    `datapoints`.
    features: np.ndarray
      A numpy array containing a featurized representation of `molecules`.
    """
    try:
      from rdkit import Chem
      from rdkit.Chem import rdmolfiles
      from rdkit.Chem import rdmolops
      from rdkit.Chem.rdchem import Mol
    except ModuleNotFoundError:
      raise ValueError("This class requires RDKit to be installed.")

    # Special case handling of single molecule
    if isinstance(molecules, str) or isinstance(molecules, Mol):
      molecules = [molecules]
    else:
      # Convert iterables to list
      molecules = list(molecules)

    features = []
    for i, mol in enumerate(molecules):
      if i % log_every_n == 0:
        logger.info("Featurizing datapoint %i" % i)
      try:
        # Process only case of SMILES strings.
        if isinstance(mol, str):
          # mol must be a SMILES string so parse
          mol = Chem.MolFromSmiles(mol)
          # TODO (ytz) this is a bandage solution to reorder the atoms
          # so that they're always in the same canonical order.
          # Presumably this should be correctly implemented in the
          # future for graph mols.
          if mol:
            new_order = rdmolfiles.CanonicalRankAtoms(mol)
            mol = rdmolops.RenumberAtoms(mol, new_order)
        # canonicalize
        if canonical:
          canonical_smiles = Chem.MolToSmiles(mol)
          mol = Chem.MolFromSmiles(canonical_smiles)

        features.append(self._featurize(mol))
      except:
        logger.warning(
@@ -243,7 +231,6 @@ class MaterialStructureFeaturizer(Featurizer):
  -----
  Some subclasses of this class will require pymatgen and matminer to be
  installed.

  """

  def featurize(self,
@@ -265,16 +252,13 @@ class MaterialStructureFeaturizer(Featurizer):
    features: np.ndarray
      A numpy array containing a featurized representation of
      `structures`.

    """

    structures = list(structures)

    try:
      from pymatgen import Structure
    except ModuleNotFoundError:
      raise ValueError("This class requires pymatgen to be installed.")

    structures = list(structures)
    features = []
    for idx, structure in enumerate(structures):
      if idx % log_every_n == 0:
@@ -312,7 +296,6 @@ class MaterialCompositionFeaturizer(Featurizer):
  -----
  Some subclasses of this class will require pymatgen and matminer to be
  installed.

  """

  def featurize(self, compositions: Iterable[str],
@@ -331,16 +314,13 @@ class MaterialCompositionFeaturizer(Featurizer):
    features: np.ndarray
      A numpy array containing a featurized representation of
      `compositions`.

    """

    compositions = list(compositions)

    try:
      from pymatgen import Composition
    except ModuleNotFoundError:
      raise ValueError("This class requires pymatgen to be installed.")

    compositions = list(compositions)
    features = []
    for idx, composition in enumerate(compositions):
      if idx % log_every_n == 0:
+8 −8
Original line number Diff line number Diff line
@@ -109,8 +109,8 @@ class GraphData:
    return Data(
        x=torch.from_numpy(self.node_features),
        edge_index=torch.from_numpy(self.edge_index).long(),
      edge_attr=None if self.edge_features is None \
        else torch.from_numpy(self.edge_features),
        edge_attr=None
        if self.edge_features is None else torch.from_numpy(self.edge_features),
    )

  def to_dgl_graph(self):
@@ -193,10 +193,10 @@ class BatchGraphData(GraphData):

    # create new edge index
    num_nodes_list = [graph.num_nodes for graph in graph_list]
    batch_edge_index = np.hstack(
      [graph.edge_index + prev_num_node for prev_num_node, graph \
        in zip([0] + num_nodes_list[:-1], graph_list)]
    )
    batch_edge_index = np.hstack([
        graph.edge_index + prev_num_node
        for prev_num_node, graph in zip([0] + num_nodes_list[:-1], graph_list)
    ])

    # graph_index indicates which nodes belong to which graph
    graph_index = []
+1 −0
Original line number Diff line number Diff line
"""
Featurizers for inorganic crystals.
"""
# flake8: noqa
from deepchem.feat.material_featurizers.element_property_fingerprint import ElementPropertyFingerprint
from deepchem.feat.material_featurizers.sine_coulomb_matrix import SineCoulombMatrix
from deepchem.feat.material_featurizers.cgcnn_featurizer import CGCNNFeaturizer
+1 −0
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ ignore =
    E129,  # Visually indented line with same indent as next logical line
    W503,  # Line break before binary operator
    W504,  # Line break after binary operator
    E722   # do not use bare 'except'
max-line-length = 300

[yapf]