Unverified Commit 88d8bd25 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2015 from deepchem/merge_log

Improve DiskDataset logging
parents 4eefa98d 6169dfd1
Loading
Loading
Loading
Loading
+11 −1
Original line number Diff line number Diff line
@@ -1099,6 +1099,8 @@ class DiskDataset(Dataset):
    # Create temp directory to store resharded version
    reshard_dir = tempfile.mkdtemp()

    n_shards = self.get_number_shards()

    # Write data in new shards
    def generator():
      tasks = self.get_task_names()
@@ -1106,7 +1108,8 @@ class DiskDataset(Dataset):
      y_next = np.zeros((0,) + (len(tasks),))
      w_next = np.zeros((0,) + (len(tasks),))
      ids_next = np.zeros((0,), dtype=object)
      for (X, y, w, ids) in self.itershards():
      for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
        logger.info("Resharding shard %d/%d" % (shard_num, n_shards))
        X_next = np.concatenate([X_next, X], axis=0)
        y_next = np.concatenate([y_next, y], axis=0)
        w_next = np.concatenate([w_next, w], axis=0)
@@ -1366,8 +1369,11 @@ class DiskDataset(Dataset):
      out_dir = tempfile.mkdtemp()
    tasks = self.get_task_names()

    n_shards = self.get_number_shards()

    def generator():
      for shard_num, row in self.metadata_df.iterrows():
        logger.info("Transforming shard %d/%d" % (shard_num, n_shards))
        X, y, w, ids = self.get_shard(shard_num)
        newx, newy, neww = fn(X, y, w)
        yield (newx, newy, neww, ids)
@@ -1485,6 +1491,7 @@ class DiskDataset(Dataset):

    def generator():
      for ind, dataset in enumerate(datasets):
        logger.info("Merging in dataset %d/%d" % (ind, len(datasets)))
        X, y, w, ids = (dataset.X, dataset.y, dataset.w, dataset.ids)
        yield (X, y, w, ids)

@@ -1761,9 +1768,12 @@ class DiskDataset(Dataset):
    indices = np.array(sorted(indices)).astype(int)
    tasks = self.get_task_names()

    n_shards = self.get_number_shards()

    def generator():
      count, indices_count = 0, 0
      for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
        logger.info("Selecting from shard %d/%d" % (shard_num, n_shards))
        shard_len = len(X)
        # Find indices which rest in this shard
        num_shard_elts = 0