Unverified commit 6bdba2e0, authored by Bharath Ramsundar, committed by GitHub
Browse files

Merge pull request #2107 from peastman/pytorch

make_pytorch_dataset() can return batches
parents e0537202 630aca5c
Loading
Loading
Loading
Loading
+54 −21
Original line number Diff line number Diff line
@@ -497,11 +497,14 @@ class Dataset(object):

    return tf.data.Dataset.from_generator(gen_data, dtypes, shapes)

  def make_pytorch_dataset(self, epochs: int = 1, deterministic: bool = False):
  def make_pytorch_dataset(self,
                           epochs: int = 1,
                           deterministic: bool = False,
                           batch_size: int = None):
    """Create a torch.utils.data.IterableDataset that iterates over the data in this Dataset.

    Each value returned by the Dataset's iterator is a tuple of (X, y,
    w, id) for one sample.
    Each value returned by the Dataset's iterator is a tuple of (X, y, w, id)
    containing the data for one batch, or for a single sample if batch_size is None.

    Parameters
    ----------
@@ -510,6 +513,9 @@ class Dataset(object):
    deterministic: bool
      if True, the data is produced in order.  If False, a different
      random permutation of the data is used for each epoch.
    batch_size: int
      the number of samples to return in each batch.  If None, each returned
      value is a single sample.

    Returns
    -------
@@ -855,19 +861,25 @@ class NumpyDataset(Dataset):
    ids = self.ids[indices]
    return NumpyDataset(X, y, w, ids)

  def make_pytorch_dataset(self, epochs: int = 1, deterministic: bool = False):
  def make_pytorch_dataset(self,
                           epochs: int = 1,
                           deterministic: bool = False,
                           batch_size: int = None):
    """Create a torch.utils.data.IterableDataset that iterates over the data in this Dataset.

    Each value returned by the Dataset's iterator is a tuple of (X, y, w, id) for
    one sample.
    Each value returned by the Dataset's iterator is a tuple of (X, y, w, id)
    containing the data for one batch, or for a single sample if batch_size is None.

    Parameters
    ----------
    epochs: int
      the number of times to iterate over the Dataset
    deterministic: bool
      if True, the data is produced in order.  If False, a different random
      permutation of the data is used for each epoch.
      if True, the data is produced in order.  If False, a different
      random permutation of the data is used for each epoch.
    batch_size: int
      the number of samples to return in each batch.  If None, each returned
      value is a single sample.

    Returns
    -------
@@ -881,7 +893,10 @@ class NumpyDataset(Dataset):
      raise ValueError("This method requires PyTorch to be installed.")

    pytorch_ds = _TorchNumpyDataset(
        numpy_dataset=self, epochs=epochs, deterministic=deterministic)
        numpy_dataset=self,
        epochs=epochs,
        deterministic=deterministic,
        batch_size=batch_size)
    return pytorch_ds

  @staticmethod
@@ -1684,19 +1699,25 @@ class DiskDataset(Dataset):
    return DiskDataset.write_data_to_disk(out_dir, basename, tasks, X, y, w,
                                          ids)

  def make_pytorch_dataset(self, epochs: int = 1, deterministic: bool = False):
  def make_pytorch_dataset(self,
                           epochs: int = 1,
                           deterministic: bool = False,
                           batch_size: int = None):
    """Create a torch.utils.data.IterableDataset that iterates over the data in this Dataset.

    Each value returned by the Dataset's iterator is a tuple of (X, y, w, id) for
    one sample.
    Each value returned by the Dataset's iterator is a tuple of (X, y, w, id)
    containing the data for one batch, or for a single sample if batch_size is None.

    Parameters
    ----------
    epochs: int
      the number of times to iterate over the Dataset
    deterministic: bool
      if True, the data is produced in order.  If False, a different random
      permutation of the data is used for each epoch.
      if True, the data is produced in order.  If False, a different
      random permutation of the data is used for each epoch.
    batch_size: int
      the number of samples to return in each batch.  If None, each returned
      value is a single sample.

    Returns
    -------
@@ -1710,7 +1731,10 @@ class DiskDataset(Dataset):
      raise ValueError("This method requires PyTorch to be installed.")

    pytorch_ds = _TorchDiskDataset(
        disk_dataset=self, epochs=epochs, deterministic=deterministic)
        disk_dataset=self,
        epochs=epochs,
        deterministic=deterministic,
        batch_size=batch_size)
    return pytorch_ds

  @staticmethod
@@ -2589,11 +2613,14 @@ class ImageDataset(Dataset):
    ids = self._ids[indices]
    return ImageDataset(X, y, w, ids)

  def make_pytorch_dataset(self, epochs: int = 1, deterministic: bool = False):
  def make_pytorch_dataset(self,
                           epochs: int = 1,
                           deterministic: bool = False,
                           batch_size: int = None):
    """Create a torch.utils.data.IterableDataset that iterates over the data in this Dataset.

    Each value returned by the Dataset's iterator is a tuple of (X, y,
    w, id) for one sample.
    Each value returned by the Dataset's iterator is a tuple of (X, y, w, id)
    containing the data for one batch, or for a single sample if batch_size is None.

    Parameters
    ----------
