Commit 94d6db05 authored by Nathan Frey's avatar Nathan Frey
Browse files

Dynamic defaults and kwargs

parent 9e9f6b2a
Loading
Loading
Loading
Loading

deepchem/molnet/defaults.json

deleted100644 → 0
+0 −1
Original line number Original line Diff line number Diff line
{"featurizer": ["AdjacencyFingerprint", "AtomicCoordinates", "BPSymmetryFunctionInput", "BindingPocketFeaturizer", "CircularFingerprint", "ComplexFeaturizer", "ConvMolFeaturizer", "CoulombMatrix", "CoulombMatrixEig", "Featurizer", "NeighborListComplexAtomicCoordinates", "OneHotFeaturizer", "RDKitDescriptors", "RawFeaturizer", "RdkitGridFeaturizer", "SmilesToImage", "SmilesToSeq", "UserDefinedFeaturizer", "WeaveFeaturizer"], "transformer": ["ANITransformer", "BalancingTransformer", "CDFTransformer", "ClippingTransformer", "CoulombFitTransformer", "DAGTransformer", "IRVTransformer", "LogTransformer", "MinMaxTransformer", "NormalizationTransformer", "PowerTransformer"], "splitter": ["ButinaSplitter", "DiskDataset", "FingerprintSplitter", "IndexSplitter", "IndiceSplitter", "MaxMinSplitter", "MolecularWeightSplitter", "NumpyDataset", "RandomGroupSplitter", "RandomSplitter", "RandomStratifiedSplitter", "ScaffoldGenerator", "ScaffoldSplitter", "SingletaskStratifiedSplitter", "SpecifiedIndexSplitter", "SpecifiedSplitter", "Splitter", "TaskSplitter", "TimeSplitterPDBbind"]}
 No newline at end of file
+34 −23
Original line number Original line Diff line number Diff line
@@ -9,47 +9,58 @@ import logging
import json
import json
from typing import Dict, List
from typing import Dict, List


from deepchem.feat.base_classes import Featurizer
from deepchem.trans.transformers import Transformer
from deepchem.splits.splitters import Splitter

logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)




def get_defaults(inspect_modules: bool = False) -> Dict[str, List[str]]:
def get_defaults(module_name: str = None) -> Dict[str, object]:
  """Get featurizers, transformers, and splitters.
  """Get featurizers, transformers, and splitters.


  This function returns a dictionary with keys 'featurizer', 'transformer',
  This function returns a dictionary with class names as keys and classes
  and 'splitter'. Each value is a list of names of classes in that
  as values. All MolNet ``load_x`` functions should specify which
  category. All MolNet ``load_x`` functions should specify which
  featurizers, transformers, and splitters the dataset supports and
  featurizers, transformers, and splitters the dataset supports and
  provide sensible defaults.
  provide sensible defaults.


  Parameters
  Parameters
  ----------
  ----------
  inspect_modules : bool (default False)
  module_name : {"feat", "trans", "splits"}
    Inspect dc.feat, dc.trans, and dc.splits modules to get class names.
    Default classes from deepchem.`module_name` will be returned.


  Returns
  Returns
  -------
  -------
  defaults : dict
  defaults : Dict[str, object]
    Contains names of all available featurizers, transformers, and splitters.
    Keys are class names and values are class constructors. 

  Examples
  --------
  >> splitter = get_defaults('splits')['RandomSplitter']() 
  >> transformer = get_defaults('trans')['BalancingTransformer'](dataset, {"transform_X": True})
  >> featurizer = get_defaults('feat')["CoulombMatrix"](max_atoms=5)  


  """
  """


  if not inspect_modules:
  if module_name not in ["feat", "trans", "splits"]:
    path = os.path.dirname(os.path.abspath(__file__))
    raise ValueError(
    defaults = json.load(open(os.path.join(path, "defaults.json")))
        "Input argument must be either 'feat', 'trans', or 'splits'.")
  else:

    module = importlib.import_module("deepchem.feat", package="deepchem")
  if module_name == "feat":
    featurizers = [x[0] for x in inspect.getmembers(module, inspect.isclass)]
    sc = Featurizer
  elif module_name == "trans":
    sc = Transformer
  elif module_name == "splits":
    sc = Splitter


    module = importlib.import_module("deepchem.trans", package="deepchem")
  module_name = "deepchem." + module_name
    transformers = [x[0] for x in inspect.getmembers(module, inspect.isclass)]


    module = importlib.import_module("deepchem.splits", package="deepchem")
  module = importlib.import_module(module_name, package="deepchem")
    splitters = [x[0] for x in inspect.getmembers(module, inspect.isclass)]


  defaults = {
  defaults = {
        'featurizer': featurizers,
      x[0]: x[1]
        'transformer': transformers,
      for x in inspect.getmembers(module, inspect.isclass)
        'splitter': splitters
      if issubclass(x[1], sc)
  }
  }


  return defaults
  return defaults
