Unverified Commit 3467db1e authored by Bharath Ramsundar, committed by GitHub

Merge pull request #1859 from peastman/shard

DiskDataset caches shards in memory
parents c84e7409 19b93c93
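
A rough usage sketch of the feature (hedged: `from_numpy` and `itershards` are pre-existing DiskDataset APIs; the in-memory shard cache and the `memory_cache_size` property are what this PR adds). Repeated passes over a DiskDataset now reuse shards held in a small memory cache, defaulting to 20 MB, instead of rereading every shard from disk:

  import numpy as np
  import deepchem as dc

  # Write a small sharded dataset to disk.
  X = np.random.rand(1000, 10)
  y = np.random.rand(1000, 1)
  dataset = dc.data.DiskDataset.from_numpy(X, y)

  # The first pass loads shards from disk; later passes are served from the
  # in-memory cache as long as the shards fit in the 20 MB default budget.
  for X_s, y_s, w_s, ids_s in dataset.itershards():
    pass

  # The budget is adjustable through the new memory_cache_size property.
  dataset.memory_cache_size = 40 * (1 << 20)  # 40 MB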
+54 −27
@@ -947,6 +947,9 @@ class DiskDataset(Dataset):
 
     logger.info("Loading dataset from disk.")
    self.tasks, self.metadata_df = self.load_metadata()
+    self._cached_shards = None
+    self._memory_cache_size = 20 * (1 << 20)  # 20 MB
+    self._cache_used = 0
 
   @staticmethod
   def create_dataset(shard_generator, data_dir=None, tasks=[]):
@@ -1050,6 +1053,7 @@ class DiskDataset(Dataset):
   def save_to_disk(self):
     """Save dataset to disk."""
     save_metadata(self.tasks, self.metadata_df, self.data_dir)
+    self._cached_shards = None
 
   def move(self, new_data_dir):
     """Moves dataset to new directory."""
@@ -1143,32 +1147,7 @@ class DiskDataset(Dataset):
     generator defined by this function returns the data from a particular shard.
     The order of shards returned is guaranteed to remain fixed.
     """
-
-    def iterate(dataset):
-      for _, row in dataset.metadata_df.iterrows():
-        X = np.array(load_from_disk(os.path.join(dataset.data_dir, row['X'])))
-        ids = np.array(
-            load_from_disk(os.path.join(dataset.data_dir, row['ids'])),
-            dtype=object)
-        # These columns may be missing if the dataset is unlabelled.
-        if row['y'] is not None:
-          y = np.array(load_from_disk(os.path.join(dataset.data_dir, row['y'])))
-        else:
-          y = None
-        if row['w'] is not None:
-          w_filename = os.path.join(dataset.data_dir, row['w'])
-          if os.path.exists(w_filename):
-            w = np.array(load_from_disk(w_filename))
-          else:
-            if len(y.shape) == 1:
-              w = np.ones(y.shape[0], np.float32)
-            else:
-              w = np.ones((y.shape[0], 1), np.float32)
-        else:
-          w = None
-        yield (X, y, w, ids)
-
-    return iterate(self)
+    return (self.get_shard(i) for i in range(self.get_number_shards()))
 
   def iterbatches(self,
                   batch_size=None,
@@ -1607,6 +1586,24 @@ class DiskDataset(Dataset):
 
   def get_shard(self, i):
     """Retrieves data for the i-th shard from disk."""
+
+    class Shard(object):
+
+      def __init__(self, X, y, w, ids):
+        self.X = X
+        self.y = y
+        self.w = w
+        self.ids = ids
+
+    # See if we have a cached copy of this shard.
+    if self._cached_shards is None:
+      self._cached_shards = [None] * self.get_number_shards()
+      self._cache_used = 0
+    if self._cached_shards[i] is not None:
+      shard = self._cached_shards[i]
+      return (shard.X, shard.y, shard.w, shard.ids)
+
+    # We don't, so load it from disk.
     row = self.metadata_df.iloc[i]
     X = np.array(load_from_disk(os.path.join(self.data_dir, row['X'])))

@@ -1630,7 +1627,24 @@ class DiskDataset(Dataset):
 
     ids = np.array(
         load_from_disk(os.path.join(self.data_dir, row['ids'])), dtype=object)
-    return (X, y, w, ids)
+
+    # Try to cache this shard for later use.  Since the normal usage pattern is
+    # a series of passes through the whole dataset, there's no point doing
+    # anything fancy.  It never makes sense to evict another shard from the
+    # cache to make room for this one, because we'll probably want that other
+    # shard again before the next time we want this one.  So just cache as many
+    # as we can and then stop.
+
+    shard = Shard(X, y, w, ids)
+    shard_size = X.nbytes + ids.nbytes
+    if y is not None:
+      shard_size += y.nbytes
+    if w is not None:
+      shard_size += w.nbytes
+    if self._cache_used + shard_size < self._memory_cache_size:
+      self._cached_shards[i] = shard
+      self._cache_used += shard_size
+    return (shard.X, shard.y, shard.w, shard.ids)
 
   def add_shard(self, X, y, w, ids):
     """Adds a data shard."""
@@ -1649,6 +1663,7 @@ class DiskDataset(Dataset):
     basename = "shard-%d" % shard_num
     tasks = self.get_task_names()
     DiskDataset.write_data_to_disk(self.data_dir, basename, tasks, X, y, w, ids)
+    self._cached_shards = None
 
   def select(self, indices, select_dir=None):
     """Creates a new dataset from a selection of indices from self.
@@ -1759,6 +1774,18 @@ class DiskDataset(Dataset):
     else:
       return np.concatenate(ws)
 
+  @property
+  def memory_cache_size(self):
+    """Get the size of the memory cache for this dataset, measured in bytes."""
+    return self._memory_cache_size
+
+  @memory_cache_size.setter
+  def memory_cache_size(self, size):
+    """Set the size of the memory cache for this dataset, measured in bytes."""
+    self._memory_cache_size = size
+    if self._cache_used > size:
+      self._cached_shards = None
+
   def __len__(self):
     """
     Finds number of elements in dataset.
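
The comment in get_shard describes a deliberately simple policy: cache shards greedily in access order until the byte budget is spent, and never evict. A standalone sketch of that policy (illustrative only, not DeepChem code):

  import numpy as np

  def fill_and_stop_cache(shards, budget_bytes):
    """Greedy, no-eviction cache: keep shards until the budget is spent."""
    cache = [None] * len(shards)
    used = 0
    for i, (X, y, w, ids) in enumerate(shards):
      size = X.nbytes + ids.nbytes
      if y is not None:
        size += y.nbytes
      if w is not None:
        size += w.nbytes
      if used + size < budget_bytes:
        cache[i] = (X, y, w, ids)
        used += size
    return cache, used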
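And a short sketch of the new property in use, continuing the dataset from the first sketch above (sizes are hypothetical): reads return the current budget, and shrinking the budget below what is already cached drops the whole cache, so shards reload from disk (and re-cache) on the next access.

  dataset.memory_cache_size = 100 * (1 << 20)  # allow up to 100 MB of shards
  print(dataset.memory_cache_size)             # 104857600
  dataset.memory_cache_size = 0                # clears any cached shards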
+2 −0
@@ -153,6 +153,8 @@ class Splitter(object):
     else:
       valid_dataset = None
     test_dataset = dataset.select(test_inds, test_dir)
+    if isinstance(train_dataset, DiskDataset):
+      train_dataset.memory_cache_size = 40 * (1 << 20)  # 40 MB
 
     return train_dataset, valid_dataset, test_dataset

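A sketch of the effect of this change, assuming the standard RandomSplitter workflow: DiskDataset train splits now come back with a 40 MB cache budget, on the assumption that training makes many passes over the training set.

  import deepchem as dc

  splitter = dc.splits.RandomSplitter()
  train, valid, test = splitter.train_valid_test_split(dataset)
  # For a DiskDataset input, the train split's cache budget is now 40 MB.
  print(train.memory_cache_size)  # 41943040
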
+4 −5
@@ -62,7 +62,6 @@ class Transformer(object):
                transform_w=False,
                dataset=None):
     """Initializes transformation based on dataset statistics."""
-    self.dataset = dataset
     self.transform_X = transform_X
     self.transform_y = transform_y
     self.transform_w = transform_w
@@ -482,12 +481,12 @@ class BalancingTransformer(Transformer):
     assert transform_w
 
     # Compute weighting factors from dataset.
-    y = self.dataset.y
-    w = self.dataset.w
+    y = dataset.y
+    w = dataset.w
     # Ensure dataset is binary
     np.testing.assert_allclose(sorted(np.unique(y)), np.array([0., 1.]))
     weights = []
-    for ind, task in enumerate(self.dataset.get_task_names()):
+    for ind, task in enumerate(dataset.get_task_names()):
       task_w = w[:, ind]
       task_y = y[:, ind]
       # Remove labels with zero weights
@@ -505,7 +504,7 @@ class BalancingTransformer(Transformer):
   def transform_array(self, X, y, w):
     """Transform the data in a set of (X, y, w) arrays."""
     w_balanced = np.zeros_like(w)
-    for ind, task in enumerate(self.dataset.get_task_names()):
+    for ind in range(y.shape[1]):
       task_y = y[:, ind]
       task_w = w[:, ind]
       zero_indices = np.logical_and(task_y == 0, task_w != 0)
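
Since transform_array now derives the task count from y.shape[1] rather than from the transformer's stored dataset, a fitted BalancingTransformer can be applied directly to raw arrays. A minimal sketch, assuming the constructor signature shown above:

  import numpy as np
  import deepchem as dc

  # An imbalanced single-task binary dataset (80 negatives, 20 positives).
  y = np.array([0.0] * 80 + [1.0] * 20).reshape(-1, 1)
  X = np.random.rand(100, 10)
  w = np.ones_like(y)
  train = dc.data.NumpyDataset(X, y, w)

  transformer = dc.trans.BalancingTransformer(transform_w=True, dataset=train)
  # Works on raw arrays; the task count is read from y.shape[1].
  X_t, y_t, w_t = transformer.transform_array(X, y, w)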
+1 −0
@@ -292,6 +292,7 @@ def load_dataset_from_disk(save_dir):
   train = deepchem.data.DiskDataset(train_dir)
   valid = deepchem.data.DiskDataset(valid_dir)
   test = deepchem.data.DiskDataset(test_dir)
+  train.memory_cache_size = 40 * (1 << 20)  # 40 MB
   all_dataset = (train, valid, test)
   with open(os.path.join(save_dir, "transformers.pkl"), 'rb') as f:
     transformers = pickle.load(f)
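
Usage sketch for the loader change (assumptions: load_dataset_from_disk lives in deepchem.utils.save in this version of the codebase and returns a (loaded, datasets, transformers) tuple; the path is hypothetical):

  from deepchem.utils.save import load_dataset_from_disk

  save_dir = "/path/to/saved/dataset"  # hypothetical
  loaded, all_dataset, transformers = load_dataset_from_disk(save_dir)
  if loaded:
    train, valid, test = all_dataset
    # Per the change above, the train split comes back with a 40 MB cache.
    print(train.memory_cache_size)  # 41943040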