@@ -2602,6 +2629,9 @@ class ImageDataset(Dataset):
    deterministic: bool
      if True, the data is produced in order.  If False, a different
      random permutation of the data is used for each epoch.
    batch_size: int
      the number of samples to return in each batch.  If None, each returned
      value is a single sample.

    Returns
    -------
@@ -2615,7 +2645,10 @@ class ImageDataset(Dataset):
      raise ValueError("This method requires PyTorch to be installed.")

    pytorch_ds = _TorchImageDataset(
        image_dataset=self, epochs=epochs, deterministic=deterministic)
        image_dataset=self,
        epochs=epochs,
        deterministic=deterministic,
        batch_size=batch_size)
    return pytorch_ds


+68 −21
Original line number Diff line number Diff line
@@ -8,8 +8,11 @@ from deepchem.data.datasets import NumpyDataset, DiskDataset, ImageDataset

class _TorchNumpyDataset(torch.utils.data.IterableDataset):  # type: ignore

  def __init__(self, numpy_dataset: NumpyDataset, epochs: int,
               deterministic: bool):
  def __init__(self,
               numpy_dataset: NumpyDataset,
               epochs: int,
               deterministic: bool,
               batch_size: int = None):
    """
    Parameters
    ----------
@@ -20,10 +23,14 @@ class _TorchNumpyDataset(torch.utils.data.IterableDataset): # type: ignore
    deterministic: bool
      if True, the data is produced in order.  If False, a different random
      permutation of the data is used for each epoch.
    batch_size: int
      the number of samples to return in each batch.  If None, each returned
      value is a single sample.
    """
    self.numpy_dataset = numpy_dataset
    self.epochs = epochs
    self.deterministic = deterministic
    self.batch_size = batch_size

  def __iter__(self):
    n_samples = self.numpy_dataset._X.shape[0]
@@ -38,16 +45,28 @@ class _TorchNumpyDataset(torch.utils.data.IterableDataset): # type: ignore
      if self.deterministic:
        order = first_sample + np.arange(last_sample - first_sample)
      else:
        order = first_sample + np.random.permutation(last_sample - first_sample)
        # Ensure that every worker will pick the same random order for each epoch.
        random = np.random.RandomState(epoch)
        order = random.permutation(n_samples)[first_sample:last_sample]
      if self.batch_size is None:
        for i in order:
          yield (self.numpy_dataset._X[i], self.numpy_dataset._y[i],
                 self.numpy_dataset._w[i], self.numpy_dataset._ids[i])
      else:
        for i in range(0, len(order), self.batch_size):
          indices = order[i:i + self.batch_size]
          yield (self.numpy_dataset._X[indices], self.numpy_dataset._y[indices],
                 self.numpy_dataset._w[indices],
                 self.numpy_dataset._ids[indices])


class _TorchDiskDataset(torch.utils.data.IterableDataset):  # type: ignore

  def __init__(self, disk_dataset: DiskDataset, epochs: int,
               deterministic: bool):
  def __init__(self,
               disk_dataset: DiskDataset,
               epochs: int,
               deterministic: bool,
               batch_size: int = None):
    """
    Parameters
    ----------
@@ -58,10 +77,14 @@ class _TorchDiskDataset(torch.utils.data.IterableDataset): # type: ignore
    deterministic: bool
      if True, the data is produced in order.  If False, a different random
      permutation of the data is used for each epoch.
    batch_size: int
      the number of samples to return in each batch.  If None, each returned
      value is a single sample.
    """
    self.disk_dataset = disk_dataset
    self.epochs = epochs
    self.deterministic = deterministic
    self.batch_size = batch_size

  def __iter__(self):
    worker_info = torch.utils.data.get_worker_info()
