Unverified Commit 03bde69a authored by Suzukazole's avatar Suzukazole
Browse files

add mixed training to loader

parent d6368781
Loading
Loading
Loading
Loading
+7 −2
Original line number Diff line number Diff line
@@ -72,7 +72,7 @@ def load_uspto(
    data_dir: Optional[str] = None,
    save_dir: Optional[str] = None,
    subset: str = "MIT",
    sep_reagent: bool = True,  # functionality to be added!
    sep_reagent: bool = False,
    **kwargs
) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:
  """Load USPTO Datasets.
@@ -148,6 +148,11 @@ def load_uspto(
         arXiv preprint arXiv:2001.01408.
  """
  
  if sep_reagent:
    transformers = [TransformerGenerator(dc.trans.RxnSplitTransformer, sep_reagent=True)]
  else:
    transformers = [TransformerGenerator(dc.trans.RxnSplitTransformer, sep_reagent=False)]
  
  loader = _USPTOLoader(
      featurizer,
      splitter,