Commit 6169dfd1 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

bugfix

parent ad9dddfa
Loading
Loading
Loading
Loading
+6 −6
Original line number Diff line number Diff line
@@ -1108,8 +1108,8 @@ class DiskDataset(Dataset):
      y_next = np.zeros((0,) + (len(tasks),))
      w_next = np.zeros((0,) + (len(tasks),))
      ids_next = np.zeros((0,), dtype=object)
      for (X, y, w, ids) in self.itershards():
        loger.info("Resharding shard %d/%d" % (shard_num, n_shards))
      for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
        logger.info("Resharding shard %d/%d" % (shard_num, n_shards))
        X_next = np.concatenate([X_next, X], axis=0)
        y_next = np.concatenate([y_next, y], axis=0)
        w_next = np.concatenate([w_next, w], axis=0)
@@ -1369,11 +1369,11 @@ class DiskDataset(Dataset):
      out_dir = tempfile.mkdtemp()
    tasks = self.get_task_names()

    n_shards = self.get_number_shard()
    n_shards = self.get_number_shards()

    def generator():
      for shard_num, row in self.metadata_df.iterrows():
        loger.info("Transforming shard %d/%d" % (shard_num, n_shards))
        logger.info("Transforming shard %d/%d" % (shard_num, n_shards))
        X, y, w, ids = self.get_shard(shard_num)
        newx, newy, neww = fn(X, y, w)
        yield (newx, newy, neww, ids)
@@ -1491,7 +1491,7 @@ class DiskDataset(Dataset):

    def generator():
      for ind, dataset in enumerate(datasets):
        loger.info("Merging in dataset %d/%d" % (ind, len(datasets)))
        logger.info("Merging in dataset %d/%d" % (ind, len(datasets)))
        X, y, w, ids = (dataset.X, dataset.y, dataset.w, dataset.ids)
        yield (X, y, w, ids)

@@ -1773,7 +1773,7 @@ class DiskDataset(Dataset):
    def generator():
      count, indices_count = 0, 0
      for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
        loger.info("Selecting from shard %d/%d" % (shard_num, n_shards))
        logger.info("Selecting from shard %d/%d" % (shard_num, n_shards))
        shard_len = len(X)
        # Find indices which rest in this shard
        num_shard_elts = 0