Commit 40265f8e authored by Yutong Zhao

Add additional tests for pad_batches

parent 7fdfceee
+9 −5
@@ -646,7 +646,8 @@ class DiskDataset(Dataset):
    Parameters:
    -----------
    batch_size: int
      Number of elements in a batch
      Number of elements in a batch. If batch_size is None, the entire dataset
      across all shards will be returned as a single batch.

    epoch: int
      Not used
@@ -663,7 +664,7 @@ class DiskDataset(Dataset):

    """

    def iterate(dataset):
    def iterate(dataset, batch_size):
      num_shards = dataset.get_number_shards()
      if not deterministic:
        shard_perm = np.random.permutation(num_shards)
@@ -679,7 +680,10 @@ class DiskDataset(Dataset):

      total_yield = 0

      num_global_batches = math.ceil(dataset.get_shape()[0][0]/batch_size)
      if batch_size is None:
        batch_size = len(dataset)

      num_global_batches = math.ceil(len(dataset)/batch_size)
      cur_global_batch = 0
      cur_shard = 0
      carry = None
@@ -738,14 +742,14 @@ class DiskDataset(Dataset):

            # (ytz): this skips everything except possibly the last shard
            if pad_batches:
              (X_b, y_b, w_b, ids_b) = pad_batches(shard_batch_size, X_b, y_b, w_b, ids_b)
              (X_b, y_b, w_b, ids_b) = pad_batch(batch_size, X_b, y_b, w_b, ids_b)

            yield X_b, y_b, w_b, ids_b
            cur_global_batch += 1
          cur_local_batch += 1
        cur_shard += 1

    return iterate(self)
    return iterate(self, batch_size)

  def itersamples(self):
    """Get an object that iterates over the samples in the dataset.
+125 −45
@@ -397,6 +397,67 @@ class TestDatasets(unittest.TestCase):
      batch_sizes.append(len(X))
    self.assertEqual([3, 3, 3, 1], batch_sizes)

  def test_disk_pad_batches(self):
    shard_sizes = [21, 11, 41, 21, 51]
    batch_size = 10

    all_Xs, all_ys, all_ws, all_ids = [], [], [], []

    def shard_generator():
      for sz in shard_sizes:
        X_b = np.random.rand(sz, 1)
        y_b = np.random.rand(sz, 1)
        w_b = np.random.rand(sz, 1)
        ids_b = np.random.rand(sz)

        all_Xs.append(X_b)
        all_ys.append(y_b)
        all_ws.append(w_b)
        all_ids.append(ids_b)

        yield X_b, y_b, w_b, ids_b

    dataset = dc.data.DiskDataset.create_dataset(
      shard_generator())

    all_Xs = np.concatenate(all_Xs, axis=0)
    all_ys = np.concatenate(all_ys, axis=0)
    all_ws = np.concatenate(all_ws, axis=0)
    all_ids = np.concatenate(all_ids, axis=0)

    test_Xs, test_ys, test_ws, test_ids = [], [], [], []
    for bidx, (a, b, c, d) in enumerate(dataset.iterbatches(
      batch_size=batch_size,
      pad_batches=True,
      deterministic=True)):

      test_Xs.append(a)
      test_ys.append(b)
      test_ws.append(c)
      test_ids.append(d)

    test_Xs = np.concatenate(test_Xs, axis=0)
    test_ys = np.concatenate(test_ys, axis=0)
    test_ws = np.concatenate(test_ws, axis=0)
    test_ids = np.concatenate(test_ids, axis=0)

    total_size = sum(shard_sizes)

    assert bidx == math.ceil(total_size/batch_size) - 1

    expected_batches = math.ceil(total_size/batch_size)*batch_size

    assert len(test_Xs) == expected_batches
    assert len(test_ys) == expected_batches
    assert len(test_ws) == expected_batches
    assert len(test_ids) == expected_batches

    np.testing.assert_array_equal(all_Xs, test_Xs[:total_size, :])
    np.testing.assert_array_equal(all_ys, test_ys[:total_size, :])
    np.testing.assert_array_equal(all_ws, test_ws[:total_size, :])
    np.testing.assert_array_equal(all_ids, test_ids[:total_size])


  def test_disk_iterate_batch(self):

    all_batch_sizes = [
@@ -410,19 +471,21 @@ class TestDatasets(unittest.TestCase):
      [21, 11, 41, 21, 51]
    ]

    for _ in range(50):
    for idx in range(25):
      shard_length = random.randint(1, 32)
      shard_sizes = []
      for _ in range(shard_length):
        shard_sizes.append(random.randint(1, 128))
      all_shard_sizes.append(shard_sizes)
      if idx == 0:
        # special case to test batch_size=None
        all_batch_sizes.append(None)
      else:
        all_batch_sizes.append(random.randint(1, 256))
 
      batch_size = random.randint(1, 256)
      all_batch_sizes.append(batch_size)
    for shard_sizes, batch_size in zip(all_shard_sizes, all_batch_sizes):

    for shard_sizes in all_shard_sizes:

      All_Xs, All_ys, All_ws, All_ids = [], [], [], []
      all_Xs, all_ys, all_ws, all_ids = [], [], [], []

      def shard_generator():
        for sz in shard_sizes:
@@ -431,10 +494,10 @@ class TestDatasets(unittest.TestCase):
          w_b = np.random.rand(sz, 1)
          ids_b = np.random.rand(sz)

          All_Xs.append(X_b)
          All_ys.append(y_b)
          All_ws.append(w_b)
          All_ids.append(ids_b)
          all_Xs.append(X_b)
          all_ys.append(y_b)
          all_ws.append(w_b)
          all_ids.append(ids_b)

          yield X_b, y_b, w_b, ids_b

@@ -442,63 +505,80 @@ class TestDatasets(unittest.TestCase):
      dataset = dc.data.DiskDataset.create_dataset(
        shard_generator())

      All_Xs = np.concatenate(All_Xs, axis=0)
      All_ys = np.concatenate(All_ys, axis=0)
      All_ws = np.concatenate(All_ws, axis=0)
      All_ids = np.concatenate(All_ids, axis=0)
      all_Xs = np.concatenate(all_Xs, axis=0)
      all_ys = np.concatenate(all_ys, axis=0)
      all_ws = np.concatenate(all_ws, axis=0)
      all_ids = np.concatenate(all_ids, axis=0)

      total_size = sum(shard_sizes)

      assert dataset.X.shape[0] == total_size

      # deterministic
      Test_Xs, Test_ys, Test_ws, Test_ids = [], [], [], []
      test_Xs, test_ys, test_ws, test_ids = [], [], [], []
      for bidx, (a, b, c, d) in enumerate(dataset.iterbatches(
        batch_size=batch_size,
        pad_batches=False,
        deterministic=True)):

        Test_Xs.append(a)
        Test_ys.append(b)
        Test_ws.append(c)
        Test_ids.append(d)

      Test_Xs = np.concatenate(Test_Xs, axis=0)
      Test_ys = np.concatenate(Test_ys, axis=0)
      Test_ws = np.concatenate(Test_ws, axis=0)
      Test_ids = np.concatenate(Test_ids, axis=0)

        test_Xs.append(a)
        test_ys.append(b)
        test_ws.append(c)
        test_ids.append(d)

      if batch_size is None:
        assert len(test_Xs) == 1
        assert len(test_ys) == 1
        assert len(test_ws) == 1
        assert len(test_ids) == 1

      test_Xs = np.concatenate(test_Xs, axis=0)
      test_ys = np.concatenate(test_ys, axis=0)
      test_ws = np.concatenate(test_ws, axis=0)
      test_ids = np.concatenate(test_ids, axis=0)

      if batch_size is None:
        assert bidx == 0
      else:
        assert bidx == math.ceil(total_size/batch_size) - 1

      np.testing.assert_array_equal(All_Xs, Test_Xs)
      np.testing.assert_array_equal(All_ys, Test_ys)
      np.testing.assert_array_equal(All_ws, Test_ws)
      np.testing.assert_array_equal(All_ids, Test_ids)

      np.testing.assert_array_equal(all_Xs, test_Xs)
      np.testing.assert_array_equal(all_ys, test_ys)
      np.testing.assert_array_equal(all_ws, test_ws)
      np.testing.assert_array_equal(all_ids, test_ids)

      # non-deterministic
      Test_Xs, Test_ys, Test_ws, Test_ids = [], [], [], []
      test_Xs, test_ys, test_ws, test_ids = [], [], [], []
      for bidx, (a, b, c, d) in enumerate(dataset.iterbatches(
        batch_size=batch_size,
        pad_batches=False,
        deterministic=False)):

        Test_Xs.append(a)
        Test_ys.append(b)
        Test_ws.append(c)
        Test_ids.append(d)

      Test_Xs = np.concatenate(Test_Xs, axis=0)
      Test_ys = np.concatenate(Test_ys, axis=0)
      Test_ws = np.concatenate(Test_ws, axis=0)
      Test_ids = np.concatenate(Test_ids, axis=0)

        test_Xs.append(a)
        test_ys.append(b)
        test_ws.append(c)
        test_ids.append(d)

      if batch_size is None:
        assert len(test_Xs) == 1
        assert len(test_ys) == 1
        assert len(test_ws) == 1
        assert len(test_ids) == 1

      test_Xs = np.concatenate(test_Xs, axis=0)
      test_ys = np.concatenate(test_ys, axis=0)
      test_ws = np.concatenate(test_ws, axis=0)
      test_ids = np.concatenate(test_ids, axis=0)

      if batch_size is None:
        assert bidx == 0
      else:
        assert bidx == math.ceil(total_size/batch_size) - 1

      np.testing.assert_array_equal(np.sort(All_Xs, axis=0), np.sort(Test_Xs, axis=0))
      np.testing.assert_array_equal(np.sort(All_ys, axis=0), np.sort(Test_ys, axis=0))
      np.testing.assert_array_equal(np.sort(All_ws, axis=0), np.sort(Test_ws, axis=0))
      np.testing.assert_array_equal(np.sort(All_ids, axis=0), np.sort(Test_ids, axis=0))
      np.testing.assert_array_equal(np.sort(all_Xs, axis=0), np.sort(test_Xs, axis=0))
      np.testing.assert_array_equal(np.sort(all_ys, axis=0), np.sort(test_ys, axis=0))
      np.testing.assert_array_equal(np.sort(all_ws, axis=0), np.sort(test_ws, axis=0))
      np.testing.assert_array_equal(np.sort(all_ids, axis=0), np.sort(test_ids, axis=0))

  def test_numpy_iterate_batch_size(self):
    solubility_dataset = dc.data.tests.load_solubility_data()
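
A companion sketch of the padding behaviour exercised by the new test_disk_pad_batches test above: with pad_batches=True every yielded batch, including the final one, is padded up to batch_size rows, so ceil(N/batch_size) batches are produced in total. Again a toy example, not part of the commit:

import math
import numpy as np
import deepchem as dc

shard_sizes = [21, 11, 41, 21, 51]

def shard_generator():
  for sz in shard_sizes:
    yield (np.random.rand(sz, 1), np.random.rand(sz, 1),
           np.random.rand(sz, 1), np.random.rand(sz))

dataset = dc.data.DiskDataset.create_dataset(shard_generator())
total_size = sum(shard_sizes)  # 145 samples
batch_size = 10

Xs = [X for (X, y, w, ids) in dataset.iterbatches(
    batch_size=batch_size, pad_batches=True, deterministic=True)]

# every batch, including the last one, has exactly batch_size rows
assert all(X.shape[0] == batch_size for X in Xs)
# ceil(145 / 10) == 15 batches are yielded
assert len(Xs) == math.ceil(total_size / batch_size)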