Commit 48f23dbd authored by flo's avatar flo
Browse files

yapf

parent 994aa14a
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -927,7 +927,8 @@ class DiskDataset(Dataset):
        pass
    if tasks:
      if len(tasks) < len(datasets) or len(set(map(tuple, tasks))) > 1:
        raise ValueError('Cannot merge datasets with different task specifications')
        raise ValueError(
            'Cannot merge datasets with different task specifications')
      tasks = tasks[0]

    def generator():
@@ -935,7 +936,8 @@ class DiskDataset(Dataset):
        X, y, w, ids = (dataset.X, dataset.y, dataset.w, dataset.ids)
        yield (X, y, w, ids)

    return DiskDataset.create_dataset(generator(), data_dir=merge_dir, tasks=tasks)
    return DiskDataset.create_dataset(
        generator(), data_dir=merge_dir, tasks=tasks)

  def subset(self, shard_nums, subset_dir=None):
    """Creates a subset of the original dataset on disk."""
+4 −4
Original line number Diff line number Diff line
@@ -596,8 +596,8 @@ class TestDatasets(unittest.TestCase):
        test_ids.append(d)

      if batch_size is None:
        for idx, (tx, ty, tw,
                  tids) in enumerate(zip(test_Xs, test_ys, test_ws, test_ids)):
        for idx, (tx, ty, tw, tids) in enumerate(
            zip(test_Xs, test_ys, test_ws, test_ids)):
          assert len(tx) == shard_sizes[idx]
          assert len(ty) == shard_sizes[idx]
          assert len(tw) == shard_sizes[idx]