+36 −46
Original line number Original line Diff line number Diff line
@@ -9,7 +9,7 @@ from deepchem.trans import Transformer
from deepchem.split.splitters import Splitter
from deepchem.split.splitters import Splitter
from deepchem.molnet.defaults import get_defaults
from deepchem.molnet.defaults import get_defaults


from typing import List, Tuple, Optional
from typing import List, Tuple, Dict, Optional


logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)


@@ -17,54 +17,36 @@ DEFAULT_DIR = deepchem.utils.get_data_dir()
MYDATASET_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/mydataset.tar.gz'
MYDATASET_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/mydataset.tar.gz'
MYDATASET_CSV_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/mydataset.csv'
MYDATASET_CSV_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/mydataset.csv'


# Get dictionary of all featurizers, transformers, and splitters
# Check this dict or `defaults.json` to see list of available classes
DEFAULTS = get_defaults()

# dict of accepted featurizers for this dataset
# dict of accepted featurizers for this dataset
# update for your dataset
# modify the returned dicts for your dataset
DEFAULT_FEATURIZERS = {
DEFAULT_FEATURIZERS = get_defaults("feat")
    'AdjacencyFingerprint': deepchem.feat.AdjacencyFingerprint(),
    'AtomicCoordinates': deepchem.feat.AtomicCoordinates(),
    'ConvMolFeaturizer': deepchem.feat.ConvMolFeaturizer(),
    'CoulombMatrix': deepchem.feat.CoulombMatrix(max_atoms=5),
    'RDKitDescriptors': deepchem.feat.RDKitDescriptors(),
    'RawFeaturizer': deepchem.feat.RawFeaturizer(),
    'CircularFingerprint': deepchem.feat.CircularFingerprint(size=1024),
}


# dict of accepted transformers
# dict of accepted transformers
DEFAULT_TRANSFORMERS = {
DEFAULT_TRANSFORMERS = get_defaults("trans")
    'Power': deepchem.trans.PowerTransformer(),
    'Balancing': deepchem.trans.BalancingTransformer(),
    'Log': deepchem.trans.LogTransformer(),
    'MinMax': deepchem.trans.MinMaxTransformer()
}


# dict of accepted splitters
# dict of accepted splitters
DEFAULT_SPLITTERS = {
# NOTE: the key must be "splits"; get_defaults() only accepts "feat",
# "trans", or "splits" and raises ValueError for anything else.
DEFAULT_SPLITTERS = get_defaults("splits")
    'Index': deepchem.splits.IndexSplitter(),
    'Random': deepchem.splits.RandomSplitter(),
    'Scaffold': deepchem.splits.ScaffoldSplitter(),
}




def load_mydataset(
def load_mydataset(
    featurizer: Featurizer = DEFAULT_FEATURIZERS['RawFeaturizer'],
    featurizer: Featurizer = DEFAULT_FEATURIZERS['RawFeaturizer'],
    transformers: Tuple[Transformer] = (DEFAULT_TRANSFORMERS['Power']),
    transformers: Tuple[Transformer] = (
    splitter: Splitter = DEFAULT_SPLITTERS['Random'],
        DEFAULT_TRANSFORMERS['PowerTransformer']),
    splitter: Splitter = DEFAULT_SPLITTERS['RandomSplitter'],
    reload: bool = True,
    reload: bool = True,
    data_dir: Optional[str] = None,
    data_dir: Optional[str] = None,
    save_dir: Optional[str] = None,
    save_dir: Optional[str] = None,
    featurizer_kwargs: Optional[Dict[str, object]] = None,
    splitter_kwargs: Optional[Dict[str, object]] = None,
    transformer_kwargs: Optional[Dict[str, Dict[str, object]]] = None,
    **kwargs) -> Tuple[List, Tuple, List]:
    **kwargs) -> Tuple[List, Tuple, List]:
  """Load mydataset.
  """Load mydataset.


  This is a template for adding a function to load a dataset from
  This is a template for adding a function to load a dataset from
  MoleculeNet. Adjust the global variable URL strings, default parameters,
  MoleculeNet. Adjust the global variable URL strings, default parameters,
  default featurizers, transformers, and splitters, and variable names as
  default featurizers, transformers, and splitters, and variable names as
  needed. A dictionary of all available featurizers, transformers, and
  needed. All available featurizers, transformers, and
  splitters is available in the global variable `DEFAULTS` and also
  splitters are in the `DEFAULT_FEATURIZERS`, `DEFAULT_TRANSFORMERS`, and
  `DEFAULT_SPLITTERS` global variables.
  in `deepchem/molnet/defaults.json`.


  If `reload = True` and `data_dir` (`save_dir`) is specified, the loader
  If `reload = True` and `data_dir` (`save_dir`) is specified, the loader
  will attempt to load the raw dataset (featurized dataset) from disk.
  will attempt to load the raw dataset (featurized dataset) from disk.
