Commit 09db59f3 authored by flo's avatar flo
Browse files

added test to the merge method checking the size of the tasks, and the shape...

added test to the merge method checking the size of the tasks, and the shape of the new merged dataset
parent a54f39ba
Loading
Loading
Loading
Loading
+22 −0
Original line number Diff line number Diff line
@@ -662,6 +662,28 @@ class TestDatasets(unittest.TestCase):
      batch_sizes.append(len(X))
    self.assertEqual([3, 3, 3, 1], batch_sizes)

  def test_merge(self):
    """Test that dataset merge works."""
    num_datapoints = 10
    num_features = 10
    num_tasks = 1
    num_datasets = 4
    datasets = []
    for i in range(num_datasets):
      Xi = np.random.rand(num_datapoints, num_features)
      yi = np.random.randint(2, size=(num_datapoints, num_tasks))
      wi = np.ones((num_datapoints, num_tasks))
      idsi = np.array(["id"] * num_datapoints)
      dataseti = dc.data.DiskDataset.from_numpy(Xi, yi, wi, idsi)
      datasets.append(dataseti)

    new_data = dc.data.datasets.DiskDataset.merge(datasets)

    # Check that we have all the data in
    assert new_data.X.shape == (num_datapoints*num_datasets, num_features)
    assert new_data.y.shape == (num_datapoints*num_datasets, num_tasks)
    assert len(new_data.tasks) == len(datasets[0].tasks)


if __name__ == "__main__":
  unittest.main()