Commit 7fdfceee authored by Yutong Zhao

Fix iterbatches not iterating in batch_size across shards.

parent f37eb55b
+84 −33
@@ -5,6 +5,7 @@ from __future__ import print_function
 from __future__ import division
 from __future__ import unicode_literals
 import os
+import math
 import numpy as np
 import pandas as pd
 import random
@@ -636,9 +637,30 @@ class DiskDataset(Dataset):
                    epoch=0,
                    deterministic=False,
                    pad_batches=False):
-    """Get an object that iterates over minibatches from the dataset.
+    """Get an object that iterates over minibatches from the dataset. It is guaranteed
+    that the number of batches returned is math.ceil(dataset.get_shape()[0][0]/batch_size).
 
     Each minibatch is returned as a tuple of four numpy arrays: (X, y, w, ids).
 
+    Parameters:
+    -----------
+    batch_size: int
+      Number of elements in a batch.
+
+    epoch: int
+      Not used.
+
+    deterministic: bool
+      If False, shuffle each shard before generating the batches. Note that the
+      shuffling is only local: samples are never mixed between different shards.
+
+    pad_batches: bool
+      Whether to pad the last batch, globally, so that it has exactly batch_size
+      elements.
+
+    """
 
     def iterate(dataset):
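The guarantee stated in the new docstring is easy to sanity-check with a toy calculation. The numbers below are illustrative only (they mirror the [3, 3, 3, 1] batch pattern asserted in the existing tests) and are not part of the commit:

import math

# 10 samples in batches of 3 give batch sizes [3, 3, 3, 1], i.e. 4 batches,
# matching math.ceil(10 / 3) no matter how the samples are spread over shards.
total_size, batch_size = 10, 3
assert math.ceil(total_size / batch_size) == 4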
@@ -647,52 +669,81 @@ class DiskDataset(Dataset):
         shard_perm = np.random.permutation(num_shards)
       else:
         shard_perm = np.arange(num_shards)
-      pool = Pool(1)
+
+      # (ytz): Depending on the application, thread-based pools may be faster
+      # than process-based pools, since process-based pools need to pickle/serialize
+      # objects as an extra overhead. Also, as hideously un-thread-safe as this looks,
+      # we're actually protected by the GIL.
+      pool = Pool(1)  # mp.dummy aliases ThreadPool to Pool
+
       next_shard = pool.apply_async(dataset.get_shard, (shard_perm[0],))
-      for i in range(num_shards):
+
+      total_yield = 0
+
+      num_global_batches = math.ceil(dataset.get_shape()[0][0]/batch_size)
+      cur_global_batch = 0
+      cur_shard = 0
+      carry = None
+
+      while cur_global_batch < num_global_batches:
+
         X, y, w, ids = next_shard.get()
-        if i < num_shards - 1:
-          next_shard = pool.apply_async(dataset.get_shard, (shard_perm[i + 1],))
-        n_samples = X.shape[0]
-        # TODO(rbharath): This happens in tests sometimes, but don't understand why?
-        # Handle edge case.
-        if n_samples == 0:
+        if cur_shard < num_shards - 1:
+          next_shard = pool.apply_async(dataset.get_shard, (shard_perm[cur_shard + 1],))
+        else:
+          pool.close()
+
+        if carry is not None:
+          X = np.concatenate([carry[0], X], axis=0)
+          y = np.concatenate([carry[1], y], axis=0)
+          w = np.concatenate([carry[2], w], axis=0)
+          ids = np.concatenate([carry[3], ids], axis=0)
+          carry = None
+
+        n_shard_samples = X.shape[0]
+        cur_local_batch = 0
+        num_local_batches = math.ceil(n_shard_samples/batch_size)
+
+        if n_shard_samples == 0:
           continue
         if not deterministic:
-          sample_perm = np.random.permutation(n_samples)
+          sample_perm = np.random.permutation(n_shard_samples)
         else:
-          sample_perm = np.arange(n_samples)
-        if batch_size is None:
-          shard_batch_size = n_samples
-        else:
-          shard_batch_size = batch_size
-        batch_idx = 0
-        num_batches = np.math.ceil(n_samples / shard_batch_size)
-        while batch_idx < num_batches:
-          start = batch_idx * shard_batch_size
-          end = min(n_samples, (batch_idx + 1) * shard_batch_size)
+          sample_perm = np.arange(n_shard_samples)
+
+        while cur_local_batch < num_local_batches:
+          start = cur_local_batch * batch_size
+          end = min(n_shard_samples, (cur_local_batch + 1) * batch_size)
+
           indices = range(start, end)
           perm_indices = sample_perm[indices]
-          X_batch = X[perm_indices]
+          X_b = X[perm_indices]
+
           if y is not None:
-            y_batch = y[perm_indices]
+            y_b = y[perm_indices]
           else:
-            y_batch = None
+            y_b = None
+
           if w is not None:
-            w_batch = w[perm_indices]
+            w_b = w[perm_indices]
           else:
-            w_batch = None
-          ids_batch = ids[perm_indices]
+            w_b = None
+
+          ids_b = ids[perm_indices]
+
+          assert len(X_b) <= batch_size
+          if len(X_b) < batch_size and cur_shard != num_shards - 1:
+            assert carry is None
+            carry = [X_b, y_b, w_b, ids_b]
+          else:
-          if pad_batches:
-            (X_batch, y_batch, w_batch, ids_batch) = pad_batch(
-                shard_batch_size, X_batch, y_batch, w_batch, ids_batch)
-          batch_idx += 1
-          yield (X_batch, y_batch, w_batch, ids_batch)
-      pool.close()
+            # (ytz): this skips everything except possibly the last shard
+            if pad_batches:
+              (X_b, y_b, w_b, ids_b) = pad_batch(batch_size, X_b, y_b, w_b, ids_b)
+
+            yield X_b, y_b, w_b, ids_b
+            cur_global_batch += 1
+          cur_local_batch += 1
+        cur_shard += 1
 
     return iterate(self)
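A note on the pool usage above: the next shard is prefetched on a one-worker thread pool so that disk I/O overlaps with batch generation. A minimal sketch of that pattern, where load_shard is a hypothetical stand-in for dataset.get_shard, not DeepChem's API:

from multiprocessing.dummy import Pool  # thread-backed Pool; the GIL protects the shared state

def load_shard(i):
  # hypothetical placeholder for disk I/O
  return [i] * 3

num_shards = 3
pool = Pool(1)
nxt = pool.apply_async(load_shard, (0,))
for i in range(num_shards):
  shard = nxt.get()
  if i < num_shards - 1:
    nxt = pool.apply_async(load_shard, (i + 1,))  # kick off the next load
  else:
    pool.close()
  # ... consume `shard` while the next one loads in the background ...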

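The core of the fix is the carry mechanism: a tail of fewer than batch_size samples at the end of a shard is no longer yielded as a short batch, but prepended to the next shard, so only the final batch of the whole dataset can be short. A self-contained sketch of the idea over plain lists; iterbatches_sketch and its list-of-lists shards are illustrative, not the actual implementation:

import math

def iterbatches_sketch(shards, batch_size):
  # Yield fixed-size batches across shards, carrying short tails forward.
  total = sum(len(s) for s in shards)
  carry = []
  num_yielded = 0
  for shard_idx, shard in enumerate(shards):
    data = carry + list(shard)  # prepend the leftover tail, if any
    carry = []
    for start in range(0, len(data), batch_size):
      batch = data[start:start + batch_size]
      if len(batch) < batch_size and shard_idx != len(shards) - 1:
        carry = batch  # too short: top it up from the next shard instead
      else:
        yield batch
        num_yielded += 1
  # the guarantee from the docstring holds regardless of shard boundaries
  assert num_yielded == math.ceil(total / batch_size)

# Shards of sizes 2, 4 and 6 with batch_size 5 give batch sizes [5, 5, 2].
sizes = [len(b) for b in iterbatches_sketch([[1, 2], [3, 4, 5, 6], list(range(7, 13))], 5)]
assert sizes == [5, 5, 2]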
+108 −0
@@ -9,6 +9,8 @@ __author__ = "Bharath Ramsundar"
 __copyright__ = "Copyright 2016, Stanford University"
 __license__ = "MIT"
 
+import random
+import math
 import unittest
 import tempfile
 import os
@@ -395,6 +397,109 @@ class TestDatasets(unittest.TestCase):
       batch_sizes.append(len(X))
     self.assertEqual([3, 3, 3, 1], batch_sizes)
 
+  def test_disk_iterate_batch(self):
+
+    all_batch_sizes = [32, 17, 11]
+    all_shard_sizes = [
+      [1, 1, 1, 1, 1],
+      [31, 31, 31, 31, 31],
+      [21, 11, 41, 21, 51]
+    ]
+
+    for _ in range(50):
+      shard_length = random.randint(1, 32)
+      shard_sizes = []
+      for _ in range(shard_length):
+        shard_sizes.append(random.randint(1, 128))
+      all_shard_sizes.append(shard_sizes)
+
+      batch_size = random.randint(1, 256)
+      all_batch_sizes.append(batch_size)
+
+    for shard_sizes, batch_size in zip(all_shard_sizes, all_batch_sizes):
+
+      All_Xs, All_ys, All_ws, All_ids = [], [], [], []
+
+      def shard_generator():
+        for sz in shard_sizes:
+          X_b = np.random.rand(sz, 1)
+          y_b = np.random.rand(sz, 1)
+          w_b = np.random.rand(sz, 1)
+          ids_b = np.random.rand(sz)
+
+          All_Xs.append(X_b)
+          All_ys.append(y_b)
+          All_ws.append(w_b)
+          All_ids.append(ids_b)
+
+          yield X_b, y_b, w_b, ids_b
+
+      dataset = dc.data.DiskDataset.create_dataset(shard_generator())
+
+      All_Xs = np.concatenate(All_Xs, axis=0)
+      All_ys = np.concatenate(All_ys, axis=0)
+      All_ws = np.concatenate(All_ws, axis=0)
+      All_ids = np.concatenate(All_ids, axis=0)
+
+      total_size = sum(shard_sizes)
+
+      assert dataset.X.shape[0] == total_size
+
+      # deterministic
+      Test_Xs, Test_ys, Test_ws, Test_ids = [], [], [], []
+      for bidx, (a, b, c, d) in enumerate(dataset.iterbatches(
+          batch_size=batch_size, pad_batches=False, deterministic=True)):
+
+        Test_Xs.append(a)
+        Test_ys.append(b)
+        Test_ws.append(c)
+        Test_ids.append(d)
+
+      Test_Xs = np.concatenate(Test_Xs, axis=0)
+      Test_ys = np.concatenate(Test_ys, axis=0)
+      Test_ws = np.concatenate(Test_ws, axis=0)
+      Test_ids = np.concatenate(Test_ids, axis=0)
+
+      assert bidx == math.ceil(total_size / batch_size) - 1
+
+      np.testing.assert_array_equal(All_Xs, Test_Xs)
+      np.testing.assert_array_equal(All_ys, Test_ys)
+      np.testing.assert_array_equal(All_ws, Test_ws)
+      np.testing.assert_array_equal(All_ids, Test_ids)
+
+      # non-deterministic
+      Test_Xs, Test_ys, Test_ws, Test_ids = [], [], [], []
+      for bidx, (a, b, c, d) in enumerate(dataset.iterbatches(
+          batch_size=batch_size, pad_batches=False, deterministic=False)):
+
+        Test_Xs.append(a)
+        Test_ys.append(b)
+        Test_ws.append(c)
+        Test_ids.append(d)
+
+      Test_Xs = np.concatenate(Test_Xs, axis=0)
+      Test_ys = np.concatenate(Test_ys, axis=0)
+      Test_ws = np.concatenate(Test_ws, axis=0)
+      Test_ids = np.concatenate(Test_ids, axis=0)
+
+      assert bidx == math.ceil(total_size / batch_size) - 1
+
+      np.testing.assert_array_equal(np.sort(All_Xs, axis=0), np.sort(Test_Xs, axis=0))
+      np.testing.assert_array_equal(np.sort(All_ys, axis=0), np.sort(Test_ys, axis=0))
+      np.testing.assert_array_equal(np.sort(All_ws, axis=0), np.sort(Test_ws, axis=0))
+      np.testing.assert_array_equal(np.sort(All_ids, axis=0), np.sort(Test_ids, axis=0))
+
   def test_numpy_iterate_batch_size(self):
     solubility_dataset = dc.data.tests.load_solubility_data()
     X, y, _, _ = (solubility_dataset.X, solubility_dataset.y,
@@ -406,3 +511,6 @@ class TestDatasets(unittest.TestCase):
         3, pad_batches=False, deterministic=True):
       batch_sizes.append(len(X))
     self.assertEqual([3, 3, 3, 1], batch_sizes)
+
+if __name__ == "__main__":
+  unittest.main()
\ No newline at end of file
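One detail of the non-deterministic check worth spelling out: the test arrays have a single column of independent random draws, so sorting along axis 0 canonicalizes the row order, and any row permutation of the same data compares equal after sorting. A minimal demonstration of the idea (the seed and sizes are illustrative):

import numpy as np

rng = np.random.RandomState(0)
a = rng.rand(5, 1)         # one random column per row, like the test data
b = a[rng.permutation(5)]  # a shuffled copy of the same rows
np.testing.assert_array_equal(np.sort(a, axis=0), np.sort(b, axis=0))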