Commit 1cb76ba4 authored by peastman's avatar peastman
Browse files

Created ImageDataset

parent 787bffa4
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ from deepchem.data.datasets import pad_batch
from deepchem.data.datasets import Dataset
from deepchem.data.datasets import NumpyDataset
from deepchem.data.datasets import DiskDataset
from deepchem.data.datasets import ImageDataset
from deepchem.data.datasets import sparsify_features
from deepchem.data.datasets import densify_features
from deepchem.data.supports import *
+195 −0
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@ from __future__ import unicode_literals
import json
import os
import math
import deepchem as dc
import numpy as np
import pandas as pd
import random
@@ -1339,6 +1340,200 @@ class DiskDataset(Dataset):
    return self.metadata_df["y_stds"]


class ImageDataset(Dataset):
  """A Dataset that loads data from image files on disk."""

  def __init__(self, X, y, w=None, ids=None):
    """Create a dataset whose X and/or y array is defined by image files on disk.

    Parameters
    ----------
    X: ndarray or list of strings
      The dataset's input data.  This may be either a single NumPy array directly
      containing the data, or a list containing the paths to the image files
    y: ndarray or list of strings
      The dataset's labels.  This may be either a single NumPy array directly
      containing the data, or a list containing the paths to the image files
    w: ndarray
      a 1D or 2D array containing the weights for each sample or sample/task pair
    ids: ndarray
      the sample IDs
    """
    self._X_shape = self._find_array_shape(X)
    self._y_shape = self._find_array_shape(y)
    n_samples = len(X)
    if w is None:
      w = np.ones(self._y_shape[:2])
    if ids is None:
      if not isinstance(X, np.ndarray):
        ids = X
      elif not isinstance(y, np.ndarray):
        ids = y
      else:
        ids = np.arange(n_samples)
    self._X = X
    self._y = y
    self._w = w
    self._ids = np.array(ids, dtype=object)

  def _find_array_shape(self, array):
    if isinstance(array, np.ndarray):
      return array.shape
    image_shape = dc.data.ImageLoader.load_img([array[0]]).shape[1:]
    return np.concatenate([[len(array)], image_shape])

  def __len__(self):
    """
    Get the number of elements in the dataset.
    """
    return self._X_shape[0]

  def get_shape(self):
    """Get the shape of the dataset.

    Returns four tuples, giving the shape of the X, y, w, and ids arrays.
    """
    return self._X_shape, self._y_shape, self._w.shape, self._ids.shape

  def get_task_names(self):
    """Get the names of the tasks associated with this dataset."""
    if len(self._y_shape) < 2:
      return np.array([0])
    return np.arange(self._y_shape[1])

  @property
  def X(self):
    """Get the X vector for this dataset as a single numpy array."""
    if isinstance(self._X, np.ndarray):
      return self._X
    return dc.data.ImageLoader.load_img(self._X)

  @property
  def y(self):
    """Get the y vector for this dataset as a single numpy array."""
    if isinstance(self._y, np.ndarray):
      return self._y
    return dc.data.ImageLoader.load_img(self._y)

  @property
  def ids(self):
    """Get the ids vector for this dataset as a single numpy array."""
    return self._ids

  @property
  def w(self):
    """Get the weight vector for this dataset as a single numpy array."""
    return self._w

  def iterbatches(self,
                  batch_size=None,
                  epoch=0,
                  deterministic=False,
                  pad_batches=False):
    """Get an object that iterates over minibatches from the dataset.

    Each minibatch is returned as a tuple of four numpy arrays: (X, y, w, ids).
    """

    def iterate(dataset, batch_size, deterministic, pad_batches):
      n_samples = dataset._X_shape[0]
      if not deterministic:
        sample_perm = np.random.permutation(n_samples)
      else:
        sample_perm = np.arange(n_samples)
      if batch_size is None:
        batch_size = n_samples
      batch_idx = 0
      num_batches = np.math.ceil(n_samples / batch_size)
      while batch_idx < num_batches:
        start = batch_idx * batch_size
        end = min(n_samples, (batch_idx + 1) * batch_size)
        indices = range(start, end)
        perm_indices = sample_perm[indices]
        if isinstance(dataset._X, np.ndarray):
          X_batch = dataset._X[perm_indices]
        else:
          X_batch = dc.data.ImageLoader.load_img(
              [dataset._X[i] for i in perm_indices])
        if isinstance(dataset._y, np.ndarray):
          y_batch = dataset._y[perm_indices]
        else:
          y_batch = dc.data.ImageLoader.load_img(
              [dataset._y[i] for i in perm_indices])
        w_batch = dataset._w[perm_indices]
        ids_batch = dataset._ids[perm_indices]
        if pad_batches:
          (X_batch, y_batch, w_batch, ids_batch) = pad_batch(
              batch_size, X_batch, y_batch, w_batch, ids_batch)
        batch_idx += 1
        yield (X_batch, y_batch, w_batch, ids_batch)

    return iterate(self, batch_size, deterministic, pad_batches)

  def itersamples(self):
    """Get an object that iterates over the samples in the dataset.

    Example:

    >>> dataset = NumpyDataset(np.ones((2,2)))
    >>> for x, y, w, id in dataset.itersamples():
    ...   print(x.tolist(), y.tolist(), w.tolist(), id)
    [1.0, 1.0] [0.0] [0.0] 0
    [1.0, 1.0] [0.0] [0.0] 1
    """

    def get_image(array, index):
      if isinstance(array, np.ndarray):
        return array[index]
      return dc.data.ImageLoader.load_img([array[index]])[0]

    n_samples = self._X_shape[0]
    return ((get_image(self._X, i), get_image(self._y, i), self._w[i],
             self._ids[i]) for i in range(n_samples))

  def transform(self, fn, **args):
    """Construct a new dataset by applying a transformation to every sample in this dataset.

    The argument is a function that can be called as follows:

    >> newx, newy, neww = fn(x, y, w)

    It might be called only once with the whole dataset, or multiple times with
    different subsets of the data.  Each time it is called, it should transform
    the samples and return the transformed data.

    Parameters
    ----------
    fn: function
      A function to apply to each sample in the dataset

    Returns
    -------
    a newly constructed Dataset object
    """
    newx, newy, neww = fn(self.X, self.y, self.w)
    return NumpyDataset(newx, newy, neww, self.ids[:])

  def select(self, indices, select_dir=None):
    """Creates a new dataset from a selection of indices from self.

    TODO(rbharath): select_dir is here due to dc.splits always passing in
    splits.

    Parameters
    ----------
    indices: list
      List of indices to select.
    select_dir: string
      Ignored.
    """
    X = self.X[indices]
    y = self.y[indices]
    w = self.w[indices]
    ids = self.ids[indices]
    return NumpyDataset(X, y, w, ids)


