Commit b982095f authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Adding in indexing capabilities

parent e06055e7
Loading
Loading
Loading
Loading
+112 −10
Original line number Diff line number Diff line
@@ -753,6 +753,21 @@ class NumpyDataset(Dataset):
    """Get the number of elements in the dataset."""
    return len(self._y)

  def __getitem__(self,
                  index: Union[int, slice]) -> Union[Tuple, Iterator[Batch]]:
    """Implements indexing for this object.

    Parameters
    ----------
    index: int or slice
      Position of a single sample, or a slice selecting several samples.
      Omitted slice bounds (e.g. `dataset[:5]`, `dataset[3:]`,
      `dataset[::2]`) are resolved against `len(self)` following the
      standard Python slice rules.

    Returns
    -------
    Tuple or Iterator
      For an int, the tuple `(X, y, w, id)` of that sample. For a slice, a
      generator that lazily yields one such tuple per selected sample.

    Raises
    ------
    IndexError
      If `index` is an int outside `[0, len(self))`.
    """
    if isinstance(index, int):
      if index < 0 or index >= len(self):
        raise IndexError("Index %d is out of bounds" % index)
      return (self._X[index], self._y[index], self._w[index], self._ids[index])
    # `slice.indices` fills in missing start/stop/step and clamps the bounds to
    # the dataset length; passing `None` bounds straight to `range` would
    # raise a TypeError.
    indices = range(*index.indices(len(self)))
    return (
        (self._X[i], self._y[i], self._w[i], self._ids[i]) for i in indices)

  def get_shape(self) -> Tuple[Shape, Shape, Shape, Shape]:
    """Get the shape of the dataset.

@@ -1050,7 +1065,7 @@ class DiskDataset(Dataset):

  On disk, a `DiskDataset` has a simple structure. All files for a given
  `DiskDataset` are stored in a `data_dir`. The contents of `data_dir` should
  be laid out as follows:
  be laid out as follows::

  data_dir/
    |
@@ -1676,6 +1691,55 @@ class DiskDataset(Dataset):

    return iterate(self, batch_size, epochs)

  def _get_shard_for_index(self, index: int) -> Tuple[int, int]:
    """Locate the shard that holds a given datapoint.

    Walks the rows of `self.metadata_df` in order, accumulating the number
    of ids recorded for each shard until the running range covers `index`.

    Parameters
    ----------
    index: int
      Position of the datapoint within the whole dataset.

    Returns
    -------
    Tuple[int, int]
      The metadata row label of the shard and the position of the datapoint
      inside that shard.

    Raises
    ------
    ValueError
      If the dataset uses legacy metadata, or if no shard covers `index`
      even though it is within bounds.
    IndexError
      If `index` is outside `[0, len(self))`.
    """
    if self.legacy_metadata:
      raise ValueError(
          "Indexing not supported for legacy metadata. Please convert metadata by calling self.reshard()"
      )
    if not 0 <= index < len(self):
      raise IndexError("Index %d is out of bounds" % index)
    offset = 0
    for shard_label, shard_row in self.metadata_df.iterrows():
      raw_shape = shard_row['ids_shape']
      # Shards without a recorded ids shape contribute no datapoints.
      if raw_shape is None:
        continue
      ids_shape = make_tuple(str(raw_shape))
      # An empty shape tuple means the shard holds zero datapoints.
      shard_size = ids_shape[0] if len(ids_shape) > 0 else 0
      upper = offset + shard_size
      if offset <= index < upper:
        return (shard_label, index - offset)
      offset = upper
    # An in-bounds index must fall in some shard; reaching here means the
    # metadata disagrees with len(self).
    raise ValueError("Malformed metadata detected.")

  def __getitem__(self,
                  index: Union[int, slice]) -> Union[Tuple, Iterator[Batch]]:
    """Implements indexing for this object.

    Parameters
    ----------
    index: int or slice
      Position of a single sample, or a slice selecting several samples.
      Omitted slice bounds (e.g. `dataset[:5]`, `dataset[3:]`,
      `dataset[::2]`) are resolved against `len(self)` following the
      standard Python slice rules.

    Returns
    -------
    Tuple or Iterator
      For an int, the tuple `(X, y, w, id)` of that sample. For a slice, a
      generator that lazily yields one such tuple per selected sample.

    Raises
    ------
    IndexError
      If `index` is an int outside `[0, len(self))`.
    ValueError
      Raised by `_get_shard_for_index` when the dataset uses legacy
      metadata. For slices this surfaces when the generator is consumed.
    """

    def load_sample(position):
      # Map the global position to (shard, offset) and read that one row.
      shard_num, index_in_shard = self._get_shard_for_index(position)
      (X_shard, y_shard, w_shard, ids_shard) = self.get_shard(shard_num)
      return (X_shard[index_in_shard], y_shard[index_in_shard],
              w_shard[index_in_shard], ids_shard[index_in_shard])

    if isinstance(index, int):
      if index < 0 or index >= len(self):
        raise IndexError("Index %d is out of bounds" % index)
      return load_sample(index)
    # `slice.indices` fills in missing start/stop/step and clamps the bounds to
    # the dataset length; passing `None` bounds straight to `range` would
    # raise a TypeError.
    indices = range(*index.indices(len(self)))
    # NOTE(review): the original comment says caching makes this efficient,
    # presumably shard caching inside `get_shard` -- confirm.
    return (load_sample(i) for i in indices)

  def itersamples(self) -> Iterator[Batch]:
    """Get an object that iterates over the samples in the dataset.

