Commit c6bb2f0a authored by Yutong Zhao
Browse files

yapf

parent 40265f8e
Loading
Loading
Loading
Loading
+9 −7
Original line number Diff line number Diff line
@@ -638,7 +638,7 @@ class DiskDataset(Dataset):
                  deterministic=False,
                  pad_batches=False):
    """ Get an object that iterates over minibatches from the dataset. It is guaranteed
    that the number of batches returned is math.ceil(dataset.get_shape()[0][0]/batch_size).
    that the number of batches returned is math.ceil(len(dataset)/batch_size).
    
    Each minibatch is returned as a tuple of four numpy arrays: (X, y, w, ids).

@@ -692,7 +692,8 @@ class DiskDataset(Dataset):

        X, y, w, ids = next_shard.get()
        if cur_shard < num_shards - 1:
          next_shard = pool.apply_async(dataset.get_shard, (shard_perm[cur_shard + 1],))
          next_shard = pool.apply_async(dataset.get_shard,
                                        (shard_perm[cur_shard + 1],))
        else:
          pool.close()

@@ -742,7 +743,8 @@ class DiskDataset(Dataset):

            # (ytz): this skips everything except possibly the last shard
            if pad_batches:
              (X_b, y_b, w_b, ids_b) = pad_batch(batch_size, X_b, y_b, w_b, ids_b)
              (X_b, y_b, w_b, ids_b) = pad_batch(batch_size, X_b, y_b, w_b,
                                                 ids_b)

            yield X_b, y_b, w_b, ids_b
            cur_global_batch += 1
+30 −39
Original line number Diff line number Diff line
@@ -417,8 +417,7 @@ class TestDatasets(unittest.TestCase):

        yield X_b, y_b, w_b, ids_b

    dataset = dc.data.DiskDataset.create_dataset(
      shard_generator())
    dataset = dc.data.DiskDataset.create_dataset(shard_generator())

    all_Xs = np.concatenate(all_Xs, axis=0)
    all_ys = np.concatenate(all_ys, axis=0)
@@ -426,10 +425,9 @@ class TestDatasets(unittest.TestCase):
    all_ids = np.concatenate(all_ids, axis=0)

    test_Xs, test_ys, test_ws, test_ids = [], [], [], []
    for bidx, (a, b, c, d) in enumerate(dataset.iterbatches(
      batch_size=batch_size,
      pad_batches=True,
      deterministic=True)):
    for bidx, (a, b, c, d) in enumerate(
        dataset.iterbatches(
            batch_size=batch_size, pad_batches=True, deterministic=True)):

      test_Xs.append(a)
      test_ys.append(b)
@@ -457,19 +455,11 @@ class TestDatasets(unittest.TestCase):
    np.testing.assert_array_equal(all_ws, test_ws[:total_size, :])
    np.testing.assert_array_equal(all_ids, test_ids[:total_size])


  def test_disk_iterate_batch(self):

    all_batch_sizes = [
      32,
      17,
      11
    ]
    all_shard_sizes = [
      [1, 1, 1, 1, 1],
      [31, 31, 31, 31, 31],
      [21, 11, 41, 21, 51]
    ]
    all_batch_sizes = [32, 17, 11]
    all_shard_sizes = [[1, 1, 1, 1, 1], [31, 31, 31, 31, 31],
                       [21, 11, 41, 21, 51]]

    for idx in range(25):
      shard_length = random.randint(1, 32)
@@ -501,9 +491,7 @@ class TestDatasets(unittest.TestCase):

          yield X_b, y_b, w_b, ids_b


      dataset = dc.data.DiskDataset.create_dataset(
        shard_generator())
      dataset = dc.data.DiskDataset.create_dataset(shard_generator())

      all_Xs = np.concatenate(all_Xs, axis=0)
      all_ys = np.concatenate(all_ys, axis=0)
@@ -516,10 +504,9 @@ class TestDatasets(unittest.TestCase):

      # deterministic
      test_Xs, test_ys, test_ws, test_ids = [], [], [], []
      for bidx, (a, b, c, d) in enumerate(dataset.iterbatches(
        batch_size=batch_size,
        pad_batches=False,
        deterministic=True)):
      for bidx, (a, b, c, d) in enumerate(
          dataset.iterbatches(
              batch_size=batch_size, pad_batches=False, deterministic=True)):

        test_Xs.append(a)
        test_ys.append(b)
@@ -549,10 +536,9 @@ class TestDatasets(unittest.TestCase):

      # non-deterministic
      test_Xs, test_ys, test_ws, test_ids = [], [], [], []
      for bidx, (a, b, c, d) in enumerate(dataset.iterbatches(
        batch_size=batch_size,
        pad_batches=False,
        deterministic=False)):
      for bidx, (a, b, c, d) in enumerate(
          dataset.iterbatches(
              batch_size=batch_size, pad_batches=False, deterministic=False)):

        test_Xs.append(a)
        test_ys.append(b)
@@ -575,10 +561,14 @@ class TestDatasets(unittest.TestCase):
      else:
        assert bidx == math.ceil(total_size / batch_size) - 1

      np.testing.assert_array_equal(np.sort(all_Xs, axis=0), np.sort(test_Xs, axis=0))
      np.testing.assert_array_equal(np.sort(all_ys, axis=0), np.sort(test_ys, axis=0))
      np.testing.assert_array_equal(np.sort(all_ws, axis=0), np.sort(test_ws, axis=0))
      np.testing.assert_array_equal(np.sort(all_ids, axis=0), np.sort(test_ids, axis=0))
      np.testing.assert_array_equal(
          np.sort(all_Xs, axis=0), np.sort(test_Xs, axis=0))
      np.testing.assert_array_equal(
          np.sort(all_ys, axis=0), np.sort(test_ys, axis=0))
      np.testing.assert_array_equal(
          np.sort(all_ws, axis=0), np.sort(test_ws, axis=0))
      np.testing.assert_array_equal(
          np.sort(all_ids, axis=0), np.sort(test_ids, axis=0))

  def test_numpy_iterate_batch_size(self):
    solubility_dataset = dc.data.tests.load_solubility_data()
@@ -592,5 +582,6 @@ class TestDatasets(unittest.TestCase):
      batch_sizes.append(len(X))
    self.assertEqual([3, 3, 3, 1], batch_sizes)


if __name__ == "__main__":
  unittest.main()