Unverified Commit 510b9bf1 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2104 from deepchem/ordered_select

Make Order Respecting Select for DiskDataset
parents d3bb4ee8 df7a0ca3
Loading
Loading
Loading
Loading
+173 −119
Original line number Diff line number Diff line
@@ -1863,7 +1863,7 @@ class DiskDataset(Dataset):
    time2 = time.time()
    logger.info("TIMING: sparse_shuffle took %0.3f s" % (time2 - time1))

  def complete_shuffle(self, data_dir: Optional[str] = None) -> "DiskDataset":
  def complete_shuffle(self, data_dir: Optional[str] = None) -> Dataset:
    """Completely shuffle across all data, across all shards.

    Note
@@ -1893,55 +1893,7 @@ class DiskDataset(Dataset):
    N = len(self)
    perm = np.random.permutation(N)
    shard_size = self.get_shard_size()

    def generator():
      start = 0
      shard_num = 0
      while start < N:
        logger.info("Constructing shard %d" % shard_num)
        if start + shard_size < N:
          end = start + shard_size
        else:
          end = N
        shard_indices = perm[start:end]
        # Note that this is in sorted order which doesn't respect the random
        # permutation.
        shard_dataset = self.select(shard_indices)
        # One bit of trickiness here is that select() will return in sorted
        # order. For example, suppose we'd like these elements in our permuted
        # shard:
        #
        # [12, 234, 1, 4]
        #
        # Then select would return elements in order
        #
        # [1, 4, 12, 234]
        #
        # We need to recover the original ordering. We can do this by using
        # np.where to find the locations of the original indices in the sorted
        # indices.
        sorted_indices = np.array(sorted(shard_indices))
        reverted_indices = np.array(
            # We know there's only one match for np.where since this is a
            # permutation, so the [0][0] pulls out the exact match location.
            [
                np.where(sorted_indices == orig_index)[0][0]
                for orig_index in shard_indices
            ])
        # Let's pull out shard elements
        shard_X, shard_y, shard_w, shard_ids = (shard_dataset.X,
                                                shard_dataset.y,
                                                shard_dataset.w,
                                                shard_dataset.ids)

        yield (shard_X[reverted_indices], shard_y[reverted_indices],
               shard_w[reverted_indices], shard_ids[reverted_indices])

        start = end
        shard_num += 1

    return DiskDataset.create_dataset(
        generator(), data_dir=data_dir, tasks=self.get_task_names())
    return self.select(perm, data_dir, self.get_shard_size())

  def shuffle_each_shard(self,
                         shard_basenames: Optional[List[str]] = None) -> None:
@@ -2097,56 +2049,117 @@ class DiskDataset(Dataset):
    DiskDataset.write_data_to_disk(self.data_dir, basename, tasks, X, y, w, ids)
    self._cached_shards = None

  def select(self, indices: Sequence[int],
             select_dir: Optional[str] = None) -> "DiskDataset":
  def select(self,
             indices: Sequence[int],
             select_dir: Optional[str] = None,
             select_shard_size: Optional[int] = None,
             output_numpy_dataset: Optional[bool] = False) -> Dataset:
    """Creates a new dataset from a selection of indices from self.

    Note
    ----
    The specified indices will be returned in sorted order. That is, if you
    request that indices `[3, 1, 2]` are returned, you will get a
    `DiskDataset` which contains elements in order `[1, 2, 3]`.
    Examples
    --------
    >>> import numpy as np
    >>> X = np.random.rand(10, 10)
    >>> dataset = dc.data.DiskDataset.from_numpy(X)
    >>> selected = dataset.select([1, 3, 4])
    >>> len(selected)
    3

    Parameters
    ----------
    indices: list
      List of indices to select.
    select_dir: Optional[str], (default None)
      Path to new directory that the selected indices will be copied
      Path to new directory that the selected samples will be copied
      to.
    select_shard_size: Optional[int], (default None)
      If specified, the shard-size to use for output selected `DiskDataset`.
      If not output_numpy_dataset, then this is set to this current dataset's
      shard size if not manually specified. 
    output_numpy_dataset: Optional[bool], (default False)
      If True, output an in-memory `NumpyDataset` instead of a `DiskDataset`.
      Note that `select_dir` and `select_shard_size` must be `None` if this
      is `True`

    Returns
    -------
    DiskDataset
      Contains selected indices.
      Contains selected samples.
    """
    if output_numpy_dataset and (select_dir is not None or
                                 select_shard_size is not None):
      raise ValueError(
          "If output_numpy_dataset is set, then select_dir and select_shard_size must both be None"
      )
    if output_numpy_dataset:
      # When outputting a NumpyDataset, we have 1 in-memory shard
      select_shard_size = len(indices)
    else:
      if select_dir is not None:
        if not os.path.exists(select_dir):
          os.makedirs(select_dir)
      else:
        select_dir = tempfile.mkdtemp()
      if select_shard_size is None:
        select_shard_size = self.get_shard_size()
    # Handle edge case with empty indices
    if not len(indices):
      if not output_numpy_dataset:
        return DiskDataset.create_dataset([], data_dir=select_dir)
    indices = np.array(sorted(indices)).astype(int)
    tasks = self.get_task_names()
      else:
        return NumpyDataset(
            np.array([]), np.array([]), np.array([]), np.array([]))

    N = len(indices)
    indices = np.array(indices).astype(int)
    tasks = self.get_task_names()
    n_shards = self.get_number_shards()

    # We use two loops here. The outer while loop walks over selection shards
    # (the chunks of the indices to select that should go into separate
    # output shards), while the inner for loop walks over the shards in the
    # source dataset to select out the shard indices from that source shard.
    def generator():
      start = 0
      select_shard_num = 0
      while start < N:
        logger.info(
            "Constructing selection output shard %d" % (select_shard_num + 1))
        end = min(start + select_shard_size, N)
        select_shard_indices = indices[start:end]
        sorted_indices = np.array(sorted(select_shard_indices)).astype(int)

        Xs, ys, ws, ids_s = [], [], [], []
        count, indices_count = 0, 0
      for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
        logger.info("Selecting from shard %d/%d" % (shard_num, n_shards))
        shard_len = len(X)
        for shard_num in range(self.get_number_shards()):
          logger.info(
              "Selecting from input shard %d/%d for selection output shard %d" %
              (shard_num + 1, n_shards, select_shard_num + 1))
          if self.legacy_metadata:
            ids = self.get_shard_ids(shard_num)
            shard_len = len(ids)
          else:
            shard_X_shape, _, _, _ = self._get_shard_shape(shard_num)
            if len(shard_X_shape) > 0:
              shard_len = shard_X_shape[0]
            else:
              shard_len = 0
          # Find indices which rest in this shard
          num_shard_elts = 0
        while indices[indices_count + num_shard_elts] < count + shard_len:
          while sorted_indices[indices_count +
                               num_shard_elts] < count + shard_len:
            num_shard_elts += 1
          if indices_count + num_shard_elts >= len(indices):
            if (indices_count + num_shard_elts) >= len(sorted_indices):
              break
          if num_shard_elts == 0:
            count += shard_len
            continue
          else:
            X, y, w, ids = self.get_shard(shard_num)
          # Need to offset indices to fit within shard_size
        shard_inds = indices[indices_count:indices_count +
          shard_inds = sorted_indices[indices_count:indices_count +
                                      num_shard_elts] - count
          # Handle empty case where no data from this shard needed
          X_sel = X[shard_inds]
          # Handle the case of datasets with y/w missing
          if y is not None:
@@ -2158,16 +2171,42 @@ class DiskDataset(Dataset):
          else:
            w_sel = None
          ids_sel = ids[shard_inds]
        yield (X_sel, y_sel, w_sel, ids_sel)
        # Updating counts
          Xs.append(X_sel)
          ys.append(y_sel)
          ws.append(w_sel)
          ids_s.append(ids_sel)
          indices_count += num_shard_elts
          count += shard_len
        # Break when all indices have been used up already
        if indices_count >= len(indices):
          return
          # Break if all indices have been used up already
          if indices_count >= len(sorted_indices):
            break
        # Note these will be in the sorted order
        X = np.concatenate(Xs, axis=0)
        y = np.concatenate(ys, axis=0)
        w = np.concatenate(ws, axis=0)
        ids = np.concatenate(ids_s, axis=0)
        # We need to recover the original ordering. We can do this by using
        # np.where to find the locations of the original indices in the sorted
        # indices.
        reverted_indices = np.array(
            # We know there's only one match for np.where since this is a
            # permutation, so the [0][0] pulls out the exact match location.
            [
                np.where(sorted_indices == orig_index)[0][0]
                for orig_index in select_shard_indices
            ])
        X, y, w, ids = X[reverted_indices], y[reverted_indices], w[
            reverted_indices], ids[reverted_indices]
        yield (X, y, w, ids)
        start = end
        select_shard_num += 1

    if not output_numpy_dataset:
      return DiskDataset.create_dataset(
          generator(), data_dir=select_dir, tasks=tasks)
    else:
      X, y, w, ids = next(generator())
      return NumpyDataset(X, y, w, ids)

  @property
  def ids(self) -> np.ndarray:
@@ -2247,14 +2286,14 @@ class DiskDataset(Dataset):
      total += len(y)
    return total

  def get_shape(self) -> Tuple[Shape, Shape, Shape, Shape]:
    """Finds shape of dataset."""
  def _get_shard_shape(self,
                       shard_num: int) -> Tuple[Shape, Shape, Shape, Shape]:
    """Finds the shape of the specified shard."""
    if self.legacy_metadata:
      raise ValueError(
          "This function requires the new metadata format to be called. Please reshard this dataset by calling the reshard() method."
      )
    n_tasks = len(self.get_task_names())
    n_rows = len(self.metadata_df.index)
    # If shape metadata is available use it to directly compute shape from
    # metadata
    if not self.legacy_metadata:
      for shard_num in range(n_rows):
    row = self.metadata_df.iloc[shard_num]
    if row['X_shape'] is not None:
      shard_X_shape = make_tuple(str(row['X_shape']))
@@ -2276,6 +2315,21 @@ class DiskDataset(Dataset):
      shard_ids_shape = make_tuple(str(row['ids_shape']))
    else:
      shard_ids_shape = tuple()
    X_shape, y_shape, w_shape, ids_shape = tuple(
        np.array(shard_X_shape)), tuple(np.array(shard_y_shape)), tuple(
            np.array(shard_w_shape)), tuple(np.array(shard_ids_shape))
    return X_shape, y_shape, w_shape, ids_shape

  def get_shape(self) -> Tuple[Shape, Shape, Shape, Shape]:
    """Finds shape of dataset."""
    n_tasks = len(self.get_task_names())
    n_rows = len(self.metadata_df.index)
    # If shape metadata is available use it to directly compute shape from
    # metadata
    if not self.legacy_metadata:
      for shard_num in range(n_rows):
        shard_X_shape, shard_y_shape, shard_w_shape, shard_ids_shape = self._get_shard_shape(
            shard_num)
        if shard_num == 0:
          X_shape, y_shape, w_shape, ids_shape = np.array(
              shard_X_shape), np.array(shard_y_shape), np.array(
+0 −21
Original line number Diff line number Diff line
@@ -272,27 +272,6 @@ def test_reshard():
  np.testing.assert_array_equal(ids, ids_rr)


def test_select():
  """Check that `DiskDataset.select` returns the requested rows.

  A small single-task `DiskDataset` is built from random arrays, a few
  ascending indices are selected, and each field of the result is compared
  with the source arrays indexed by the same indices.
  """
  n_samples, n_features, n_tasks = 10, 10, 1
  X = np.random.rand(n_samples, n_features)
  y = np.random.randint(2, size=(n_samples, n_tasks))
  w = np.ones((n_samples, n_tasks))
  ids = np.array(["id"] * n_samples)
  dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids)

  indices = [0, 4, 5, 8]
  picked = dataset.select(indices)
  # Pull every field once, then compare field by field.
  actual = (picked.X, picked.y, picked.w, picked.ids)
  for source, got in zip((X, y, w, ids), actual):
    np.testing.assert_array_equal(source[indices], got)


def test_complete_shuffle():
  shard_sizes = [1, 2, 3, 4, 5]

+0 −4
Original line number Diff line number Diff line
"""
Tests for ImageDataset class
"""
__author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "MIT"

import unittest
import numpy as np
import deepchem as dc
+130 −0
Original line number Diff line number Diff line
import deepchem as dc
import numpy as np
import os


def test_select():
  """Check that `DiskDataset.select` returns a `DiskDataset` of the chosen rows.

  A small single-task dataset is created with `DiskDataset.from_numpy`, a few
  ascending indices are selected, and the X, y, w and ids of the selection
  are compared against the source arrays indexed by those same indices.
  """
  n_samples, n_features, n_tasks = 10, 10, 1
  X = np.random.rand(n_samples, n_features)
  y = np.random.randint(2, size=(n_samples, n_tasks))
  w = np.ones((n_samples, n_tasks))
  ids = np.array(["id"] * n_samples)
  dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids)

  indices = [0, 4, 5, 8]
  selected = dataset.select(indices)
  assert isinstance(selected, dc.data.DiskDataset)

  # Read every field of the selection up front, then compare pairwise.
  observed = (selected.X, selected.y, selected.w, selected.ids)
  for source, got in zip((X, y, w, ids), observed):
    np.testing.assert_array_equal(source[indices], got)


def test_image_dataset_select():
  """Test that select works on image datasets.

  An `ImageDataset` is built from every file in the `images` directory next
  to this test module. A set of unsorted indices is selected, and the fields
  of the selection are compared against the source dataset's fields indexed
  by the same indices.
  """
  path = os.path.join(os.path.dirname(__file__), 'images')
  # os.listdir returns entries in arbitrary order; sort so that the dataset
  # layout (and therefore any failure) is reproducible across runs/platforms.
  files = [os.path.join(path, f) for f in sorted(os.listdir(path))]
  # Create exactly one label per file instead of hard-coding the file count.
  dataset = dc.data.ImageDataset(files, np.random.random(len(files)))
  indices = [0, 4, 5, 8, 2]
  select_dataset = dataset.select(indices)
  assert isinstance(select_dataset, dc.data.ImageDataset)
  X_sel, y_sel, w_sel, ids_sel = (select_dataset.X, select_dataset.y,
                                  select_dataset.w, select_dataset.ids)
  np.testing.assert_array_equal(dataset.X[indices], X_sel)
  np.testing.assert_array_equal(dataset.y[indices], y_sel)
  np.testing.assert_array_equal(dataset.w[indices], w_sel)
  np.testing.assert_array_equal(dataset.ids[indices], ids_sel)


def test_numpy_dataset_select():
  """Check that `NumpyDataset.select` honours the order of the given indices.

  The requested indices are deliberately not ascending (note the trailing
  `2`). Each field of the selection is compared with the source arrays
  indexed by those same indices, and the result must be a `NumpyDataset`.
  """
  n_samples, n_features, n_tasks = 10, 10, 1
  X = np.random.rand(n_samples, n_features)
  y = np.random.randint(2, size=(n_samples, n_tasks))
  w = np.ones((n_samples, n_tasks))
  ids = np.array(["id"] * n_samples)
  dataset = dc.data.NumpyDataset(X, y, w, ids)

  indices = [0, 4, 5, 8, 2]
  selected = dataset.select(indices)
  assert isinstance(selected, dc.data.NumpyDataset)

  # Read every field of the selection up front, then compare pairwise.
  observed = (selected.X, selected.y, selected.w, selected.ids)
  for source, got in zip((X, y, w, ids), observed):
    np.testing.assert_array_equal(source[indices], got)


def test_select_multishard():
  """Test that dataset select works with multiple shards.

  The source `DiskDataset` is resharded with `shard_size=10` before an
  unsorted list of indices is selected. Each field of the selection is
  compared with the source arrays indexed by those same indices.

  Every sample is given a unique id so that the ids comparison can actually
  detect a wrong or mis-ordered row (identical ids would always compare
  equal regardless of which rows were returned).
  """
  num_datapoints = 100
  num_features = 10
  num_tasks = 1
  X = np.random.rand(num_datapoints, num_features)
  y = np.random.randint(2, size=(num_datapoints, num_tasks))
  w = np.ones((num_datapoints, num_tasks))
  ids = np.array(["id%d" % i for i in range(num_datapoints)])
  dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids)
  dataset.reshard(shard_size=10)

  indices = [10, 42, 51, 82, 2, 4, 6]
  select_dataset = dataset.select(indices)
  assert isinstance(select_dataset, dc.data.DiskDataset)
  X_sel, y_sel, w_sel, ids_sel = (select_dataset.X, select_dataset.y,
                                  select_dataset.w, select_dataset.ids)
  np.testing.assert_array_equal(X[indices], X_sel)
  np.testing.assert_array_equal(y[indices], y_sel)
  np.testing.assert_array_equal(w[indices], w_sel)
  np.testing.assert_array_equal(ids[indices], ids_sel)


def test_select_not_sorted():
  """Test that dataset select respects indices that are not in sorted order.

  A `DiskDataset` is built from random arrays and selected with an unsorted
  index list. Each field of the selection must match the source arrays
  indexed by the same (unsorted) indices.

  Every sample is given a unique id so that the ids comparison can actually
  detect a mis-ordered result (identical ids would always compare equal
  regardless of the order in which rows were returned).
  """
  num_datapoints = 10
  num_features = 10
  num_tasks = 1
  X = np.random.rand(num_datapoints, num_features)
  y = np.random.randint(2, size=(num_datapoints, num_tasks))
  w = np.ones((num_datapoints, num_tasks))
  ids = np.array(["id%d" % i for i in range(num_datapoints)])
  dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids)

  indices = [4, 2, 8, 5, 0]
  select_dataset = dataset.select(indices)
  assert isinstance(select_dataset, dc.data.DiskDataset)
  X_sel, y_sel, w_sel, ids_sel = (select_dataset.X, select_dataset.y,
                                  select_dataset.w, select_dataset.ids)
  np.testing.assert_array_equal(X[indices], X_sel)
  np.testing.assert_array_equal(y[indices], y_sel)
  np.testing.assert_array_equal(w[indices], w_sel)
  np.testing.assert_array_equal(ids[indices], ids_sel)


def test_select_to_numpy():
  """Check that `select(..., output_numpy_dataset=True)` yields a `NumpyDataset`.

  The source is a `DiskDataset` built from random arrays. The selection must
  be a `NumpyDataset` whose X, y, w and ids equal the source arrays indexed
  by the requested indices.
  """
  n_samples, n_features, n_tasks = 10, 10, 1
  X = np.random.rand(n_samples, n_features)
  y = np.random.randint(2, size=(n_samples, n_tasks))
  w = np.ones((n_samples, n_tasks))
  ids = np.array(["id"] * n_samples)
  dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids)

  indices = [0, 4, 5, 8]
  selected = dataset.select(indices, output_numpy_dataset=True)
  assert isinstance(selected, dc.data.NumpyDataset)

  # Read every field of the selection up front, then compare pairwise.
  observed = (selected.X, selected.y, selected.w, selected.ids)
  for source, got in zip((X, y, w, ids), observed):
    np.testing.assert_array_equal(source[indices], got)
+36 −3
Original line number Diff line number Diff line
@@ -20,6 +20,14 @@ def test_complete_shuffle_one_shard():
  assert shuffled.X.shape == dataset.X.shape
  assert shuffled.y.shape == dataset.y.shape
  assert shuffled.w.shape == dataset.w.shape
  original_indices = dict((id, i) for i, id in enumerate(dataset.ids))
  shuffled_indices = dict((id, i) for i, id in enumerate(shuffled.ids))
  for id in dataset.ids:
    i = original_indices[id]
    j = shuffled_indices[id]
    assert np.array_equal(dataset.X[i], shuffled.X[j])
    assert np.array_equal(dataset.y[i], shuffled.y[j])
    assert np.array_equal(dataset.w[i], shuffled.w[j])


def test_complete_shuffle_multiple_shard():
@@ -34,6 +42,14 @@ def test_complete_shuffle_multiple_shard():
  assert shuffled.X.shape == dataset.X.shape
  assert shuffled.y.shape == dataset.y.shape
  assert shuffled.w.shape == dataset.w.shape
  original_indices = dict((id, i) for i, id in enumerate(dataset.ids))
  shuffled_indices = dict((id, i) for i, id in enumerate(shuffled.ids))
  for id in dataset.ids:
    i = original_indices[id]
    j = shuffled_indices[id]
    assert np.array_equal(dataset.X[i], shuffled.X[j])
    assert np.array_equal(dataset.y[i], shuffled.y[j])
    assert np.array_equal(dataset.w[i], shuffled.w[j])


def test_complete_shuffle_multiple_shard_uneven():
@@ -48,6 +64,14 @@ def test_complete_shuffle_multiple_shard_uneven():
  assert shuffled.X.shape == dataset.X.shape
  assert shuffled.y.shape == dataset.y.shape
  assert shuffled.w.shape == dataset.w.shape
  original_indices = dict((id, i) for i, id in enumerate(dataset.ids))
  shuffled_indices = dict((id, i) for i, id in enumerate(shuffled.ids))
  for id in dataset.ids:
    i = original_indices[id]
    j = shuffled_indices[id]
    assert np.array_equal(dataset.X[i], shuffled.X[j])
    assert np.array_equal(dataset.y[i], shuffled.y[j])
    assert np.array_equal(dataset.w[i], shuffled.w[j])


def test_complete_shuffle():
@@ -66,10 +90,11 @@ def test_complete_shuffle():
                                      dataset.ids)
  orig_len = len(dataset)

  dataset = dataset.complete_shuffle()
  X_new, y_new, w_new, new_ids = (dataset.X, dataset.y, dataset.w, dataset.ids)
  shuffled = dataset.complete_shuffle()
  X_new, y_new, w_new, new_ids = (shuffled.X, shuffled.y, shuffled.w,
                                  shuffled.ids)

  assert len(dataset) == orig_len
  assert len(shuffled) == orig_len
  # The shuffling should have switched up the ordering
  assert not np.array_equal(orig_ids, new_ids)
  # But all the same entries should still be present
@@ -78,6 +103,14 @@ def test_complete_shuffle():
  assert X_orig.shape == X_new.shape
  assert y_orig.shape == y_new.shape
  assert w_orig.shape == w_new.shape
  original_indices = dict((id, i) for i, id in enumerate(dataset.ids))
  shuffled_indices = dict((id, i) for i, id in enumerate(shuffled.ids))
  for id in dataset.ids:
    i = original_indices[id]
    j = shuffled_indices[id]
    assert np.array_equal(dataset.X[i], shuffled.X[j])
    assert np.array_equal(dataset.y[i], shuffled.y[j])
    assert np.array_equal(dataset.w[i], shuffled.w[j])


def test_sparse_shuffle():