Commit 2e265c1b authored by Nathan Frey's avatar Nathan Frey
Browse files

Add defaults to loader template

parent 21d34c62
Loading
Loading
Loading
Loading
+63 −66
Original line number Diff line number Diff line
@@ -5,8 +5,10 @@ import os
import logging
import deepchem
from deepchem.feat import Featurizer
from deepchem.trans import Transformer
from deepchem.split.splitters import Splitter

from typing import Iterable, List
from typing import List, Tuple, Optional

logger = logging.getLogger(__name__)

@@ -14,28 +16,46 @@ DEFAULT_DIR = deepchem.utils.get_data_dir()
MYDATASET_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/mydataset.tar.gz'
MYDATASET_CSV_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/mydataset.csv'

# Featurizers accepted for this dataset, keyed by the short name a caller may
# pass as a string. load_mydataset() resolves a str `featurizer` argument
# through this dict and uses the 'Raw' entry as its default value.
# Each value is an instance constructed once, when this module is imported.
DEFAULT_FEATURIZERS = {
    'Raw': deepchem.feat.RawFeaturizer(),
    'ECFP': deepchem.feat.CircularFingerprint(size=1024),
}

# Transformers accepted for this dataset, keyed by the short name a caller may
# pass as a string. load_mydataset() replaces every str entry of its
# `transformers` argument with the matching instance from this dict.
# Each value is an instance constructed once, when this module is imported.
# NOTE(review): load_mydataset's default is written as
# `(DEFAULT_TRANSFORMERS['Power'])`, which has no trailing comma and is
# therefore the bare transformer rather than a 1-tuple, yet the function
# iterates over `transformers` -- confirm and add the comma if intended.
DEFAULT_TRANSFORMERS = {
    'Power': deepchem.trans.PowerTransformer(),
}

def load_mydataset(featurizer: str = None,
                   split: str = 'random',
# Splitters accepted for this dataset, keyed by the short name a caller may
# pass as a string. load_mydataset() resolves a str `splitter` argument
# through this dict and uses the 'Random' entry as its default value.
# Each value is an instance constructed once, when this module is imported.
# NOTE(review): these entries live under `deepchem.splits`, while the
# `Splitter` import at the top of the file reads `deepchem.split.splitters`
# -- confirm which module path is correct.
DEFAULT_SPLITTERS = {
    'Index': deepchem.splits.IndexSplitter(),
    'Random': deepchem.splits.RandomSplitter(),
}


def load_mydataset(
    featurizer: Featurizer = DEFAULT_FEATURIZERS['Raw'],
    transformers: Tuple[Transformer] = (DEFAULT_TRANSFORMERS['Power']),
    splitter: Splitter = DEFAULT_SPLITTERS['Random'],
    reload: bool = True,
                   move_mean: bool = True,
                   data_dir: str = None,
                   save_dir: str = None,
                   **kwargs) -> Iterable:
    data_dir: Optional[str] = None,
    save_dir: Optional[str] = None,
    **kwargs) -> Tuple[List, Tuple, List]:
  """Load mydataset.

  This is a template for adding a function to load a dataset from
  MoleculeNet. Adjust the global variable URL strings, default parameters,
  and variable names as needed. The function will need to be modified
  to handle the allowed featurizers for your dataset. 
  default featurizers, transformers, and splitters, and variable names as
  needed.

  If `reload = True` and `data_dir` (`save_dir`) is specified, the loader
  will attempt to load the raw dataset (featurized dataset) from disk.
  Otherwise, the dataset will be downloaded from the DeepChem AWS bucket.

  The dataset will be featurized with `featurizer` and separated into
  train/val/test sets according to `split`. Additional kwargs may
  be given for specific featurizers and splitters.
  train/val/test sets according to `splitter`. Additional kwargs may
  be given for specific featurizers, transformers, and splitters.

  Please refer to the MoleculeNet documentation for further information
  https://deepchem.readthedocs.io/en/latest/moleculenet.html.
@@ -44,22 +64,23 @@ def load_mydataset(featurizer: str = None,
  ----------
  featurizer: {List of allowed featurizers for this dataset}
    A featurizer that inherits from deepchem.feat.Featurizer.
  split: {'random', 'stratified', 'index', 'scaffold'}
  transformers: Tuple{List of allowed transformers for this dataset}
    A transformer that inherits from deepchem.trans.Transformer.
  splitter: {List of allowed splitters for this dataset}
    A splitter that inherits from deepchem.splits.splitters.Splitter.
  reload: bool (default True)
    Try to reload dataset from disk if already downloaded. Save to disk
    after featurizing.
  move_mean: bool (default True)
    Center data to have 0 mean after transform.
  data_dir: str, optional
    Path to datasets.
  save_dir: str, optional
    Path to featurized datasets.
  **kwargs: optional arguments to featurizers and splitters.
  **kwargs: optional arguments to methods of featurizers, transformers, and
  splitters.

  Returns
  -------
  tasks, datasets, transformers : iterable
  tasks, datasets, transformers : tuple
    tasks : list
      Column names corresponding to machine learning target variables.
    datasets : tuple
@@ -81,12 +102,12 @@ def load_mydataset(featurizer: str = None,

  Examples
  --------
  >>> import deepchem as dc
  >>> tasks, datasets, transformers = dc.molnet.load_tox21(reload=False)
  >>> train_dataset, val_dataset, test_dataset = datasets
  >>> n_tasks = len(tasks)
  >>> n_features = train_dataset.get_data_shape()[0]
  >>> model = dc.models.MultitaskClassifier(n_tasks, n_features)
  >> import deepchem as dc
  >> tasks, datasets, transformers = dc.molnet.load_tox21(reload=False)
  >> train_dataset, val_dataset, test_dataset = datasets
  >> n_tasks = len(tasks)
  >> n_features = train_dataset.get_data_shape()[0]
  >> model = dc.models.MultitaskClassifier(n_tasks, n_features)

  """

