Commit 69fac4b4 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Added in more tests for shuffle and cleaned up types

parent b8b7ad31
Loading
Loading
Loading
Loading
+14 −50
Original line number Diff line number Diff line
@@ -1863,7 +1863,7 @@ class DiskDataset(Dataset):
    time2 = time.time()
    logger.info("TIMING: sparse_shuffle took %0.3f s" % (time2 - time1))

  def complete_shuffle(self, data_dir: Optional[str] = None) -> "DiskDataset":
  def complete_shuffle(self, data_dir: Optional[str] = None) -> Dataset:
    """Completely shuffle across all data, across all shards.

    Note
@@ -1893,49 +1893,7 @@ class DiskDataset(Dataset):
    N = len(self)
    perm = np.random.permutation(N)
    shard_size = self.get_shard_size()

    def generator():
      start = 0
      shard_num = 0
      while start < N:
        logger.info("Constructing shard %d" % shard_num)
        end = min(start + shard_size, N)
        shard_indices = perm[start:end]
        # Note that this is in sorted order which doesn't respect the random
        # permutation.
        shard_dataset = self.select(shard_indices, output_numpy_dataset=True)
        # One bit of trickiness here is that select() will return in sorted
        # order. For example, suppose we'd like these elements in our permuted
        # shard:
        #
        # [12, 234, 1, 4]
        #
        # Then select would return elements in order
        #
        # [1, 4, 12, 234]
        #
        sorted_indices = np.array(sorted(shard_indices))
        reverted_indices = np.array(
            # We know there's only one match for np.where since this is a
            # permutation, so the [0][0] pulls out the exact match location.
            [
                np.where(sorted_indices == orig_index)[0][0]
                for orig_index in shard_indices
            ])
        # Let's pull out shard elements
        shard_X, shard_y, shard_w, shard_ids = (shard_dataset.X,
                                                shard_dataset.y,
                                                shard_dataset.w,
                                                shard_dataset.ids)

        yield (shard_X[reverted_indices], shard_y[reverted_indices],
               shard_w[reverted_indices], shard_ids[reverted_indices])

        start = end
        shard_num += 1

    return DiskDataset.create_dataset(
        generator(), data_dir=data_dir, tasks=self.get_task_names())
    return self.select(perm, data_dir, self.get_shard_size())

  def shuffle_each_shard(self,
                         shard_basenames: Optional[List[str]] = None) -> None:
@@ -2095,7 +2053,7 @@ class DiskDataset(Dataset):
             indices: Sequence[int],
             select_dir: Optional[str] = None,
             select_shard_size: Optional[int] = None,
             output_numpy_dataset: Optional[bool] = False) -> "DiskDataset":
             output_numpy_dataset: Optional[bool] = False) -> Dataset:
    """Creates a new dataset from a selection of indices from self.

    Examples
@@ -2112,7 +2070,7 @@ class DiskDataset(Dataset):
    indices: list
      List of indices to select.
    select_dir: Optional[str], (default None)
      Path to new directory that the selected indices will be copied
      Path to new directory that the selected samples will be copied
      to.
    select_shard_size: Optional[int], (default None)
      If specified, the shard-size to use for output selected `DiskDataset`.
