Unverified Commit 4deb33c9 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #914 from proteneer/fix_iter_batches

Fix iterbatches not iterating in batch_size across shards and pad_batches with None y or w.
parents 0404791b 36404369
Loading
Loading
Loading
Loading
+116 −39
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals
import os
import math
import numpy as np
import pandas as pd
import random
@@ -89,15 +90,24 @@ def pad_batch(batch_size, X_b, y_b, w_b, ids_b):
  else:
    X_out = np.zeros((batch_size,), dtype=X_b.dtype)

  num_tasks = y_b.shape[1]
  y_out = np.zeros((batch_size, num_tasks), dtype=y_b.dtype)
  w_out = np.zeros((batch_size, num_tasks), dtype=w_b.dtype)
  if y_b is None:
    y_out = None
  else:
    y_out = np.zeros((batch_size, y_b.shape[1]), dtype=y_b.dtype)

  if w_b is None:
    w_out = None
  else:
    w_out = np.zeros((batch_size, w_b.shape[1]), dtype=w_b.dtype)

  ids_out = np.zeros((batch_size,), dtype=ids_b.dtype)

  # Fill in batch arrays
  start = 0
  # Only the first set of copy will be counted in training loss
  if w_out is not None:
    w_out[start:start + num_samples] = w_b[:]

  while start < batch_size:
    num_left = batch_size - start
    if num_left < num_samples:
@@ -105,7 +115,10 @@ def pad_batch(batch_size, X_b, y_b, w_b, ids_b):
    else:
      increment = num_samples
    X_out[start:start + increment] = X_b[:increment]

    if y_out is not None:
      y_out[start:start + increment] = y_b[:increment]

    ids_out[start:start + increment] = ids_b[:increment]
    start += increment

@@ -636,65 +649,129 @@ class DiskDataset(Dataset):
                  epoch=0,
                  deterministic=False,
                  pad_batches=False):
    """Get an object that iterates over minibatches from the dataset.
    """ Get an object that iterates over minibatches from the dataset. It is guaranteed
    that the number of batches returned is math.ceil(len(dataset)/batch_size).
    
    Each minibatch is returned as a tuple of four numpy arrays: (X, y, w, ids).


    Parameters:
    -----------
    batch_size: int
      Number of elements in a batch. If None, then it yields batches with size equal to the size
      of each individual shard.

    epoch: int
      Not used

    deterministic: bool
      If False, each shard is shuffled before generating the batches; if True, it is not.
      Note that this is only local in the sense that it does not ever mix between different
      shards.

    pad_batches: bool
      Whether or not we should pad the last batch, globally, such that it has exactly batch_size
      elements.


    """

    def iterate(dataset):
    def iterate(dataset, batch_size):
      num_shards = dataset.get_number_shards()
      if not deterministic:
        shard_perm = np.random.permutation(num_shards)
      else:
        shard_perm = np.arange(num_shards)
      pool = Pool(1)

      # (ytz): Depending on the application, thread-based pools may be faster
      # than process based pools, since process based pools need to pickle/serialize
      # objects as an extra overhead. Also, as hideously as un-thread safe this looks,
      # we're actually protected by the GIL.
      pool = Pool(1)  # mp.dummy aliases ThreadPool to Pool
      next_shard = pool.apply_async(dataset.get_shard, (shard_perm[0],))
      for i in range(num_shards):

      total_yield = 0

      if batch_size is None:
        num_global_batches = num_shards
      else:
        num_global_batches = math.ceil(dataset.get_shape()[0][0] / batch_size)

      cur_global_batch = 0
      cur_shard = 0
      carry = None

      while cur_global_batch < num_global_batches:

        X, y, w, ids = next_shard.get()
        if i < num_shards - 1:
          next_shard = pool.apply_async(dataset.get_shard, (shard_perm[i + 1],))
        n_samples = X.shape[0]
        # TODO(rbharath): This happens in tests sometimes, but don't understand why?
        # Handle edge case.
        if n_samples == 0:
          continue
        if not deterministic:
          sample_perm = np.random.permutation(n_samples)
        if cur_shard < num_shards - 1:
          next_shard = pool.apply_async(dataset.get_shard,
                                        (shard_perm[cur_shard + 1],))
        else:
          sample_perm = np.arange(n_samples)
          pool.close()

        if carry is not None:
          X = np.concatenate([carry[0], X], axis=0)
          if y is not None:
            y = np.concatenate([carry[1], y], axis=0)
          if w is not None:
            w = np.concatenate([carry[2], w], axis=0)
          ids = np.concatenate([carry[3], ids], axis=0)
          carry = None

        n_shard_samples = X.shape[0]
        cur_local_batch = 0
        if batch_size is None:
          shard_batch_size = n_samples
          shard_batch_size = n_shard_samples
        else:
          shard_batch_size = batch_size

        batch_idx = 0
        num_batches = np.math.ceil(n_samples / shard_batch_size)
        while batch_idx < num_batches:
          start = batch_idx * shard_batch_size
          end = min(n_samples, (batch_idx + 1) * shard_batch_size)
        num_local_batches = math.ceil(n_shard_samples / shard_batch_size)

        if n_shard_samples == 0:
          continue
        if not deterministic:
          sample_perm = np.random.permutation(n_shard_samples)
        else:
          sample_perm = np.arange(n_shard_samples)

        while cur_local_batch < num_local_batches:
          start = cur_local_batch * shard_batch_size
          end = min(n_shard_samples, (cur_local_batch + 1) * shard_batch_size)

          indices = range(start, end)
          perm_indices = sample_perm[indices]
          X_batch = X[perm_indices]
          X_b = X[perm_indices]

          if y is not None:
            y_batch = y[perm_indices]
            y_b = y[perm_indices]
          else:
            y_batch = None
            y_b = None

          if w is not None:
            w_batch = w[perm_indices]
            w_b = w[perm_indices]
          else:
            w_b = None

          ids_b = ids[perm_indices]

          assert len(X_b) <= shard_batch_size
          if len(X_b) < shard_batch_size and cur_shard != num_shards - 1:
            assert carry is None
            carry = [X_b, y_b, w_b, ids_b]
          else:
            w_batch = None

          ids_batch = ids[perm_indices]
            # (ytz): this skips everything except possibly the last shard
            if pad_batches:
            (X_batch, y_batch, w_batch, ids_batch) = pad_batch(
                shard_batch_size, X_batch, y_batch, w_batch, ids_batch)
          batch_idx += 1
          yield (X_batch, y_batch, w_batch, ids_batch)
      pool.close()
              (X_b, y_b, w_b, ids_b) = pad_batch(shard_batch_size, X_b, y_b,
                                                 w_b, ids_b)

    return iterate(self)
            yield X_b, y_b, w_b, ids_b
            cur_global_batch += 1
          cur_local_batch += 1
        cur_shard += 1

    return iterate(self, batch_size)

  def itersamples(self):
    """Get an object that iterates over the samples in the dataset.
