Commit ad9dddfa authored by Bharath Ramsundar
Browse files

change

parent 9393f7c5
Loading
Loading
Loading
Loading
+3 −0
Original line number Original line Diff line number Diff line
@@ -1099,6 +1099,8 @@ class DiskDataset(Dataset):
    # Create temp directory to store resharded version
    # Create temp directory to store resharded version
    reshard_dir = tempfile.mkdtemp()
    reshard_dir = tempfile.mkdtemp()


    n_shards = self.get_number_shards()

    # Write data in new shards
    # Write data in new shards
    def generator():
    def generator():
      tasks = self.get_task_names()
      tasks = self.get_task_names()
@@ -1107,6 +1109,7 @@ class DiskDataset(Dataset):
      w_next = np.zeros((0,) + (len(tasks),))
      w_next = np.zeros((0,) + (len(tasks),))
      ids_next = np.zeros((0,), dtype=object)
      ids_next = np.zeros((0,), dtype=object)
      for (X, y, w, ids) in self.itershards():
      for (X, y, w, ids) in self.itershards():
        logger.info("Resharding shard %d/%d" % (shard_num, n_shards))
        X_next = np.concatenate([X_next, X], axis=0)
        X_next = np.concatenate([X_next, X], axis=0)
        y_next = np.concatenate([y_next, y], axis=0)
        y_next = np.concatenate([y_next, y], axis=0)
        w_next = np.concatenate([w_next, w], axis=0)
        w_next = np.concatenate([w_next, w], axis=0)