class Databag(object):
  """
  A utility class to iterate through multiple datasets together.
+92 −0
Original line number Diff line number Diff line
"""
Tests for ImageDataset class
"""
from __future__ import division
from __future__ import unicode_literals

__author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "MIT"

import unittest
import numpy as np
import deepchem as dc
import os
from tensorflow.python.framework import test_util


class TestImageDataset(test_util.TensorFlowTestCase):
  """
  Test ImageDataset class.
  """

  def test_load_images(self):
    """Test that ImageDataset loads images."""

    files = [os.path.join('images', f) for f in os.listdir('images')]

    # First try using images for X.

    ds1 = dc.data.ImageDataset(files, np.random.random(10))
    x_shape, y_shape, w_shape, ids_shape = ds1.get_shape()
    np.testing.assert_array_equal([10, 28, 28], x_shape)
    np.testing.assert_array_equal([10], y_shape)
    np.testing.assert_array_equal([10], w_shape)
    np.testing.assert_array_equal([10], ids_shape)
    np.testing.assert_array_equal(ds1.X.shape, x_shape)
    np.testing.assert_array_equal(ds1.y.shape, y_shape)
    np.testing.assert_array_equal(ds1.w.shape, w_shape)
    np.testing.assert_array_equal(ds1.ids.shape, ids_shape)

    # Now try using images for y.

    ds2 = dc.data.ImageDataset(np.random.random(10), files)
    x_shape, y_shape, w_shape, ids_shape = ds2.get_shape()
    np.testing.assert_array_equal([10], x_shape)
    np.testing.assert_array_equal([10, 28, 28], y_shape)
    np.testing.assert_array_equal([10, 28], w_shape)
    np.testing.assert_array_equal([10], ids_shape)
    np.testing.assert_array_equal(ds2.X.shape, x_shape)
    np.testing.assert_array_equal(ds2.y.shape, y_shape)
    np.testing.assert_array_equal(ds2.w.shape, w_shape)
    np.testing.assert_array_equal(ds2.ids.shape, ids_shape)
    np.testing.assert_array_equal(ds1.X, ds2.y)

  def test_itersamples(self):
    """Test iterating samples of an ImageDataset."""

    files = [os.path.join('images', f) for f in os.listdir('images')]
    ds = dc.data.ImageDataset(files, np.random.random(10))
    X = ds.X
    i = 0
    for x, y, w, id in ds.itersamples():
      np.testing.assert_array_equal(x, X[i])
      assert y == ds.y[i]
      assert w == ds.w[i]
      assert id == ds.ids[i]
      i += 1
    assert i == 10

  def test_iterbatches(self):
    """Test iterating batches of an ImageDataset."""

    files = [os.path.join('images', f) for f in os.listdir('images')]
    ds = dc.data.ImageDataset(files, np.random.random(10))
    X = ds.X
    iterated_ids = set()
    for x, y, w, ids in ds.iterbatches(2):
      np.testing.assert_array_equal([2, 28, 28], x.shape)
      np.testing.assert_array_equal([2], y.shape)
      np.testing.assert_array_equal([2], w.shape)
      np.testing.assert_array_equal([2], ids.shape)
      for i in (0, 1):
        assert ids[i] in files
        assert ids[i] not in iterated_ids
        iterated_ids.add(ids[i])
        index = files.index(ids[i])
        np.testing.assert_array_equal(x[i], X[index])
    assert len(iterated_ids) == 10


if __name__ == "__main__":
  unittest.main()