Commit b8b7ad31 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

First cut at implementing ordered select

parent d3bb4ee8
Loading
Loading
Loading
Loading
+115 −54
Original line number Diff line number Diff line
@@ -1899,14 +1899,11 @@ class DiskDataset(Dataset):
      shard_num = 0
      while start < N:
        logger.info("Constructing shard %d" % shard_num)
        if start + shard_size < N:
          end = start + shard_size
        else:
          end = N
        end = min(start + shard_size, N)
        shard_indices = perm[start:end]
        # Note that this is in sorted order which doesn't respect the random
        # permutation.
        shard_dataset = self.select(shard_indices)
        shard_dataset = self.select(shard_indices, output_numpy_dataset=True)
        # One bit of trickiness here is that select() will return in sorted
        # order. For example, suppose we'd like these elements in our permuted
        # shard:
@@ -1917,9 +1914,6 @@ class DiskDataset(Dataset):
        #
        # [1, 4, 12, 234]
        #
        # We need to recover the original ordering. We can do this by using
        # np.where to find the locatios of the original indices in the sorted
        # indices.
        sorted_indices = np.array(sorted(shard_indices))
        reverted_indices = np.array(
            # We know there's only one match for np.where since this is a
@@ -2097,15 +2091,21 @@ class DiskDataset(Dataset):
    DiskDataset.write_data_to_disk(self.data_dir, basename, tasks, X, y, w, ids)
    self._cached_shards = None

  def select(self, indices: Sequence[int],
             select_dir: Optional[str] = None) -> "DiskDataset":
  def select(self,
             indices: Sequence[int],
             select_dir: Optional[str] = None,
             select_shard_size: Optional[int] = None,
             output_numpy_dataset: Optional[bool] = False) -> "DiskDataset":
    """Creates a new dataset from a selection of indices from self.

    Note
    ----
    The specified indices will be returned in sorted order. That is, if you
    request that indices `[3, 1, 2]` are returned, you will get a
    `DiskDataset` which contains elements in order `[1, 2, 3]`.
    Examples
    --------
    >>> import numpy as np
    >>> X = np.random.rand(10, 10)
    >>> dataset = dc.data.DiskDataset.from_numpy(X)
    >>> selected = dataset.select([1, 3, 4])
    >>> len(selected)
    3

    Parameters
    ----------
@@ -2114,38 +2114,74 @@ class DiskDataset(Dataset):
    select_dir: Optional[str], (default None)
      Path to new directory that the selected indices will be copied
      to.
    select_shard_size: Optional[int], (default None)
      If specified, the shard-size to use for output selected `DiskDataset`.
      If not output_numpy_dataset, then this is set to this current dataset's
      shard size if not manually specified. 
    output_numpy_dataset: Optional[bool], (default False)
      If True, output an in-memory `NumpyDataset` instead of a `DiskDataset`.
      Note that `select_dir` and `select_shard_size` must be `None` if this
      is `True`

    Returns
    -------
    DiskDataset
      Contains selected indices.
    """
    if output_numpy_dataset and (select_dir is not None or
                                 select_shard_size is not None):
      raise ValueError(
          "If output_numpy_dataset is set, then select_dir and select_shard_size must both be None"
      )
    if output_numpy_dataset:
      # When outputting a NumpyDataset, we have 1 in-memory shard
      select_shard_size = len(indices)
    else:
      if select_dir is not None:
        if not os.path.exists(select_dir):
          os.makedirs(select_dir)
      else:
        select_dir = tempfile.mkdtemp()
      if select_shard_size is None:
        select_shard_size = self.get_shard_size()
    # Handle edge case with empty indices
    if not len(indices):
      if not output_numpy_dataset:
        return DiskDataset.create_dataset([], data_dir=select_dir)
    indices = np.array(sorted(indices)).astype(int)
    tasks = self.get_task_names()
      else:
        return NumpyDataset(
            np.array([]), np.array([]), np.array([]), np.array([]))

    N = len(indices)
    indices = np.array(indices).astype(int)
    tasks = self.get_task_names()
    n_shards = self.get_number_shards()

    # We use two loops here. The outer while loop walks over selection shards
    # (the chunks of the indices to select that should go into separate
    # output shards), while the inner for loop walks over the shards in the
    # source datasets to select out the shard indices from that  source shard
    def generator():
      start = 0
      while start < N:
        end = min(start + select_shard_size, N)
        select_shard_indices = indices[start:end]
        sorted_indices = np.array(sorted(select_shard_indices)).astype(int)

        Xs, ys, ws, ids_s = [], [], [], []
        count, indices_count = 0, 0
        for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
          logger.info("Selecting from shard %d/%d" % (shard_num, n_shards))
          shard_len = len(X)
          # Find indices which rest in this shard
          num_shard_elts = 0
        while indices[indices_count + num_shard_elts] < count + shard_len:
          while sorted_indices[indices_count +
                               num_shard_elts] < count + shard_len:
            num_shard_elts += 1
          if indices_count + num_shard_elts >= len(indices):
            if indices_count + num_shard_elts >= len(sorted_indices):
              break
          # Need to offset indices to fit within shard_size
        shard_inds = indices[indices_count:indices_count +
          shard_inds = sorted_indices[indices_count:indices_count +
                                      num_shard_elts] - count
          X_sel = X[shard_inds]
          # Handle the case of datasets with y/w missing
@@ -2158,16 +2194,41 @@ class DiskDataset(Dataset):
          else:
            w_sel = None
          ids_sel = ids[shard_inds]
        yield (X_sel, y_sel, w_sel, ids_sel)
        # Updating counts
          Xs.append(X_sel)
          ys.append(y_sel)
          ws.append(w_sel)
          ids_s.append(ids_sel)
          indices_count += num_shard_elts
          count += shard_len
          # Break when all indices have been used up already
          if indices_count >= len(indices):
          return
            break
        # Note these will be in the sorted order
        X = np.concatenate(Xs, axis=0)
        y = np.concatenate(ys, axis=0)
        w = np.concatenate(ws, axis=0)
        ids = np.concatenate(ids_s, axis=0)
        # We need to recover the original ordering. We can do this by using
        # np.where to find the locatios of the original indices in the sorted
        # indices.
        reverted_indices = np.array(
            # We know there's only one match for np.where since this is a
            # permutation, so the [0][0] pulls out the exact match location.
            [
                np.where(sorted_indices == orig_index)[0][0]
                for orig_index in select_shard_indices
            ])
        X, y, w, ids = X[reverted_indices], y[reverted_indices], w[
            reverted_indices], ids[reverted_indices]
        yield (X, y, w, ids)
        start = end

    if not output_numpy_dataset:
      return DiskDataset.create_dataset(
          generator(), data_dir=select_dir, tasks=tasks)
    else:
      X, y, w, ids = next(generator())
      return NumpyDataset(X, y, w, ids)

  @property
  def ids(self) -> np.ndarray:
+0 −21
Original line number Diff line number Diff line
@@ -272,27 +272,6 @@ def test_reshard():
  np.testing.assert_array_equal(ids, ids_rr)


def test_select():
  """Test that dataset select works."""
  num_datapoints = 10
  num_features = 10
  num_tasks = 1
  X = np.random.rand(num_datapoints, num_features)
  y = np.random.randint(2, size=(num_datapoints, num_tasks))
  w = np.ones((num_datapoints, num_tasks))
  ids = np.array(["id"] * num_datapoints)
  dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids)

  indices = [0, 4, 5, 8]
  select_dataset = dataset.select(indices)
  X_sel, y_sel, w_sel, ids_sel = (select_dataset.X, select_dataset.y,
                                  select_dataset.w, select_dataset.ids)
  np.testing.assert_array_equal(X[indices], X_sel)
  np.testing.assert_array_equal(y[indices], y_sel)
  np.testing.assert_array_equal(w[indices], w_sel)
  np.testing.assert_array_equal(ids[indices], ids_sel)


def test_complete_shuffle():
  shard_sizes = [1, 2, 3, 4, 5]

+0 −4
Original line number Diff line number Diff line
"""
Tests for ImageDataset class
"""
__author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "MIT"

import unittest
import numpy as np
import deepchem as dc
+130 −0
Original line number Diff line number Diff line
import deepchem as dc
import numpy as np
import os


def test_select():
  """Test that dataset select works."""
  num_datapoints = 10
  num_features = 10
  num_tasks = 1
  X = np.random.rand(num_datapoints, num_features)
  y = np.random.randint(2, size=(num_datapoints, num_tasks))
  w = np.ones((num_datapoints, num_tasks))
  ids = np.array(["id"] * num_datapoints)
  dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids)

  indices = [0, 4, 5, 8]
  select_dataset = dataset.select(indices)
  assert isinstance(select_dataset, dc.data.DiskDataset)
  X_sel, y_sel, w_sel, ids_sel = (select_dataset.X, select_dataset.y,
                                  select_dataset.w, select_dataset.ids)
  np.testing.assert_array_equal(X[indices], X_sel)
  np.testing.assert_array_equal(y[indices], y_sel)
  np.testing.assert_array_equal(w[indices], w_sel)
  np.testing.assert_array_equal(ids[indices], ids_sel)


