Commit b2fadcf2 authored by nd-02110114's avatar nd-02110114
Browse files

🐛 fix import error

parent 208050eb
Loading
Loading
Loading
Loading
+36 −30
Original line number Diff line number Diff line
@@ -46,6 +46,42 @@ class TransformerGenerator(object):
    return name


featurizers = {
    'graphconv': dc.feat.ConvMolFeaturizer(),
    'weave': dc.feat.WeaveFeaturizer(),
}

try:
  featurizers['ecfp'] = dc.feat.CircularFingerprint(size=1024)
  featurizers['raw'] = dc.feat.RawFeaturizer()
  featurizers['smiles2img'] = dc.feat.SmilesToImage(img_size=80, img_spec='std')
  featurizers['onehot'] = dc.feat.OneHotFeaturizer()
except:
  pass

splitters = {
    'index': dc.splits.IndexSplitter(),
    'random': dc.splits.RandomSplitter(),
    'scaffold': dc.splits.ScaffoldSplitter(),
    'butina': dc.splits.ButinaSplitter(),
    'task': dc.splits.TaskSplitter(),
    'stratified': dc.splits.RandomStratifiedSplitter()
}

transformers = {
    'balancing':
    TransformerGenerator(dc.trans.BalancingTransformer),
    'normalization':
    TransformerGenerator(dc.trans.NormalizationTransformer, transform_y=True),
    'minmax':
    TransformerGenerator(dc.trans.MinMaxTransformer, transform_y=True),
    'clipping':
    TransformerGenerator(dc.trans.ClippingTransformer, transform_y=True),
    'log':
    TransformerGenerator(dc.trans.LogTransformer, transform_y=True)
}


class _MolnetLoader(object):
  """The class provides common functionality used by many molnet loader functions.
  It is an abstract class.  Subclasses implement loading of particular datasets.
@@ -79,36 +115,6 @@ class _MolnetLoader(object):
    save_dir: str
      a directory to save the dataset in
    """
    featurizers = {
        'ecfp': dc.feat.CircularFingerprint(size=1024),
        'graphconv': dc.feat.ConvMolFeaturizer(),
        'weave': dc.feat.WeaveFeaturizer(),
        'raw': dc.feat.RawFeaturizer(),
        'smiles2img': dc.feat.SmilesToImage(img_size=80, img_spec='std')
    }

    splitters = {
        'index': dc.splits.IndexSplitter(),
        'random': dc.splits.RandomSplitter(),
        'scaffold': dc.splits.ScaffoldSplitter(),
        'butina': dc.splits.ButinaSplitter(),
        'task': dc.splits.TaskSplitter(),
        'stratified': dc.splits.RandomStratifiedSplitter()
    }

    transformers = {
        'balancing':
        TransformerGenerator(dc.trans.BalancingTransformer),
        'normalization':
        TransformerGenerator(dc.trans.NormalizationTransformer, transform_y=True),
        'minmax':
        TransformerGenerator(dc.trans.MinMaxTransformer, transform_y=True),
        'clipping':
        TransformerGenerator(dc.trans.ClippingTransformer, transform_y=True),
        'log':
        TransformerGenerator(dc.trans.LogTransformer, transform_y=True)
    }

    if 'split' in kwargs:
      splitter = kwargs['split']
      logger.warning("'split' is deprecated.  Use 'splitter' instead.")
+1 −1
Original line number Diff line number Diff line
@@ -54,7 +54,7 @@ class _Zinc15Loader(_MolnetLoader):


def load_zinc15(
    featurizer: Union[dc.feat.Featurizer, str] = dc.feat.OneHotFeaturizer(),
    featurizer: Union[dc.feat.Featurizer, str] = 'OneHot',
    splitter: Union[dc.splits.Splitter, str, None] = 'random',
    transformers: List[Union[TransformerGenerator, str]] = ['normalization'],
    reload: bool = True,