Commit a570015b authored by Bharath Ramsundar
Browse files

Changing DiskDataset constructor

parent d65baa3b
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -160,7 +160,7 @@ class DataLoader(object):
        log("TIMING: featurizing shard %d took %0.3f s" % (shard_num, time2-time1),
            self.verbose)
        yield X, y, w, ids
    return DiskDataset(shard_generator(), data_dir, self.tasks)
    return DiskDataset.create_dataset(shard_generator(), data_dir, self.tasks)

  def get_shards(self, input_files, shard_size):
    """Stub for children classes."""
+31 −19
Original line number Diff line number Diff line
@@ -346,27 +346,38 @@ class DiskDataset(Dataset):
  """
  A Dataset that is stored as a set of files on disk.
  """
  def __init__(self, shard_generator=[], data_dir=None, tasks=[],
               reload=False, verbose=True):
  def __init__(self, data_dir, verbose=True):
    """
    Turns featurized dataframes into numpy files, writes them & metadata to disk.
    """
    if data_dir is not None:
      if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    else:
      data_dir = tempfile.mkdtemp()
    self.data_dir = data_dir
    self.verbose = verbose

    if reload:
      log("Loading pre-existing dataset.", self.verbose)
    log("Loading dataset from disk.", self.verbose)
    if os.path.exists(self._get_metadata_filename()):
      (self.tasks, self.metadata_df) = load_from_disk(
          self._get_metadata_filename())
    else:
        raise ValueError("No metadata found.")
      return
      raise ValueError("No metadata found on disk.")


  @staticmethod
  def create_dataset(shard_generator, data_dir=None, tasks=[]):
    """Creates a new DiskDataset

    Parameters
    ----------
    shard_generator: Iterable
      An iterable (either a list or generator) that provides tuples of data
      (X, y, w, ids). Each tuple will be written to a separate shard on disk.
    data_dir: str
      Filename for data directory. Creates a temp directory if none specified.
    tasks: list
      List of tasks for this dataset.
    """
    if data_dir is None:
      self.data_dir = tempfile.mkdtemp()
    self.data_dir = data_dir

    metadata_rows = []
    time1 = time.time()
@@ -376,14 +387,15 @@ class DiskDataset(Dataset):
          DiskDataset.write_data_to_disk(
              self.data_dir, basename, tasks, X, y, w, ids))
    self.tasks = tasks
    self.metadata_df = DiskDataset.construct_metadata(metadata_rows)
    self.metadata_df = DiskDataset._construct_metadata(metadata_rows)
    self.save_to_disk()
    time2 = time.time()
    print("TIMING: dataset construction took %0.3f s" % (time2-time1),
          self.verbose)
    return DiskDataset(data_dir)

  @staticmethod
  def construct_metadata(metadata_entries):
  def _construct_metadata(metadata_entries):
    """Construct a dataframe containing metadata.
  
    metadata_entries should have elements returned by write_data_to_disk
@@ -717,7 +729,7 @@ class DiskDataset(Dataset):
    """Shuffles the order of the shards for this dataset."""
    metadata_rows = self.metadata_df.values.tolist()
    random.shuffle(metadata_rows)
    self.metadata_df = DiskDataset.construct_metadata(metadata_rows)
    self.metadata_df = DiskDataset._construct_metadata(metadata_rows)
    self.save_to_disk()

  def get_shard(self, i):
@@ -745,7 +757,7 @@ class DiskDataset(Dataset):
    metadata_rows.append(
        DiskDataset.write_data_to_disk(
            self.data_dir, basename, tasks, X, y, w, ids))
    self.metadata_df = DiskDataset.construct_metadata(metadata_rows)
    self.metadata_df = DiskDataset._construct_metadata(metadata_rows)
    self.save_to_disk()

  def set_shard(self, shard_num, X, y, w, ids):
+1 −2
Original line number Diff line number Diff line
@@ -161,7 +161,6 @@ class TestDataLoader(unittest.TestCase):
    # Now perform move
    shutil.move(data_dir, moved_data_dir)

    moved_featurized_dataset = dc.data.DiskDataset(
        data_dir=moved_data_dir, reload=True)
    moved_featurized_dataset = dc.data.DiskDataset(moved_data_dir)

    assert len(moved_featurized_dataset) == n_dataset
+2 −6
Original line number Diff line number Diff line
@@ -39,8 +39,7 @@ class TestLoad(unittest.TestCase):
    X, y, w, ids = (dataset.X, dataset.y, dataset.w, dataset.ids)
    shutil.move(data_dir, moved_data_dir)

    moved_dataset = dc.data.DiskDataset(
        data_dir=moved_data_dir, reload=True)
    moved_dataset = dc.data.DiskDataset(moved_data_dir)

    X_moved, y_moved, w_moved, ids_moved = (moved_dataset.X, moved_dataset.y,
                                            moved_dataset.w, moved_dataset.ids)
@@ -56,9 +55,6 @@ class TestLoad(unittest.TestCase):
    # Only for debug!
    np.random.seed(123)

    # Set some global variables up top
    reload = True

    current_dir = os.path.dirname(os.path.realpath(__file__))
    ##Make directories to store the raw and featurized datasets.
    data_dir = tempfile.mkdtemp() 
@@ -85,9 +81,9 @@ class TestLoad(unittest.TestCase):

    ####### Do singletask load
    y_tasks, w_tasks, = [], []
    dataset = dc.data.DiskDataset(data_dir)
    for ind, task in enumerate(all_tasks):
      print("Processing task %s" % task)
      dataset = dc.data.DiskDataset(data_dir=data_dir, reload=reload)

      X_task, y_task, w_task, ids_task = (dataset.X, dataset.y, dataset.w,
                                          dataset.ids)
+1 −11
Original line number Diff line number Diff line
@@ -60,7 +60,7 @@ class SingletaskToMultitask(Model):
    tasks = dataset.get_task_names()
    assert len(tasks) == len(task_dirs)
    log("Splitting multitask dataset into singletask datasets", dataset.verbose)
    task_datasets = [DiskDataset(data_dir=task_dirs[task_num], tasks=[task])
    task_datasets = [DiskDataset.create_dataset([], task_dirs[task_num], [task])
                    for (task_num, task) in enumerate(tasks)]
    #task_metadata_rows = {task: [] for task in tasks}
    for shard_num, (X, y, w, ids) in enumerate(dataset.itershards()):
@@ -81,16 +81,6 @@ class SingletaskToMultitask(Model):
        task_datasets[task_num].add_shard(X_nonzero, y_nonzero, w_nonzero,
                                          ids_nonzero)

        #if X_nonzero.size > 0: 
        #  task_metadata_rows[task].append(
        #    DiskDataset.write_data_to_disk(
        #        task_dirs[task_num], basename, [task],
        #        X_nonzero, y_nonzero, w_nonzero, ids_nonzero))
    
    #task_datasets = [
    #    DiskDataset(data_dir=task_dirs[task_num],
    #            metadata_rows=task_metadata_rows[task])
    #    for (task_num, task) in enumerate(tasks)]
    return task_datasets