Unverified commit 9e553fca, authored by Karl Leswing, committed via GitHub
Browse files

Merge pull request #1031 from fmonta/fix1030

Fix #1030: preserve the task specification when merging datasets in dc.data.DiskDataset.merge
parents 3ba55f47 48f23dbd
Loading
Loading
Loading
Loading
+18 −1
Original line number Diff line number Diff line
@@ -939,12 +939,29 @@ class DiskDataset(Dataset):
    else:
      merge_dir = tempfile.mkdtemp()

    # Protect against generator exhaustion
    datasets = list(datasets)

    # This ensures tasks are consistent for all datasets
    tasks = []
    for dataset in datasets:
      try:
        tasks.append(dataset.tasks)
      except AttributeError:
        pass
    if tasks:
      if len(tasks) < len(datasets) or len(set(map(tuple, tasks))) > 1:
        raise ValueError(
            'Cannot merge datasets with different task specifications')
      tasks = tasks[0]

    def generator():
      for ind, dataset in enumerate(datasets):
        X, y, w, ids = (dataset.X, dataset.y, dataset.w, dataset.ids)
        yield (X, y, w, ids)

    return DiskDataset.create_dataset(generator(), data_dir=merge_dir)
    return DiskDataset.create_dataset(
        generator(), data_dir=merge_dir, tasks=tasks)

  def subset(self, shard_nums, subset_dir=None):
    """Creates a subset of the original dataset on disk."""
+24 −2
Original line number Diff line number Diff line
@@ -596,8 +596,8 @@ class TestDatasets(unittest.TestCase):
        test_ids.append(d)

      if batch_size is None:
        for idx, (tx, ty, tw,
                  tids) in enumerate(zip(test_Xs, test_ys, test_ws, test_ids)):
        for idx, (tx, ty, tw, tids) in enumerate(
            zip(test_Xs, test_ys, test_ws, test_ids)):
          assert len(tx) == shard_sizes[idx]
          assert len(ty) == shard_sizes[idx]
          assert len(tw) == shard_sizes[idx]
@@ -662,6 +662,28 @@ class TestDatasets(unittest.TestCase):
      batch_sizes.append(len(X))
    self.assertEqual([3, 3, 3, 1], batch_sizes)

  def test_merge(self):
    """Test that DiskDataset.merge concatenates datasets correctly.

    Builds several small random datasets that share an identical task
    specification, merges them, and verifies that the merged dataset
    contains every datapoint (X, y, and w) and preserves the shared
    task list.
    """
    num_datapoints = 10
    num_features = 10
    num_tasks = 1
    num_datasets = 4
    datasets = []
    for _ in range(num_datasets):
      Xi = np.random.rand(num_datapoints, num_features)
      yi = np.random.randint(2, size=(num_datapoints, num_tasks))
      wi = np.ones((num_datapoints, num_tasks))
      idsi = np.array(["id"] * num_datapoints)
      datasets.append(dc.data.DiskDataset.from_numpy(Xi, yi, wi, idsi))

    # Use the same public alias as from_numpy above for consistency
    # (previously this reached into dc.data.datasets directly).
    new_data = dc.data.DiskDataset.merge(datasets)

    # All datapoints from every input dataset must be present.
    assert new_data.X.shape == (num_datapoints * num_datasets, num_features)
    assert new_data.y.shape == (num_datapoints * num_datasets, num_tasks)
    assert new_data.w.shape == (num_datapoints * num_datasets, num_tasks)
    # The merged dataset must carry the shared task specification (#1030).
    assert len(new_data.tasks) == len(datasets[0].tasks)


# Run the full test suite when this file is executed directly.
if __name__ == "__main__":
  unittest.main()