Commit c110c8fb authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Changes

parent b8749616
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -2018,6 +2018,9 @@ class DiskDataset(Dataset):
  def get_shape(self) -> Tuple[Shape, Shape, Shape, Shape]:
    """Finds shape of dataset."""
    n_tasks = len(self.get_task_names())
    n_rows = len(self.metadata_df.index)
    for i in range(n_rows):
      row = self.metadata_df.iloc[i]
    #for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
    #  if shard_num == 0:
    #    X_shape = np.array(X.shape)