@@ -2126,7 +2084,7 @@ class DiskDataset(Dataset):
    Returns
    -------
    DiskDataset
      Contains selected indices.
      Contains selected samples.
    """
    if output_numpy_dataset and (select_dir is not None or
                                 select_shard_size is not None):
@@ -2163,7 +2121,10 @@ class DiskDataset(Dataset):
    # source datasets to select out the shard indices from that  source shard
    def generator():
      start = 0
      select_shard_num = 0
      while start < N:
        logger.info(
            "Constructing selection output shard %d" % (select_shard_num + 1))
        end = min(start + select_shard_size, N)
        select_shard_indices = indices[start:end]
        sorted_indices = np.array(sorted(select_shard_indices)).astype(int)
@@ -2171,14 +2132,16 @@ class DiskDataset(Dataset):
        Xs, ys, ws, ids_s = [], [], [], []
        count, indices_count = 0, 0
        for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
          logger.info("Selecting from shard %d/%d" % (shard_num, n_shards))
          logger.info(
              "Selecting from input shard %d/%d for selection output shard %d" %
              (shard_num + 1, n_shards, select_shard_num + 1))
          shard_len = len(X)
          # Find indices which rest in this shard
          num_shard_elts = 0
          while sorted_indices[indices_count +
                               num_shard_elts] < count + shard_len:
            num_shard_elts += 1
            if indices_count + num_shard_elts >= len(sorted_indices):
            if (indices_count + num_shard_elts) >= len(sorted_indices):
              break
          # Need to offset indices to fit within shard_size
          shard_inds = sorted_indices[indices_count:indices_count +
@@ -2201,7 +2164,7 @@ class DiskDataset(Dataset):
          indices_count += num_shard_elts
          count += shard_len
          # Break when all indices have been used up already
          if indices_count >= len(indices):
          if indices_count >= len(sorted_indices):
            break
        # Note these will be in the sorted order
        X = np.concatenate(Xs, axis=0)
@@ -2222,6 +2185,7 @@ class DiskDataset(Dataset):
            reverted_indices], ids[reverted_indices]
        yield (X, y, w, ids)
        start = end
        select_shard_num += 1

    if not output_numpy_dataset:
      return DiskDataset.create_dataset(
+36 −3
Original line number Diff line number Diff line
@@ -20,6 +20,14 @@ def test_complete_shuffle_one_shard():
  assert shuffled.X.shape == dataset.X.shape
  assert shuffled.y.shape == dataset.y.shape
  assert shuffled.w.shape == dataset.w.shape
  original_indices = dict((id, i) for i, id in enumerate(dataset.ids))
  shuffled_indices = dict((id, i) for i, id in enumerate(shuffled.ids))
  for id in dataset.ids:
    i = original_indices[id]
    j = shuffled_indices[id]
    assert np.array_equal(dataset.X[i], shuffled.X[j])
    assert np.array_equal(dataset.y[i], shuffled.y[j])
    assert np.array_equal(dataset.w[i], shuffled.w[j])


def test_complete_shuffle_multiple_shard():
@@ -34,6 +42,14 @@ def test_complete_shuffle_multiple_shard():
  assert shuffled.X.shape == dataset.X.shape
  assert shuffled.y.shape == dataset.y.shape
  assert shuffled.w.shape == dataset.w.shape
  original_indices = dict((id, i) for i, id in enumerate(dataset.ids))
  shuffled_indices = dict((id, i) for i, id in enumerate(shuffled.ids))
  for id in dataset.ids:
    i = original_indices[id]
    j = shuffled_indices[id]
    assert np.array_equal(dataset.X[i], shuffled.X[j])
    assert np.array_equal(dataset.y[i], shuffled.y[j])
    assert np.array_equal(dataset.w[i], shuffled.w[j])


def test_complete_shuffle_multiple_shard_uneven():
@@ -48,6 +64,14 @@ def test_complete_shuffle_multiple_shard_uneven():
  assert shuffled.X.shape == dataset.X.shape
  assert shuffled.y.shape == dataset.y.shape
  assert shuffled.w.shape == dataset.w.shape
  original_indices = dict((id, i) for i, id in enumerate(dataset.ids))
  shuffled_indices = dict((id, i) for i, id in enumerate(shuffled.ids))
  for id in dataset.ids:
    i = original_indices[id]
    j = shuffled_indices[id]
    assert np.array_equal(dataset.X[i], shuffled.X[j])
    assert np.array_equal(dataset.y[i], shuffled.y[j])
    assert np.array_equal(dataset.w[i], shuffled.w[j])


def test_complete_shuffle():
@@ -66,10 +90,11 @@ def test_complete_shuffle():
                                      dataset.ids)
  orig_len = len(dataset)

  dataset = dataset.complete_shuffle()
  X_new, y_new, w_new, new_ids = (dataset.X, dataset.y, dataset.w, dataset.ids)
  shuffled = dataset.complete_shuffle()
  X_new, y_new, w_new, new_ids = (shuffled.X, shuffled.y, shuffled.w,
                                  shuffled.ids)

  assert len(dataset) == orig_len
  assert len(shuffled) == orig_len
  # The shuffling should have switched up the ordering
  assert not np.array_equal(orig_ids, new_ids)
  # But all the same entries should still be present
@@ -78,6 +103,14 @@ def test_complete_shuffle():
  assert X_orig.shape == X_new.shape
  assert y_orig.shape == y_new.shape
  assert w_orig.shape == w_new.shape
  original_indices = dict((id, i) for i, id in enumerate(dataset.ids))
  shuffled_indices = dict((id, i) for i, id in enumerate(shuffled.ids))
  for id in dataset.ids:
    i = original_indices[id]
    j = shuffled_indices[id]
    assert np.array_equal(dataset.X[i], shuffled.X[j])
    assert np.array_equal(dataset.y[i], shuffled.y[j])
    assert np.array_equal(dataset.w[i], shuffled.w[j])


def test_sparse_shuffle():