Commit 2e66e122 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

First attempts at faster shuffling routine

parent bff615fe
Loading
Loading
Loading
Loading
+71 −0
Original line number Diff line number Diff line
@@ -13,7 +13,9 @@ from functools import partial
from deepchem.utils.save import save_to_disk
from deepchem.utils.save import load_from_disk
from deepchem.utils.save import log
import tempfile
import time
import shutil

__author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
@@ -169,6 +171,16 @@ class Dataset(object):
            self.metadata_df.iterrows().next()[1]['X-transformed']))[0]
    return np.shape(sample_X)

  def get_shard_size(self):
    """Gets size of shards on disk."""
    if not len(self.metadata_df):
      raise ValueError("No data in dataset.")
    sample_y = load_from_disk(
        os.path.join(
            self.data_dir,
            self.metadata_df.iterrows().next()[1]['y-transformed']))[0]
    return len(sample_y)

  def _get_metadata_filename(self):
    """
    Get standard location for metadata file.
@@ -182,6 +194,7 @@ class Dataset(object):
    """
    return self.metadata_df.shape[0]


  def itershards(self):
    """
    Iterates over all shards in dataset.
@@ -221,6 +234,46 @@ class Dataset(object):
        ids_batch = ids[indices]
        yield (X_batch, y_batch, w_batch, ids_batch)

  def reshard(self, shard_size):
    """Reshards data to have specified shard size."""
    # Create temp directory to store resharded version
    reshard_dir = tempfile.mkdtemp()
    new_metadata = []
    # Write data in new shards
    ind = 0
    tasks = self.get_task_names() 
    X_next = np.zeros((0,) + self.get_data_shape())
    y_next = np.zeros((0,) + (len(tasks),))
    w_next = np.zeros((0,) + (len(tasks),))
    ids_next = np.zeros((0,), dtype=object)
    for (X, y, w, ids) in self.itershards():
      X_next = np.vstack([X_next, X])
      y_next = np.vstack([y_next, y])
      w_next = np.vstack([w_next, w])
      ids_next = np.concatenate([ids_next, ids])
      while len(X_next) > shard_size:
        X_batch, X_next = X_next[:shard_size], X_next[shard_size:]
        y_batch, y_next = y_next[:shard_size], y_next[shard_size:]
        w_batch, w_next = w_next[:shard_size], w_next[shard_size:]
        ids_batch, ids_next = ids_next[:shard_size], ids_next[shard_size:]
        new_basename = "reshard-%d" % ind
        new_metadata.append(Dataset.write_data_to_disk(
            reshard_dir, new_basename, tasks, X_batch, y_batch, w_batch, ids_batch))
        ind += 1
    # Handle spillover from last shard
    new_basename = "reshard-%d" % ind
    new_metadata.append(Dataset.write_data_to_disk(
        reshard_dir, new_basename, tasks, X_next, y_next, w_next, ids_next))
    ind += 1
    # Get new metadata rows
    resharded_dataset = Dataset(
        data_dir=reshard_dir, tasks=tasks, metadata_rows=new_metadata,
        verbosity=self.verbosity)
    shutil.rmtree(self.data_dir)
    shutil.move(reshard_dir, self.data_dir)
    self.metadata_df = resharded_dataset.metadata_df
    self.save_to_disk()

  @staticmethod
  def from_numpy(data_dir, X, y, w=None, ids=None, tasks=None, verbosity=None):
    n_samples = len(X)
@@ -274,6 +327,17 @@ class Dataset(object):
                   metadata_rows=metadata_rows,
                   verbosity=self.verbosity)

  def reshard_shuffle(self, reshard_size=10):
    """Shuffles by resharding, shuffling shards, undoing resharding."""
    orig_shard_size = self.get_shard_size()
    log("Resharding to shard-size %d." % reshard_size, self.verbosity)
    self.reshard(shard_size=reshard_size)
    log("Shuffling shard order.", self.verbosity)
    self.shuffle_shards()
    log("Resharding to original shard-size %d." % orig_shard_size,
        self.verbosity)
    self.reshard(shard_size=orig_shard_size)

  def shuffle(self, iterations=1):
    """Shuffles this dataset on disk to have random order."""
    #np.random.seed(9452)
@@ -319,6 +383,13 @@ class Dataset(object):
      self.metadata_df = Dataset.construct_metadata(metadata_rows)
      self.save_to_disk()

  def shuffle_shards(self):
    """Shuffles the order of the shards for this dataset."""
    metadata_rows = self.metadata_df.values.tolist()
    random.shuffle(metadata_rows)
    self.metadata_df = Dataset.construct_metadata(metadata_rows)
    self.save_to_disk()

  def get_shard(self, i):
    """Retrieves data for the i-th shard from disk."""
    row = self.metadata_df.iloc[i]
+23 −0
Original line number Diff line number Diff line
@@ -48,6 +48,29 @@ class TestBasicDatasetAPI(TestDatasetAPI):
    solubility_dataset = self.load_solubility_data()
    assert len(solubility_dataset) == 10

  def test_reshard(self):
    """Test that resharding the dataset works."""
    solubility_dataset = self.load_solubility_data()
    X, y, w, ids = solubility_dataset.to_numpy()
    assert solubility_dataset.get_number_shards() == 1
    solubility_dataset.reshard(shard_size=1)
    X_r, y_r, w_r, ids_r = solubility_dataset.to_numpy()
    assert solubility_dataset.get_number_shards() == 10
    solubility_dataset.reshard(shard_size=10)
    X_rr, y_rr, w_rr, ids_rr = solubility_dataset.to_numpy()

    # Test first resharding worked
    np.testing.assert_array_equal(X, X_r)
    np.testing.assert_array_equal(y, y_r)
    np.testing.assert_array_equal(w, w_r)
    np.testing.assert_array_equal(ids, ids_r)

    # Test second resharding worked
    np.testing.assert_array_equal(X, X_rr)
    np.testing.assert_array_equal(y, y_rr)
    np.testing.assert_array_equal(w, w_rr)
    np.testing.assert_array_equal(ids, ids_rr)

  def test_select(self):
    """Test that dataset select works."""
    num_datapoints = 10
+66 −0
Original line number Diff line number Diff line
@@ -56,3 +56,69 @@ class TestShuffle(TestAPI):
    assert X_orig.shape == X_new.shape
    assert y_orig.shape == y_new.shape
    assert w_orig.shape == w_new.shape

  def test_reshard_shuffle(self):
    """Test that datasets can be merged."""
    verbosity = "high"
    current_dir = os.path.dirname(os.path.realpath(__file__))
    data_dir = os.path.join(self.base_dir, "dataset")

    dataset_file = os.path.join(
        current_dir, "../../models/tests/example.csv")

    featurizer = CircularFingerprint(size=1024)
    tasks = ["log-solubility"]
    loader = DataLoader(tasks=tasks,
                        smiles_field="smiles",
                        featurizer=featurizer,
                        verbosity=verbosity)
    dataset = loader.featurize(
        dataset_file, data_dir, shard_size=2)

    X_orig, y_orig, w_orig, orig_ids = dataset.to_numpy()
    orig_len = len(dataset)

    dataset.reshard_shuffle(reshard_size=1)
    X_new, y_new, w_new, new_ids = dataset.to_numpy()
    
    assert len(dataset) == orig_len
    # The shuffling should have switched up the ordering
    assert not np.array_equal(orig_ids, new_ids)
    # But all the same entries should still be present
    assert sorted(orig_ids) == sorted(new_ids)
    # All the data should have same shape
    assert X_orig.shape == X_new.shape
    assert y_orig.shape == y_new.shape
    assert w_orig.shape == w_new.shape

  def test_shuffle_shards(self):
    """Test that shuffle_shards works."""
    n_samples = 100
    n_tasks = 10
    n_features = 10

    X = np.random.rand(n_samples, n_features)
    y = np.random.randint(2, size=(n_samples, n_tasks))
    w = np.random.randint(2, size=(n_samples, n_tasks))
    ids = np.arange(n_samples)
    dataset = Dataset.from_numpy(self.data_dir, X, y, w, ids)
    dataset.reshard(shard_size=10)
    dataset.shuffle_shards()

    X_s, y_s, w_s, ids_s = dataset.to_numpy()

    assert X_s.shape == X.shape
    assert y_s.shape == y.shape
    assert ids_s.shape == ids.shape
    assert w_s.shape == w.shape

    # The ids should now store the performed permutation. Check that the
    # original dataset is recoverable.
    for i in range(n_samples):
      np.testing.assert_array_equal(X_s[i], X[ids_s[i]])
      np.testing.assert_array_equal(y_s[i], y[ids_s[i]])
      np.testing.assert_array_equal(w_s[i], w[ids_s[i]])
      np.testing.assert_array_equal(ids_s[i], ids[ids_s[i]])