Unverified Commit 4970501a authored by Suzukazole's avatar Suzukazole
Browse files

change y to ids and checks

parent 5464cd46
Loading
Loading
Loading
Loading
+15 −8
Original line number Diff line number Diff line
@@ -1500,10 +1500,10 @@ class DiskDataset(Dataset):
    """Gets size of shards on disk."""
    if not len(self.metadata_df):
      raise ValueError("No data in dataset.")
    sample_y = load_from_disk(
    sample_ids = load_from_disk(
        os.path.join(self.data_dir,
                     next(self.metadata_df.iterrows())[1]['y']))
    return len(sample_y)
                     next(self.metadata_df.iterrows())[1]['ids']))
    return len(sample_ids)

  def _get_metadata_filename(self) -> Tuple[str, str]:
    """Get standard location for metadata file."""
@@ -2369,11 +2369,11 @@ class DiskDataset(Dataset):
          if y is not None:
            y_sel = y[shard_inds]
          else:
            y_sel = None
            y_sel = np.array([])
          if w is not None:
            w_sel = w[shard_inds]
          else:
            w_sel = None
            w_sel = np.array([])
          ids_sel = ids[shard_inds]
          Xs.append(X_sel)
          ys.append(y_sel)
@@ -2399,9 +2399,16 @@ class DiskDataset(Dataset):
                np.where(sorted_indices == orig_index)[0][0]
                for orig_index in select_shard_indices
            ])
        X, y, w, ids = X[reverted_indices], y[reverted_indices], w[
            reverted_indices], ids[reverted_indices]
        yield (X, y, w, ids)
        if y.size == 0:
          tup_y = y
        else:
          tup_y = y[reverted_indices]
        if w.size == 0:
          tup_w = w
        else:
          tup_w = w[reverted_indices]
        X, ids = X[reverted_indices], ids[reverted_indices]
        yield (X, tup_y, tup_w, ids)
        start = end
        select_shard_num += 1