Commit ab558d41 authored by Nathan Frey's avatar Nathan Frey
Browse files

Fix mypy errors

parent 1e564f22
Loading
Loading
Loading
Loading
+10 −14
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from deepchem.trans import Transformer
from deepchem.splits.splitters import Splitter
from deepchem.molnet.defaults import get_defaults

from typing import List, Tuple, Dict, Optional, Union
from typing import List, Tuple, Dict, Optional, Union, Any, Type

logger = logging.getLogger(__name__)

@@ -39,22 +39,19 @@ DEFAULT_SPLITTERS = {k: DEFAULT_SPLITTERS[k] for k in splitters}


def load_bandgap(
    featurizer: MaterialCompositionFeaturizer = DEFAULT_FEATURIZERS[
        'ElementPropertyFingerprint'],
    transformers: List[Transformer] = [
        DEFAULT_TRANSFORMERS['NormalizationTransformer']
    ],
    splitter: Splitter = DEFAULT_SPLITTERS['RandomSplitter'],
    featurizer=DEFAULT_FEATURIZERS['ElementPropertyFingerprint'],
    transformers: List = [DEFAULT_TRANSFORMERS['NormalizationTransformer']],
    splitter=DEFAULT_SPLITTERS['RandomSplitter'],
    reload: bool = True,
    data_dir: Optional[str] = None,
    save_dir: Optional[str] = None,
    featurizer_kwargs: Dict[str, object] = {'data_source': 'matminer'},
    splitter_kwargs: Dict[str, object] = {
    featurizer_kwargs: Dict[str, Any] = {'data_source': 'matminer'},
    splitter_kwargs: Dict[str, Any] = {
        'frac_train': 0.8,
        'frac_valid': 0.1,
        'frac_test': 0.1
    },
    transformer_kwargs: Dict[str, Dict[str, object]] = {
    transformer_kwargs: Dict[str, Dict[str, Any]] = {
        'NormalizationTransformer': {
            'transform_X': True
        }
@@ -86,9 +83,9 @@ def load_bandgap(
    Path to datasets.
  save_dir : str, optional
    Path to featurized datasets.
  featurizer_kwargs : dict
  featurizer_kwargs : Optional[Dict[str, Any]]
    Specify parameters to featurizer, e.g. {"size": 1024}
  splitter_kwargs : dict
  splitter_kwargs : Optional[Dict[str, Any]]
    Specify parameters to splitter, e.g. {"seed": 42}
  transformer_kwargs : dict
    Maps transformer names to constructor arguments, e.g.
@@ -158,8 +155,7 @@ def load_bandgap(
      return my_tasks, all_dataset, transformers

  # First type of supported featurizers
  supported_featurizers = ['ElementPropertyFingerprint'
                          ]  # type: List[Featurizer]
  supported_featurizers = ['ElementPropertyFingerprint']  # type: List[str]

  # Load .tar.gz file
  if featurizer.__class__.__name__ in supported_featurizers:
+10 −13
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from deepchem.trans import Transformer
from deepchem.splits.splitters import Splitter
from deepchem.molnet.defaults import get_defaults

from typing import List, Tuple, Dict, Optional, Union
from typing import List, Tuple, Dict, Optional, Union, Any, Type, Callable

logger = logging.getLogger(__name__)

@@ -37,22 +37,19 @@ DEFAULT_SPLITTERS = {k: DEFAULT_SPLITTERS[k] for k in splitters}


def load_perovskite(
    featurizer: MaterialStructureFeaturizer = DEFAULT_FEATURIZERS[
        'SineCoulombMatrix'],
    transformers: List[Transformer] = [
        DEFAULT_TRANSFORMERS['NormalizationTransformer']
    ],
    splitter: Splitter = DEFAULT_SPLITTERS['RandomSplitter'],
    featurizer=DEFAULT_FEATURIZERS['SineCoulombMatrix'],
    transformers: List = [DEFAULT_TRANSFORMERS['NormalizationTransformer']],
    splitter=DEFAULT_SPLITTERS['RandomSplitter'],
    reload: bool = True,
    data_dir: Optional[str] = None,
    save_dir: Optional[str] = None,
    featurizer_kwargs: Dict[str, object] = None,
    splitter_kwargs: Dict[str, object] = {
    featurizer_kwargs: Dict[str, Any] = {},
    splitter_kwargs: Dict[str, Any] = {
        'frac_train': 0.8,
        'frac_valid': 0.1,
        'frac_test': 0.1
    },
    transformer_kwargs: Dict[str, Dict[str, object]] = {
    transformer_kwargs: Dict[str, Dict[str, Any]] = {
        'NormalizationTransformer': {
            'transform_X': True
        }
@@ -84,9 +81,9 @@ def load_perovskite(
    Path to datasets.
  save_dir : str, optional
    Path to featurized datasets.
  featurizer_kwargs : dict
  featurizer_kwargs : Optional[Dict[str, Any]]
    Specify parameters to featurizer, e.g. {"size": 1024}
  splitter_kwargs : dict
  splitter_kwargs : Optional[Dict[str, Any]]
    Specify parameters to splitter, e.g. {"seed": 42}
  transformer_kwargs : dict
    Maps transformer names to constructor arguments, e.g.
@@ -157,7 +154,7 @@ def load_perovskite(

  # First type of supported featurizers
  supported_featurizers = ['StructureGraphFeaturizer',
                           'SineCoulombMatrix']  # type: List[Featurizer]
                           'SineCoulombMatrix']  # type: List[str]

  # Load .tar.gz file
  if featurizer.__class__.__name__ in supported_featurizers: