Commit 09452829 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #513 from rbharath/numpy_select

Make splitters work with NumpyDataset
parents 35956bec f5ff588e
Loading
Loading
Loading
Loading
+28 −2
Original line number Diff line number Diff line
@@ -368,6 +368,25 @@ class NumpyDataset(Dataset):
    newx, newy, neww = fn(self._X, self._y, self._w)
    return NumpyDataset(newx, newy, neww, self._ids[:])

  def select(self, indices, select_dir=None):
    """Creates a new dataset from a selection of indices from self.

    TODO(rbharath): select_dir is here due to dc.splits always passing in
    splits.

    Parameters
    ----------
    indices: list
      List of indices to select.
    select_dir: string
      Ignored.
    """
    X = self.X[indices]
    y = self.y[indices]
    w = self.w[indices]
    ids = self.ids[indices]
    return NumpyDataset(X, y, w, ids)


class DiskDataset(Dataset):
  """
@@ -907,8 +926,15 @@ class DiskDataset(Dataset):
        shard_inds = indices[indices_count:indices_count +
                             num_shard_elts] - count
        X_sel = X[shard_inds]
        # Handle the case of datasets with y/w missing
        if y is not None:
          y_sel = y[shard_inds]
        else:
          y_sel = None
        if w is not None:
          w_sel = w[shard_inds]
        else:
          w_sel = None
        ids_sel = ids[shard_inds]
        yield (X_sel, y_sel, w_sel, ids_sel)
        # Updating counts
+0 −4
Original line number Diff line number Diff line
@@ -93,10 +93,6 @@ class Splitter(object):

    Returns Dataset objects.
    """
    if (isinstance(dataset, NumpyDataset)):
      raise ValueError(
          "Only possible with DiskDataset.  NumpyDataset doesn't support .select"
      )
    log("Computing train/valid/test indices", self.verbose)
    train_inds, valid_inds, test_inds = self.split(
        dataset,