Commit b5fbca7c authored by Bharath Ramsundar, committed by GitHub

Merge pull request #410 from lilleswing/378-doctests-cr

378: Add Doctests To Travis
parents c7a722b2 f1b47030
.travis.yml +1 −0
@@ -20,6 +20,7 @@ install:
- python setup.py install
script:
- nosetests -v deepchem --nologcapture
- find ./deepchem | grep .py$ |xargs python -m doctest -v
- bash devtools/travis-ci/test_format_code.sh
after_success:
- echo $TRAVIS_SECURE_ENV_VARS
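
For context: the new line walks every .py file under deepchem/ and hands it to `python -m doctest -v`, which imports the file, re-runs any interactive examples found in docstrings, and fails the build if the printed output differs. A minimal sketch of what that step checks (hypothetical module, not part of this diff):

    # example_doctest.py -- illustrative only. Running
    # `python -m doctest -v example_doctest.py` re-executes the
    # interactive session in the docstring and reports a failure
    # if the actual output does not match the expected output.
    def double(x):
      """Double a number.

      >>> double(2)
      4
      """
      return 2 * x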
deepchem/data/data_loader.py +43 −22
@@ -20,14 +20,15 @@ from deepchem.utils.save import load_sdf_files
from deepchem.feat import UserDefinedFeaturizer
from deepchem.data import DiskDataset


def convert_df_to_numpy(df, tasks, verbose=False):
  """Transforms a dataframe containing deepchem input into numpy arrays"""
  n_samples = df.shape[0]
  n_tasks = len(tasks)

  time1 = time.time()
  y = np.hstack([
      np.reshape(np.array(df[task].values), (n_samples, 1)) for task in tasks])
  y = np.hstack(
      [np.reshape(np.array(df[task].values), (n_samples, 1)) for task in tasks])
  time2 = time.time()

  w = np.ones((n_samples, n_tasks))
@@ -49,6 +50,7 @@ def convert_df_to_numpy(df, tasks, verbose=False):

  return y.astype(float), w.astype(float)


def featurize_smiles_df(df, featurizer, field, log_every_N=1000, verbose=True):
  """Featurize individual compounds in dataframe.

@@ -64,11 +66,12 @@ def featurize_smiles_df(df, featurizer, field, log_every_N=1000, verbose=True):
    if ind % log_every_N == 0:
      log("Featurizing sample %d" % ind, verbose)
    features.append(featurizer.featurize([mol]))
  valid_inds = np.array([1 if elt.size > 0 else 0 for elt in features],
                        dtype=bool)
  valid_inds = np.array(
      [1 if elt.size > 0 else 0 for elt in features], dtype=bool)
  features = [elt for (is_valid, elt) in zip(valid_inds, features) if is_valid]
  return np.squeeze(np.array(features)), valid_inds


def get_user_specified_features(df, featurizer, verbose=True):
  """Extract and merge user specified features. 

@@ -86,12 +89,15 @@ def get_user_specified_features(df, featurizer, verbose=True):

  """
  time1 = time.time()
  df[featurizer.feature_fields] = df[featurizer.feature_fields].apply(pd.to_numeric)
  df[featurizer.feature_fields] = df[featurizer.feature_fields].apply(
      pd.to_numeric)
  X_shard = df.as_matrix(columns=featurizer.feature_fields)
  time2 = time.time()
  log("TIMING: user specified processing took %0.3f s" % (time2-time1), verbose)
  log("TIMING: user specified processing took %0.3f s" % (time2 - time1),
      verbose)
  return X_shard


def featurize_mol_df(df, featurizer, field, verbose=True, log_every_N=1000):
  """Featurize individual compounds in dataframe.

@@ -108,11 +114,12 @@ def featurize_mol_df(df, featurizer, field, verbose=True, log_every_N=1000):
    if ind % log_every_N == 0:
      log("Featurizing sample %d" % ind, verbose)
    features.append(featurizer.featurize([mol]))
  valid_inds = np.array([1 if elt.size > 0 else 0 for elt in features],
                        dtype=bool)
  valid_inds = np.array(
      [1 if elt.size > 0 else 0 for elt in features], dtype=bool)
  features = [elt for (is_valid, elt) in zip(valid_inds, features) if is_valid]
  return np.squeeze(np.array(features)), valid_inds


class DataLoader(object):
  """
  Handles loading/featurizing of chemical samples (datapoints).
@@ -121,9 +128,14 @@ class DataLoader(object):
  dataframe object to disk as output.
  """

  def __init__(self, tasks, smiles_field=None,
               id_field=None, mol_field=None, featurizer=None,
               verbose=True, log_every_n=1000):
  def __init__(self,
               tasks,
               smiles_field=None,
               id_field=None,
               mol_field=None,
               featurizer=None,
               verbose=True,
               log_every_n=1000):
    """Extracts data from input as Pandas data frame"""
    if not isinstance(tasks, list):
      raise ValueError("tasks must be a list.")
@@ -148,8 +160,10 @@ class DataLoader(object):

    if not isinstance(input_files, list):
      input_files = [input_files]

    def shard_generator():
      for shard_num, shard in enumerate(self.get_shards(input_files, shard_size)):
      for shard_num, shard in enumerate(
          self.get_shards(input_files, shard_size)):
        time1 = time.time()
        X, valid_inds = self.featurize_shard(shard)
        ids = shard[self.id_field].values
@@ -167,10 +181,12 @@ class DataLoader(object):
          assert len(X) == len(ids)

        time2 = time.time()
        log("TIMING: featurizing shard %d took %0.3f s" % (shard_num, time2-time1),
            self.verbose)
        log("TIMING: featurizing shard %d took %0.3f s" %
            (shard_num, time2 - time1), self.verbose)
        yield X, y, w, ids
    return DiskDataset.create_dataset(shard_generator(), data_dir, self.tasks)

    return DiskDataset.create_dataset(
        shard_generator(), data_dir, self.tasks, verbose=self.verbose)

  def get_shards(self, input_files, shard_size):
    """Stub for children classes."""
@@ -180,23 +196,26 @@ class DataLoader(object):
    """Featurizes a shard of an input dataframe."""
    raise NotImplementedError


class CSVLoader(DataLoader):
  """
  Handles loading of CSV files.
  """

  def get_shards(self, input_files, shard_size, verbose=True):
    """Defines a generator which returns data for each shard"""
    return load_csv_files(input_files, shard_size, verbose=verbose)

  def featurize_shard(self, shard):
    """Featurizes a shard of an input dataframe."""
    return featurize_smiles_df(shard, self.featurizer,
                               field=self.smiles_field)
    return featurize_smiles_df(shard, self.featurizer, field=self.smiles_field)


class UserCSVLoader(DataLoader):
  """
  Handles loading of CSV files with user-defined featurizers.
  """

  def get_shards(self, input_files, shard_size):
    """Defines a generator which returns data for each shard"""
    return load_csv_files(input_files, shard_size)
@@ -207,16 +226,18 @@ class UserCSVLoader(DataLoader):
    X = get_user_specified_features(shard, self.featurizer)
    return (X, np.ones(len(X), dtype=bool))


class SDFLoader(DataLoader):
  """
  Handles loading of SDF files.
  """

  def get_shards(self, input_files, shard_size):
    """Defines a generator which returns data for each shard"""
    return load_sdf_files(input_files)

  def featurize_shard(self, shard):
    """Featurizes a shard of an input dataframe."""
    log("Currently featurizing feature_type: %s"
        % self.featurizer.__class__.__name__, self.verbose)
    log("Currently featurizing feature_type: %s" %
        self.featurizer.__class__.__name__, self.verbose)
    return featurize_mol_df(shard, self.featurizer, field=self.mol_field)
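
For reference, these loader classes are driven through the loader's `featurize` call, whose shard generator (reformatted above) chains `get_shards` and `featurize_shard` into `DiskDataset.create_dataset`. A hedged sketch of typical usage (the CSV path, task name, and featurizer choice are illustrative, not taken from this diff):

    import deepchem as dc

    # Hypothetical CSV with columns "smiles" and "task1".
    featurizer = dc.feat.CircularFingerprint(size=1024)
    loader = dc.data.CSVLoader(
        tasks=["task1"], smiles_field="smiles", featurizer=featurizer)
    # featurize() shards the input (a single path is wrapped into a list,
    # as shown in the diff), featurizes each shard, and writes the result
    # to disk as a DiskDataset.
    dataset = loader.featurize("my_data.csv")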
deepchem/data/datasets.py +167 −119
@@ -20,6 +20,7 @@ __author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "GPL"


def sparsify_features(X):
  """Extracts a sparse feature representation from dense feature array."""
  n_samples = len(X)
@@ -31,6 +32,7 @@ def sparsify_features(X):
  X_sparse = np.array(X_sparse, dtype=object)
  return X_sparse


def densify_features(X_sparse, num_features):
  """Expands sparse feature representation to dense feature array."""
  n_samples = len(X_sparse)
@@ -40,6 +42,7 @@ def densify_features(X_sparse, num_features):
    X[i][nonzero_inds.astype(int)] = nonzero_vals
  return X


def pad_features(batch_size, X_b):
  """Pads a batch of features to have precisely batch_size elements.
  
@@ -69,6 +72,7 @@ def pad_features(batch_size, X_b):
      start += increment
    return X_out


def pad_batch(batch_size, X_b, y_b, w_b, ids_b):
  """Pads batch to have size precisely batch_size elements.

@@ -151,7 +155,11 @@ class Dataset(object):
    """Get the weight vector for this dataset as a single numpy array."""
    raise NotImplementedError()

  def iterbatches(self, batch_size=None, epoch=0, deterministic=False, pad_batches=False):
  def iterbatches(self,
                  batch_size=None,
                  epoch=0,
                  deterministic=False,
                  pad_batches=False):
    """Get an object that iterates over minibatches from the dataset.

    Each minibatch is returned as a tuple of four numpy arrays: (X, y, w, ids).
@@ -163,9 +171,11 @@ class Dataset(object):

    Example:

    >>> dataset = NumpyDataset(np.ones((2,2)))
    >>> for x, y, w, id in dataset.itersamples():
    >>>   print(x, y, w, id)

    ...   print(x, y, w, id)
    [ 1.  1.] [ 0.] [ 0.] 0
    [ 1.  1.] [ 0.] [ 0.] 1
    """
    raise NotImplementedError()
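
This hunk is the crux of the PR: once doctests run on Travis, continuation lines inside a compound statement must be marked with `...` rather than a repeated `>>>`, and the expected output must follow verbatim, or `python -m doctest` reports a failure. A standalone sketch of the convention (not deepchem-specific):

    def count_to(n):
      """Yield integers from 0 to n-1.

      >>> for i in count_to(2):
      ...   print(i)
      0
      1
      """
      return iter(range(n))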

@@ -174,7 +184,7 @@ class Dataset(object):

    The argument is a function that can be called as follows:

    >>> newx, newy, neww = fn(x, y, w)
    >> newx, newy, neww = fn(x, y, w)

    It might be called only once with the whole dataset, or multiple times with
    different subsets of the data.  Each time it is called, it should transform
@@ -286,12 +296,16 @@ class NumpyDataset(Dataset):
    """Get the weight vector for this dataset as a single numpy array."""
    return self._w

  def iterbatches(self, batch_size=None, epoch=0, deterministic=False,
  def iterbatches(self,
                  batch_size=None,
                  epoch=0,
                  deterministic=False,
                  pad_batches=False):
    """Get an object that iterates over minibatches from the dataset.

    Each minibatch is returned as a tuple of four numpy arrays: (X, y, w, ids).
    """

    def iterate(dataset, batch_size, deterministic, pad_batches):
      n_samples = dataset._X.shape[0]
      if not deterministic:
@@ -313,6 +327,7 @@ class NumpyDataset(Dataset):
          (X_batch, y_batch, w_batch, ids_batch) = pad_batch(
              batch_size, X_batch, y_batch, w_batch, ids_batch)
        yield (X_batch, y_batch, w_batch, ids_batch)

    return iterate(self, batch_size, deterministic, pad_batches)

  def itersamples(self):
@@ -320,9 +335,11 @@ class NumpyDataset(Dataset):

    Example:

    >>> dataset = NumpyDataset(np.ones((2,2)))
    >>> for x, y, w, id in dataset.itersamples():
    >>>   print(x, y, w, id)

    ...   print(x, y, w, id)
    [ 1.  1.] [ 0.] [ 0.] 0
    [ 1.  1.] [ 0.] [ 0.] 1
    """
    n_samples = self._X.shape[0]
    return ((self._X[i], self._y[i], self._w[i], self._ids[i])
@@ -333,7 +350,7 @@ class NumpyDataset(Dataset):

    The argument is a function that can be called as follows:

    >>> newx, newy, neww = fn(x, y, w)
    >> newx, newy, neww = fn(x, y, w)

    It might be called only once with the whole dataset, or multiple times with
    different subsets of the data.  Each time it is called, it should transform
@@ -351,10 +368,12 @@ class NumpyDataset(Dataset):
    newx, newy, neww = fn(self._X, self._y, self._w)
    return NumpyDataset(newx, newy, neww, self._ids[:])


class DiskDataset(Dataset):
  """
  A Dataset that is stored as a set of files on disk.
  """

  def __init__(self, data_dir, verbose=True):
    """
    Turns featurized dataframes into numpy files, writes them & metadata to disk.
@@ -364,12 +383,11 @@ class DiskDataset(Dataset):

    log("Loading dataset from disk.", self.verbose)
    if os.path.exists(self._get_metadata_filename()):
      (self.tasks, self.metadata_df) = load_from_disk(
          self._get_metadata_filename())
      (self.tasks,
       self.metadata_df) = load_from_disk(self._get_metadata_filename())
    else:
      raise ValueError("No metadata found on disk.")


  @staticmethod
  def create_dataset(shard_generator, data_dir=None, tasks=[], verbose=True):
    """Creates a new DiskDataset
@@ -394,14 +412,14 @@ class DiskDataset(Dataset):
    for shard_num, (X, y, w, ids) in enumerate(shard_generator):
      basename = "shard-%d" % shard_num
      metadata_rows.append(
          DiskDataset.write_data_to_disk(
              data_dir, basename, tasks, X, y, w, ids))
          DiskDataset.write_data_to_disk(data_dir, basename, tasks, X, y, w,
                                         ids))
    metadata_df = DiskDataset._construct_metadata(metadata_rows)
    metadata_filename = os.path.join(data_dir, "metadata.joblib")
    save_to_disk((tasks, metadata_df), metadata_filename)
    time2 = time.time()
    print("TIMING: dataset construction took %0.3f s" % (time2-time1), verbose)
    return DiskDataset(data_dir)
    log("TIMING: dataset construction took %0.3f s" % (time2 - time1), verbose)
    return DiskDataset(data_dir, verbose=verbose)

  @staticmethod
  def _construct_metadata(metadata_entries):
@@ -411,13 +429,16 @@ class DiskDataset(Dataset):
    above.
    """
    columns = ('basename', 'task_names', 'ids', 'X', 'y', 'w')
    metadata_df = pd.DataFrame(
        metadata_entries,
        columns=columns)
    metadata_df = pd.DataFrame(metadata_entries, columns=columns)
    return metadata_df

  @staticmethod
  def write_data_to_disk(data_dir, basename, tasks, X=None, y=None, w=None,
  def write_data_to_disk(data_dir,
                         basename,
                         tasks,
                         X=None,
                         y=None,
                         w=None,
                         ids=None):
    if X is not None:
      out_X = "%s-X.joblib" % basename
@@ -448,8 +469,7 @@ class DiskDataset(Dataset):

  def save_to_disk(self):
    """Save dataset to disk."""
    save_to_disk(
        (self.tasks, self.metadata_df), self._get_metadata_filename())
    save_to_disk((self.tasks, self.metadata_df), self._get_metadata_filename())

  def move(self, new_data_dir):
    """Moves dataset to new directory."""
@@ -470,6 +490,7 @@ class DiskDataset(Dataset):
    # Create temp directory to store resharded version
    reshard_dir = tempfile.mkdtemp()
    new_metadata = []

    # Write data in new shards
    def generator():
      tasks = self.get_task_names()
@@ -490,8 +511,9 @@ class DiskDataset(Dataset):
          yield (X_batch, y_batch, w_batch, ids_batch)
      # Handle spillover from last shard
      yield (X_next, y_next, w_next, ids_next)
    resharded_dataset = DiskDataset.create_dataset(generator(), data_dir=reshard_dir,
                                                   tasks=self.tasks)

    resharded_dataset = DiskDataset.create_dataset(
        generator(), data_dir=reshard_dir, tasks=self.tasks)
    shutil.rmtree(self.data_dir)
    shutil.move(reshard_dir, self.data_dir)
    self.metadata_df = resharded_dataset.metadata_df
@@ -504,9 +526,8 @@ class DiskDataset(Dataset):
    if not len(self.metadata_df):
      raise ValueError("No data in dataset.")
    sample_X = load_from_disk(
        os.path.join(
            self.data_dir,
            next(self.metadata_df.iterrows())[1]['X']))[0]
        os.path.join(self.data_dir, next(self.metadata_df.iterrows())[1]['X']))[
            0]
    return np.shape(sample_X)

  def get_shard_size(self):
@@ -514,9 +535,7 @@ class DiskDataset(Dataset):
    if not len(self.metadata_df):
      raise ValueError("No data in dataset.")
    sample_y = load_from_disk(
        os.path.join(
            self.data_dir,
            next(self.metadata_df.iterrows())[1]['y']))
        os.path.join(self.data_dir, next(self.metadata_df.iterrows())[1]['y']))
    return len(sample_y)

  def _get_metadata_filename(self):
@@ -540,16 +559,16 @@ class DiskDataset(Dataset):
    generator defined by this function returns the data from a particular shard.
    The order of shards returned is guaranteed to remain fixed.
    """

    def iterate(dataset):
      for _, row in dataset.metadata_df.iterrows():
        X = np.array(load_from_disk(
            os.path.join(dataset.data_dir, row['X'])))
        ids = np.array(load_from_disk(
            os.path.join(dataset.data_dir, row['ids'])), dtype=object)
        X = np.array(load_from_disk(os.path.join(dataset.data_dir, row['X'])))
        ids = np.array(
            load_from_disk(os.path.join(dataset.data_dir, row['ids'])),
            dtype=object)
        # These columns may be missing if the dataset is unlabelled.
        if row['y'] is not None:
          y = np.array(load_from_disk(
            os.path.join(dataset.data_dir, row['y'])))
          y = np.array(load_from_disk(os.path.join(dataset.data_dir, row['y'])))
        else:
          y = None
        if row['w'] is not None:
@@ -561,14 +580,19 @@ class DiskDataset(Dataset):
        else:
          w = None
        yield (X, y, w, ids)

    return iterate(self)

  def iterbatches(self, batch_size=None, epoch=0, deterministic=False,
  def iterbatches(self,
                  batch_size=None,
                  epoch=0,
                  deterministic=False,
                  pad_batches=False):
    """Get an object that iterates over minibatches from the dataset.

    Each minibatch is returned as a tuple of four numpy arrays: (X, y, w, ids).
    """

    def iterate(dataset):
      num_shards = dataset.get_number_shards()
      if not deterministic:
@@ -591,7 +615,10 @@ class DiskDataset(Dataset):
        else:
          shard_batch_size = batch_size
        interval_points = np.linspace(
            0, n_samples, np.ceil(float(n_samples)/shard_batch_size)+1, dtype=int)
            0,
            n_samples,
            np.ceil(float(n_samples) / shard_batch_size) + 1,
            dtype=int)
        for j in range(len(interval_points) - 1):
          indices = range(interval_points[j], interval_points[j + 1])
          perm_indices = sample_perm[indices]
@@ -612,6 +639,7 @@ class DiskDataset(Dataset):
            (X_batch, y_batch, w_batch, ids_batch) = pad_batch(
                shard_batch_size, X_batch, y_batch, w_batch, ids_batch)
          yield (X_batch, y_batch, w_batch, ids_batch)

    return iterate(self)

  def itersamples(self):
@@ -619,19 +647,26 @@ class DiskDataset(Dataset):

    Example:

    >>> dataset = DiskDataset.from_numpy(np.ones((2,2)), np.ones((2,1)), verbose=False)
    >>> for x, y, w, id in dataset.itersamples():
    >>>   print(x, y, w, id)
    ...   print(x, y, w, id)
    [ 1.  1.] [ 1.] [ 1.] 0
    [ 1.  1.] [ 1.] [ 1.] 1
    """

    def iterate(dataset):
      for (X_shard, y_shard, w_shard, ids_shard) in dataset.itershards():
        n_samples = X_shard.shape[0]
        for i in range(n_samples):

          def sanitize(elem):
            if elem is None:
              return None
            else:
              return elem[i]

          yield map(sanitize, [X_shard, y_shard, w_shard, ids_shard])

    return iterate(self)

  def transform(self, fn, **args):
@@ -639,7 +674,7 @@ class DiskDataset(Dataset):

    The argument is a function that can be called as follows:

    >>> newx, newy, neww = fn(x, y, w)
    >> newx, newy, neww = fn(x, y, w)

    It might be called only once with the whole dataset, or multiple times with different
    subsets of the data.  Each time it is called, it should transform the samples and return
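
As with the earlier occurrences, the `>>>` prompt here is downgraded to `>>` because the snippet documents a call signature rather than a runnable session; doctest would otherwise execute it and fail on the undefined names. A hedged sketch of a function satisfying the documented fn(x, y, w) contract (illustrative, not from the diff):

    import numpy as np

    def scale_labels(x, y, w):
      # Matches the documented contract: accept the feature, label, and
      # weight arrays for a batch and return transformed copies.
      return x, 2.0 * np.asarray(y), w

    # dataset.transform(scale_labels) then yields a dataset with doubled labels.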
@@ -662,15 +697,24 @@ class DiskDataset(Dataset):
    else:
      out_dir = tempfile.mkdtemp()
    tasks = self.get_task_names()

    def generator():
      for shard_num, row in self.metadata_df.iterrows():
        X, y, w, ids = self.get_shard(shard_num)
        newx, newy, neww = fn(X, y, w)
        yield (newx, newy, neww, ids)
    return DiskDataset.create_dataset(generator(), data_dir=out_dir, tasks=tasks)

    return DiskDataset.create_dataset(
        generator(), data_dir=out_dir, tasks=tasks)

  @staticmethod
  def from_numpy(X, y, w=None, ids=None, tasks=None, data_dir=None):
  def from_numpy(X,
                 y,
                 w=None,
                 ids=None,
                 tasks=None,
                 data_dir=None,
                 verbose=True):
    """Creates a DiskDataset object from specified Numpy arrays."""
    #if data_dir is None:
    #  data_dir = tempfile.mkdtemp()
@@ -688,8 +732,8 @@ class DiskDataset(Dataset):
    if tasks is None:
      tasks = np.arange(n_tasks)
    #raw_data = (X, y, w, ids)
    return DiskDataset.create_dataset([(X, y, w, ids)], data_dir=data_dir,
                                      tasks=tasks)
    return DiskDataset.create_dataset(
        [(X, y, w, ids)], data_dir=data_dir, tasks=tasks, verbose=verbose)

  @staticmethod
  def merge(datasets, merge_dir=None):
@@ -699,10 +743,12 @@ class DiskDataset(Dataset):
        os.makedirs(merge_dir)
    else:
      merge_dir = tempfile.mkdtemp()

    def generator():
      for ind, dataset in enumerate(datasets):
        X, y, w, ids = (dataset.X, dataset.y, dataset.w, dataset.ids)
        yield (X, y, w, ids)

    return DiskDataset.create_dataset(generator(), data_dir=merge_dir)

  def subset(self, shard_nums, subset_dir=None):
@@ -713,14 +759,16 @@ class DiskDataset(Dataset):
    else:
      subset_dir = tempfile.mkdtemp()
    tasks = self.get_task_names()

    def generator():
      for shard_num, row in self.metadata_df.iterrows():
        if shard_num not in shard_nums:
          continue
        X, y, w, ids = self.get_shard(shard_num)
        yield (X, y, w, ids)
    return DiskDataset.create_dataset(generator(), data_dir=subset_dir,
                                      tasks=tasks)

    return DiskDataset.create_dataset(
        generator(), data_dir=subset_dir, tasks=tasks)

  def sparse_shuffle(self):
    """Shuffling that exploits data sparsity to shuffle large datasets.
@@ -738,12 +786,11 @@ class DiskDataset(Dataset):
      if num_features is None:
        num_features = X_s.shape[1]
      X_sparse = sparsify_features(X_s)
      X_sparses, ys, ws, ids = (
          X_sparses + [X_sparse], ys + [y_s], ws + [w_s],
      X_sparses, ys, ws, ids = (X_sparses + [X_sparse], ys + [y_s], ws + [w_s],
                                ids + [np.atleast_1d(np.squeeze(ids_s))])
    # Get full dataset in memory
    (X_sparse, y, w, ids) = (
        np.vstack(X_sparses), np.vstack(ys), np.vstack(ws), np.concatenate(ids))
    (X_sparse, y, w, ids) = (np.vstack(X_sparses), np.vstack(ys), np.vstack(ws),
                             np.concatenate(ids))
    # Shuffle in memory
    num_samples = len(X_sparse)
    permutation = np.random.permutation(num_samples)
@@ -752,13 +799,12 @@ class DiskDataset(Dataset):
    # Write shuffled shards out to disk
    for i in range(num_shards):
      start, stop = i * shard_size, (i + 1) * shard_size
      (X_sparse_s, y_s, w_s, ids_s) = (
          X_sparse[start:stop], y[start:stop], w[start:stop], ids[start:stop])
      (X_sparse_s, y_s, w_s, ids_s) = (X_sparse[start:stop], y[start:stop],
                                       w[start:stop], ids[start:stop])
      X_s = densify_features(X_sparse_s, num_features)
      self.set_shard(i, X_s, y_s, w_s, ids_s)
    time2 = time.time()
    log("TIMING: sparse_shuffle took %0.3f s" % (time2-time1),
        self.verbose)
    log("TIMING: sparse_shuffle took %0.3f s" % (time2 - time1), self.verbose)

  def shuffle_each_shard(self):
    """Shuffles elements within each shard of the datset."""
@@ -772,10 +818,10 @@ class DiskDataset(Dataset):
      X, y, w, ids = self.get_shard(i)
      n = X.shape[0]
      permutation = np.random.permutation(n)
      X, y, w, ids = (X[permutation], y[permutation],
                      w[permutation], ids[permutation])
      DiskDataset.write_data_to_disk(
          self.data_dir, basename, tasks, X, y, w, ids)
      X, y, w, ids = (X[permutation], y[permutation], w[permutation],
                      ids[permutation])
      DiskDataset.write_data_to_disk(self.data_dir, basename, tasks, X, y, w,
                                     ids)

  def shuffle_shards(self):
    """Shuffles the order of the shards for this dataset."""
@@ -787,12 +833,10 @@ class DiskDataset(Dataset):
  def get_shard(self, i):
    """Retrieves data for the i-th shard from disk."""
    row = self.metadata_df.iloc[i]
    X = np.array(load_from_disk(
        os.path.join(self.data_dir, row['X'])))
    X = np.array(load_from_disk(os.path.join(self.data_dir, row['X'])))

    if row['y'] is not None:
      y = np.array(load_from_disk(
        os.path.join(self.data_dir, row['y'])))
      y = np.array(load_from_disk(os.path.join(self.data_dir, row['y'])))
    else:
      y = None

@@ -806,8 +850,8 @@ class DiskDataset(Dataset):
    else:
      w = None

    ids = np.array(load_from_disk(
        os.path.join(self.data_dir, row['ids'])), dtype=object)
    ids = np.array(
        load_from_disk(os.path.join(self.data_dir, row['ids'])), dtype=object)
    return (X, y, w, ids)

  def add_shard(self, X, y, w, ids):
@@ -817,8 +861,8 @@ class DiskDataset(Dataset):
    basename = "shard-%d" % shard_num
    tasks = self.get_task_names()
    metadata_rows.append(
        DiskDataset.write_data_to_disk(
            self.data_dir, basename, tasks, X, y, w, ids))
        DiskDataset.write_data_to_disk(self.data_dir, basename, tasks, X, y, w,
                                       ids))
    self.metadata_df = DiskDataset._construct_metadata(metadata_rows)
    self.save_to_disk()

@@ -845,9 +889,11 @@ class DiskDataset(Dataset):
      select_dir = tempfile.mkdtemp()
    # Handle edge case with empty indices
    if not len(indices):
      return DiskDataset.create_dataset([], data_dir=select_dir)
      return DiskDataset.create_dataset(
          [], data_dir=select_dir, verbose=self.verbose)
    indices = np.array(sorted(indices)).astype(int)
    tasks = self.get_task_names()

    def generator():
      count, indices_count = 0, 0
      for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
@@ -859,7 +905,8 @@ class DiskDataset(Dataset):
          if indices_count + num_shard_elts >= len(indices):
            break
        # Need to offset indices to fit within shard_size
        shard_inds =  indices[indices_count:indices_count+num_shard_elts] - count
        shard_inds = indices[indices_count:indices_count +
                             num_shard_elts] - count
        X_sel = X[shard_inds]
        y_sel = y[shard_inds]
        w_sel = w[shard_inds]
@@ -871,8 +918,9 @@ class DiskDataset(Dataset):
        # Break when all indices have been used up already
        if indices_count >= len(indices):
          return
    return DiskDataset.create_dataset(generator(), data_dir=select_dir,
                                      tasks=tasks)

    return DiskDataset.create_dataset(
        generator(), data_dir=select_dir, tasks=tasks, verbose=self.verbose)

  @property
  def ids(self):
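
With these fixes, the docstring examples in this file execute cleanly under the new Travis step; a sketch mirroring the `DiskDataset.from_numpy` doctest shown above (output values as documented in the diff):

    import numpy as np
    import deepchem as dc

    dataset = dc.data.DiskDataset.from_numpy(
        np.ones((2, 2)), np.ones((2, 1)), verbose=False)
    for x, y, w, id in dataset.itersamples():
      print(x, y, w, id)
    # Prints (per the docstring): [ 1.  1.] [ 1.] [ 1.] 0
    #                             [ 1.  1.] [ 1.] [ 1.] 1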
+31 −15

File changed; preview size limit exceeded, changes collapsed.

+410 −308

File changed; preview size limit exceeded, changes collapsed.