Unverified Commit e99b6492 authored by Suzukazole's avatar Suzukazole
Browse files

MIT loader and init

parent 8e8bc721
Loading
Loading
Loading
Loading
+57 −12
Original line number Diff line number Diff line
@@ -28,25 +28,70 @@ USPTO_STEREO_VALID = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/U

class _USPTOLoader(_MolnetLoader):
  """MolNet loader for the USPTO reaction datasets.

  Only the 'MIT' subset is currently wired up: its train, test and
  validation CSV files are downloaded into ``data_dir`` when missing and
  featurized with a ``CSVLoader`` reading the "smiles" column.
  """

  def __init__(self, *args, subset: str, sep_reagent: bool, **kwargs):
    """Store the USPTO-specific options on top of the base loader state.

    Parameters
    ----------
    *args, **kwargs
      Forwarded unchanged to ``_MolnetLoader.__init__``.
    subset: str
      Name of the USPTO subset; also used to build ``self.name``
      ('USPTO_' + subset).
    sep_reagent: bool
      Stored on the instance; not consulted by ``create_dataset`` yet.
    """
    super(_USPTOLoader, self).__init__(*args, **kwargs)
    self.subset = subset
    self.sep_reagent = sep_reagent
    self.name = 'USPTO_' + subset

  def create_dataset(self) -> Tuple[DiskDataset, DiskDataset, DiskDataset]:
    """Download (if needed) and featurize the files of the selected subset.

    Returns
    -------
    Tuple[DiskDataset, DiskDataset, DiskDataset]
      The featurized datasets in the order (train, test, validation).
      NOTE(review): this returns three datasets rather than a single one;
      confirm that ``_MolnetLoader.load_dataset`` accepts this shape.

    Raises
    ------
    ValueError
      If ``self.subset`` is neither 'MIT' nor 'STEREO'.
    NotImplementedError
      If ``self.subset`` is 'STEREO', which has no loading logic yet.
    """
    if self.subset not in ['MIT', 'STEREO']:
      raise ValueError("Valid Subset names are MIT and STEREO.")
    if self.subset != 'MIT':
      # Fail with a clear message instead of falling through to the return
      # statement with unbound locals.
      raise NotImplementedError(
          "Loading the %s subset is not implemented yet." % self.subset)

    # (label used in download logs, label used in loading logs, source URL)
    splits = (
        ("training", "train", USPTO_MIT_TRAIN),
        ("test", "test", USPTO_MIT_TEST),
        ("validation", "validation", USPTO_MIT_VALID),
    )

    # download_url stores the file under the last path component of the URL,
    # so the local path must be derived from the basename, not the full URL.
    local_paths = [
        os.path.join(self.data_dir, url.split('/')[-1]) for _, _, url in splits
    ]

    # Fetch each file independently so that a partially populated data_dir is
    # completed, and so that every split comes from its own URL.
    for (download_label, _, url), path in zip(splits, local_paths):
      if os.path.exists(path):
        continue
      logger.info("Downloading %s file...", download_label)
      dc.utils.data_utils.download_url(url=url, dest_dir=self.data_dir)
      logger.info("%s file download complete.", download_label.capitalize())

    loader = dc.data.CSVLoader(
        tasks=self.tasks, feature_field="smiles", featurizer=self.featurizer)

    datasets = []
    for (_, load_label, _), path in zip(splits, local_paths):
      logger.info("Loading %s dataset..", load_label)
      datasets.append(loader.create_dataset(path, shard_size=8192))
    logger.info("Loading successful!")

    train_dataset, test_dataset, valid_dataset = datasets
    return train_dataset, test_dataset, valid_dataset


def load_uspto(
    featurizer: Optional[Union[dc.feat.Featurizer, str]] = None,
    splitter: Union[dc.splits.Splitter, str, None] = None,
    transformers: Optional[List[Union[TransformerGenerator, str]]] = None,
    reload: bool = True,
    data_dir: Optional[str] = None,
    save_dir: Optional[str] = None,
    subset: str = "MIT",
    sep_reagent: bool = True,
    **kwargs
) -> Tuple[List[str], Tuple[DiskDataset, ...], List[dc.trans.Transformer]]:
  """Load a USPTO dataset through ``_USPTOLoader``.

  Parameters
  ----------
  featurizer: Featurizer or str, optional
    Passed positionally to ``_USPTOLoader`` as the featurizer.
  splitter: Splitter or str, optional
    Passed positionally to ``_USPTOLoader`` as the splitter.
  transformers: list of TransformerGenerators or strings, optional
    Passed positionally to ``_USPTOLoader`` as the transformers.
  reload: bool
    Forwarded to ``loader.load_dataset`` as its second argument.
  data_dir: str, optional
    Passed positionally to ``_USPTOLoader`` as the data directory.
  save_dir: str, optional
    Passed positionally to ``_USPTOLoader`` as the save directory.
  subset: str
    USPTO subset name; the loader names the dataset 'USPTO_' + subset.
  sep_reagent: bool
    Forwarded to ``_USPTOLoader``.
  **kwargs
    Extra keyword arguments forwarded to ``_USPTOLoader``.

  Returns
  -------
  Whatever ``loader.load_dataset(loader.name, reload)`` returns; annotated
  as a (tasks, datasets, transformers) tuple.
  """
  loader = _USPTOLoader(
      featurizer,
      splitter,
      transformers,
      data_dir,
      save_dir,
      subset=subset,
      sep_reagent=sep_reagent,
      **kwargs)
  return loader.load_dataset(loader.name, reload)