Unverified Commit 8e2543c4 authored by Suzukazole's avatar Suzukazole
Browse files

fix tasks, transformers

parent 738479f3
Loading
Loading
Loading
Loading
+12 −7
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ import logging
import deepchem
import numpy as np
from deepchem.data import Dataset
from deepchem.molnet.load_function.molnet_loader import _MolnetLoader
from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader
from typing import List, Optional, Tuple, Union
import deepchem as dc

@@ -21,6 +21,7 @@ USPTO_STEREO_URL = "https://deepchemdata.s3.us-west-1.amazonaws.com/datasets/USP
USPTO_50K_URL = "https://deepchemdata.s3.us-west-1.amazonaws.com/datasets/USPTO_50K.csv"
USPTO_FULL_URL = "https://deepchemdata.s3.us-west-1.amazonaws.com/datasets/USPTO_FULL.csv"

USPTO_TASK = []

class _USPTOLoader(_MolnetLoader):

@@ -45,6 +46,8 @@ class _USPTOLoader(_MolnetLoader):

    if self.subset == 'FULL':
      dataset_url = USPTO_FULL_URL
      if self.splitter == 'SpecifiedSplitter':
        raise ValueError("There is no pre computed split for the full dataset, use a custom split instead!")

    dataset_file = os.path.join(self.data_dir, self.name + '.csv')

@@ -53,15 +56,15 @@ class _USPTOLoader(_MolnetLoader):
      dc.utils.data_utils.download_url(url=dataset_url, dest_dir=self.data_dir)
      logger.info("Dataset download complete.")

    loader = dc.data.CSVLoader(tasks=[], featurizer=self.featurizer)
    loader = dc.data.CSVLoader(tasks=self.tasks, featurizer=self.featurizer)

    return loader.create_dataset(dataset_file, shard_size=8192)


def load_uspto(
    featurizer=None,  # should I remove this?
    featurizer: Union[dc.feat.Featurizer, str] = dc.feat.UserDefinedFeaturizer([]),  # This will be changed to dummy featurizer!
    splitter: Union[dc.splits.Splitter, str, None] = 'SpecifiedSplitter',
    transformers=None,
    transformers: List[Union[TransformerGenerator, str]] = [],
    reload: bool = True,
    data_dir: Optional[str] = None,
    save_dir: Optional[str] = None,
@@ -116,8 +119,7 @@ def load_uspto(
  subset : str (default 'MIT')
    Subset of dataset to download. 'FULL', 'MIT', 'STEREO', and '50K' are supported.
  sep_reagent : bool (default True)
    Toggle to load dataset with reactants and reagents separated or mixed.
    
    Toggle to load dataset with reactants and reado I call it 
  Returns
  -------
  tasks, datasets, transformers : tuple
@@ -146,11 +148,14 @@ def load_uspto(
  # get test and valid lists if subset is MIT, 50K, STEREO and splitter = specified.
  # if subset is Full use splitter passed by the user.
  # splitter = dc.splits.SpecifiedSplitter(valid_indices=,test_indices=)
  featurizer = dc.feat.UserDefinedFeaturizer([])


 
  loader = _USPTOLoader(
      featurizer,
      splitter,
      transformers,
      USPTO_TASK,
      data_dir,
      save_dir,
      subset=subset,