Commit 483dbac8 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #647 from lilleswing/singletaskstratified

Splitter k_fold interface
parents 73160191 650843e5
Loading
Loading
Loading
Loading
+9 −3
Original line number Diff line number Diff line
@@ -53,7 +53,7 @@ class Splitter(object):
    """Creates splitter object."""
    self.verbose = verbose

  def k_fold_split(self, dataset, k, directories=None):
  def k_fold_split(self, dataset, k, directories=None, **kwargs):
    """Does K-fold split of dataset."""
    log("Computing K-fold split", self.verbose)
    if directories is None:
@@ -278,7 +278,7 @@ class RandomStratifiedSplitter(Splitter):

    return train_dataset, valid_dataset, test_dataset

  def k_fold_split(self, dataset, k, directories=None):
  def k_fold_split(self, dataset, k, directories=None, **kwargs):
    """Needs custom implementation due to ragged splits for stratification."""
    log("Computing K-fold split", self.verbose)
    if directories is None:
@@ -332,7 +332,13 @@ class SingletaskStratifiedSplitter(Splitter):
    self.task_number = task_number
    self.verbose = verbose

  def k_fold_split(self, dataset, k, seed=None, log_every_n=None):
  def k_fold_split(self,
                   dataset,
                   k,
                   directories=None,
                   seed=None,
                   log_every_n=None,
                   **kwargs):
    """
        Splits compounds into k-folds using stratified sampling.
        Overriding base class k_fold_split.