Commit 0e9b9b73 authored by nd-02110114's avatar nd-02110114
Browse files

Revert "🐛 fix bug"

This reverts commit b23600e1.
parent 2f82ffa6
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -37,7 +37,7 @@ from deepchem.molnet.load_function.material_datasets.load_perovskite import load
from deepchem.molnet.load_function.material_datasets.load_mp_formation_energy import load_mp_formation_energy
from deepchem.molnet.load_function.material_datasets.load_mp_metallicity import load_mp_metallicity

from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader
from deepchem.molnet.load_function.molnet_loader import featurizers, splitters, transformers, TransformerGenerator, _MolnetLoader

from deepchem.molnet.dnasim import simulate_motif_density_localization
from deepchem.molnet.dnasim import simulate_motif_counting
+31 −30
Original line number Diff line number Diff line
@@ -46,6 +46,37 @@ class TransformerGenerator(object):
    return name


featurizers = {
    'ecfp': dc.feat.CircularFingerprint(size=1024),
    'graphconv': dc.feat.ConvMolFeaturizer(),
    'weave': dc.feat.WeaveFeaturizer(),
    'raw': dc.feat.RawFeaturizer(),
    'smiles2img': dc.feat.SmilesToImage(img_size=80, img_spec='std')
}

splitters = {
    'index': dc.splits.IndexSplitter(),
    'random': dc.splits.RandomSplitter(),
    'scaffold': dc.splits.ScaffoldSplitter(),
    'butina': dc.splits.ButinaSplitter(),
    'task': dc.splits.TaskSplitter(),
    'stratified': dc.splits.RandomStratifiedSplitter()
}

transformers = {
    'balancing':
    TransformerGenerator(dc.trans.BalancingTransformer),
    'normalization':
    TransformerGenerator(dc.trans.NormalizationTransformer, transform_y=True),
    'minmax':
    TransformerGenerator(dc.trans.MinMaxTransformer, transform_y=True),
    'clipping':
    TransformerGenerator(dc.trans.ClippingTransformer, transform_y=True),
    'log':
    TransformerGenerator(dc.trans.LogTransformer, transform_y=True)
}


class _MolnetLoader(object):
  """The class provides common functionality used by many molnet loader functions.
  It is an abstract class.  Subclasses implement loading of particular datasets.
@@ -79,36 +110,6 @@ class _MolnetLoader(object):
    save_dir: str
      a directory to save the dataset in
    """
    featurizers = {
        'ecfp': dc.feat.CircularFingerprint(size=1024),
        'graphconv': dc.feat.ConvMolFeaturizer(),
        'weave': dc.feat.WeaveFeaturizer(),
        'raw': dc.feat.RawFeaturizer(),
        'smiles2img': dc.feat.SmilesToImage(img_size=80, img_spec='std')
    }

    splitters = {
        'index': dc.splits.IndexSplitter(),
        'random': dc.splits.RandomSplitter(),
        'scaffold': dc.splits.ScaffoldSplitter(),
        'butina': dc.splits.ButinaSplitter(),
        'task': dc.splits.TaskSplitter(),
        'stratified': dc.splits.RandomStratifiedSplitter()
    }

    transformers = {
        'balancing':
        TransformerGenerator(dc.trans.BalancingTransformer),
        'normalization':
        TransformerGenerator(dc.trans.NormalizationTransformer, transform_y=True),
        'minmax':
        TransformerGenerator(dc.trans.MinMaxTransformer, transform_y=True),
        'clipping':
        TransformerGenerator(dc.trans.ClippingTransformer, transform_y=True),
        'log':
        TransformerGenerator(dc.trans.LogTransformer, transform_y=True)
    }

    if 'split' in kwargs:
      splitter = kwargs['split']
      logger.warning("'split' is deprecated.  Use 'splitter' instead.")