Commit 9e9f6b2a authored by Nathan Frey's avatar Nathan Frey
Browse files

Add defaults list and generalize loaders

parent 2e265c1b
Loading
Loading
Loading
Loading
+1 −0
Original line number Original line Diff line number Diff line
{"featurizer": ["AdjacencyFingerprint", "AtomicCoordinates", "BPSymmetryFunctionInput", "BindingPocketFeaturizer", "CircularFingerprint", "ComplexFeaturizer", "ConvMolFeaturizer", "CoulombMatrix", "CoulombMatrixEig", "Featurizer", "NeighborListComplexAtomicCoordinates", "OneHotFeaturizer", "RDKitDescriptors", "RawFeaturizer", "RdkitGridFeaturizer", "SmilesToImage", "SmilesToSeq", "UserDefinedFeaturizer", "WeaveFeaturizer"], "transformer": ["ANITransformer", "BalancingTransformer", "CDFTransformer", "ClippingTransformer", "CoulombFitTransformer", "DAGTransformer", "IRVTransformer", "LogTransformer", "MinMaxTransformer", "NormalizationTransformer", "PowerTransformer"], "splitter": ["ButinaSplitter", "DiskDataset", "FingerprintSplitter", "IndexSplitter", "IndiceSplitter", "MaxMinSplitter", "MolecularWeightSplitter", "NumpyDataset", "RandomGroupSplitter", "RandomSplitter", "RandomStratifiedSplitter", "ScaffoldGenerator", "ScaffoldSplitter", "SingletaskStratifiedSplitter", "SpecifiedIndexSplitter", "SpecifiedSplitter", "Splitter", "TaskSplitter", "TimeSplitterPDBbind"]}
 No newline at end of file
+55 −0
Original line number Original line Diff line number Diff line
"""
Featurizers, transformers, and splitters for MolNet.
"""

import os
import importlib
import inspect
import logging
import json
from typing import Dict, List

logger = logging.getLogger(__name__)


def get_defaults(inspect_modules: bool = False) -> Dict[str, List[str]]:
  """Get featurizers, transformers, and splitters.

  This function returns a dictionary with keys 'featurizer', 'transformer',
  and 'splitter'. Each value is a list of names of classes in that
  category. All MolNet ``load_x`` functions should specify which
  featurizers, transformers, and splitters the dataset supports and
  provide sensible defaults.

  Parameters
  ----------
  inspect_modules : bool (default False)
    If False, read the class names from the ``defaults.json`` file located
    in the same directory as this module. If True, import ``deepchem.feat``,
    ``deepchem.trans``, and ``deepchem.splits`` and list the classes found
    in each of them.

  Returns
  -------
  defaults : dict
    Contains names of all available featurizers, transformers, and splitters.

  Notes
  -----
  With ``inspect_modules=True`` every class that is a member of the
  inspected module is listed, including classes that were only imported
  into it, so the lists are not restricted to subclasses of a particular
  base class.

  """

  if not inspect_modules:
    path = os.path.dirname(os.path.abspath(__file__))
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(os.path.join(path, "defaults.json"), encoding="utf-8") as f:
      return json.load(f)

  # Output key -> module whose classes populate that key.
  category_modules = {
      'featurizer': "deepchem.feat",
      'transformer': "deepchem.trans",
      'splitter': "deepchem.splits",
  }

  # inspect.getmembers returns (name, object) pairs sorted by name.
  return {
      category: [
          name for name, _ in inspect.getmembers(
              importlib.import_module(module_name), inspect.isclass)
      ] for category, module_name in category_modules.items()
  }
+51 −22
Original line number Original line Diff line number Diff line
@@ -7,6 +7,7 @@ import deepchem
from deepchem.feat import Featurizer
from deepchem.feat import Featurizer
from deepchem.trans import Transformer
from deepchem.trans import Transformer
from deepchem.split.splitters import Splitter
from deepchem.split.splitters import Splitter
from deepchem.molnet.defaults import get_defaults


from typing import List, Tuple, Optional
from typing import List, Tuple, Optional


@@ -16,26 +17,40 @@ DEFAULT_DIR = deepchem.utils.get_data_dir()
MYDATASET_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/mydataset.tar.gz'
MYDATASET_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/mydataset.tar.gz'
MYDATASET_CSV_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/mydataset.csv'
MYDATASET_CSV_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/mydataset.csv'


# Get dictionary of all featurizers, transformers, and splitters
# Check this dict or `defaults.json` to see list of available classes
DEFAULTS = get_defaults()

