Commit 02262a27 authored by Bharath Ramsundar
Browse files

Removing loading of unnecessary shards

parent 69fac4b4
Loading
Loading
Loading
Loading
+73 −24
Original line number Diff line number Diff line
@@ -2131,11 +2131,18 @@ class DiskDataset(Dataset):

        Xs, ys, ws, ids_s = [], [], [], []
        count, indices_count = 0, 0
        for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
        for shard_num in range(self.get_number_shards()):
          logger.info(
              "Selecting from input shard %d/%d for selection output shard %d" %
              (shard_num + 1, n_shards, select_shard_num + 1))
          if self.legacy_metadata:
            shard_len = len(X)
          else:
            shard_X_shape, _, _, _ = self._get_shard_shape(shard_num)
            if len(shard_X_shape) > 0:
              shard_len = shard_X_shape[0]
            else:
              shard_len = 0
          # Find indices which rest in this shard
          num_shard_elts = 0
          while sorted_indices[indices_count +
@@ -2143,9 +2150,15 @@ class DiskDataset(Dataset):
            num_shard_elts += 1
            if (indices_count + num_shard_elts) >= len(sorted_indices):
              break
          if num_shard_elts == 0:
            count += shard_len
            continue
          else:
            X, y, w, ids = self.get_shard(shard_num)
          # Need to offset indices to fit within shard_size
          shard_inds = sorted_indices[indices_count:indices_count +
                                      num_shard_elts] - count
          # Handle empty case where no data from this shard needed
          X_sel = X[shard_inds]
          # Handle the case of datasets with y/w missing
          if y is not None:
@@ -2163,7 +2176,7 @@ class DiskDataset(Dataset):
          ids_s.append(ids_sel)
          indices_count += num_shard_elts
          count += shard_len
          # Break when all indices have been used up already
          # Break if all indices have been used up already
          if indices_count >= len(sorted_indices):
            break
        # Note these will be in the sorted order
@@ -2272,14 +2285,14 @@ class DiskDataset(Dataset):
      total += len(y)
    return total

  def get_shape(self) -> Tuple[Shape, Shape, Shape, Shape]:
    """Finds shape of dataset."""
  def _get_shard_shape(self,
                       shard_num: int) -> Tuple[Shape, Shape, Shape, Shape]:
    """Finds the shape of the specified shard."""
    if self.legacy_metadata:
      raise ValueError(
          "This function requires the new metadata format to be called. Please reshard this dataset by calling the reshard() method."
      )
    n_tasks = len(self.get_task_names())
    n_rows = len(self.metadata_df.index)
    # If shape metadata is available use it to directly compute shape from
    # metadata
    if not self.legacy_metadata:
      for shard_num in range(n_rows):
    row = self.metadata_df.iloc[shard_num]
    if row['X_shape'] is not None:
      shard_X_shape = make_tuple(str(row['X_shape']))
@@ -2301,6 +2314,42 @@ class DiskDataset(Dataset):
      shard_ids_shape = make_tuple(str(row['ids_shape']))
    else:
      shard_ids_shape = tuple()
    X_shape, y_shape, w_shape, ids_shape = tuple(
        np.array(shard_X_shape)), tuple(np.array(shard_y_shape)), tuple(
            np.array(shard_w_shape)), tuple(np.array(shard_ids_shape))
    return X_shape, y_shape, w_shape, ids_shape

  def get_shape(self) -> Tuple[Shape, Shape, Shape, Shape]:
    """Finds shape of dataset."""
    n_tasks = len(self.get_task_names())
    n_rows = len(self.metadata_df.index)
    # If shape metadata is available use it to directly compute shape from
    # metadata
    if not self.legacy_metadata:
      for shard_num in range(n_rows):
        #row = self.metadata_df.iloc[shard_num]
        #if row['X_shape'] is not None:
        #  shard_X_shape = make_tuple(str(row['X_shape']))
        #else:
        #  shard_X_shape = tuple()
        #if n_tasks > 0:
        #  if row['y_shape'] is not None:
        #    shard_y_shape = make_tuple(str(row['y_shape']))
        #  else:
        #    shard_y_shape = tuple()
        #  if row['w_shape'] is not None:
        #    shard_w_shape = make_tuple(str(row['w_shape']))
        #  else:
        #    shard_w_shape = tuple()
        #else:
        #  shard_y_shape = tuple()
        #  shard_w_shape = tuple()
        #if row['ids_shape'] is not None:
        #  shard_ids_shape = make_tuple(str(row['ids_shape']))
        #else:
        #  shard_ids_shape = tuple()
        shard_X_shape, shard_y_shape, shard_w_shape, shard_ids_shape = self._get_shard_shape(
            shard_num)
        if shard_num == 0:
          X_shape, y_shape, w_shape, ids_shape = np.array(
              shard_X_shape), np.array(shard_y_shape), np.array(