Unverified Commit 9f9a2ed6 authored by Suzukazole's avatar Suzukazole
Browse files

fix url typo

parent e99b6492
Loading
Loading
Loading
Loading
+39 −16
Original line number Diff line number Diff line
@@ -4,11 +4,10 @@ Loads synthetic reaction datasets from USPTO.
This file contains loaders for synthetic reaction datasets from the US Patent Office. http://nextmovesoftware.com/blog/2014/02/27/unleashing-over-a-million-reactions-into-the-wild/.
"""
import os
import csv
import logging
import deepchem
import numpy as np
from deepchem.data import DiskDataset
from deepchem.data import Dataset
from deepchem.molnet.load_function.molnet_loader import _MolnetLoader
from typing import List, Optional, Tuple, Union
import deepchem as dc
@@ -34,9 +33,11 @@ class _USPTOLoader(_MolnetLoader):
    self.sep_reagent = sep_reagent
    self.name = 'USPTO_' + subset

  def create_dataset(self) -> DiskDataset:
  def create_dataset(self) -> Tuple[Dataset, ...]:
    #####INCOMPLETE/INCORRECT: I don't think this is the right way to bypass the splitter!
    if self.subset not in ['MIT', 'STEREO']:
      raise ValueError("Valid Subset names are MIT and STEREO.")

    if self.subset == 'MIT':
      train_file = os.path.join(self.data_dir, USPTO_MIT_TRAIN)
      test_file = os.path.join(self.data_dir, USPTO_MIT_TEST)
@@ -51,12 +52,33 @@ class _USPTOLoader(_MolnetLoader):

        logger.info("Downloading test file...")
        dc.utils.data_utils.download_url(
            url=USPTO_MIT_TRAIN, dest_dir=self.data_dir)
            url=USPTO_MIT_TEST, dest_dir=self.data_dir)
        logger.info("Test file download complete.")

        logger.info("Downloading validation file...")
        dc.utils.data_utils.download_url(
            url=USPTO_MIT_TRAIN, dest_dir=self.data_dir)
            url=USPTO_MIT_VALID, dest_dir=self.data_dir)
        logger.info("Validation file download complete.")
      if self.subset == 'STEREO':
        train_file = os.path.join(self.data_dir, USPTO_STEREO_TRAIN)
        test_file = os.path.join(self.data_dir, USPTO_STEREO_TEST)
        valid_file = os.path.join(self.data_dir, USPTO_STEREO_VALID)

        if not os.path.exists(train_file):

          logger.info("Downloading training file...")
          dc.utils.data_utils.download_url(
              url=USPTO_STEREO_TRAIN, dest_dir=self.data_dir)
          logger.info("Training file download complete.")

          logger.info("Downloading test file...")
          dc.utils.data_utils.download_url(
              url=USPTO_STEREO_TEST, dest_dir=self.data_dir)
          logger.info("Test file download complete.")

          logger.info("Downloading validation file...")
          dc.utils.data_utils.download_url(
              url=USPTO_STEREO_VALID, dest_dir=self.data_dir)
          logger.info("Validation file download complete.")

    loader = dc.data.CSVLoader(
@@ -70,7 +92,8 @@ class _USPTOLoader(_MolnetLoader):
    valid_file = loader.create_dataset(valid_file, shard_size=8192)
    logger.info("Loading successful!")

    return train_file, test_file, valid_file
    #need to figure out how to return the train, test and valid files!
    return (train_file, test_file, valid_file)  


def load_uspto(
@@ -83,7 +106,7 @@ def load_uspto(
    subset: str = "MIT",
    sep_reagent: bool = True,
    **kwargs
) -> Tuple[List[str], Tuple[DiskDataset, ...], List[dc.trans.Transformer]]:
) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:

  loader = _USPTOLoader(
      featurizer,