def test_image_dataset_select():
  """Test that select works on image datasets."""
  path = os.path.join(os.path.dirname(__file__), 'images')
  files = [os.path.join(path, f) for f in os.listdir(path)]
  dataset = dc.data.ImageDataset(files, np.random.random(10))
  indices = [0, 4, 5, 8, 2]
  select_dataset = dataset.select(indices)
  assert isinstance(select_dataset, dc.data.ImageDataset)
  X_sel, y_sel, w_sel, ids_sel = (select_dataset.X, select_dataset.y,
                                  select_dataset.w, select_dataset.ids)
  np.testing.assert_array_equal(dataset.X[indices], X_sel)
  np.testing.assert_array_equal(dataset.y[indices], y_sel)
  np.testing.assert_array_equal(dataset.w[indices], w_sel)
  np.testing.assert_array_equal(dataset.ids[indices], ids_sel)


def test_numpy_dataset_select():
  """Test that dataset select works with numpy dataset."""
  num_datapoints = 10
  num_features = 10
  num_tasks = 1
  X = np.random.rand(num_datapoints, num_features)
  y = np.random.randint(2, size=(num_datapoints, num_tasks))
  w = np.ones((num_datapoints, num_tasks))
  ids = np.array(["id"] * num_datapoints)
  dataset = dc.data.NumpyDataset(X, y, w, ids)

  indices = [0, 4, 5, 8, 2]
  select_dataset = dataset.select(indices)
  assert isinstance(select_dataset, dc.data.NumpyDataset)
  X_sel, y_sel, w_sel, ids_sel = (select_dataset.X, select_dataset.y,
                                  select_dataset.w, select_dataset.ids)
  np.testing.assert_array_equal(X[indices], X_sel)
  np.testing.assert_array_equal(y[indices], y_sel)
  np.testing.assert_array_equal(w[indices], w_sel)
  np.testing.assert_array_equal(ids[indices], ids_sel)


