Commit b75985db authored by peastman's avatar peastman
Browse files

Fixed more assumptions about array shapes

parent b5a4d7c3
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -338,6 +338,10 @@ class RandomStratifiedSplitter(Splitter):
      return dataset_1, dataset_2
    X, y, w, ids = randomize_arrays((dataset.X, dataset.y, dataset.w,
                                     dataset.ids))
    if len(y.shape) == 1:
      y = np.expand_dims(y, 1)
    if len(w.shape) == 1:
      w = np.expand_dims(w, 1)
    split_indices = self.get_task_split_indices(y, w, frac_split)

    # Create weight matrices fpor two haves.