Commit f783fe82 authored by joegomes's avatar joegomes
Browse files

Implemented SingletaskStratifiedSplitter

parent 561862cb
Loading
Loading
Loading
Loading
+107 −0
Original line number Diff line number Diff line
@@ -255,6 +255,113 @@ class RandomStratifiedSplitter(Splitter):
      fold_datasets.append(fold_dataset)
    return fold_datasets

class SingletaskStratifiedSplitter(Splitter):
  """
  Class for doing data splits by stratification on a single task.

  Example:

  >>> n_samples = 100
  >>> n_features = 10
  >>> n_tasks = 10
  >>> X = np.random.rand(n_samples, n_features)
  >>> y = np.random.rand(n_samples, n_tasks)
  >>> w = np.ones_like(y)
  >>> dataset = dc.data.NumpyDataset(X, y, w, ids=None)
  >>> splitter = SingletaskStratifiedSplitter(task_number=5)
  >>> train_dataset, test_dataset = splitter.train_test_split(dataset)

  """

  def __init__(self, task_number=0, verbose=False):
    """
    Creates splitter object.

    Parameters
    ----------
    task_number: int (Optional, Default 0)
      Task number for stratification.
    verbose: bool (Optional, Default False)
      Controls logging frequency.
    """
    self.task_number = task_number
    self.verbose = verbose

  def k_fold_split(self, dataset, k, seed=None, log_every_n=None,
                   directories=None):
    """
    Splits compounds into k-folds using stratified sampling.
    Overriding base class k_fold_split.

    Parameters
    ----------
    dataset: dc.data.Dataset object
      Dataset.
    k: int
      Number of folds.
    seed: int (Optional, Default None)
      Random seed.
    log_every_n: int (Optional, Default None)
      Log every n examples (not currently used)
    directories: list of str (Optional, Default None)
      Directories in which to store the k fold datasets. Defaults to k
      fresh temporary directories.

    Returns
    -------
    fold_datasets: List
      List containing dc.data.Dataset objects
    """
    log("Computing K-fold split", self.verbose)
    # Fix: `directories` was referenced but never defined (NameError); it is
    # now an optional parameter so callers can control fold locations.
    if directories is None:
      directories = [tempfile.mkdtemp() for _ in range(k)]
    else:
      assert len(directories) == k

    # Sort examples by the stratification task's y value so that
    # np.array_split assigns contiguous y-ranges to each fold.
    y_s = dataset.y[:, self.task_number]
    sortidx = np.argsort(y_s)
    sortidx_list = np.array_split(sortidx, k)

    fold_datasets = []
    for fold in range(k):
      fold_dir = directories[fold]
      fold_ind = sortidx_list[fold]
      fold_dataset = dataset.select(fold_ind, fold_dir)
      fold_datasets.append(fold_dataset)
    return fold_datasets

  def split(self, dataset, seed=None, frac_train=.8, frac_valid=.1,
            frac_test=.1, log_every_n=None):
    """
    Splits compounds into train/validation/test using stratified sampling.

    Parameters
    ----------
    dataset: dc.data.Dataset object
      Dataset.
    seed: int (Optional, Default None)
      Random seed.
    frac_train: float (Optional, Default .8)
      Fraction of dataset put into training data.
    frac_valid: float (Optional, Default .1)
      Fraction of dataset put into validation data.
    frac_test: float (Optional, Default .1)
      Fraction of dataset put into test data.
    log_every_n: int (Optional, Default None)
      Log every n examples (not currently used)

    Returns
    -------
    Tuple of (train_indices, valid_indices, test_indices) integer arrays
    indexing into the dataset.
    """
    # JSG Assert that split fractions can be written as proper fractions over 10.
    # This can be generalized in the future with some common denominator determination.
    # This will work for 80/20 train/test or 80/10/10 train/valid/test (most use cases).
    np.testing.assert_equal(frac_train + frac_valid + frac_test, 1.)
    np.testing.assert_equal(10 * frac_train + 10 * frac_valid + 10 * frac_test,
                            10.)

    if seed is not None:
      np.random.seed(seed)

    # Fix: stratify on the task COLUMN of y; the original indexed
    # dataset.y[self.task_number], which selects a row (one sample).
    y_s = dataset.y[:, self.task_number]
    sortidx = np.argsort(y_s)

    split_cd = 10
    train_cutoff = int(frac_train * split_cd)
    # Fix: cutoffs are cumulative slice boundaries within each chunk of 10;
    # the original used int(frac_valid * split_cd) alone, yielding an
    # empty/incorrect validation slice.
    valid_cutoff = train_cutoff + int(frac_valid * split_cd)

    train_idx = np.array([], dtype=int)
    valid_idx = np.array([], dtype=int)
    test_idx = np.array([], dtype=int)

    # Deal each sorted chunk of 10 examples randomly into train/valid/test so
    # every split covers the full range of y values (stratification).
    while sortidx.shape[0] >= split_cd:
      sortidx_split, sortidx = np.split(sortidx, [split_cd])
      shuffled = np.random.permutation(range(split_cd))
      # Fix: bracket indexing (shuffled[:n]); the original used call syntax
      # shuffled(:n), which is a syntax error.
      train_idx = np.hstack([train_idx, sortidx_split[shuffled[:train_cutoff]]])
      valid_idx = np.hstack(
          [valid_idx, sortidx_split[shuffled[train_cutoff:valid_cutoff]]])
      test_idx = np.hstack([test_idx, sortidx_split[shuffled[valid_cutoff:]]])

    # Append remaining examples to train.
    # Fix: assign the result; the original discarded the hstack output.
    if sortidx.shape[0] > 0:
      train_idx = np.hstack([train_idx, sortidx])

    return (train_idx, valid_idx, test_idx)

class MolecularWeightSplitter(Splitter):
  """