def test_select_multishard():
  """Test that dataset select works with multiple shards."""
  num_datapoints = 100
  num_features = 10
  num_tasks = 1
  X = np.random.rand(num_datapoints, num_features)
  y = np.random.randint(2, size=(num_datapoints, num_tasks))
  w = np.ones((num_datapoints, num_tasks))
  ids = np.array(["id"] * num_datapoints)
  dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids)
  dataset.reshard(shard_size=10)

  indices = [10, 42, 51, 82, 2, 4, 6]
  select_dataset = dataset.select(indices)
  assert isinstance(select_dataset, dc.data.DiskDataset)
  X_sel, y_sel, w_sel, ids_sel = (select_dataset.X, select_dataset.y,
                                  select_dataset.w, select_dataset.ids)
  np.testing.assert_array_equal(X[indices], X_sel)
  np.testing.assert_array_equal(y[indices], y_sel)
  np.testing.assert_array_equal(w[indices], w_sel)
  np.testing.assert_array_equal(ids[indices], ids_sel)


def test_select_not_sorted():
  """Test that dataset select with ids not in sorted order."""
  num_datapoints = 10
  num_features = 10
  num_tasks = 1
  X = np.random.rand(num_datapoints, num_features)
  y = np.random.randint(2, size=(num_datapoints, num_tasks))
  w = np.ones((num_datapoints, num_tasks))
  ids = np.array(["id"] * num_datapoints)
  dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids)

  indices = [4, 2, 8, 5, 0]
  select_dataset = dataset.select(indices)
  assert isinstance(select_dataset, dc.data.DiskDataset)
  X_sel, y_sel, w_sel, ids_sel = (select_dataset.X, select_dataset.y,
                                  select_dataset.w, select_dataset.ids)
  np.testing.assert_array_equal(X[indices], X_sel)
  np.testing.assert_array_equal(y[indices], y_sel)
  np.testing.assert_array_equal(w[indices], w_sel)
  np.testing.assert_array_equal(ids[indices], ids_sel)


def test_select_to_numpy():
  """Test that dataset select works."""
  num_datapoints = 10
  num_features = 10
  num_tasks = 1
  X = np.random.rand(num_datapoints, num_features)
  y = np.random.randint(2, size=(num_datapoints, num_tasks))
  w = np.ones((num_datapoints, num_tasks))
  ids = np.array(["id"] * num_datapoints)
  dataset = dc.data.DiskDataset.from_numpy(X, y, w, ids)

  indices = [0, 4, 5, 8]
  select_dataset = dataset.select(indices, output_numpy_dataset=True)
  assert isinstance(select_dataset, dc.data.NumpyDataset)
  X_sel, y_sel, w_sel, ids_sel = (select_dataset.X, select_dataset.y,
                                  select_dataset.w, select_dataset.ids)
  np.testing.assert_array_equal(X[indices], X_sel)
  np.testing.assert_array_equal(y[indices], y_sel)
  np.testing.assert_array_equal(w[indices], w_sel)
  np.testing.assert_array_equal(ids[indices], ids_sel)