Unverified Commit 95f43b86 authored by Suzukazole's avatar Suzukazole
Browse files

Removed toggle, added support for 50K

parent 7e936532
Loading
Loading
Loading
Loading
+1 −14
Original line number Diff line number Diff line
@@ -88,8 +88,7 @@ class _MolnetLoader(object):
               splitter: Union[dc.splits.Splitter, str, None],
               transformer_generators: List[Union[TransformerGenerator, str]],
               tasks: List[str], data_dir: Optional[str],
               save_dir: Optional[str],
               precomputed_splits: Optional[bool], **kwargs):
               save_dir: Optional[str], **kwargs):
    """Construct an object for loading a dataset.

    Parameters
@@ -133,7 +132,6 @@ class _MolnetLoader(object):
    self.tasks = list(tasks)
    self.data_dir = data_dir
    self.save_dir = save_dir
    self.precomputed_splits = precomputed_splits
    self.args = kwargs

  def load_dataset(
@@ -172,13 +170,6 @@ class _MolnetLoader(object):
            save_folder)
        if all_dataset is not None:
          return self.tasks, all_dataset, transformers
    
    # Try to load precomputed splits.

    if self.precomputed_splits is True:
      train, valid, test = self.load_precomputed_splits()
      return (train, valid, test)

    # Create the dataset

    logger.info("About to featurize %s dataset." % name)
@@ -218,7 +209,3 @@ class _MolnetLoader(object):
  def create_dataset(self) -> Dataset:
    """Subclasses must implement this to load the dataset."""
    raise NotImplementedError()

  def load_precomputed_splits(self) -> Tuple[Dataset, ...]:
    """Subclasses must implement this to load precomputed train/test/eval splits."""
    raise NotImplementedError()
+18 −61
Original line number Diff line number Diff line
@@ -16,13 +16,9 @@ logger = logging.getLogger(__name__)

DEFAULT_DIR = deepchem.utils.data_utils.get_data_dir()

USPTO_MIT_TRAIN = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/USPTO_MIT_train.csv"
USPTO_MIT_TEST = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/USPTO_MIT_test.csv"
USPTO_MIT_VALID = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/USPTO_MIT_val.csv"

USPTO_STEREO_TRAIN = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/USPTO_STEREO_train.csv"
USPTO_STEREO_TEST = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/USPTO_STEREO_test.csv"
USPTO_STEREO_VALID = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/USPTO_STEREO_val.csv"
USPTO_MIT_URL = "https://deepchemdata.s3.us-west-1.amazonaws.com/datasets/USPTO_MIT.csv"
USPTO_STEREO_URL = "https://deepchemdata.s3.us-west-1.amazonaws.com/datasets/USPTO_STEREO.csv"
USPTO_50K_URL = "https://deepchemdata.s3.us-west-1.amazonaws.com/datasets/USPTO_50K.csv"


class _USPTOLoader(_MolnetLoader):
@@ -34,72 +30,32 @@ class _USPTOLoader(_MolnetLoader):
    self.name = 'USPTO_' + subset

  def create_dataset(self) -> Dataset:
    #####INCOMPLETE/INCORRECT: I don't think this is the right way to bypass the splitter!
    if self.subset not in ['MIT', 'STEREO']:
      raise ValueError("Valid Subset names are MIT and STEREO.")
    if self.subset not in ['MIT', 'STEREO', '50K']:
      raise ValueError("Valid Subset names are MIT, STEREO and 50K.")

    if self.subset == 'MIT':
      train_file = os.path.join(self.data_dir, USPTO_MIT_TRAIN)
      test_file = os.path.join(self.data_dir, USPTO_MIT_TEST)
      valid_file = os.path.join(self.data_dir, USPTO_MIT_VALID)

      if not os.path.exists(train_file):
        logger.info("Downloading training file...")
        dc.utils.data_utils.download_url(
            url=USPTO_MIT_TRAIN, dest_dir=self.data_dir)
        logger.info("Training file download complete.")

        logger.info("Downloading test file...")
        dc.utils.data_utils.download_url(
            url=USPTO_MIT_TEST, dest_dir=self.data_dir)
        logger.info("Test file download complete.")

        logger.info("Downloading validation file...")
        dc.utils.data_utils.download_url(
            url=USPTO_MIT_VALID, dest_dir=self.data_dir)
        logger.info("Validation file download complete.")
      dataset_url = USPTO_MIT_URL

    if self.subset == 'STEREO':
      train_file = os.path.join(self.data_dir, USPTO_STEREO_TRAIN)
      test_file = os.path.join(self.data_dir, USPTO_STEREO_TEST)
      valid_file = os.path.join(self.data_dir, USPTO_STEREO_VALID)

      if not os.path.exists(train_file):
        logger.info("Downloading training file...")
        dc.utils.data_utils.download_url(
            url=USPTO_STEREO_TRAIN, dest_dir=self.data_dir)
        logger.info("Training file download complete.")

        logger.info("Downloading test file...")
        dc.utils.data_utils.download_url(
            url=USPTO_STEREO_TEST, dest_dir=self.data_dir)
        logger.info("Test file download complete.")
      dataset_url = USPTO_STEREO_URL

        logger.info("Downloading validation file...")
        dc.utils.data_utils.download_url(
            url=USPTO_STEREO_VALID, dest_dir=self.data_dir)
        logger.info("Validation file download complete.")
    if self.subset == '50K':
      dataset_url = USPTO_50K_URL

    loader = dc.data.CSVLoader(
        tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer)
    dataset_file = os.path.join(self.data_dir, self.name + '.csv')

    logger.info("Loading train dataset..")
    train_file = loader.create_dataset(train_file, shard_size=8192)
    logger.info("Loading test dataset..")
    test_file = loader.create_dataset(test_file, shard_size=8192)
    logger.info("Loading validation dataset..")
    valid_file = loader.create_dataset(valid_file, shard_size=8192)
    logger.info("Loading successful!")
    if not os.path.exists(dataset_file):
      logger.info("Downloading dataset...")
      dc.utils.data_utils.download_url(url=dataset_url, dest_dir=self.data_dir)
      logger.info("Dataset download complete.")

    #need to figure out how to return the train, test and valid files!
    return (train_file, test_file, valid_file)
    loader = dc.data.CSVLoader(tasks=[], featurizer=self.featurizer)

    def load_precomputed_splits(self) -> List[Dataset, ...]:
      pass
    return loader.create_dataset(dataset_file, shard_size=8192)


def load_uspto(
    featurizer=None,
    featurizer=None,  # should I remove this?
    splitter=None,
    transformers=None,
    reload: bool = True,
@@ -110,6 +66,7 @@ def load_uspto(
    **kwargs
) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:

  featurizer = dc.feat.UserDefinedFeaturizer([])
  loader = _USPTOLoader(
      featurizer,
      splitter,