# dict of accepted featurizers for this dataset
# dict of accepted featurizers for this dataset
# update for your dataset
DEFAULT_FEATURIZERS = {
DEFAULT_FEATURIZERS = {
    'Raw': deepchem.feat.RawFeaturizer(),
    'AdjacencyFingerprint': deepchem.feat.AdjacencyFingerprint(),
    'ECFP': deepchem.feat.CircularFingerprint(size=1024),
    'AtomicCoordinates': deepchem.feat.AtomicCoordinates(),
    'ConvMolFeaturizer': deepchem.feat.ConvMolFeaturizer(),
    'CoulombMatrix': deepchem.feat.CoulombMatrix(max_atoms=5),
    'RDKitDescriptors': deepchem.feat.RDKitDescriptors(),
    'RawFeaturizer': deepchem.feat.RawFeaturizer(),
    'CircularFingerprint': deepchem.feat.CircularFingerprint(size=1024),
}
}


# dict of accepted transformers
# dict of accepted transformers
DEFAULT_TRANSFORMERS = {
DEFAULT_TRANSFORMERS = {
    'Power': deepchem.trans.PowerTransformer(),
    'Power': deepchem.trans.PowerTransformer(),
    'Balancing': deepchem.trans.BalancingTransformer(),
    'Log': deepchem.trans.LogTransformer(),
    'MinMax': deepchem.trans.MinMaxTransformer()
}
}


# dict of accepted splitters
# dict of accepted splitters
DEFAULT_SPLITTERS = {
DEFAULT_SPLITTERS = {
    'Index': deepchem.splits.IndexSplitter(),
    'Index': deepchem.splits.IndexSplitter(),
    'Random': deepchem.splits.RandomSplitter(),
    'Random': deepchem.splits.RandomSplitter(),
    'Scaffold': deepchem.splits.ScaffoldSplitter(),
}
}




