Unverified Commit f114ba67 authored by Suzukazole's avatar Suzukazole
Browse files

add test

parent 7ae5640f
Loading
Loading
Loading
Loading
+18 −0
Original line number Diff line number Diff line
@@ -106,3 +106,21 @@ def test_disk_dataset_get_legacy_shape_multishard():
  assert y_shape == (num_datapoints, num_tasks)
  assert w_shape == (num_datapoints, num_tasks)
  assert ids_shape == (num_datapoints,)


def test_get_shard_size():
  """Test that using ids for getting the shard size does not break the method."""
  num_datapoints = 100
  num_features = 10
  num_tasks = 10
  # Generate data
  X = np.random.rand(num_datapoints, num_features)
  y = np.random.randint(2, size=(num_datapoints, num_tasks))
  w = np.random.randint(2, size=(num_datapoints, num_tasks))
  ids = np.array(["id"] * num_datapoints)

  dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids)
  assert dataset.get_shard_size() == 100

  dataset.reshard(shard_size=15)
  assert dataset.get_shard_size() == 15