Unverified commit a4173be9, authored by Bharath Ramsundar, committed by GitHub
Browse files

Merge pull request #1091 from peastman/iterator

Added Dataset.make_iterator()
parents 6f86377d ae324f1f
Loading
Loading
Loading
Loading
+43 −0
Original line number Diff line number Diff line
@@ -260,6 +260,49 @@ class Dataset(object):
    else:
      return None

  def make_iterator(self,
                    batch_size=100,
                    epochs=1,
                    deterministic=False,
                    pad_batches=False):
    """Create a tf.data.Iterator that iterates over the data in this Dataset.

    The iterator's get_next() method returns a tuple of three tensors (X, y, w)
    which can be used to retrieve the features, labels, and weights respectively.

    Parameters
    ----------
    batch_size: int
      the number of samples to include in each batch
    epochs: int
      the number of times to iterate over the Dataset
    deterministic: bool
      if True, the data is produced in order.  If False, a different random
      permutation of the data is used for each epoch.
    pad_batches: bool
      if True, batches are padded as necessary to make the size of each batch
      exactly equal batch_size.

    Returns
    -------
    a one-shot tf.data.Iterator created from a generator-backed
    tf.data.Dataset that yields (X, y, w) batches

    Raises
    ------
    ValueError
      if this Dataset yields no samples, because the dtypes and shapes of the
      tensors cannot then be determined
    """
    # Retrieve the first sample so we can determine the dtypes and per-sample
    # shapes.  A sentinel default is used so that an empty Dataset produces a
    # clear error instead of a bare StopIteration leaking out of this method.

    first_sample = next(self.itersamples(), None)
    if first_sample is None:
      raise ValueError(
          "Cannot create an iterator for a Dataset that contains no samples")
    X, y, w, _ = first_sample

    import tensorflow as tf
    dtypes = (tf.as_dtype(X.dtype), tf.as_dtype(y.dtype), tf.as_dtype(w.dtype))
    # The leading None is the batch dimension, whose size is left unspecified.
    shapes = (tf.TensorShape([None] + list(X.shape)),
              tf.TensorShape([None] + list(y.shape)),
              tf.TensorShape([None] + list(w.shape)))

    # Create a Tensorflow Dataset and have it create an Iterator.

    def gen_data():
      # The epoch index is forwarded to iterbatches() as its second argument;
      # the ids returned with each batch are not part of the output.
      for epoch in range(epochs):
        for batch_X, batch_y, batch_w, _ in self.iterbatches(
            batch_size, epoch, deterministic, pad_batches):
          yield (batch_X, batch_y, batch_w)

    dataset = tf.data.Dataset.from_generator(gen_data, dtypes, shapes)
    return dataset.make_one_shot_iterator()


class NumpyDataset(Dataset):
  """A Dataset defined by in-memory numpy arrays."""
+25 −1
Original line number Diff line number Diff line
@@ -17,9 +17,11 @@ import os
import shutil
import numpy as np
import deepchem as dc
import tensorflow as tf
from tensorflow.python.framework import test_util


class TestDatasets(unittest.TestCase):
class TestDatasets(test_util.TensorFlowTestCase):
  """
  Test basic top-level API for dataset objects.
  """
@@ -684,6 +686,28 @@ class TestDatasets(unittest.TestCase):
    assert new_data.y.shape == (num_datapoints * num_datasets, num_tasks)
    assert len(new_data.tasks) == len(datasets[0].tasks)

  def test_make_iterator(self):
    """Check the batches produced by Dataset.make_iterator() and its end."""
    features = np.random.random((100, 5))
    labels = np.random.random((100, 1))
    dataset = dc.data.NumpyDataset(features, labels)
    next_element = dataset.make_iterator(
        batch_size=10, epochs=2, deterministic=True).get_next()
    # No weights are passed to the dataset; the test expects all-ones weights.
    expected_w = np.ones((10, 1))
    with self.test_session() as sess:
      # 100 samples in batches of 10 over 2 epochs gives 20 batches; with
      # deterministic ordering the second epoch repeats the first.
      for step in range(20):
        got_X, got_y, got_w = sess.run(next_element)
        start = 10 * (step % 10)
        stop = start + 10
        np.testing.assert_array_equal(features[start:stop, :], got_X)
        np.testing.assert_array_equal(labels[start:stop, :], got_y)
        np.testing.assert_array_equal(expected_w, got_w)
      # One more fetch must signal that the iterator is exhausted.
      try:
        sess.run(next_element)
      except tf.errors.OutOfRangeError:
        exhausted = True
      else:
        exhausted = False
    assert exhausted


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  unittest.main()