Commit 5019ad5d authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #848 from peastman/prefetch

Prefetch shards for DiskDataset
parents 586871a1 da7cbcd5
Loading
Loading
Loading
Loading
+7 −1
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@ from deepchem.utils.save import log
import tempfile
import time
import shutil
from multiprocessing.dummy import Pool

__author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
@@ -643,8 +644,12 @@ class DiskDataset(Dataset):
        shard_perm = np.random.permutation(num_shards)
      else:
        shard_perm = np.arange(num_shards)
      pool = Pool(1)
      next_shard = pool.apply_async(dataset.get_shard, (shard_perm[0],))
      for i in range(num_shards):
        X, y, w, ids = dataset.get_shard(shard_perm[i])
        X, y, w, ids = next_shard.get()
        if i < num_shards - 1:
          next_shard = pool.apply_async(dataset.get_shard, (shard_perm[i + 1],))
        n_samples = X.shape[0]
        # TODO(rbharath): This happens in tests sometimes, but don't understand why?
        # Handle edge case.
@@ -683,6 +688,7 @@ class DiskDataset(Dataset):
            (X_batch, y_batch, w_batch, ids_batch) = pad_batch(
                shard_batch_size, X_batch, y_batch, w_batch, ids_batch)
          yield (X_batch, y_batch, w_batch, ids_batch)
      pool.close()

    return iterate(self)