Commit 4b9d9aff authored by Peter Eastman's avatar Peter Eastman
Browse files

Prefetch shards for DiskDataset

parent 586871a1
Loading
Loading
Loading
Loading
+6 −1
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@ from deepchem.utils.save import log
import tempfile
import time
import shutil
from multiprocessing.dummy import Pool

__author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
@@ -643,8 +644,12 @@ class DiskDataset(Dataset):
        shard_perm = np.random.permutation(num_shards)
      else:
        shard_perm = np.arange(num_shards)
      pool = Pool(1)
      next_shard = pool.apply_async(dataset.get_shard, (shard_perm[0],))
      for i in range(num_shards):
        X, y, w, ids = dataset.get_shard(shard_perm[i])
        X, y, w, ids = next_shard.get()
        if i < num_shards - 1:
          next_shard = pool.apply_async(dataset.get_shard, (shard_perm[i + 1],))
        n_samples = X.shape[0]
        # TODO(rbharath): This happens in tests sometimes, but don't understand why?
        # Handle edge case.