Commit fa09de26 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Changes

parent d7088698
Loading
Loading
Loading
Loading
+77 −40
Original line number Diff line number Diff line
@@ -218,8 +218,8 @@ class Dataset(object):
  The `Dataset` class attempts to provide for strong interoperability
  with other machine learning representations for datasets.
  Interconversion methods allow for `Dataset` objects to be converted
  to and from pandas dataframes, tensorflow datasets, and pytorch
  datasets (only to and not from for pytorch at present).
  to and from numpy arrays, pandas dataframes, tensorflow datasets,
  and pytorch datasets (only to and not from for pytorch at present).

  Note that you can never instantiate a `Dataset` object directly.
  Instead you will need to instantiate one of the concrete subclasses.
@@ -253,6 +253,13 @@ class Dataset(object):
    Returns
    -------
    Numpy array of features `X`.

    Note
    ----
    If data is stored on disk, accessing this field may involve loading
    data from disk and could potentially be slow. Using
    `iterbatches()` or `itersamples()` may be more efficient for
    larger datasets.
    """
    raise NotImplementedError()

@@ -263,6 +270,13 @@ class Dataset(object):
    Returns
    -------
    Numpy array of labels `y`.

    Note
    ----
    If data is stored on disk, accessing this field may involve loading
    data from disk and could potentially be slow. Using
    `iterbatches()` or `itersamples()` may be more efficient for
    larger datasets.
    """
    raise NotImplementedError()

@@ -273,6 +287,13 @@ class Dataset(object):
    Returns
    -------
    Numpy array of identifiers `ids`.

    Note
    ----
    If data is stored on disk, accessing this field may involve loading
    data from disk and could potentially be slow. Using
    `iterbatches()` or `itersamples()` may be more efficient for
    larger datasets.
    """

    raise NotImplementedError()
@@ -284,18 +305,30 @@ class Dataset(object):
    Returns
    -------
    Numpy array of weights `w`.

    Note
    ----
    If data is stored on disk, accessing this field may involve loading
    data from disk and could potentially be slow. Using
    `iterbatches()` or `itersamples()` may be more efficient for
    larger datasets.
    """
    raise NotImplementedError()

  def __repr__(self):
    """Convert self to REPL print representation.

    The representation always reports the shapes of `X`, `y` and `w`
    along with the task names. The ids are included only when the
    dataset has fewer than `dc.utils.get_max_print_size()` elements;
    for larger datasets `self.ids` is not read at all.

    Returns
    -------
    str
      A string of the form
      `<ClassName X.shape: ..., y.shape: ..., w.shape: ..., ids: ..., task_names: ...>`
      with the `ids` entry left out for large datasets.
    """
    # Arrays with more than `threshold` elements are summarized by
    # numpy (with "...") instead of being printed in full.
    threshold = dc.utils.get_print_threshold()
    task_str = np.array2string(
        np.array(self.get_task_names()), threshold=threshold)
    if len(self) < dc.utils.get_max_print_size():
      # Only touch `self.ids` for small datasets; formatting it
      # unconditionally would defeat the purpose of the size guard.
      id_str = np.array2string(self.ids, threshold=threshold)
      return "<%s X.shape: %s, y.shape: %s, w.shape: %s, ids: %s, task_names: %s>" % (
          self.__class__.__name__, str(self.X.shape), str(self.y.shape),
          str(self.w.shape), id_str, task_str)
    return "<%s X.shape: %s, y.shape: %s, w.shape: %s, task_names: %s>" % (
        self.__class__.__name__, str(self.X.shape), str(self.y.shape),
        str(self.w.shape), task_str)

  def __str__(self):
    """Convert self to str representation."""
@@ -485,9 +518,13 @@ class Dataset(object):

    Returns
    -------
    pandas datafarme. Will have column "X1,X2,..." for features,
    "y1,y2,..." for labels, "w1,w2,..." for weights, and column "ids"
    for identifiers.
    pandas dataframe. If there is only a single feature per datapoint,
    will have column "X" else will have columns "X1,X2,..." for
    features.  If there is only a single label per datapoint, will
    have column "y" else will have columns "y1,y2,..." for labels. If
    there is only a single weight per datapoint will have column "w"
    else will have columns "w1,w2,...". Will have column "ids" for
    identifiers.
    """
    X = self.X
    y = self.y
@@ -603,7 +640,7 @@ class NumpyDataset(Dataset):
  objects. For example

  >>> import numpy as np
  >>> NumpyDataset(X=np.random.rand(5, 3), y=np.random.rand(5,), ids=np.arange(5))
  >>> dataset = NumpyDataset(X=np.random.rand(5, 3), y=np.random.rand(5,), ids=np.arange(5))
  """

  def __init__(self, X, y=None, w=None, ids=None, n_tasks=1):
@@ -614,9 +651,11 @@ class NumpyDataset(Dataset):
    X: np.ndarray
      Input features. Of shape `(n_samples,...)`
    y: np.ndarray, optional
      Labels. Of shape `(n_samples, n_tasks)` typically.
      Labels. Of shape `(n_samples, ...)`. Note that each label can
      have an arbitrary shape.
    w: np.ndarray, optional
      Weights. Of same shape as `y`.
      Weights. Should either be 1D of shape `(n_samples,)` or if
      there's more than one task, of shape `(n_samples, n_tasks)`.
    ids: np.ndarray, optional
      Identifiers. Of shape `(n_samples,)`
    n_tasks: int, optional
@@ -780,15 +819,13 @@ class NumpyDataset(Dataset):
  def select(self, indices, select_dir=None):
    """Creates a new dataset from a selection of indices from self.

    TODO(rbharath): select_dir is here due to dc.splits always passing in
    splits.

    Parameters
    ----------
    indices: list
      List of indices to select.
    select_dir: string
      Ignored.
      Used to provide same API as `DiskDataset`. Ignored since
      `NumpyDataset` is purely in-memory.
    """
    X = self.X[indices]
    y = self.y[indices]
@@ -1138,31 +1175,26 @@ class DiskDataset(Dataset):
                  epoch=0,
                  deterministic=False,
                  pad_batches=False):
    """ Get an object that iterates over minibatches from the dataset. It is guaranteed
    that the number of batches returned is math.ceil(len(dataset)/batch_size).

    Each minibatch is returned as a tuple of four numpy arrays: (X, y, w, ids).
    """ Get an object that iterates over minibatches from the dataset.

    It is guaranteed that the number of batches returned is
    `math.ceil(len(dataset)/batch_size)`. Each minibatch is returned as
    a tuple of four numpy arrays: `(X, y, w, ids)`.

    Parameters
    ----------
    batch_size: int
      Number of elements in a batch. If None, then it yields batches with size equal to the size
      of each individual shard.

      Number of elements in a batch. If None, then it yields batches
      with size equal to the size of each individual shard.
    epoch: int
      Not used

    deterministic: bool
      Whether or not we should should shuffle each shard before generating the batches.
      Note that this is only local in the sense that it does not ever mix between different
      shards.

      Whether or not we should shuffle each shard before
      generating the batches.  Note that this is only local in the
      sense that it does not ever mix between different shards.
    pad_batches: bool
      Whether or not we should pad the last batch, globally, such that it has exactly batch_size
      elements.


      Whether or not we should pad the last batch, globally, such that
      it has exactly batch_size elements.
    """
    shard_indices = list(range(self.get_number_shards()))
    return self._iterbatches_from_shards(shard_indices, batch_size,
@@ -1776,13 +1808,16 @@ class ImageDataset(Dataset):
    Parameters
    ----------
    X: ndarray or list of strings
      The dataset's input data.  This may be either a single NumPy array directly
      containing the data, or a list containing the paths to the image files
      The dataset's input data.  This may be either a single NumPy
      array directly containing the data, or a list containing the
      paths to the image files
    y: ndarray or list of strings
      The dataset's labels.  This may be either a single NumPy array directly
      containing the data, or a list containing the paths to the image files
      The dataset's labels.  This may be either a single NumPy array
      directly containing the data, or a list containing the paths to
      the image files
    w: ndarray
      a 1D or 2D array containing the weights for each sample or sample/task pair
      a 1D or 2D array containing the weights for each sample or
      sample/task pair
    ids: ndarray
      the sample IDs
    """
