Commit b9ad5c8a authored by peastman's avatar peastman
Browse files

Fixed failing test cases

parent e0e6b3dd
Loading
Loading
Loading
Loading
+7 −2
Original line number Diff line number Diff line
@@ -114,6 +114,11 @@ class TaskSplitter(Splitter):
        fold_tasks = range(fold * n_per_fold, (fold + 1) * n_per_fold)
      else:
        fold_tasks = range(fold * n_per_fold, n_tasks)
      fold_datasets.append(
          NumpyDataset(X, y[:, fold_tasks], w[:, fold_tasks], ids))
      if len(w.shape) == 1:
        w_tasks = w
      elif w.shape[1] == 1:
        w_tasks = w[:, 0]
      else:
        w_tasks = w[:, fold_tasks]
      fold_datasets.append(NumpyDataset(X, y[:, fold_tasks], w_tasks, ids))
    return fold_datasets