Unverified Commit f33bc92d authored by Suzukazole's avatar Suzukazole
Browse files

add skip transform

parent 03bde69a
Loading
Loading
Loading
Loading
+19 −8
Original line number Diff line number Diff line
@@ -25,11 +25,12 @@ USPTO_TASK: List[str] = []

class _USPTOLoader(_MolnetLoader):

  def __init__(self, *args, subset: str, sep_reagent: bool, **kwargs):
  def __init__(self,
               *args,
               subset: str,
               sep_reagent: bool,
               skip_transform: bool,
               **kwargs):
    """Initialize the loader and record the USPTO-specific options.

    Positional arguments and any remaining keyword arguments are passed
    through unchanged to the parent loader's initializer.

    Parameters
    ----------
    subset: str
      Identifier of the USPTO subset; also used to build the loader name
      as ``'USPTO_' + subset``.
    sep_reagent: bool
      Flag stored on the instance as ``self.sep_reagent``.
    skip_transform: bool
      Flag stored on the instance as ``self.skip_transform``.
    """
    super().__init__(*args, **kwargs)
    # The three options are keyword-only, so they never collide with *args.
    self.skip_transform = skip_transform
    self.sep_reagent = sep_reagent
    self.subset = subset
    # Loader name is derived from the subset.
    self.name = 'USPTO_' + subset

  def create_dataset(self) -> Dataset:
    if self.subset not in ['MIT', 'STEREO', '50K', 'FULL']:
@@ -72,7 +73,8 @@ def load_uspto(
    data_dir: Optional[str] = None,
    save_dir: Optional[str] = None,
    subset: str = "MIT",
    sep_reagent: bool = False,
    sep_reagent: bool = True,
    skip_transform: bool = True,
    **kwargs
) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:
  """Load USPTO Datasets.
@@ -121,7 +123,10 @@ def load_uspto(
  subset: str (default 'MIT')
    Subset of dataset to download. 'FULL', 'MIT', 'STEREO', and '50K' are supported.
  sep_reagent: bool (default True)
    Toggle to load dataset with reactants and reado I call it
    Toggle to load dataset with reactants and reagents either separated or mixed.
  skip_transform: bool (default True)
    Toggle to skip the source/target transformation.

  Returns
  -------
  tasks, datasets, transformers : tuple
@@ -148,6 +153,11 @@ def load_uspto(
         arXiv preprint arXiv:2001.01408.
  """
  
  if skip_transform:
    if not sep_reagent:
      raise ValueError("To enable mixed training you must not skip the transformation.")
    transformers = []
  else:
    if sep_reagent:
      transformers = [TransformerGenerator(dc.trans.RxnSplitTransformer, sep_reagent=True)]
    else:
@@ -162,5 +172,6 @@ def load_uspto(
      save_dir,
      subset=subset,
      sep_reagent=sep_reagent,
      skip_transform=skip_transform,
      **kwargs)
  return loader.load_dataset(loader.name, reload)