@@ -1823,7 +1858,8 @@ class ImageDataset(Dataset):
  def get_shape(self):
    """Get the shape of the dataset.

    Returns
    -------
    A tuple of four shape tuples, in the order `(X, y, w, ids)`. The
    `X` and `y` entries are read from the stored `_X_shape` and
    `_y_shape` attributes, while the `w` and `ids` entries are taken
    from the `shape` of the stored weight and id arrays.
    """
    features_shape = self._X_shape
    labels_shape = self._y_shape
    weights_shape = self._w.shape
    ids_shape = self._ids.shape
    return (features_shape, labels_shape, weights_shape, ids_shape)

@@ -1864,7 +1900,8 @@ class ImageDataset(Dataset):
                  pad_batches=False):
    """Get an object that iterates over minibatches from the dataset.

    Each minibatch is returned as a tuple of four numpy arrays: (X, y, w, ids).
    Each minibatch is returned as a tuple of four numpy arrays: (X, y,
    w, ids).
    """

    def iterate(dataset, batch_size, deterministic, pad_batches):
@@ -1955,7 +1992,7 @@ class ImageDataset(Dataset):
      List of indices to select.
    select_dir: string
      Used to provide same API as `DiskDataset`. Ignored since
      `NumpYDataset` is purely in-memory.
      `ImageDataset` is purely in-memory.
    """
    if isinstance(self._X, np.ndarray):
      X = self._X[indices]
+7 −0
Original line number Diff line number Diff line
@@ -813,3 +813,10 @@ class TestDatasets(test_util.TensorFlowTestCase):
        X=np.random.rand(50, 3), y=np.random.rand(50, 20), ids=np.arange(50))
    ref_str = '<NumpyDataset X.shape: (50, 3), y.shape: (50, 20), w.shape: (50, 1), ids: [0 1 2 ... 47 48 49], task_names: [ 0  1  2 ... 17 18 19]>'
    assert str(dataset) == ref_str

    # Test max print size
    dc.utils.set_max_print_size(25)
    dataset = dc.data.NumpyDataset(
        X=np.random.rand(50, 3), y=np.random.rand(50,), ids=np.arange(50))
    ref_str = '<NumpyDataset X.shape: (50, 3), y.shape: (50,), w.shape: (50,), task_names: [0]>'
    assert str(dataset) == ref_str
+46 −8
Original line number Diff line number Diff line
@@ -67,7 +67,7 @@ _print_threshold = 10


def get_print_threshold():
  """Return the printing threshold for array.
  """Return the printing threshold for datasets.

  The print threshold is the number of elements from ids/tasks to
  print when printing representations of `Dataset` objects.
@@ -95,6 +95,44 @@ def set_print_threshold(threshold):
  _print_threshold = threshold


# Datasets with at least this many elements leave their ids out of the
# string representation, since printing ids for a large dataset can be
# very slow. Read and updated through `get_max_print_size()` and
# `set_max_print_size()`.
_max_print_size = 1000


def get_max_print_size():
  """Return the max print size for a dataset.

  Printing `self.ids` as part of the string representation of a large
  dataset can be very slow, so ids are only printed for datasets
  smaller than this module-level limit.

  Returns
  -------
  max_print_size: int
    Maximum length of a dataset for ids to be printed in string
    representation.
  """
  return _max_print_size


def set_max_print_size(max_print_size):
  """Set the max print size for a dataset.

  Printing `self.ids` as part of the string representation of a large
  dataset can be very slow, so ids are only printed for datasets
  smaller than this limit. Calling this function rebinds the
  module-level value returned by `get_max_print_size()`.

  Parameters
  ----------
  max_print_size: int
    Maximum length of a dataset for ids to be printed in string
    representation.
  """
  # Rebind the module-level setting rather than creating a local.
  global _max_print_size
  _max_print_size = max_print_size


def download_url(url, dest_dir=get_data_dir(), name=None):
  """Download a file to disk.