Unverified Commit 6d480814 authored by peastman, committed by GitHub

Merge pull request #2190 from peastman/stratified

Reimplemented RandomStratifiedSplitter
parents ea0fe592 dcfac2aa
+90 −191
@@ -448,220 +448,119 @@ class RandomStratifiedSplitter(Splitter):
  """RandomStratified Splitter class.

  For sparse multitask datasets, a standard split offers no guarantees
-  that the splits will have any activate compounds. This class guarantees
-  that each task will have a proportional split of the activates in a
-  split. To do this, a ragged split is performed with different numbers
-  of compounds taken from each task. Thus, the length of the split arrays
-  may exceed the split of the original array. That said, no datapoint is
-  copied to more than one split, so correctness is still ensured.
-
-  TODO(rbharath): This splitter should be refactored to match style of
-  other splitter classes.
+  that the splits will have any active compounds. This class tries to
+  arrange that each split has a proportional number of the actives for each
+  task. This is strictly guaranteed only for single-task datasets, but for
+  sparse multitask datasets it usually manages to produce a fairly accurate
+  division of the actives for each task.

  Notes
  -----
-  This splitter is only valid for boolean label data.
+  This splitter is primarily designed for boolean labeled data. It considers
+  only whether a label is zero or non-zero. When labels can take on multiple
+  non-zero values, it does not try to give each split a proportional fraction
+  of the samples with each value.
  """

-  def get_task_split_indices(self, y: np.ndarray, w: np.ndarray,
-                             frac_split: float) -> List[int]:
-    """Returns num datapoints needed per task to split properly."""
-    w_present = (w != 0)
-    y_present = y * w_present
-
-    # Compute number of actives needed per task.
-    task_actives = np.sum(y_present, axis=0)
-    task_split_actives = (frac_split * task_actives).astype(int)
-
-    # loop through each column and obtain index required to splice out for
-    # required fraction of hits
-    split_indices = []
-    n_tasks = np.shape(y)[1]
-    for task in range(n_tasks):
-      actives_count = task_split_actives[task]
-      cum_task_actives = np.cumsum(y_present[:, task])
-      # Find the first index where the cumulative number of actives equals
-      # the actives_count
-      split_index = np.amin(np.where(cum_task_actives >= actives_count)[0])
-      # Note that np.where tells us last index required to exceed
-      # actives_count, so we actually want the following location
-      split_indices.append(split_index + 1)
-    return split_indices
-
-  # TODO(rbharath): Refactor this split method to match API of other
-  # splits (or potentially refactor those to match this).
-  def split(  # type: ignore [override]
-      self,
-      dataset: Dataset,
-      frac_split: float,
-      split_dirs: Optional[List[str]] = None
-  ) -> Tuple[Dataset, Optional[Dataset]]:
-    """
-    Method that does bulk of splitting dataset.
-    """
-    if split_dirs is not None:
-      assert len(split_dirs) == 2
-    else:
-      split_dirs = [tempfile.mkdtemp(), tempfile.mkdtemp()]
-
-    # Handle edge case where frac_split is 1
-    if frac_split == 1:
-      dataset_1 = DiskDataset.from_numpy(dataset.X, dataset.y, dataset.w,
-                                         dataset.ids)
-      dataset_2 = None
-      return dataset_1, dataset_2
-    X, y, w, ids = randomize_arrays((dataset.X, dataset.y, dataset.w,
-                                     dataset.ids))
-    if len(y.shape) == 1:
-      y = np.expand_dims(y, 1)
-    if len(w.shape) == 1:
-      w = np.expand_dims(w, 1)
-    split_indices = self.get_task_split_indices(y, w, frac_split)
-
-    # Create weight matrices for two halves.
-    w_1, w_2 = np.zeros_like(w), np.zeros_like(w)
-    for task, split_index in enumerate(split_indices):
-      # copy over up to required index for weight first_split
-      w_1[:split_index, task] = w[:split_index, task]
-      w_2[split_index:, task] = w[split_index:, task]
-
-    # check if any rows in either w_1 or w_2 are just zeros
-    rows_1 = w_1.any(axis=1)
-    X_1, y_1, w_1, ids_1 = X[rows_1], y[rows_1], w_1[rows_1], ids[rows_1]
-    dataset_1 = DiskDataset.from_numpy(X_1, y_1, w_1, ids_1)
-
-    rows_2 = w_2.any(axis=1)
-    X_2, y_2, w_2, ids_2 = X[rows_2], y[rows_2], w_2[rows_2], ids[rows_2]
-    dataset_2 = DiskDataset.from_numpy(X_2, y_2, w_2, ids_2)
-
-    return dataset_1, dataset_2
-
-  # FIXME: Signature of "train_valid_test_split" incompatible with supertype "Splitter"
-  def train_valid_test_split(  # type: ignore [override]
-      self,
+  def split(self,
+            dataset: Dataset,
-      train_dir: Optional[str] = None,
-      valid_dir: Optional[str] = None,
-      test_dir: Optional[str] = None,
+            frac_train: float = 0.8,
+            frac_valid: float = 0.1,
+            frac_test: float = 0.1,
+            seed: Optional[int] = None,
-      log_every_n: int = 1000,
-      **kwargs) -> Union[Tuple[Dataset, None, None], Tuple[Dataset, Dataset,
-                                                           Optional[Dataset]]]:
-    """ Splits self into train/validation/test sets.
-
-    Most splitters use the superclass implementation
-    `Splitter.train_valid_test_split` but this class has to override the
-    implementation to deal with potentially ragged splits.
+            log_every_n: Optional[int] = None) -> Tuple:
+    """Return indices for specified split

    Parameters
    ----------
-    dataset: Dataset
+    dataset: dc.data.Dataset
      Dataset to be split.
-    train_dir: str, optional (default None)
-      If specified, the directory in which the generated
-      training dataset should be stored. This is only
-      considered if `isinstance(dataset, dc.data.DiskDataset)`
-    valid_dir: str, optional (default None)
-      If specified, the directory in which the generated
-      valid dataset should be stored. This is only
-      considered if `isinstance(dataset, dc.data.DiskDataset)`
-      is True.
-    test_dir: str, optional (default None)
-      If specified, the directory in which the generated
-      test dataset should be stored. This is only
-      considered if `isinstance(dataset, dc.data.DiskDataset)`
-      is True.
+    seed: int, optional (default None)
+      Random seed to use.
    frac_train: float, optional (default 0.8)
      The fraction of data to be used for the training split.
    frac_valid: float, optional (default 0.1)
      The fraction of data to be used for the validation split.
    frac_test: float, optional (default 0.1)
      The fraction of data to be used for the test split.
-    seed: int, optional (default None)
-      Random seed to use.
-    log_every_n: int, optional (default 1000)
+    log_every_n: int, optional (default None)
      Controls the logger by dictating how often logger outputs
      will be produced.

    Returns
    -------
-    Tuple[Dataset, Optional[Dataset], Optional[Dataset]]
-      A tuple of train, valid and test datasets as dc.data.Dataset objects.
-      In some cases, valid or test dataset is None.
+    Tuple
+      A tuple `(train_inds, valid_inds, test_inds)` of the indices (integers) for
+      the various splits.
    """
-    if train_dir is None:
-      train_dir = tempfile.mkdtemp()
-    if valid_dir is None:
-      valid_dir = tempfile.mkdtemp()
-    if test_dir is None:
-      test_dir = tempfile.mkdtemp()
-    rem_dir = tempfile.mkdtemp()
-    train_dataset, rem_dataset = self.split(dataset, frac_train,
-                                            [train_dir, rem_dir])
-
-    # calculate percent split for valid (out of test and valid)
-    if frac_valid + frac_test > 0:
-      valid_percentage = frac_valid / (frac_valid + frac_test)
-    else:
-      return train_dataset, None, None
-    # split remaining data into valid and test, treating sub test set also as sparse
-    # FIXME: Argument 1 to "split" of "RandomStratifiedSplitter" has incompatible type
-    # "Optional[Dataset]"; expected "Dataset"
-    valid_dataset, test_dataset = self.split(
-        rem_dataset,  # type: ignore
-        valid_percentage,
-        [valid_dir, test_dir])
+    y_present = (dataset.y != 0) * (dataset.w != 0)
+    if len(y_present.shape) == 1:
+      y_present = np.expand_dims(y_present, 1)
+    elif len(y_present.shape) > 2:
+      raise ValueError(
+          'RandomStratifiedSplitter cannot be applied when y has more than two dimensions'
+      )
+    if seed is not None:
+      np.random.seed(seed)

-    return train_dataset, valid_dataset, test_dataset
+    # Figure out how many positive samples we want for each task in each dataset.

-  # FIXME: Signature of "k_fold_split" incompatible with supertype "Splitter"
-  def k_fold_split(  # type: ignore [override]
-      self,
-      dataset: Dataset,
-      k: int,
-      directories: Optional[List[str]] = None,
-      **kwargs) -> List[Dataset]:
-    """Needs custom implementation due to ragged splits for stratification.
-
-    Parameters
-    ----------
-    dataset: Dataset
-      Dataset to be split.
-    k: int
-      Number of folds to split `dataset` into.
-    directories: List[str], optional (default None)
-      List of length k filepaths to save the result disk-datasets.
-
-    Returns
-    -------
-    fold_datasets: List[Dataset]
-      List of dc.data.Dataset objects
-    """
-    logger.info("Computing K-fold split")
-    if directories is None:
-      directories = [tempfile.mkdtemp() for _ in range(k)]
-    else:
-      assert len(directories) == k
-    fold_datasets = []
-    # rem_dataset is remaining portion of dataset
-    rem_dataset: Optional[Dataset] = dataset
-    for fold in range(k):
-      # Note starts as 1/k since fold starts at 0. Ends at 1 since fold goes up
-      # to k-1.
-      frac_fold = 1. / (k - fold)
-      fold_dir = directories[fold]
-      rem_dir = tempfile.mkdtemp()
-      # FIXME: Argument 1 to "split" of "RandomStratifiedSplitter" has incompatible type
-      # "Optional[Dataset]"; expected "Dataset"
-      fold_dataset, rem_dataset = self.split(
-          rem_dataset,  # type: ignore
-          frac_fold,
-          [fold_dir, rem_dir])
-      fold_datasets.append(fold_dataset)
-    return fold_datasets
+    n_tasks = y_present.shape[1]
+    indices_for_task = [
+        np.random.permutation(np.nonzero(y_present[:, i])[0])
+        for i in range(n_tasks)
+    ]
+    count_for_task = np.array([len(x) for x in indices_for_task])
+    train_target = np.round(frac_train * count_for_task).astype(np.int)
+    valid_target = np.round(frac_valid * count_for_task).astype(np.int)
+    test_target = np.round(frac_test * count_for_task).astype(np.int)

+    # Assign the positive samples to datasets.  Since a sample may be positive
+    # on more than one task, we need to keep track of the effect of each added
+    # sample on each task.  To try to keep everything balanced, we cycle through
+    # tasks, assigning one positive sample for each one.

+    train_counts = np.zeros(n_tasks, np.int)
+    valid_counts = np.zeros(n_tasks, np.int)
+    test_counts = np.zeros(n_tasks, np.int)
+    set_target = [train_target, valid_target, test_target]
+    set_counts = [train_counts, valid_counts, test_counts]
+    set_inds: List[List[int]] = [[], [], []]
+    assigned = set()
+    max_count = np.max(count_for_task)
+    for i in range(max_count):
+      for task in range(n_tasks):
+        indices = indices_for_task[task]
+        if i < len(indices) and indices[i] not in assigned:
+          # We have a sample that hasn't been assigned yet.  Assign it to
+          # whichever set currently has the lowest fraction of its target for
+          # this task.
+
+          index = indices[i]
+          set_frac = [
+              1 if set_target[i][task] == 0 else
+              set_counts[i][task] / set_target[i][task] for i in range(3)
+          ]
+          s = np.argmin(set_frac)
+          set_inds[s].append(index)
+          assigned.add(index)
+          set_counts[s] += y_present[index]
+
+    # The remaining samples are negative for all tasks.  Add them to fill out
+    # each set to the correct total number.
+
+    n_samples = y_present.shape[0]
+    set_size = [
+        int(np.round(n_samples * f))
+        for f in (frac_train, frac_valid, frac_test)
+    ]
+    s = 0
+    for i in np.random.permutation(range(n_samples)):
+      if i not in assigned:
+        while s < 2 and len(set_inds[s]) >= set_size[s]:
+          s += 1
+        set_inds[s].append(i)
+    return tuple(sorted(x) for x in set_inds)


class SingletaskStratifiedSplitter(Splitter):
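With this change, `split` returns plain index tuples, so the generic `Splitter` wrappers (`train_test_split`, `train_valid_test_split`, `k_fold_split`) apply unchanged and the ragged, dataset-returning overrides deleted above are no longer needed. A minimal usage sketch mirroring the updated tests below (shapes, fractions, and variable names here are illustrative, not part of the diff):

    import numpy as np
    import deepchem as dc

    # Sparse multitask labels: roughly 5% actives per task, as in the tests.
    n_samples, n_tasks = 100, 10
    X = np.ones(n_samples)
    y = np.random.binomial(1, 0.05, size=(n_samples, n_tasks))
    w = np.ones((n_samples, n_tasks))
    dataset = dc.data.NumpyDataset(X, y, w)

    splitter = dc.splits.RandomStratifiedSplitter()
    # The reimplemented split() returns index tuples rather than ragged
    # Dataset halves, so no datapoint is duplicated or dropped.
    train_inds, valid_inds, test_inds = splitter.split(
        dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1)

    # Because indices are returned, the inherited convenience wrappers work too:
    train, valid, test = splitter.train_valid_test_split(dataset)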
+29 −39
@@ -320,16 +320,20 @@ class TestSplitter(unittest.TestCase):
    n_positives = 20
    n_tasks = 1

    X = np.ones(n_samples)
    y = np.zeros((n_samples, n_tasks))
    y[:n_positives] = 1
    w = np.ones((n_samples, n_tasks))
    dataset = dc.data.NumpyDataset(X, y, w)
    stratified_splitter = dc.splits.RandomStratifiedSplitter()
-    column_indices = stratified_splitter.get_task_split_indices(
-        y, w, frac_split=.5)
+    train, valid, test = stratified_splitter.split(dataset, 0.5, 0, 0.5)

-    split_index = column_indices[0]
-    # The split index should partition dataset in half.
-    assert np.count_nonzero(y[:split_index]) == 10
+    assert len(train) == 50
+    assert len(valid) == 0
+    assert len(test) == 50
+    assert np.count_nonzero(y[train]) == 10
+    assert np.count_nonzero(y[test]) == 10

  def test_singletask_stratified_column_indices_mask(self):
    """
@@ -341,22 +345,22 @@ class TestSplitter(unittest.TestCase):
    n_tasks = 1

    # Test case where some weights are zero (i.e. masked)
    X = np.ones(n_samples)
    y = np.zeros((n_samples, n_tasks))
    y[:n_positives] = 1
    w = np.ones((n_samples, n_tasks))
    # Set half the positives to have zero weight
    w[:n_positives // 2] = 0
    dataset = dc.data.NumpyDataset(X, y, w)

    stratified_splitter = dc.splits.RandomStratifiedSplitter()
-    column_indices = stratified_splitter.get_task_split_indices(
-        y, w, frac_split=.5)
+    train, valid, test = stratified_splitter.split(dataset, 0.5, 0, 0.5)

-    split_index = column_indices[0]
    # There are 10 nonzero actives.
    # The split index should partition this into half, so expect 5
    w_present = (w != 0)
    y_present = y * w_present
-    assert np.count_nonzero(y_present[:split_index]) == 5
+    assert np.count_nonzero(y_present[train]) == 5

  def test_multitask_stratified_column_indices(self):
    """
@@ -365,18 +369,19 @@ class TestSplitter(unittest.TestCase):
    n_samples = 100
    n_tasks = 10
    p = .05  # proportion actives
    X = np.ones(n_samples)
    y = np.random.binomial(1, p, size=(n_samples, n_tasks))
    w = np.ones((n_samples, n_tasks))
    dataset = dc.data.NumpyDataset(X, y, w)

    stratified_splitter = dc.splits.RandomStratifiedSplitter()
-    split_indices = stratified_splitter.get_task_split_indices(
-        y, w, frac_split=.5)
+    train, valid, test = stratified_splitter.split(dataset, 0.5, 0, 0.5)

    for task in range(n_tasks):
-      split_index = split_indices[task]
      task_actives = np.count_nonzero(y[:, task])
-      # The split index should partition dataset in half.
-      assert np.count_nonzero(y[:split_index, task]) == int(task_actives / 2)
+      # The split should partition the positives for each task roughly in half.
+      target = task_actives / 2
+      assert target - 2 <= np.count_nonzero(y[train, task]) <= target + 2

  def test_multitask_stratified_column_indices_masked(self):
    """
@@ -385,23 +390,24 @@ class TestSplitter(unittest.TestCase):
    n_samples = 200
    n_tasks = 10
    p = .05  # proportion actives
    X = np.ones(n_samples)
    y = np.random.binomial(1, p, size=(n_samples, n_tasks))
    w = np.ones((n_samples, n_tasks))
    # Mask half the examples
    w[:n_samples // 2] = 0
    dataset = dc.data.NumpyDataset(X, y, w)

    stratified_splitter = dc.splits.RandomStratifiedSplitter()
-    split_indices = stratified_splitter.get_task_split_indices(
-        y, w, frac_split=.5)
+    train, valid, test = stratified_splitter.split(dataset, 0.5, 0, 0.5)

    w_present = (w != 0)
    y_present = y * w_present
    for task in range(n_tasks):
-      split_index = split_indices[task]
      task_actives = np.count_nonzero(y_present[:, task])
+      target = task_actives / 2
-      # The split index should partition dataset in half.
-      assert np.count_nonzero(y_present[:split_index, task]) == int(
-          task_actives / 2)
+      assert target - 1 <= np.count_nonzero(
+          y_present[train, task]) <= target + 1

  def test_random_stratified_split(self):
    """
@@ -422,7 +428,10 @@ class TestSplitter(unittest.TestCase):
    dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids)

    stratified_splitter = dc.splits.RandomStratifiedSplitter()
-    dataset_1, dataset_2 = stratified_splitter.split(dataset, frac_split=.5)
+    dataset_1, dataset_2 = stratified_splitter.train_test_split(
+        dataset, frac_train=.5)
+    print(dataset_1.get_shape())
+    print(dataset_2.get_shape())

    # Should have split cleanly in half (picked random seed to ensure this)
    assert len(dataset_1) == 10
@@ -483,6 +492,7 @@ class TestSplitter(unittest.TestCase):

    K = 5
    fold_datasets = stratified_splitter.k_fold_split(dataset, K)
+    fold_datasets = [f[1] for f in fold_datasets]

    for fold in range(K):
      fold_dataset = fold_datasets[fold]
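The one-line change above reflects that the inherited `Splitter.k_fold_split` returns a list of `(train, cv)` pairs, where `cv` is the held-out fold, rather than the bare fold datasets the deleted override produced. A short sketch of the new return shape (toy data; names are illustrative):

    import numpy as np
    import deepchem as dc

    dataset = dc.data.NumpyDataset(np.ones(20),
                                   np.random.binomial(1, 0.5, size=(20, 1)))
    splitter = dc.splits.RandomStratifiedSplitter()
    folds = splitter.k_fold_split(dataset, 5)  # list of (train_fold, cv_fold)
    held_out = [cv for (train, cv) in folds]   # the k held-out folds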
@@ -546,26 +556,6 @@ class TestSplitter(unittest.TestCase):
    assert len(valid_data) == 1
    assert len(test_data) == 1

-  def test_stratified_multitask_split(self):
-    """
-    Test multitask RandomStratifiedSplitter class
-    """
-    # sparsity is determined by number of w weights that are 0 for a given
-    # task structure of w np array is such that each row corresponds to a
-    # sample. The loaded sparse dataset has many rows with only zeros
-    sparse_dataset = load_sparse_multitask_dataset()
-
-    stratified_splitter = dc.splits.RandomStratifiedSplitter()
-    datasets = stratified_splitter.train_valid_test_split(
-        sparse_dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1)
-    train_data, valid_data, test_data = datasets
-
-    for dataset_index, dataset in enumerate(datasets):
-      w = dataset.w
-      # verify that there are no rows (samples) in weights matrix w
-      # that have no hits.
-      assert len(np.where(w.any(axis=1) == 0)[0]) == 0
-
  def test_specified_split(self):

    solubility_dataset = load_solubility_data()
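For reference, the heart of the new implementation is the greedy rule visible in the first hunk: each positive sample is assigned to whichever split has filled the smallest fraction of its per-task active-count target. A self-contained sketch of just that rule, reduced to a single task (all names here are illustrative, not DeepChem API):

    import numpy as np

    def pick_split(set_counts, set_target, task):
      # Fraction of each split's active-count target for `task` already
      # filled; a split whose target is zero is treated as full (fraction 1).
      set_frac = [
          1 if set_target[s][task] == 0 else
          set_counts[s][task] / set_target[s][task] for s in range(3)
      ]
      # The sample goes to the split furthest below its target.
      return int(np.argmin(set_frac))

    # Toy run: one task with 10 actives and an 80/10/10 target.
    set_target = np.array([[8], [1], [1]])
    set_counts = np.zeros((3, 1))
    for _ in range(10):
      s = pick_split(set_counts, set_target, task=0)
      set_counts[s, 0] += 1
    print(set_counts.ravel())  # -> [8. 1. 1.]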