Commit 994aa14a authored by flo's avatar flo
Browse files

ensure all datasets have the same tasks

parent 09db59f3
Loading
Loading
Loading
Loading
+16 −1
Original line number Diff line number Diff line
@@ -915,12 +915,27 @@ class DiskDataset(Dataset):
    else:
      merge_dir = tempfile.mkdtemp()

    # Protect against generator exhaustion
    datasets = list(datasets)

    # This ensures tasks are consistent for all datasets
    tasks = []
    for dataset in datasets:
      try:
        tasks.append(dataset.tasks)
      except AttributeError:
        pass
    if tasks:
      if len(tasks) < len(datasets) or len(set(map(tuple, tasks))) > 1:
        raise ValueError('Cannot merge datasets with different task specifications')
      tasks = tasks[0]

    def generator():
      for ind, dataset in enumerate(datasets):
        X, y, w, ids = (dataset.X, dataset.y, dataset.w, dataset.ids)
        yield (X, y, w, ids)

    return DiskDataset.create_dataset(generator(), data_dir=merge_dir, tasks=next(iter(datasets)).tasks)
    return DiskDataset.create_dataset(generator(), data_dir=merge_dir, tasks=tasks)

  def subset(self, shard_nums, subset_dir=None):
    """Creates a subset of the original dataset on disk."""