Commit 9ce96c35 authored by Yutong Zhao

Preserve behavior of batch_size=None to iterate in shard_size

parent c6bb2f0a
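For context, a minimal sketch of what the change means for callers, assuming a small DiskDataset built with from_numpy (the shapes and shard layout here are illustrative, not part of the commit):

    import numpy as np
    import deepchem as dc

    # Hypothetical dataset; from_numpy writes it to disk as one or more shards.
    X = np.random.rand(20, 3)
    y = np.random.rand(20, 1)
    dataset = dc.data.DiskDataset.from_numpy(X, y)

    # Before this commit, batch_size=None returned the entire dataset as a
    # single batch; after it, batch_size=None yields one batch per shard,
    # each sized to that shard.
    for X_b, y_b, w_b, ids_b in dataset.iterbatches(batch_size=None):
      print(X_b.shape[0])  # per-shard sample counts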
+17 −11
@@ -646,8 +646,8 @@ class DiskDataset(Dataset):
    Parameters:
    -----------
    batch_size: int
-      Number of elements in a batch. If None, then the entire dataset
-      across all shards will be returned as a single batch.
+      Number of elements in a batch. If None, then it yields batches with
+      size equal to the size of each individual shard.

    epoch: int
      Not used
@@ -681,9 +681,10 @@ class DiskDataset(Dataset):
      total_yield = 0

      if batch_size is None:
-        batch_size = len(dataset)
-
-      num_global_batches = math.ceil(len(dataset) / batch_size)
+        num_global_batches = num_shards
+      else:
+        num_global_batches = math.ceil(len(dataset) / batch_size)

      cur_global_batch = 0
      cur_shard = 0
      carry = None
@@ -706,7 +707,12 @@ class DiskDataset(Dataset):

        n_shard_samples = X.shape[0]
        cur_local_batch = 0
-        num_local_batches = math.ceil(n_shard_samples / batch_size)
+        if batch_size is None:
+          shard_batch_size = n_shard_samples
+        else:
+          shard_batch_size = batch_size
+
+        num_local_batches = math.ceil(n_shard_samples / shard_batch_size)

        if n_shard_samples == 0:
          continue
@@ -716,8 +722,8 @@ class DiskDataset(Dataset):
          sample_perm = np.arange(n_shard_samples)

        while cur_local_batch < num_local_batches:
-          start = cur_local_batch * batch_size
-          end = min(n_shard_samples, (cur_local_batch + 1) * batch_size)
+          start = cur_local_batch * shard_batch_size
+          end = min(n_shard_samples, (cur_local_batch + 1) * shard_batch_size)

          indices = range(start, end)
          perm_indices = sample_perm[indices]
@@ -735,16 +741,16 @@ class DiskDataset(Dataset):

          ids_b = ids[perm_indices]

-          assert len(X_b) <= batch_size
-          if len(X_b) < batch_size and cur_shard != num_shards - 1:
+          assert len(X_b) <= shard_batch_size
+          if len(X_b) < shard_batch_size and cur_shard != num_shards - 1:
            assert carry is None
            carry = [X_b, y_b, w_b, ids_b]
          else:

            # (ytz): this skips everything except possibly the last shard
            if pad_batches:
-              (X_b, y_b, w_b, ids_b) = pad_batch(batch_size, X_b, y_b, w_b,
-                                                 ids_b)
+              (X_b, y_b, w_b, ids_b) = pad_batch(shard_batch_size, X_b, y_b,
+                                                 w_b, ids_b)

            yield X_b, y_b, w_b, ids_b
            cur_global_batch += 1
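Taken together, the new control flow can be read as the following standalone sketch (a simplification that treats shards as plain NumPy arrays and omits the carry/pad_batches handling shown above):

    import math
    import numpy as np

    def iter_shard_batches(shards, batch_size=None):
      # batch_size=None falls back to each shard's own size, so the loop
      # yields exactly one batch per shard; otherwise it yields fixed-size
      # batches within each shard (no carry across shards in this sketch).
      for X in shards:
        n_shard_samples = X.shape[0]
        if n_shard_samples == 0:
          continue
        shard_batch_size = n_shard_samples if batch_size is None else batch_size
        num_local_batches = math.ceil(n_shard_samples / shard_batch_size)
        for cur_local_batch in range(num_local_batches):
          start = cur_local_batch * shard_batch_size
          end = min(n_shard_samples, (cur_local_batch + 1) * shard_batch_size)
          yield X[start:end]

    shards = [np.zeros((7, 3)), np.zeros((3, 3)), np.zeros((12, 3))]
    print([b.shape[0] for b in iter_shard_batches(shards)])     # [7, 3, 12]
    print([b.shape[0] for b in iter_shard_batches(shards, 5)])  # [5, 2, 3, 5, 5, 2]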
+12 −14
@@ -457,8 +457,8 @@ class TestDatasets(unittest.TestCase):

  def test_disk_iterate_batch(self):

-    all_batch_sizes = [32, 17, 11]
-    all_shard_sizes = [[1, 1, 1, 1, 1], [31, 31, 31, 31, 31],
+    all_batch_sizes = [None, 32, 17, 11]
+    all_shard_sizes = [[7, 3, 12, 4, 5], [1, 1, 1, 1, 1], [31, 31, 31, 31, 31],
                       [21, 11, 41, 21, 51]]

    for idx in range(25):
@@ -514,10 +514,12 @@ class TestDatasets(unittest.TestCase):
        test_ids.append(d)

      if batch_size is None:
-        assert len(test_Xs) == 1
-        assert len(test_ys) == 1
-        assert len(test_ws) == 1
-        assert len(test_ids) == 1
+        for idx, (tx, ty, tw,
+                  tids) in enumerate(zip(test_Xs, test_ys, test_ws, test_ids)):
+          assert len(tx) == shard_sizes[idx]
+          assert len(ty) == shard_sizes[idx]
+          assert len(tw) == shard_sizes[idx]
+          assert len(tids) == shard_sizes[idx]

      test_Xs = np.concatenate(test_Xs, axis=0)
      test_ys = np.concatenate(test_ys, axis=0)
@@ -525,7 +527,7 @@ class TestDatasets(unittest.TestCase):
      test_ids = np.concatenate(test_ids, axis=0)

      if batch_size is None:
-        assert bidx == 0
+        assert bidx == len(shard_sizes) - 1
      else:
        assert bidx == math.ceil(total_size / batch_size) - 1

@@ -536,6 +538,7 @@ class TestDatasets(unittest.TestCase):

      # non-deterministic
      test_Xs, test_ys, test_ws, test_ids = [], [], [], []

      for bidx, (a, b, c, d) in enumerate(
          dataset.iterbatches(
              batch_size=batch_size, pad_batches=False, deterministic=False)):
@@ -545,19 +548,14 @@ class TestDatasets(unittest.TestCase):
        test_ws.append(c)
        test_ids.append(d)

-      if batch_size is None:
-        assert len(test_Xs) == 1
-        assert len(test_ys) == 1
-        assert len(test_ws) == 1
-        assert len(test_ids) == 1
-
      # we don't know the order in which the shards are iterated in.
      test_Xs = np.concatenate(test_Xs, axis=0)
      test_ys = np.concatenate(test_ys, axis=0)
      test_ws = np.concatenate(test_ws, axis=0)
      test_ids = np.concatenate(test_ids, axis=0)

      if batch_size is None:
-        assert bidx == 0
+        assert bidx == len(shard_sizes) - 1
      else:
        assert bidx == math.ceil(total_size / batch_size) - 1
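As a quick sanity check of the updated assertions, the final bidx each mode should reach can be computed directly (a sketch; total_size is the sum of the shard sizes, shown here for the new [7, 3, 12, 4, 5] layout):

    import math

    shard_sizes = [7, 3, 12, 4, 5]
    total_size = sum(shard_sizes)  # 31

    for batch_size in [None, 32, 17, 11]:
      if batch_size is None:
        last_bidx = len(shard_sizes) - 1  # one batch per shard -> 4
      else:
        last_bidx = math.ceil(total_size / batch_size) - 1
      print(batch_size, last_bidx)
    # None 4, 32 0, 17 1, 11 2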