Commit e4325547 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Add test for datasets with uneven shards

parent b28229f0
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -36,6 +36,20 @@ def test_complete_shuffle_multiple_shard():
  assert shuffled.w.shape == dataset.w.shape


def test_complete_shuffle_multiple_shard_uneven():
  """Test that complete shuffle works with multiple shards and some shards not full size."""
  X = np.random.rand(57, 10)
  dataset = dc.data.DiskDataset.from_numpy(X)
  dataset.reshard(shard_size=10)
  shuffled = dataset.complete_shuffle()
  assert len(shuffled) == len(dataset)
  assert not np.array_equal(shuffled.ids, dataset.ids)
  assert sorted(shuffled.ids) == sorted(dataset.ids)
  assert shuffled.X.shape == dataset.X.shape
  assert shuffled.y.shape == dataset.y.shape
  assert shuffled.w.shape == dataset.w.shape


def test_complete_shuffle():
  """Test that complete shuffle."""
  current_dir = os.path.dirname(os.path.realpath(__file__))