@@ -2490,12 +2554,29 @@ class DiskDataset(Dataset):
      self._cached_shards = None

  def __len__(self) -> int:
    """Finds number of elements in dataset.

    With legacy metadata the ids file of every shard is loaded from disk and
    its length is counted. Otherwise the count is the sum of the leading
    dimension of each shard's `ids_shape` metadata entry; shards with no
    recorded shape or an empty shape tuple count as zero.

    Returns
    -------
    int
      Total number of datapoints across all shards.
    """
    if self.legacy_metadata:
      return sum(
          len(load_from_disk(os.path.join(self.data_dir, shard_row['ids'])))
          for _, shard_row in self.metadata_df.iterrows())

    n_samples = 0
    for _, shard_row in self.metadata_df.iterrows():
      raw_shape = shard_row['ids_shape']
      if raw_shape is None:
        continue
      ids_shape = make_tuple(str(raw_shape))
      if len(ids_shape) > 0:
        n_samples += ids_shape[0]
    return n_samples

  def _get_shard_shape(self,
                       shard_num: int) -> Tuple[Shape, Shape, Shape, Shape]:
@@ -2746,9 +2827,30 @@ class ImageDataset(Dataset):

    return iterate(self, batch_size, epochs, deterministic, pad_batches)

  def _get_image(self, array: Union[np.ndarray, List[str]],
                 indices: Union[int, np.ndarray]) -> np.ndarray:
    """Method for loading an image
  def __getitem__(self,
                  index: Union[int, slice]) -> Union[Tuple, Iterator[Batch]]:
    """Implements indexing for this object.

    Parameters
    ----------
    index: int or slice
      Position of a single sample, or a slice selecting several samples.
      Omitted slice bounds (e.g. `dataset[:5]`, `dataset[3:]`,
      `dataset[::2]`) are resolved against `len(self)` following the
      standard Python slice rules.

    Returns
    -------
    Tuple or Iterator
      For an int, the tuple `(X, y, w, id)` of that sample. For a slice, a
      generator that lazily yields one such tuple per selected sample.
      Entries of `X`/`y` that are not stored as a numpy array are loaded
      through `dc.data.ImageLoader.load_img`.

    Raises
    ------
    IndexError
      If `index` is an int outside `[0, len(self))`.
    """

    def get_image(array, position):
      # In-memory arrays are indexed directly; anything else is treated as a
      # sequence of file names and the selected file is loaded on demand.
      if isinstance(array, np.ndarray):
        return array[position]
      return dc.data.ImageLoader.load_img([array[position]])[0]

    if isinstance(index, int):
      if index < 0 or index >= len(self):
        raise IndexError("Index %d is out of bounds" % index)
      return (get_image(self._X, index), get_image(self._y, index),
              self._w[index], self._ids[index])
    # `slice.indices` fills in missing start/stop/step and clamps the bounds to
    # the dataset length; passing `None` bounds straight to `range` would
    # raise a TypeError.
    indices = range(*index.indices(len(self)))
    return ((get_image(self._X, i), get_image(self._y, i), self._w[i],
             self._ids[i]) for i in indices)

  def itersamples(self) -> Iterator[Batch]:
    """Get an object that iterates over the samples in the dataset.

    Parameters
    ----------
