Commit 3eb87363 authored by Nathan Frey's avatar Nathan Frey
Browse files

Add tests for get_defaults

parent 94d6db05
Loading
Loading
Loading
Loading
+2 −2
Original line number Original line Diff line number Diff line
@@ -7,7 +7,7 @@ import importlib
import inspect
import inspect
import logging
import logging
import json
import json
from typing import Dict, List
from typing import Dict, List, Any


from deepchem.feat.base_classes import Featurizer
from deepchem.feat.base_classes import Featurizer
from deepchem.trans.transformers import Transformer
from deepchem.trans.transformers import Transformer
@@ -16,7 +16,7 @@ from deepchem.splits.splitters import Splitter
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)




def get_defaults(module_name: str = None) -> Dict[str, object]:
def get_defaults(module_name: str = None) -> Dict[str, Any]:
  """Get featurizers, transformers, and splitters.
  """Get featurizers, transformers, and splitters.


  This function returns a dictionary with class names as keys and classes
  This function returns a dictionary with class names as keys and classes
+11 −3
Original line number Original line Diff line number Diff line
@@ -18,15 +18,23 @@ MYDATASET_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/
MYDATASET_CSV_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/mydataset.csv'
MYDATASET_CSV_URL = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/mydataset.csv'


# dict of accepted featurizers for this dataset
# dict of accepted featurizers for this dataset
# modify the returned dicts your dataset
# modify the returned dicts for your dataset
DEFAULT_FEATURIZERS = get_defaults("feat")
DEFAULT_FEATURIZERS = get_defaults("feat")


# Names of supported featurizers
mydataset_featurizers = ['Featurizer1', 'Featurizer2', 'Featurizer3']
DEFAULT_FEATURIZERS = {k: DEFAULT_FEATURIZERS[k] for k in mydataset_featurizers}

# dict of accepted transformers
# dict of accepted transformers
DEFAULT_TRANSFORMERS = get_defaults("trans")
DEFAULT_TRANSFORMERS = get_defaults("trans")


# dict of accepted splitters
# dict of accepted splitters
DEFAULT_SPLITTERS = get_defaults("split")
DEFAULT_SPLITTERS = get_defaults("split")


# names of supported splitters
mydataset_splitters = ['Splitter1', 'Splitter2', 'Splitter3']
DEFAULT_SPLITTERS = {k: DEFAULT_SPLITTERS[k] for k in mydataset_splitters}



def load_mydataset(
def load_mydataset(
    featurizer: Featurizer = DEFAULT_FEATURIZERS['RawFeaturizer'],
    featurizer: Featurizer = DEFAULT_FEATURIZERS['RawFeaturizer'],
@@ -203,9 +211,9 @@ def load_mydataset(


  # Initialize transformers
  # Initialize transformers
  transformers = [
  transformers = [
      DEFAULT_TRANSFORMERS[t](dataset, **transformer_kwargs[t])
      DEFAULT_TRANSFORMERS[t](dataset=dataset, **transformer_kwargs[t])
      if isinstance(t, str) else t(
      if isinstance(t, str) else t(
          dataset, **transformer_kwargs[str(t.__class__.__name__)])
          dataset=dataset, **transformer_kwargs[str(t.__class__.__name__)])
      for t in transformers
      for t in transformers
  ]
  ]


+40 −0
Original line number Original line Diff line number Diff line
"""
Tests for getting featurizer, transformer, and splitter classes.
"""
import csv
import tempfile
import unittest

import numpy as np
import os
import pytest

import deepchem as dc
from deepchem.feat.base_classes import Featurizer
from deepchem.trans.transformers import Transformer
from deepchem.splits.splitters import Splitter
from deepchem.molnet.defaults import get_defaults


class TestDefaults(unittest.TestCase):
  """
  Tests for getting featurizer, transformer, and splitter classes.
  """

  def test_defaults(self):
    """Test getting defaults for MolNet loaders."""
    feats = get_defaults("feat")
    trans = get_defaults("trans")
    splits = get_defaults("splits")

    fkey = next(iter(feats))
    assert type(fkey) == str
    assert issubclass(feats[fkey], Featurizer)

    tkey = next(iter(trans))
    assert type(tkey) == str
    assert issubclass(trans[tkey], Transformer)

    skey = next(iter(splits))
    assert type(skey) == str
    assert issubclass(splits[skey], Splitter)