Unverified Commit 7e936532 authored by Suzukazole's avatar Suzukazole
Browse files

Add precomputed_splits toggle

parent 140761dd
Loading
Loading
Loading
Loading
+13 −1
Original line number Diff line number Diff line
@@ -88,7 +88,8 @@ class _MolnetLoader(object):
               splitter: Union[dc.splits.Splitter, str, None],
               transformer_generators: List[Union[TransformerGenerator, str]],
               tasks: List[str], data_dir: Optional[str],
               save_dir: Optional[str], **kwargs):
               save_dir: Optional[str],
               precomputed_splits: Optional[bool], **kwargs):
    """Construct an object for loading a dataset.

    Parameters
@@ -132,6 +133,7 @@ class _MolnetLoader(object):
    self.tasks = list(tasks)
    self.data_dir = data_dir
    self.save_dir = save_dir
    self.precomputed_splits = precomputed_splits
    self.args = kwargs

  def load_dataset(
@@ -171,6 +173,12 @@ class _MolnetLoader(object):
        if all_dataset is not None:
          return self.tasks, all_dataset, transformers
    
    # Try to load precomputed splits.

    if self.precomputed_splits is True:
      train, valid, test = self.load_precomputed_splits()
      return (train, valid, test)

    # Create the dataset

    logger.info("About to featurize %s dataset." % name)
@@ -210,3 +218,7 @@ class _MolnetLoader(object):
  def create_dataset(self) -> Dataset:
    """Load and return the dataset for this loader.

    The base class only provides a placeholder; every concrete loader
    subclass has to supply its own implementation.

    Raises
    ------
    NotImplementedError
      Always, when this base-class version is invoked.
    """
    raise NotImplementedError

  def load_precomputed_splits(self) -> Tuple[Dataset, ...]:
    """Load dataset splits that were computed ahead of time.

    The base class only provides a placeholder; a subclass that ships
    precomputed splits overrides this. The code that consumes the result
    unpacks it as ``train, valid, test``, in that order.

    Raises
    ------
    NotImplementedError
      Always, when this base-class version is invoked.
    """
    raise NotImplementedError
+4 −1
Original line number Diff line number Diff line
@@ -33,7 +33,7 @@ class _USPTOLoader(_MolnetLoader):
    self.sep_reagent = sep_reagent
    self.name = 'USPTO_' + subset

  def create_dataset(self) -> Tuple[Dataset, ...]:
  def create_dataset(self) -> Dataset:
    #####INCOMPLETE/INCORRECT: I don't think this is the right way to bypass the splitter!
    if self.subset not in ['MIT', 'STEREO']:
      raise ValueError("Valid Subset names are MIT and STEREO.")
@@ -94,6 +94,9 @@ class _USPTOLoader(_MolnetLoader):
    #need to figure out how to return the train, test and valid files!
    return (train_file, test_file, valid_file)

    def load_precomputed_splits(self) -> Tuple[Dataset, ...]:
      """Load the precomputed train/valid/test splits (not implemented yet).

      Raises
      ------
      NotImplementedError
        Always; loading precomputed splits has not been written for this
        loader. Raising here avoids silently returning ``None``, which a
        caller unpacking three splits would report as an unrelated TypeError.
      """
      # NOTE(review): this def is indented one level deeper than
      # `create_dataset`, so it appears to be nested after that method's
      # `return` instead of being defined on the class -- confirm and dedent.
      raise NotImplementedError(
          "Loading precomputed splits is not implemented for this loader.")


def load_uspto(
    featurizer=None,