Commit 028387cd authored by joegomes's avatar joegomes
Browse files

Add SingletaskStratifiedSplitter test

parent f783fe82
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -354,9 +354,9 @@ class SingletaskStratifiedSplitter(Splitter):
    while sortidx.shape[0] >= split_cd:
      sortidx_split, sortidx = np.split(sortidx, [split_cd])
      shuffled = np.random.permutation(range(split_cd))
      train_idx = np.hstack([train_idx, sortidx_split[shuffled(:train_cutoff)]])
      valid_idx = np.hstack([valid_idx, sortidx_split[shuffled(train_cutoff:valid_cutoff)]])
      test_idx = np.hstack([test_idx, sortidx_split[shuffled(valid_cutoff:)]])
      train_idx = np.hstack([train_idx, sortidx_split[shuffled[:train_cutoff]]])
      valid_idx = np.hstack([valid_idx, sortidx_split[shuffled[train_cutoff:valid_cutoff]]])
      test_idx = np.hstack([test_idx, sortidx_split[shuffled[valid_cutoff:]]])

    # Append remaining examples to train
    if sortidx.shape[0] > 0: np.hstack([train_idx, sortidx]) 
+18 −0
Original line number Diff line number Diff line
@@ -72,6 +72,24 @@ class TestSplitters(unittest.TestCase):
    assert len(valid_data) == 1
    assert len(test_data) == 1

  def test_singletask_stratified_split(self):
    """
    Test singletask SingletaskStratifiedSplitter class.
    """
    solubility_dataset = dc.data.tests.load_solubility_data()
    stratified_splitter = dc.splits.ScaffoldSplitter()
    train_data, valid_data, test_data = \
        stratified_splitter.train_valid_test_split(
            solubility_dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1)
    assert len(train_data) == 8
    assert len(valid_data) == 1
    assert len(test_data) == 1  

    merged_dataset = dc.data.DiskDataset.merge(
        [train_data, valid_data, test_data])
    assert sorted(merged_dataset.ids) == (
           sorted(solubility_dataset.ids))

  def test_singletask_random_k_fold_split(self):
    """
    Test singletask RandomSplitter class.