def load_mydataset(
def load_mydataset(
    featurizer: Featurizer = DEFAULT_FEATURIZERS['Raw'],
    featurizer: Featurizer = DEFAULT_FEATURIZERS['RawFeaturizer'],
    transformers: Tuple[Transformer] = (DEFAULT_TRANSFORMERS['Power']),
    transformers: Tuple[Transformer] = (DEFAULT_TRANSFORMERS['Power']),
    splitter: Splitter = DEFAULT_SPLITTERS['Random'],
    splitter: Splitter = DEFAULT_SPLITTERS['Random'],
    reload: bool = True,
    reload: bool = True,
@@ -47,16 +62,23 @@ def load_mydataset(
  This is a template for adding a function to load a dataset from
  This is a template for adding a function to load a dataset from
  MoleculeNet. Adjust the global variable URL strings, default parameters,
  MoleculeNet. Adjust the global variable URL strings, default parameters,
  default featurizers, transformers, and splitters, and variable names as
  default featurizers, transformers, and splitters, and variable names as
  needed.
  needed. A dictionary of all available featurizers, transformers, and
  splitters is available in the global variable `DEFAULTS` and also
  in `deepchem/molnet/defaults.json`.


  If `reload = True` and `data_dir` (`save_dir`) is specified, the loader
  If `reload = True` and `data_dir` (`save_dir`) is specified, the loader
  will attempt to load the raw dataset (featurized dataset) from disk.
  will attempt to load the raw dataset (featurized dataset) from disk.
  Otherwise, the dataset will be downloaded from the DeepChem AWS bucket.
  Otherwise, the dataset will be downloaded from the DeepChem AWS bucket.


  The dataset will be featurized with `featurizer` and separated into
  The dataset will be featurized with `featurizer` and separated into
  train/val/test sets according to `splitter`. Additional kwargs may
  train/val/test sets according to `splitter`. Some transformers (e.g.
  `NormalizationTransformer`) must be initialized with a dataset. 
  Set up kwargs to enable these transformations. Additional kwargs may
  be given for specific featurizers, transformers, and splitters.
  be given for specific featurizers, transformers, and splitters.


  The load function must be modified with the appropriate DataLoaders
  for all supported featurizers for your dataset.

  Please refer to the MoleculeNet documentation for further information
  Please refer to the MoleculeNet documentation for further information
  https://deepchem.readthedocs.io/en/latest/moleculenet.html.
  https://deepchem.readthedocs.io/en/latest/moleculenet.html.
  
  
@@ -149,23 +171,22 @@ def load_mydataset(
    if loaded:
    if loaded:
      return my_tasks, all_dataset, transformers
      return my_tasks, all_dataset, transformers


  # 3D coordinate featurizers, e.g. 'CoulombMatrix' or 'MP'
  # First type of supported featurizers
  # For crystal structures, replace with json_featurizers
  supported_featurizers = []  # type: List[Featurizer]
  sdf_featurizers = []  # type: List[Featurizer]


  # If featurizer requires a non-CSV file format, load .tar.gz file
  # If featurizer requires a non-CSV file format, load .tar.gz file
  if featurizer in sdf_featurizers:
  if featurizer in supported_featurizers:
    dataset_file = os.path.join(data_dir, 'mydataset.sdf')
    dataset_file = os.path.join(data_dir, 'mydataset.filetype')


    if not os.path.exists(dataset_file):
    if not os.path.exists(dataset_file):
      deepchem.utils.download_url(url=MYDATASET_URL, dest_dir=data_dir)
      deepchem.utils.download_url(url=MYDATASET_URL, dest_dir=data_dir)
      deepchem.utils.untargz_file(
      deepchem.utils.untargz_file(
          os.path.join(data_dir, 'mydataset.tar.gz'), data_dir)
          os.path.join(data_dir, 'mydataset.tar.gz'), data_dir)


    loader = deepchem.data.SDFLoader(
    # Change loader to match featurizer and data file type
    loader = deepchem.data.DataLoader(
        tasks=my_tasks,
        tasks=my_tasks,
        smiles_field="smiles",  # column name holding SMILES strings
        id_field="id",  # column name holding sample identifier
        mol_field="mol",  # field where RDKit mol objects are stored
        featurizer=featurizer)
        featurizer=featurizer)


  else:  # only load CSV file
  else:  # only load CSV file
@@ -177,7 +198,7 @@ def load_mydataset(
        tasks=my_tasks, smiles_field="smiles", featurizer=featurizer)
        tasks=my_tasks, smiles_field="smiles", featurizer=featurizer)


  # Featurize dataset
  # Featurize dataset
  dataset = loader.featurize(dataset_file)
  dataset = loader.create_dataset(dataset_file)


  # 80/10/10 train/val/test split is default
  # 80/10/10 train/val/test split is default
  frac_train = kwargs.get("frac_train", 0.8)
  frac_train = kwargs.get("frac_train", 0.8)
@@ -190,6 +211,14 @@ def load_mydataset(
      frac_valid=frac_valid,
      frac_valid=frac_valid,
      frac_test=frac_test)
      frac_test=frac_test)


  # Check for transformers that require a dataset
  normalize = kwargs.get("normalize", True)  # Normalization transform
  move_mean = kwargs.get("move_mean", True)  # Zero out mean of dataset
  if normalize:
    transformers.append(
        deepchem.trans.NormalizationTransformer(
            transform_y=True, dataset=train_dataset, move_mean=move_mean))

  for transformer in transformers:
  for transformer in transformers:
    train_dataset = transformer.transform(train_dataset)
    train_dataset = transformer.transform(train_dataset)
    valid_dataset = transformer.transform(valid_dataset)
    valid_dataset = transformer.transform(valid_dataset)
+1 −1
Original line number Original line Diff line number Diff line
@@ -12,7 +12,7 @@ please follow the instructions below. Please review the `datasets already availa


1. Open an `issue <https://github.com/deepchem/deepchem/issues>`_ to discuss the dataset you want to add to MolNet.
1. Open an `issue <https://github.com/deepchem/deepchem/issues>`_ to discuss the dataset you want to add to MolNet.


2. Implement a function in the `deepchem.molnet.load_function <https://github.com/deepchem/deepchem/tree/master/deepchem/molnet/load_function>`_ module following the template function `deepchem.molnet.load_function.load_mydataset <https://github.com/deepchem/deepchem/blob/master/deepchem/molnet/load_function/load_mydataset.py>`_.
2. Implement a function in the `deepchem.molnet.load_function <https://github.com/deepchem/deepchem/tree/master/deepchem/molnet/load_function>`_ module following the template function `deepchem.molnet.load_function.load_mydataset <https://github.com/deepchem/deepchem/blob/master/deepchem/molnet/load_function/load_mydataset.py>`_. Specify which featurizers, transformers, and splitters (listed in `deepchem/molnet/defaults.json <https://github.com/deepchem/deepchem/blob/master/deepchem/molnet/defaults.json>`_) are supported for your dataset.


3. Add your load function to `deepchem.molnet.__init__.py <https://github.com/deepchem/deepchem/blob/master/deepchem/molnet/__init__.py>`_ for easy importing.
3. Add your load function to `deepchem.molnet.__init__.py <https://github.com/deepchem/deepchem/blob/master/deepchem/molnet/__init__.py>`_ for easy importing.