Unverified Commit 255af47f authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1949 from peastman/ids

Optimized retrieving IDs for DiskDataset
parents 364c16f8 cb95eeff
Loading
Loading
Loading
Loading
+11 −2
Original line number Diff line number Diff line
@@ -1646,6 +1646,15 @@ class DiskDataset(Dataset):
      self._cache_used += shard_size
    return (shard.X, shard.y, shard.w, shard.ids)

  def get_shard_ids(self, i):
    """Retrieves the list of IDs for the i-th shard from disk."""

    if self._cached_shards is not None and self._cached_shards[i] is not None:
      return self._cached_shards[i].ids
    row = self.metadata_df.iloc[i]
    return np.array(
        load_from_disk(os.path.join(self.data_dir, row['ids'])), dtype=object)

  def add_shard(self, X, y, w, ids):
    """Adds a data shard."""
    metadata_rows = self.metadata_df.values.tolist()
@@ -1728,8 +1737,8 @@ class DiskDataset(Dataset):
    if len(self) == 0:
      return np.array([])
    ids = []
    for (_, _, _, ids_b) in self.itershards():
      ids.append(np.atleast_1d(np.squeeze(ids_b)))
    for i in range(self.get_number_shards()):
      ids.append(np.atleast_1d(np.squeeze(self.get_shard_ids(i))))
    return np.concatenate(ids)

  @property