@@ -76,17 +99,25 @@ class _TorchDiskDataset(torch.utils.data.IterableDataset): # type: ignore
      return

    shard_indices = list(range(first_shard, last_shard))
    for epoch in range(self.epochs):
    for X, y, w, ids in self.disk_dataset._iterbatches_from_shards(
          shard_indices, deterministic=self.deterministic):
        shard_indices,
        batch_size=self.batch_size,
        epochs=self.epochs,
        deterministic=self.deterministic):
      if self.batch_size is None:
        for i in range(X.shape[0]):
          yield (X[i], y[i], w[i], ids[i])
      else:
        yield (X, y, w, ids)


class _TorchImageDataset(torch.utils.data.IterableDataset):  # type: ignore

  def __init__(self, image_dataset: ImageDataset, epochs: int,
               deterministic: bool):
  def __init__(self,
               image_dataset: ImageDataset,
               epochs: int,
               deterministic: bool,
               batch_size: int = None):
    """
    Parameters
    ----------
@@ -97,10 +128,14 @@ class _TorchImageDataset(torch.utils.data.IterableDataset): # type: ignore
    deterministic: bool
      if True, the data is produced in order.  If False, a different random
      permutation of the data is used for each epoch.
    batch_size: int
      the number of samples to return in each batch.  If None, each returned
      value is a single sample.
    """
    self.image_dataset = image_dataset
    self.epochs = epochs
    self.deterministic = deterministic
    self.batch_size = batch_size

  def __iter__(self):
    n_samples = self.image_dataset._X_shape[0]
@@ -115,14 +150,24 @@ class _TorchImageDataset(torch.utils.data.IterableDataset): # type: ignore
      if self.deterministic:
        order = first_sample + np.arange(last_sample - first_sample)
      else:
        order = first_sample + np.random.permutation(last_sample - first_sample)
        # Ensure that every worker will pick the same random order for each epoch.
        random = np.random.RandomState(epoch)
        order = random.permutation(n_samples)[first_sample:last_sample]
      if self.batch_size is None:
        for i in order:
          yield (self._get_image(self.image_dataset._X, i),
                 self._get_image(self.image_dataset._y, i),
                 self.image_dataset._w[i], self.image_dataset._ids[i])
      else:
        for i in range(0, len(order), self.batch_size):
          indices = order[i:i + self.batch_size]
          yield (self._get_image(self.image_dataset._X, indices),
                 self._get_image(self.image_dataset._y,
                                 indices), self.image_dataset._w[indices],
                 self.image_dataset._ids[indices])

  def _get_image(self, array: Union[np.ndarray, List[str]],
                 index: int) -> np.ndarray:
                 indices: int) -> np.ndarray:
    """Method for loading an image

    Parameters
@@ -138,5 +183,7 @@ class _TorchImageDataset(torch.utils.data.IterableDataset): # type: ignore
      Loaded image
    """
    if isinstance(array, np.ndarray):
      return array[index]
    return load_image_files([array[index]])[0]
      return array[indices]
    if isinstance(indices, np.ndarray):
      return load_image_files([array[i] for i in indices])
    return load_image_files([array[indices]])[0]
+17 −0
Original line number Diff line number Diff line
@@ -721,9 +721,26 @@ def _validate_pytorch_dataset(dataset):
    id_count[iter_id] += 1
  assert all(id_count[id] == 2 for id in ids)

  # Test iterating in batches.

  ds = dataset.make_pytorch_dataset(epochs=2, deterministic=False, batch_size=7)
  id_to_index = dict((id, i) for i, id in enumerate(ids))
  id_count = dict((id, 0) for id in ids)
  for iter_X, iter_y, iter_w, iter_id in ds:
    size = len(iter_id)
    assert size <= 7
    for i in range(size):
      j = id_to_index[iter_id[i]]
      np.testing.assert_array_equal(X[j, :], iter_X[i])
      np.testing.assert_array_equal(y[j, :], iter_y[i])
      np.testing.assert_array_equal(w[j, :], iter_w[i])
      id_count[iter_id[i]] += 1
  assert all(id_count[id] == 2 for id in ids)

  # Test iterating with multiple workers.

  import torch
  ds = dataset.make_pytorch_dataset(epochs=2, deterministic=False)
  loader = torch.utils.data.DataLoader(ds, num_workers=3)
  id_count = dict((id, 0) for id in ids)
  for iter_X, iter_y, iter_w, iter_id in loader: