Unverified Commit 376930f3 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1803 from peastman/pytorch

Can create a PyTorch Dataset from a DeepChem Dataset
parents 92919d2b c003b8c3
Loading
Loading
Loading
Loading
+162 −4
Original line number Diff line number Diff line
@@ -305,6 +305,22 @@ class Dataset(object):

    return tf.data.Dataset.from_generator(gen_data, dtypes, shapes)

  def make_pytorch_dataset(self, epochs=1, deterministic=False):
    """Create a torch.utils.data.IterableDataset over this Dataset's samples.

    Iterating the returned object yields one (X, y, w, id) tuple per sample.

    Parameters
    ----------
    epochs: int
      number of full passes to make over the data
    deterministic: bool
      if True, samples are produced in their stored order; if False, each
      epoch uses a fresh random permutation of the data

    Raises
    ------
    NotImplementedError
      always; concrete Dataset subclasses are expected to override this
    """
    raise NotImplementedError()

  def to_dataframe(self):
    """Construct a pandas DataFrame containing the data from this Dataset."""
    X = self.X
@@ -570,6 +586,48 @@ class NumpyDataset(Dataset):
    ids = self.ids[indices]
    return NumpyDataset(X, y, w, ids)

  def make_pytorch_dataset(self, epochs=1, deterministic=False):
    """Create a torch.utils.data.IterableDataset over this Dataset's samples.

    Iterating the returned object yields one (X, y, w, id) tuple per sample.
    When used with a multi-worker DataLoader, each worker is assigned a
    contiguous, near-equal slice of the samples.

    Parameters
    ----------
    epochs: int
      number of full passes to make over the data
    deterministic: bool
      if True, samples are produced in their stored order; if False, each
      epoch uses a fresh random permutation of the data
    """
    import torch

    dataset = self

    def sample_generator():
      total = dataset._X.shape[0]
      info = torch.utils.data.get_worker_info()
      if info is None:
        # Single-process loading: this iterator owns every sample.
        start, stop = 0, total
      else:
        # Partition samples into contiguous near-equal chunks per worker.
        start = info.id * total // info.num_workers
        stop = (info.id + 1) * total // info.num_workers
      count = stop - start
      for _ in range(epochs):
        if deterministic:
          indices = start + np.arange(count)
        else:
          indices = start + np.random.permutation(count)
        for j in indices:
          yield (dataset._X[j], dataset._y[j], dataset._w[j], dataset._ids[j])

    class _TorchIterable(torch.utils.data.IterableDataset):

      def __iter__(self):
        return sample_generator()

    return _TorchIterable()

  @staticmethod
  def from_DiskDataset(ds):
    """
@@ -897,9 +955,19 @@ class DiskDataset(Dataset):


    """
    shard_indices = list(range(self.get_number_shards()))
    return self._iterbatches_from_shards(shard_indices, batch_size,
                                         deterministic, pad_batches)

  def _iterbatches_from_shards(self,
                               shard_indices,
                               batch_size=None,
                               deterministic=False,
                               pad_batches=False):
    """Get an object that iterates over batches from a restricted set of shards."""

    def iterate(dataset, batch_size):
      num_shards = dataset.get_number_shards()
      num_shards = len(shard_indices)
      if not deterministic:
        shard_perm = np.random.permutation(num_shards)
      else:
@@ -910,7 +978,8 @@ class DiskDataset(Dataset):
      # objects as an extra overhead. Also, as hideously as un-thread safe this looks,
      # we're actually protected by the GIL.
      pool = Pool(1)  # mp.dummy aliases ThreadPool to Pool
      next_shard = pool.apply_async(dataset.get_shard, (shard_perm[0],))
      next_shard = pool.apply_async(dataset.get_shard,
                                    (shard_indices[shard_perm[0]],))

      total_yield = 0

@@ -927,8 +996,8 @@ class DiskDataset(Dataset):

        X, y, w, ids = next_shard.get()
        if cur_shard < num_shards - 1:
          next_shard = pool.apply_async(dataset.get_shard,
                                        (shard_perm[cur_shard + 1],))
          next_shard = pool.apply_async(
              dataset.get_shard, (shard_indices[shard_perm[cur_shard + 1]],))
        else:
          pool.close()

@@ -1068,6 +1137,47 @@ class DiskDataset(Dataset):
    return DiskDataset.create_dataset(
        generator(), data_dir=out_dir, tasks=tasks, verbose=verbose)

  def make_pytorch_dataset(self, epochs=1, deterministic=False):
    """Create a torch.utils.data.IterableDataset over this Dataset's samples.

    Iterating the returned object yields one (X, y, w, id) tuple per sample.
    When used with a multi-worker DataLoader, work is divided by shard: each
    worker streams a contiguous, near-equal range of shards.

    Parameters
    ----------
    epochs: int
      number of full passes to make over the data
    deterministic: bool
      if True, samples are produced in their stored order; if False, each
      epoch uses a fresh random permutation of the data
    """
    import torch

    dataset = self

    def shard_sample_generator():
      info = torch.utils.data.get_worker_info()
      total_shards = dataset.get_number_shards()
      if info is None:
        # Single-process loading: iterate every shard.
        first, last = 0, total_shards
      else:
        # Partition shards into contiguous near-equal ranges per worker.
        first = info.id * total_shards // info.num_workers
        last = (info.id + 1) * total_shards // info.num_workers
      if first == last:
        # More workers than shards: this worker has nothing to produce.
        return
      owned_shards = list(range(first, last))
      for _ in range(epochs):
        for X, y, w, ids in dataset._iterbatches_from_shards(
            owned_shards, deterministic=deterministic):
          for j in range(X.shape[0]):
            yield (X[j], y[j], w[j], ids[j])

    class _TorchIterable(torch.utils.data.IterableDataset):

      def __iter__(self):
        return shard_sample_generator()

    return _TorchIterable()

  @staticmethod
  def from_numpy(X,
                 y=None,
@@ -1663,6 +1773,54 @@ class ImageDataset(Dataset):
    ids = self._ids[indices]
    return ImageDataset(X, y, w, ids)

  def make_pytorch_dataset(self, epochs=1, deterministic=False):
    """Create a torch.utils.data.IterableDataset over this Dataset's samples.

    Iterating the returned object yields one (X, y, w, id) tuple per sample.
    Images stored as filenames are loaded lazily, one sample at a time.

    Parameters
    ----------
    epochs: int
      number of full passes to make over the data
    deterministic: bool
      if True, samples are produced in their stored order; if False, each
      epoch uses a fresh random permutation of the data
    """
    import torch

    dataset = self

    def load_sample(source, index):
      # In-memory arrays are indexed directly; otherwise the entry is a
      # filename that must be loaded on demand.
      if isinstance(source, np.ndarray):
        return source[index]
      return dc.data.ImageLoader.load_img([source[index]])[0]

    def sample_generator():
      total = dataset._X_shape[0]
      info = torch.utils.data.get_worker_info()
      if info is None:
        # Single-process loading: this iterator owns every sample.
        start, stop = 0, total
      else:
        # Partition samples into contiguous near-equal chunks per worker.
        start = info.id * total // info.num_workers
        stop = (info.id + 1) * total // info.num_workers
      count = stop - start
      for _ in range(epochs):
        if deterministic:
          indices = start + np.arange(count)
        else:
          indices = start + np.random.permutation(count)
        for j in indices:
          yield (load_sample(dataset._X, j), load_sample(dataset._y, j),
                 dataset._w[j], dataset._ids[j])

    class _TorchIterable(torch.utils.data.IterableDataset):

      def __iter__(self):
        return sample_generator()

    return _TorchIterable()


class Databag(object):
  """
+75 −0
Original line number Diff line number Diff line
@@ -17,6 +17,12 @@ import tensorflow as tf
import pandas as pd
from tensorflow.python.framework import test_util

# PyTorch is an optional dependency: record whether it imports so the
# pytorch-dataset tests below can be skipped when it is unavailable.
try:
  import torch
  PYTORCH_IMPORT_FAILED = False
except ImportError:
  PYTORCH_IMPORT_FAILED = True


class TestDatasets(test_util.TensorFlowTestCase):
  """
@@ -697,6 +703,75 @@ class TestDatasets(test_util.TensorFlowTestCase):
      np.testing.assert_array_equal(np.ones((10, 1)), batch_w)
    assert i == 19

  def _validate_pytorch_dataset(self, dataset):
    """Check make_pytorch_dataset() against the Dataset's stored arrays.

    Verifies deterministic ordering, full shuffled coverage over two epochs,
    and multi-worker loading through a torch DataLoader.
    """
    X = dataset.X
    y = dataset.y
    w = dataset.w
    ids = dataset.ids
    n_samples = X.shape[0]

    # Deterministic iteration must replay the samples in stored order,
    # once per epoch.

    ds = dataset.make_pytorch_dataset(epochs=2, deterministic=True)
    count = 0
    for iter_X, iter_y, iter_w, iter_id in ds:
      j = count % n_samples
      np.testing.assert_array_equal(X[j, :], iter_X)
      np.testing.assert_array_equal(y[j, :], iter_y)
      np.testing.assert_array_equal(w[j, :], iter_w)
      assert ids[j] == iter_id
      count += 1
    assert count == 2 * n_samples

    # Shuffled iteration must still visit every sample exactly twice, with
    # each sample's fields kept consistent.

    ds = dataset.make_pytorch_dataset(epochs=2, deterministic=False)
    index_of = {sample_id: k for k, sample_id in enumerate(ids)}
    seen = {sample_id: 0 for sample_id in ids}
    for iter_X, iter_y, iter_w, iter_id in ds:
      j = index_of[iter_id]
      np.testing.assert_array_equal(X[j, :], iter_X)
      np.testing.assert_array_equal(y[j, :], iter_y)
      np.testing.assert_array_equal(w[j, :], iter_w)
      seen[iter_id] += 1
    assert all(n == 2 for n in seen.values())

    # Multi-worker loading must also cover every sample exactly twice.

    import torch
    loader = torch.utils.data.DataLoader(ds, num_workers=3)
    seen = {sample_id: 0 for sample_id in ids}
    for iter_X, iter_y, iter_w, iter_id in loader:
      j = index_of[iter_id[0]]
      np.testing.assert_array_equal(X[j, :], iter_X[0])
      np.testing.assert_array_equal(y[j, :], iter_y[0])
      np.testing.assert_array_equal(w[j, :], iter_w[0])
      seen[iter_id[0]] += 1
    assert all(n == 2 for n in seen.values())

  @unittest.skipIf(PYTORCH_IMPORT_FAILED, 'PyTorch is not installed')
  def test_make_pytorch_dataset_from_numpy(self):
    """Test creating a PyTorch Dataset from a NumpyDataset."""
    n_samples = 100
    features = np.random.random((n_samples, 5))
    labels = np.random.random((n_samples, 1))
    sample_ids = [str(i) for i in range(n_samples)]
    self._validate_pytorch_dataset(
        dc.data.NumpyDataset(features, labels, ids=sample_ids))

  @unittest.skipIf(PYTORCH_IMPORT_FAILED, 'PyTorch is not installed')
  def test_make_pytorch_dataset_from_images(self):
    """Test creating a PyTorch Dataset from an ImageDataset."""
    image_dir = os.path.join(os.path.dirname(__file__), 'images')
    image_files = [
        os.path.join(image_dir, name) for name in os.listdir(image_dir)
    ]
    labels = np.random.random((10, 1))
    sample_ids = [str(i) for i in range(len(image_files))]
    self._validate_pytorch_dataset(
        dc.data.ImageDataset(image_files, labels, ids=sample_ids))

  @unittest.skipIf(PYTORCH_IMPORT_FAILED, 'PyTorch is not installed')
  def test_make_pytorch_dataset_from_disk(self):
    """Test creating a PyTorch Dataset from a DiskDataset."""
    solubility = dc.data.tests.load_solubility_data()
    self._validate_pytorch_dataset(solubility)

  def test_dataframe(self):
    """Test converting between Datasets and DataFrames."""
    dataset = dc.data.tests.load_solubility_data()