Unverified Commit 458b2911 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2402 from peastman/dtype

Do not force dtypes to be 'object'
parents b9dc469e 048a4848
Loading
Loading
Loading
Loading
+2 −4
Original line number Diff line number Diff line
@@ -2180,8 +2180,7 @@ class DiskDataset(Dataset):
    if self._cached_shards is not None and self._cached_shards[i] is not None:
      return self._cached_shards[i].y
    row = self.metadata_df.iloc[i]
    return np.array(
        load_from_disk(os.path.join(self.data_dir, row['y'])), dtype=object)
    return np.array(load_from_disk(os.path.join(self.data_dir, row['y'])))

  def get_shard_w(self, i: int) -> np.ndarray:
    """Retrieves the weights for the i-th shard from disk.
@@ -2200,8 +2199,7 @@ class DiskDataset(Dataset):
    if self._cached_shards is not None and self._cached_shards[i] is not None:
      return self._cached_shards[i].w
    row = self.metadata_df.iloc[i]
    return np.array(
        load_from_disk(os.path.join(self.data_dir, row['w'])), dtype=object)
    return np.array(load_from_disk(os.path.join(self.data_dir, row['w'])))

  def add_shard(self,
                X: np.ndarray,