Commit ed543342 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Debugging

parent a570015b
Loading
Loading
Loading
Loading
+20 −16
Original line number Diff line number Diff line
@@ -362,7 +362,7 @@ class DiskDataset(Dataset):


  @staticmethod
  def create_dataset(shard_generator, data_dir=None, tasks=[]):
  def create_dataset(shard_generator, data_dir=None, tasks=[], verbose=True):
    """Creates a new DiskDataset

    Parameters
@@ -376,8 +376,9 @@ class DiskDataset(Dataset):
      List of tasks for this dataset.
    """
    if data_dir is None:
      self.data_dir = tempfile.mkdtemp()
    self.data_dir = data_dir
      data_dir = tempfile.mkdtemp()
    elif not os.path.exists(data_dir):
      os.makedirs(data_dir)

    metadata_rows = []
    time1 = time.time()
@@ -385,13 +386,12 @@ class DiskDataset(Dataset):
      basename = "shard-%d" % shard_num 
      metadata_rows.append(
          DiskDataset.write_data_to_disk(
              self.data_dir, basename, tasks, X, y, w, ids))
    self.tasks = tasks
    self.metadata_df = DiskDataset._construct_metadata(metadata_rows)
    self.save_to_disk()
              data_dir, basename, tasks, X, y, w, ids))
    metadata_df = DiskDataset._construct_metadata(metadata_rows)
    metadata_filename = os.path.join(data_dir, "metadata.joblib")
    save_to_disk((tasks, metadata_df), metadata_filename)
    time2 = time.time()
    print("TIMING: dataset construction took %0.3f s" % (time2-time1),
          self.verbose)
    print("TIMING: dataset construction took %0.3f s" % (time2-time1), verbose)
    return DiskDataset(data_dir)

  @staticmethod
@@ -463,7 +463,8 @@ class DiskDataset(Dataset):
          yield (X_batch, y_batch, w_batch, ids_batch)
      # Handle spillover from last shard
      yield (X_next, y_next, w_next, ids_next)
    resharded_dataset = DiskDataset(generator(), data_dir=reshard_dir)
    resharded_dataset = DiskDataset.create_dataset(generator(), data_dir=reshard_dir,
                                                   tasks=self.tasks)
    shutil.rmtree(self.data_dir)
    shutil.move(reshard_dir, self.data_dir)
    self.metadata_df = resharded_dataset.metadata_df
@@ -617,7 +618,7 @@ class DiskDataset(Dataset):
        X, y, w, ids = self.get_shard(shard_num)
        newx, newy, neww = fn(X, y, w)
        yield (newx, newy, neww, ids)
    return DiskDataset(generator(), data_dir=out_dir)
    return DiskDataset.create_dataset(generator(), data_dir=out_dir, tasks=tasks)

  @staticmethod
  def from_numpy(X, y, w=None, ids=None, tasks=None, data_dir=None):
@@ -638,7 +639,8 @@ class DiskDataset(Dataset):
    if tasks is None:
      tasks = np.arange(n_tasks)
    #raw_data = (X, y, w, ids)
    return DiskDataset([(X, y, w, ids)], data_dir=data_dir, tasks=tasks)
    return DiskDataset.create_dataset([(X, y, w, ids)], data_dir=data_dir,
                                      tasks=tasks)

  @staticmethod
  def merge(datasets, merge_dir=None):
@@ -652,7 +654,7 @@ class DiskDataset(Dataset):
      for ind, dataset in enumerate(datasets):
        X, y, w, ids = (dataset.X, dataset.y, dataset.w, dataset.ids)
        yield (X, y, w, ids)
    return DiskDataset(generator(), data_dir=merge_dir)
    return DiskDataset.create_dataset(generator(), data_dir=merge_dir)

  def subset(self, shard_nums, subset_dir=None):
    """Creates a subset of the original dataset on disk."""
@@ -668,7 +670,8 @@ class DiskDataset(Dataset):
          continue
        X, y, w, ids = self.get_shard(shard_num)
        yield (X, y, w, ids)
    return DiskDataset(generator(), data_dir=subset_dir)
    return DiskDataset.create_dataset(generator(), data_dir=subset_dir,
                                      tasks=tasks)

  def sparse_shuffle(self):
    """Shuffling that exploits data sparsity to shuffle large datasets.
@@ -783,7 +786,7 @@ class DiskDataset(Dataset):
      select_dir = tempfile.mkdtemp()
    # Handle edge case with empty indices
    if not len(indices):
      return DiskDataset([], data_dir=select_dir)
      return DiskDataset.create_dataset([], data_dir=select_dir)
    indices = np.array(sorted(indices)).astype(int)
    tasks = self.get_task_names()
    def generator():
@@ -809,7 +812,8 @@ class DiskDataset(Dataset):
        # Break when all indices have been used up already
        if indices_count >= len(indices):
          return 
    return DiskDataset(generator(), data_dir=select_dir, tasks=tasks)
    return DiskDataset.create_dataset(generator(), data_dir=select_dir,
                                      tasks=tasks)

  @property
  def ids(self):