Commit 3cc05330 authored by Bharath Ramsundar
Browse files

Variants of shuffling.

parent 2e66e122
Loading
Loading
Loading
Loading
+9 −1
Original line number Diff line number Diff line
@@ -178,7 +178,7 @@ class Dataset(object):
    sample_y = load_from_disk(
        os.path.join(
            self.data_dir,
            self.metadata_df.iterrows().next()[1]['y-transformed']))[0]
            self.metadata_df.iterrows().next()[1]['y-transformed']))
    return len(sample_y)

  def _get_metadata_filename(self):
@@ -329,6 +329,9 @@ class Dataset(object):

  def reshard_shuffle(self, reshard_size=10):
    """Shuffles by resharding, shuffling shards, undoing resharding."""
    #########################################################  TIMING
    time1 = time.time()
    #########################################################  TIMING
    orig_shard_size = self.get_shard_size()
    log("Resharding to shard-size %d." % reshard_size, self.verbosity)
    self.reshard(shard_size=reshard_size)
@@ -337,6 +340,11 @@ class Dataset(object):
    log("Resharding to original shard-size %d." % orig_shard_size,
        self.verbosity)
    self.reshard(shard_size=orig_shard_size)
    #########################################################  TIMING
    time2 = time.time()
    log("TIMING: reshard_shuffle took %0.3f s" % (time2-time1),
        self.verbosity)
    #########################################################  TIMING

  def shuffle(self, iterations=1):
    """Shuffles this dataset on disk to have random order."""
+2 −0
Original line number Diff line number Diff line
@@ -54,9 +54,11 @@ class TestBasicDatasetAPI(TestDatasetAPI):
    X, y, w, ids = solubility_dataset.to_numpy()
    assert solubility_dataset.get_number_shards() == 1
    solubility_dataset.reshard(shard_size=1)
    assert solubility_dataset.get_shard_size() == 1
    X_r, y_r, w_r, ids_r = solubility_dataset.to_numpy()
    assert solubility_dataset.get_number_shards() == 10
    solubility_dataset.reshard(shard_size=10)
    assert solubility_dataset.get_shard_size() == 10
    X_rr, y_rr, w_rr, ids_rr = solubility_dataset.to_numpy()

    # Test first resharding worked