@@ -97,8 +79,14 @@ def load_mydataset(
    Path to datasets.
    Path to datasets.
  save_dir : str, optional
  save_dir : str, optional
    Path to featurized datasets.
    Path to featurized datasets.
  **kwargs : optional arguments to methods of featurizers, transformers, and
  featurizer_kwargs : dict
    splitters.
    Specify parameters to featurizer, e.g. {"size": 1024}
  splitter_kwargs : dict
    Specify parameters to splitter, e.g. {"seed": 42}
  transformer_kwargs : dict
    Maps transformer names to constructor arguments, e.g.
    {"BalancingTransformer": {"transform_X": True, "transform_y": False}}
  **kwargs : additional optional arguments.


  Returns
  Returns
  -------
  -------
@@ -150,14 +138,16 @@ def load_mydataset(
  if save_dir is None:
  if save_dir is None:
    save_dir = DEFAULT_DIR
    save_dir = DEFAULT_DIR


  # Check for str args to featurizer, splitter, and transformers
  # Check for str args to featurizer and splitter
  if isinstance(featurizer, str):
  if isinstance(featurizer, str):
    featurizer = DEFAULT_FEATURIZERS[featurizer]
    featurizer = DEFAULT_FEATURIZERS[featurizer](**featurizer_kwargs)
  elif issubclass(featurizer, Featurizer):
    featurizer = featurizer(**featurizer_kwargs)

  if isinstance(splitter, str):
  if isinstance(splitter, str):
    splitter = DEFAULT_SPLITTERS[splitter]
    splitter = DEFAULT_SPLITTERS[splitter](**splitter_kwargs)
  transformers = [
  elif issubclass(splitter, Splitter):
      DEFAULT_TRANSFORMERS[t] if isinstance(t, str) else t for t in transformers
    splitter = splitter(**splitter_kwargs)
  ]


  # Reload from disk
  # Reload from disk
  if reload:
  if reload:
@@ -211,13 +201,13 @@ def load_mydataset(
      frac_valid=frac_valid,
      frac_valid=frac_valid,
      frac_test=frac_test)
      frac_test=frac_test)


  # Check for transformers that require a dataset
  # Initialize transformers
  normalize = kwargs.get("normalize", True)  # Normalization transform
  transformers = [
  move_mean = kwargs.get("move_mean", True)  # Zero out mean of dataset
      DEFAULT_TRANSFORMERS[t](dataset, **transformer_kwargs[t])
  if normalize:
      if isinstance(t, str) else t(
    transformers.append(
          dataset, **transformer_kwargs[str(t.__class__.__name__)])
        deepchem.trans.NormalizationTransformer(
      for t in transformers
            transform_y=True, dataset=train_dataset, move_mean=move_mean))
  ]


  for transformer in transformers:
  for transformer in transformers:
    train_dataset = transformer.transform(train_dataset)
    train_dataset = transformer.transform(train_dataset)
+1 −1
Original line number Original line Diff line number Diff line
@@ -12,7 +12,7 @@ please follow the instructions below. Please review the `datasets already availa


1. Open an `issue <https://github.com/deepchem/deepchem/issues>`_ to discuss the dataset you want to add to MolNet.
1. Open an `issue <https://github.com/deepchem/deepchem/issues>`_ to discuss the dataset you want to add to MolNet.


2. Implement a function in the `deepchem.molnet.load_function <https://github.com/deepchem/deepchem/tree/master/deepchem/molnet/load_function>`_ module following the template function `deepchem.molnet.load_function.load_mydataset <https://github.com/deepchem/deepchem/blob/master/deepchem/molnet/load_function/load_mydataset.py>`_. Specify which featurizers, transformers, and splitters (listed in `deepchem/molnet/defaults <https://github.com/deepchem/deepchem/tree/master/deepchem/molnet/defaults.json>`_) are supported for your dataset. 
2. Implement a function in the `deepchem.molnet.load_function <https://github.com/deepchem/deepchem/tree/master/deepchem/molnet/load_function>`_ module following the template function `deepchem.molnet.load_function.load_mydataset <https://github.com/deepchem/deepchem/blob/master/deepchem/molnet/load_function/load_mydataset.py>`_. Specify which featurizers, transformers, and splitters (available from `deepchem.molnet.defaults <https://github.com/deepchem/deepchem/tree/master/deepchem/molnet/defaults.py>`_) are supported for your dataset. 


3. Add your load function to `deepchem.molnet.__init__.py <https://github.com/deepchem/deepchem/blob/master/deepchem/molnet/__init__.py>`_ for easy importing.
3. Add your load function to `deepchem.molnet.__init__.py <https://github.com/deepchem/deepchem/blob/master/deepchem/molnet/__init__.py>`_ for easy importing.