+85 −0
Original line number Diff line number Diff line
import deepchem as dc
import numpy as np
import os


def test_numpy_getindex_int():
  """Integer indexing of a NumpyDataset returns the matching sample."""
  features = np.random.rand(10, 10)
  labels = np.random.rand(10,)
  weights = np.random.rand(10,)
  ds = dc.data.NumpyDataset(features, labels, weights)
  position = 5
  sample_x, sample_y, sample_w, sample_id = ds[position]
  # Each component must equal the corresponding row of the source arrays.
  assert np.all(sample_x == features[position])
  assert np.all(sample_y == labels[position])
  assert np.all(sample_w == weights[position])
  # No ids were supplied, so the sample id is expected to equal its position.
  assert sample_id == position


def test_numpy_getindex_slice():
  """Test __getitem__ on slice for NumpyDataset"""
  X = np.random.rand(10, 10)
  y = np.random.rand(10,)
  w = np.random.rand(10,)
  dataset = dc.data.NumpyDataset(X, y, w)
  samples = list(dataset[3:5])
  # Without this check an empty iterator would skip every assertion below and
  # the test would pass vacuously.
  assert len(samples) == 2
  for position, (xi, yi, wi, idsi) in enumerate(samples, start=3):
    assert np.all(xi == X[position])
    assert np.all(yi == y[position])
    assert np.all(wi == w[position])
    assert idsi == position


def test_disk_getindex_int():
  """Integer indexing of a DiskDataset returns the matching sample."""
  features = np.random.rand(10, 10)
  labels = np.random.rand(10,)
  weights = np.random.rand(10,)
  ds = dc.data.DiskDataset.from_numpy(features, labels, weights)
  position = 5
  sample_x, sample_y, sample_w, sample_id = ds[position]
  # Each component must equal the corresponding row of the source arrays.
  assert np.all(sample_x == features[position])
  assert np.all(sample_y == labels[position])
  assert np.all(sample_w == weights[position])
  # No ids were supplied, so the sample id is expected to equal its position.
  assert sample_id == position


def test_disk_getindex_slice():
  """Test __getitem__ on slice for DiskDataset"""
  X = np.random.rand(10, 10)
  y = np.random.rand(10,)
  w = np.random.rand(10,)
  dataset = dc.data.DiskDataset.from_numpy(X, y, w)
  samples = list(dataset[3:5])
  # Without this check an empty iterator would skip every assertion below and
  # the test would pass vacuously.
  assert len(samples) == 2
  for position, (xi, yi, wi, idsi) in enumerate(samples, start=3):
    assert np.all(xi == X[position])
    assert np.all(yi == y[position])
    assert np.all(wi == w[position])
    assert idsi == position


def test_image_getindex_int():
  """Integer indexing of an ImageDataset matches its X/y/w/ids attributes."""
  image_dir = os.path.join(os.path.dirname(__file__), 'images')
  image_files = [os.path.join(image_dir, name) for name in os.listdir(image_dir)]
  # NOTE(review): ten labels are generated, so this presumably relies on the
  # images directory holding ten files -- confirm.
  image_ds = dc.data.ImageDataset(image_files, np.random.random(10))
  position = 5
  sample_x, sample_y, sample_w, sample_id = image_ds[position]
  assert np.all(sample_x == image_ds.X[position])
  assert np.all(sample_y == image_ds.y[position])
  assert np.all(sample_w == image_ds.w[position])
  assert np.all(sample_id == image_ds.ids[position])


def test_image_getindex_slice():
  """Test __getitem__ on slice for ImageDataset"""
  path = os.path.join(os.path.dirname(__file__), 'images')
  files = [os.path.join(path, f) for f in os.listdir(path)]
  ds = dc.data.ImageDataset(files, np.random.random(10))
  samples = list(ds[3:5])
  # Without this check an empty iterator would skip every assertion below and
  # the test would pass vacuously.
  assert len(samples) == 2
  for position, (xi, yi, wi, idsi) in enumerate(samples, start=3):
    assert np.all(xi == ds.X[position])
    assert np.all(yi == ds.y[position])
    assert np.all(wi == ds.w[position])
    assert idsi == ds.ids[position]