Commit 9393f7c5 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Changes

parent 5a17bb15
Loading
Loading
Loading
Loading
+6 −0
Original line number Original line Diff line number Diff line
@@ -1366,8 +1366,11 @@ class DiskDataset(Dataset):
      out_dir = tempfile.mkdtemp()
      out_dir = tempfile.mkdtemp()
    tasks = self.get_task_names()
    tasks = self.get_task_names()


    n_shards = self.get_number_shards()

    def generator():
    def generator():
      for shard_num, row in self.metadata_df.iterrows():
      for shard_num, row in self.metadata_df.iterrows():
        logger.info("Transforming shard %d/%d" % (shard_num, n_shards))
        X, y, w, ids = self.get_shard(shard_num)
        X, y, w, ids = self.get_shard(shard_num)
        newx, newy, neww = fn(X, y, w)
        newx, newy, neww = fn(X, y, w)
        yield (newx, newy, neww, ids)
        yield (newx, newy, neww, ids)
@@ -1762,9 +1765,12 @@ class DiskDataset(Dataset):
    indices = np.array(sorted(indices)).astype(int)
    indices = np.array(sorted(indices)).astype(int)
    tasks = self.get_task_names()
    tasks = self.get_task_names()


    n_shards = self.get_number_shards()

    def generator():
    def generator():
      count, indices_count = 0, 0
      count, indices_count = 0, 0
      for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
      for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
        logger.info("Selecting from shard %d/%d" % (shard_num, n_shards))
        shard_len = len(X)
        shard_len = len(X)
        # Find indices which rest in this shard
        # Find indices which rest in this shard
        num_shard_elts = 0
        num_shard_elts = 0