Commit 619b5d05 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Shuffle each shard added.

parent 3cc05330
Loading
Loading
Loading
Loading
+18 −0
Original line number Diff line number Diff line
@@ -340,6 +340,7 @@ class Dataset(object):
    log("Resharding to original shard-size %d." % orig_shard_size,
        self.verbosity)
    self.reshard(shard_size=orig_shard_size)
    self.shuffle_each_shard()
    #########################################################  TIMING
    time2 = time.time()
    log("TIMING: reshard_shuffle took %0.3f s" % (time2-time1),
@@ -391,6 +392,23 @@ class Dataset(object):
      self.metadata_df = Dataset.construct_metadata(metadata_rows)
      self.save_to_disk()

  def shuffle_each_shard(self):
    """Shuffles elements within each shard of the datset."""
    tasks = self.get_task_names()
    # Shuffle the arrays corresponding to each row in metadata_df
    n_rows = len(self.metadata_df.index)
    n_rows = len(self.metadata_df.index)
    for i in range(n_rows):
      row = self.metadata_df.iloc[i]
      basename = row["basename"]
      X, y, w, ids = self.get_shard(i)
      n = X.shape[0]
      permutation = np.random.permutation(n)
      X, y, w, ids = (X[permutation], y[permutation],
                      w[permutation], ids[permutation])
      Dataset.write_data_to_disk(
          self.data_dir, basename, tasks, X, y, w, ids)

  def shuffle_shards(self):
    """Shuffles the order of the shards for this dataset."""
    metadata_rows = self.metadata_df.values.tolist()
+28 −3
Original line number Diff line number Diff line
@@ -91,6 +91,34 @@ class TestShuffle(TestAPI):
    assert y_orig.shape == y_new.shape
    assert w_orig.shape == w_new.shape

  def test_shuffle_each_shard(self):
    """Test that shuffle_each_shard works."""
    n_samples = 100
    n_tasks = 10
    n_features = 10

    X = np.random.rand(n_samples, n_features)
    y = np.random.randint(2, size=(n_samples, n_tasks))
    w = np.random.randint(2, size=(n_samples, n_tasks))
    ids = np.arange(n_samples)
    dataset = Dataset.from_numpy(self.data_dir, X, y, w, ids)
    dataset.reshard(shard_size=10)

    dataset.shuffle_each_shard()
    X_s, y_s, w_s, ids_s = dataset.to_numpy()
    assert X_s.shape == X.shape
    assert y_s.shape == y.shape
    assert ids_s.shape == ids.shape
    assert w_s.shape == w.shape

    # The ids should now store the performed permutation. Check that the
    # original dataset is recoverable.
    for i in range(n_samples):
      np.testing.assert_array_equal(X_s[i], X[ids_s[i]])
      np.testing.assert_array_equal(y_s[i], y[ids_s[i]])
      np.testing.assert_array_equal(w_s[i], w[ids_s[i]])
      np.testing.assert_array_equal(ids_s[i], ids[ids_s[i]])

  def test_shuffle_shards(self):
    """Test that shuffle_shards works."""
    n_samples = 100
@@ -119,6 +147,3 @@ class TestShuffle(TestAPI):
      np.testing.assert_array_equal(y_s[i], y[ids_s[i]])
      np.testing.assert_array_equal(w_s[i], w[ids_s[i]])
      np.testing.assert_array_equal(ids_s[i], ids[ids_s[i]])