+221 −0
Original line number Diff line number Diff line
@@ -9,6 +9,8 @@ __author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "MIT"

import random
import math
import unittest
import tempfile
import os
@@ -395,6 +397,221 @@ class TestDatasets(unittest.TestCase):
      batch_sizes.append(len(X))
    self.assertEqual([3, 3, 3, 1], batch_sizes)

  def test_disk_pad_batches(self):
    """Check padded, deterministic batch iteration over a multi-shard dataset.

    A DiskDataset is created from five shards of uneven size filled with random
    X, y, w and ids. It is then iterated with `pad_batches=True` and
    `deterministic=True`, and the test verifies that:

    * the number of batches is ceil(total_size / batch_size),
    * the concatenated output holds exactly that many full batches of rows,
    * the first `total_size` rows of every field equal the data that was
      written, in the same order.
    """
    shard_sizes = [21, 11, 41, 21, 51]
    batch_size = 10
    n_fields = 4  # X, y, w, ids

    # One accumulator per field; filled while the shards are being generated.
    written = [[] for _ in range(n_fields)]

    def shard_generator():
      for n_rows in shard_sizes:
        # Tuple elements are evaluated left to right, so the random draws
        # happen in the order X, y, w, ids.
        shard = (np.random.rand(n_rows, 1), np.random.rand(n_rows, 1),
                 np.random.rand(n_rows, 1), np.random.rand(n_rows))
        for store, field in zip(written, shard):
          store.append(field)
        yield shard

    dataset = dc.data.DiskDataset.create_dataset(shard_generator())

    want_X, want_y, want_w, want_ids = [
        np.concatenate(parts, axis=0) for parts in written
    ]

    batches = list(
        dataset.iterbatches(
            batch_size=batch_size, pad_batches=True, deterministic=True))

    got_X, got_y, got_w, got_ids = [
        np.concatenate([batch[k] for batch in batches], axis=0)
        for k in range(n_fields)
    ]

    total_size = sum(shard_sizes)
    n_expected = math.ceil(total_size / batch_size)

    assert len(batches) == n_expected

    # With padding every batch is full, so the row count is a whole multiple
    # of batch_size.
    padded_size = n_expected * batch_size
    for got in (got_X, got_y, got_w, got_ids):
      assert len(got) == padded_size

    # Rows past total_size are padding and are not compared.
    np.testing.assert_array_equal(want_X, got_X[:total_size, :])
    np.testing.assert_array_equal(want_y, got_y[:total_size, :])
    np.testing.assert_array_equal(want_w, got_w[:total_size, :])
    np.testing.assert_array_equal(want_ids, got_ids[:total_size])

  def test_disk_iterate_y_w_None(self):
    """Check padded iteration when every shard is written with y=None, w=None.

    Five shards of uneven size are generated with random X and ids and `None`
    in the y and w slots. The dataset is iterated with `pad_batches=True` and
    `deterministic=True`; only X and ids of each batch are inspected. The test
    verifies the batch count, the padded row count, and that the first
    `total_size` rows of X and ids match what was written, in order.
    """
    shard_sizes = [21, 11, 41, 21, 51]
    batch_size = 10

    written_X, written_ids = [], []

    def shard_generator():
      for n_rows in shard_sizes:
        features = np.random.rand(n_rows, 1)
        identifiers = np.random.rand(n_rows)

        written_X.append(features)
        written_ids.append(identifiers)

        # Labels and weights are deliberately absent.
        yield features, None, None, identifiers

    dataset = dc.data.DiskDataset.create_dataset(shard_generator())

    want_X = np.concatenate(written_X, axis=0)
    want_ids = np.concatenate(written_ids, axis=0)

    batches = list(
        dataset.iterbatches(
            batch_size=batch_size, pad_batches=True, deterministic=True))

    # Each batch is an (X, y, w, ids) tuple; y and w are ignored here.
    got_X = np.concatenate([batch[0] for batch in batches], axis=0)
    got_ids = np.concatenate([batch[3] for batch in batches], axis=0)

    total_size = sum(shard_sizes)
    n_expected = math.ceil(total_size / batch_size)

    assert len(batches) == n_expected

    padded_size = n_expected * batch_size
    assert len(got_X) == padded_size
    assert len(got_ids) == padded_size

    # Rows past total_size are padding and are not compared.
    np.testing.assert_array_equal(want_X, got_X[:total_size, :])
    np.testing.assert_array_equal(want_ids, got_ids[:total_size])

  def test_disk_iterate_batch(self):
    """Check unpadded batch iteration over many shard layouts and batch sizes.

    Four hand-picked (shard sizes, batch size) cases are followed by 25 random
    ones; the first random case uses `batch_size=None`. For every case a
    DiskDataset is built from random shards and iterated twice with
    `pad_batches=False`:

    * deterministic: the batch count must match, the concatenated batches must
      equal the written data in order, and with `batch_size=None` each batch
      must have the length of the corresponding shard;
    * non-deterministic: the batch count must match and the concatenated
      batches must equal the written data after sorting each field.
    """
    n_fields = 4  # X, y, w, ids

    cases = [
        ([7, 3, 12, 4, 5], None),
        ([1, 1, 1, 1, 1], 32),
        ([31, 31, 31, 31, 31], 17),
        ([21, 11, 41, 21, 51], 11),
    ]

    for trial in range(25):
      n_shards = random.randint(1, 32)
      sizes = [random.randint(1, 128) for _ in range(n_shards)]
      # special case to test: the first random layout runs with no batch size
      cases.append((sizes, None if trial == 0 else random.randint(1, 256)))

    def gather(batches):
      # Stitch the k-th field of every batch back into one array per field.
      return [
          np.concatenate([batch[k] for batch in batches], axis=0)
          for k in range(n_fields)
      ]

    for shard_sizes, batch_size in cases:

      written = [[] for _ in range(n_fields)]

      def shard_generator():
        for n_rows in shard_sizes:
          # Left-to-right evaluation keeps the draw order X, y, w, ids.
          shard = (np.random.rand(n_rows, 1), np.random.rand(n_rows, 1),
                   np.random.rand(n_rows, 1), np.random.rand(n_rows))
          for store, field in zip(written, shard):
            store.append(field)
          yield shard

      # The generator is passed straight to create_dataset, so the loop
      # variables it closes over still refer to the current case.
      dataset = dc.data.DiskDataset.create_dataset(shard_generator())

      want = [np.concatenate(parts, axis=0) for parts in written]

      total_size = sum(shard_sizes)

      assert dataset.X.shape[0] == total_size

      if batch_size is None:
        n_expected = len(shard_sizes)
      else:
        n_expected = math.ceil(total_size / batch_size)

      # deterministic
      batches = list(
          dataset.iterbatches(
              batch_size=batch_size, pad_batches=False, deterministic=True))

      if batch_size is None:
        for shard_idx, batch in enumerate(batches):
          for field in batch:
            assert len(field) == shard_sizes[shard_idx]

      got = gather(batches)

      assert len(batches) == n_expected

      for want_field, got_field in zip(want, got):
        np.testing.assert_array_equal(want_field, got_field)

      # non-deterministic
      batches = list(
          dataset.iterbatches(
              batch_size=batch_size, pad_batches=False, deterministic=False))

      # we don't know the order in which the shards are iterated in.
      got = gather(batches)

      assert len(batches) == n_expected

      for want_field, got_field in zip(want, got):
        np.testing.assert_array_equal(
            np.sort(want_field, axis=0), np.sort(got_field, axis=0))

  def test_numpy_iterate_batch_size(self):
    solubility_dataset = dc.data.tests.load_solubility_data()
    X, y, _, _ = (solubility_dataset.X, solubility_dataset.y,
@@ -406,3 +623,7 @@ class TestDatasets(unittest.TestCase):
        3, pad_batches=False, deterministic=True):
      batch_sizes.append(len(X))
    self.assertEqual([3, 3, 3, 1], batch_sizes)


if __name__ == "__main__":
  unittest.main()