Commit 36404369 authored by Yutong Zhao's avatar Yutong Zhao
Browse files

Fix pad_batch so it handles batches where y or w is None

parent 9ce96c35
Loading
Loading
Loading
Loading
+22 −8
Original line number Diff line number Diff line
@@ -90,15 +90,24 @@ def pad_batch(batch_size, X_b, y_b, w_b, ids_b):
  else:
    X_out = np.zeros((batch_size,), dtype=X_b.dtype)

  num_tasks = y_b.shape[1]
  y_out = np.zeros((batch_size, num_tasks), dtype=y_b.dtype)
  w_out = np.zeros((batch_size, num_tasks), dtype=w_b.dtype)
  if y_b is None:
    y_out = None
  else:
    y_out = np.zeros((batch_size, y_b.shape[1]), dtype=y_b.dtype)

  if w_b is None:
    w_out = None
  else:
    w_out = np.zeros((batch_size, w_b.shape[1]), dtype=w_b.dtype)

  ids_out = np.zeros((batch_size,), dtype=ids_b.dtype)

  # Fill in batch arrays
  start = 0
  # Only the first set of copy will be counted in training loss
  if w_out is not None:
    w_out[start:start + num_samples] = w_b[:]

  while start < batch_size:
    num_left = batch_size - start
    if num_left < num_samples:
@@ -106,7 +115,10 @@ def pad_batch(batch_size, X_b, y_b, w_b, ids_b):
    else:
      increment = num_samples
    X_out[start:start + increment] = X_b[:increment]

    if y_out is not None:
      y_out[start:start + increment] = y_b[:increment]

    ids_out[start:start + increment] = ids_b[:increment]
    start += increment

@@ -683,7 +695,7 @@ class DiskDataset(Dataset):
      if batch_size is None:
        num_global_batches = num_shards
      else:
        num_global_batches = math.ceil(len(dataset) / batch_size)
        num_global_batches = math.ceil(dataset.get_shape()[0][0] / batch_size)

      cur_global_batch = 0
      cur_shard = 0
@@ -700,7 +712,9 @@ class DiskDataset(Dataset):

        if carry is not None:
          X = np.concatenate([carry[0], X], axis=0)
          if y is not None:
            y = np.concatenate([carry[1], y], axis=0)
          if w is not None:
            w = np.concatenate([carry[2], w], axis=0)
          ids = np.concatenate([carry[3], ids], axis=0)
          carry = None
+44 −0
Original line number Diff line number Diff line
@@ -455,6 +455,50 @@ class TestDatasets(unittest.TestCase):
    np.testing.assert_array_equal(all_ws, test_ws[:total_size, :])
    np.testing.assert_array_equal(all_ids, test_ids[:total_size])

  def test_disk_iterate_y_w_None(self):
    """Iterate a DiskDataset whose shards carry no labels or weights.

    Every shard is created with y and w set to None; iterbatches with
    pad_batches=True must still yield full-size batches of X and ids,
    and the unpadded prefix must reproduce the source data exactly.
    """
    shard_sizes = [21, 11, 41, 21, 51]
    batch_size = 10

    source_Xs, source_ids = [], []

    def shard_generator():
      # Yield (X, None, None, ids) tuples — y and w deliberately absent.
      for size in shard_sizes:
        shard_X = np.random.rand(size, 1)
        shard_ids = np.random.rand(size)
        source_Xs.append(shard_X)
        source_ids.append(shard_ids)
        yield shard_X, None, None, shard_ids

    dataset = dc.data.DiskDataset.create_dataset(shard_generator())

    expected_Xs = np.concatenate(source_Xs, axis=0)
    expected_ids = np.concatenate(source_ids, axis=0)

    seen_Xs, seen_ids = [], []
    batches = dataset.iterbatches(
        batch_size=batch_size, pad_batches=True, deterministic=True)
    for batch_index, (X_b, _, _, ids_b) in enumerate(batches):
      seen_Xs.append(X_b)
      seen_ids.append(ids_b)

    seen_Xs = np.concatenate(seen_Xs, axis=0)
    seen_ids = np.concatenate(seen_ids, axis=0)

    total_size = sum(shard_sizes)
    num_batches = math.ceil(total_size / batch_size)

    # The enumerate index of the last batch tells us how many were yielded.
    assert batch_index == num_batches - 1

    # Padding rounds the total row count up to a multiple of batch_size.
    assert len(seen_Xs) == num_batches * batch_size
    assert len(seen_ids) == num_batches * batch_size

    # Rows beyond total_size are padding; the prefix must match the source.
    np.testing.assert_array_equal(expected_Xs, seen_Xs[:total_size, :])
    np.testing.assert_array_equal(expected_ids, seen_ids[:total_size])

  def test_disk_iterate_batch(self):

    all_batch_sizes = [None, 32, 17, 11]