Commit 08e6f3ec authored by peastman's avatar peastman
Browse files

Use fixed size for dataset cache

parent b9a1fedf
Loading
Loading
Loading
Loading
+32 −6
Original line number Diff line number Diff line
@@ -15,7 +15,6 @@ import time
import shutil
import json
import warnings
import weakref
from multiprocessing.dummy import Pool
from deepchem.utils.save import save_to_disk, save_metadata
from deepchem.utils.save import load_from_disk
@@ -949,6 +948,8 @@ class DiskDataset(Dataset):
    logger.info("Loading dataset from disk.")
    self.tasks, self.metadata_df = self.load_metadata()
    self._cached_shards = None
    self._memory_cache_size = 20 * (1 << 20)  # 20 MB
    self._cache_used = 0

  @staticmethod
  def create_dataset(shard_generator, data_dir=None, tasks=[]):
@@ -1598,8 +1599,7 @@ class DiskDataset(Dataset):
    if self._cached_shards is None:
      self._cached_shards = [None] * self.get_number_shards()
    if self._cached_shards[i] is not None:
      shard = self._cached_shards[i]()
      if shard is not None:
      shard = self._cached_shards[i]
      return (shard.X, shard.y, shard.w, shard.ids)

    # We don't, so load it from disk.
@@ -1627,9 +1627,22 @@ class DiskDataset(Dataset):
    ids = np.array(
        load_from_disk(os.path.join(self.data_dir, row['ids'])), dtype=object)

    # Cache this shard for later use.
    # Try to cache this shard for later use.  Since the normal usage pattern is
    # a series of passes through the whole dataset, there's no point doing
    # anything fancy.  It never makes sense to evict another shard from the
    # cache to make room for this one, because we'll probably want that other
    # shard again before the next time we want this one.  So just cache as many
    # as we can and then stop.

    shard = Shard(X, y, w, ids)
    self._cached_shards[i] = weakref.ref(shard)
    shard_size = X.nbytes + ids.nbytes
    if y is not None:
      shard_size += y.nbytes
    if w is not None:
      shard_size += w.nbytes
    if self._cache_used + shard_size < self._memory_cache_size:
      self._cached_shards[i] = shard
      self._cache_used += shard_size
    return (shard.X, shard.y, shard.w, shard.ids)

  def add_shard(self, X, y, w, ids):
@@ -1760,6 +1773,19 @@ class DiskDataset(Dataset):
    else:
      return np.concatenate(ws)

  @property
  def memory_cache_size(self):
    """Get the size of the memory cache for this dataset, measured in bytes."""
    return self._memory_cache_size

  @memory_cache_size.setter
  def memory_cache_size(self, size):
    """Get the size of the memory cache for this dataset, measured in bytes."""
    self._memory_cache_size = size
    if self._cache_used > size:
      self._cached_shards = None
      self._cache_used = 0

  def __len__(self):
    """
    Finds number of elements in dataset.