@@ -107,13 +128,21 @@ def load_mydataset(featurizer: str = None,
  if save_dir is None:
    save_dir = DEFAULT_DIR

  # Check for str args to featurizer, splitter, and transformers
  if isinstance(featurizer, str):
    featurizer = DEFAULT_FEATURIZERS[featurizer]
  if isinstance(splitter, str):
    splitter = DEFAULT_SPLITTERS[splitter]
  transformers = [
      DEFAULT_TRANSFORMERS[t] if isinstance(t, str) else t for t in transformers
  ]

  # Reload from disk
  if reload:
    save_folder = os.path.join(save_dir, "mydataset-featurized")
    if not move_mean:
      save_folder = os.path.join(save_folder, str(featurizer) + "_mean_unmoved")
    else:
      save_folder = os.path.join(save_folder, str(featurizer))
    featurizer_name = str(featurizer.__class__.__name__)
    splitter_name = str(splitter.__class__.__name__)
    save_folder = os.path.join(save_dir, "mydataset-featurized",
                               featurizer_name, splitter_name)

    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_folder)
@@ -132,50 +161,23 @@ def load_mydataset(featurizer: str = None,
      deepchem.utils.download_url(url=MYDATASET_URL, dest_dir=data_dir)
      deepchem.utils.untargz_file(
          os.path.join(data_dir, 'mydataset.tar.gz'), data_dir)
  else:  # only load CSV file
    dataset_file = os.path.join(data_dir, "mydataset.csv")
    if not os.path.exists(dataset_file):
      deepchem.utils.download_url(url=MYDATASET_CSV_URL, dest_dir=data_dir)

  # Handle all allowed SDF featurizers
  if featurizer in sdf_featurizers:
    if featurizer == 'Featurizer1':
      featurizer = deepchem.feat.Featurizer1()
    elif featurizer == 'Featurizer2':
      featurizer = deepchem.feat.Featurizer2()

    loader = deepchem.data.SDFLoader(
        tasks=my_tasks,
        smiles_field="smiles",  # column name holding SMILES strings
        mol_field="mol",  # field where RKit mol objects are stored
        featurizer=featurizer)
  else:  # Handle allowed CSV featurizers
    if featurizer == 'Featurizer3':
      featurizer = deepchem.feat.Featurizer3()
    elif featurizer == 'Featurizer4':
      featurizer = deepchem.feat.Featurizer4()

  else:  # only load CSV file
    dataset_file = os.path.join(data_dir, "mydataset.csv")
    if not os.path.exists(dataset_file):
      deepchem.utils.download_url(url=MYDATASET_CSV_URL, dest_dir=data_dir)

    loader = deepchem.data.CSVLoader(
        tasks=my_tasks, smiles_field="smiles", featurizer=featurizer)

  # Featurize dataset
  dataset = loader.featurize(dataset_file)
  if split is None:  # Must give a recommended split for data
    raise ValueError()

  # Generate Splitter
  splitters = {
      'index':
      deepchem.splits.IndexSplitter(),
      'random':
      deepchem.splits.RandomSplitter(),
      'stratified':
      deepchem.splits.SingletaskStratifiedSplitter(task_number=len(my_tasks)),
      'scaffold':
      deepchem.splits.ScaffoldSplitter()
  }

  splitter = splitters[split]

  # 80/10/10 train/val/test split is default
  frac_train = kwargs.get("frac_train", 0.8)
@@ -188,11 +190,6 @@ def load_mydataset(featurizer: str = None,
      frac_valid=frac_valid,
      frac_test=frac_test)

  transformers = [
      deepchem.trans.NormalizationTransformer(
          transform_y=True, dataset=train_dataset, move_mean=move_mean)
  ]

  for transformer in transformers:
    train_dataset = transformer.transform(train_dataset)
    valid_dataset = transformer.transform(valid_dataset)
+1 −1
Original line number Diff line number Diff line
@@ -20,7 +20,7 @@ please follow the instructions below. Please review the `datasets already availa

5. Ask a member of the technical steering committee to add your .tar.gz or .zip file to the DeepChem AWS bucket. Modify your load function to pull down the dataset from AWS.

6. Submit a [WIR] PR (Work in progress pull request) following the PR `template <https://github.com/deepchem/deepchem/blob/master/docs/molnet_pr_template.md>`_.  
6. Submit a [WIP] PR (Work in progress pull request) following the PR `template <https://github.com/deepchem/deepchem/blob/master/docs/molnet_pr_template.md>`_.  

Load Dataset Template
---------------------