Commit b9a1fedf authored by peastman
Browse files

Cache shards for DiskDataset

parent 51f426b3
Loading
Loading
Loading
Loading
+35 −34
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@ import time
import shutil
import json
import warnings
import weakref
from multiprocessing.dummy import Pool
from deepchem.utils.save import save_to_disk, save_metadata
from deepchem.utils.save import load_from_disk
@@ -947,6 +948,7 @@ class DiskDataset(Dataset):

    logger.info("Loading dataset from disk.")
    self.tasks, self.metadata_df = self.load_metadata()
    self._cached_shards = None

  @staticmethod
  def create_dataset(shard_generator, data_dir=None, tasks=[]):
@@ -1050,6 +1052,7 @@ class DiskDataset(Dataset):
  def save_to_disk(self):
    """Save dataset to disk.

    Passes the task list, the metadata table and the data directory to
    save_metadata, then discards any cached shards.
    """
    save_metadata(self.tasks, self.metadata_df, self.data_dir)
    # Reset the shard cache; get_shard re-creates the list when it finds None.
    self._cached_shards = None

  def move(self, new_data_dir):
    """Moves dataset to new directory."""
@@ -1143,32 +1146,7 @@ class DiskDataset(Dataset):
    generator defined by this function returns the data from a particular shard.
    The order of shards returned is guaranteed to remain fixed.
    """

    def iterate(dataset):
      for _, row in dataset.metadata_df.iterrows():
        X = np.array(load_from_disk(os.path.join(dataset.data_dir, row['X'])))
        ids = np.array(
            load_from_disk(os.path.join(dataset.data_dir, row['ids'])),
            dtype=object)
        # These columns may be missing if the dataset is unlabelled.
        if row['y'] is not None:
          y = np.array(load_from_disk(os.path.join(dataset.data_dir, row['y'])))
        else:
          y = None
        if row['w'] is not None:
          w_filename = os.path.join(dataset.data_dir, row['w'])
          if os.path.exists(w_filename):
            w = np.array(load_from_disk(w_filename))
          else:
            if len(y.shape) == 1:
              w = np.ones(y.shape[0], np.float32)
            else:
              w = np.ones((y.shape[0], 1), np.float32)
        else:
          w = None
        yield (X, y, w, ids)

    return iterate(self)
    return (self.get_shard(i) for i in range(self.get_number_shards()))

  def iterbatches(self,
                  batch_size=None,
@@ -1607,6 +1585,24 @@ class DiskDataset(Dataset):

  def get_shard(self, i):
    """Retrieves data for the i-th shard from disk."""

    class Shard(object):
      """Holder for the four arrays that make up one shard.

      get_shard caches shards through weakref.ref, and a plain tuple cannot
      be the target of a weak reference, so the arrays are wrapped in this
      small object instead.
      """

      # NOTE(review): get_shard returns only the four arrays, not the Shard
      # itself, so no strong reference to the instance appears to outlive the
      # call -- confirm the cached weak reference can still be live on a
      # later lookup.

      def __init__(self, X, y, w, ids):
        self.X, self.y, self.w, self.ids = X, y, w, ids

    # See if we have a cached copy of this shard.
    if self._cached_shards is None:
      self._cached_shards = [None] * self.get_number_shards()
    if self._cached_shards[i] is not None:
      shard = self._cached_shards[i]()
      if shard is not None:
        return (shard.X, shard.y, shard.w, shard.ids)

    # We don't, so load it from disk.
    row = self.metadata_df.iloc[i]
    X = np.array(load_from_disk(os.path.join(self.data_dir, row['X'])))

@@ -1630,7 +1626,11 @@ class DiskDataset(Dataset):

    ids = np.array(
        load_from_disk(os.path.join(self.data_dir, row['ids'])), dtype=object)
    return (X, y, w, ids)

    # Cache this shard for later use.
    shard = Shard(X, y, w, ids)
    self._cached_shards[i] = weakref.ref(shard)
    return (shard.X, shard.y, shard.w, shard.ids)

  def add_shard(self, X, y, w, ids):
    """Adds a data shard."""
@@ -1649,6 +1649,7 @@ class DiskDataset(Dataset):
    basename = "shard-%d" % shard_num
    tasks = self.get_task_names()
    DiskDataset.write_data_to_disk(self.data_dir, basename, tasks, X, y, w, ids)
    self._cached_shards = None

  def select(self, indices, select_dir=None):
    """Creates a new dataset from a selection of indices from self.