Commit da7cbcd5 authored by Peter Eastman's avatar Peter Eastman
Browse files

Change to work on Python 2

parent 5cb7f9af
Loading
Loading
Loading
Loading
+42 −42
Original line number Diff line number Diff line
@@ -644,13 +644,12 @@ class DiskDataset(Dataset):
        shard_perm = np.random.permutation(num_shards)
      else:
        shard_perm = np.arange(num_shards)
      with Pool(1) as pool:
      pool = Pool(1)
      next_shard = pool.apply_async(dataset.get_shard, (shard_perm[0],))
      for i in range(num_shards):
        X, y, w, ids = next_shard.get()
        if i < num_shards - 1:
            next_shard = pool.apply_async(dataset.get_shard,
                                          (shard_perm[i + 1],))
          next_shard = pool.apply_async(dataset.get_shard, (shard_perm[i + 1],))
        n_samples = X.shape[0]
        # TODO(rbharath): This happens in tests sometimes, but don't understand why?
        # Handle edge case.
@@ -689,6 +688,7 @@ class DiskDataset(Dataset):
            (X_batch, y_batch, w_batch, ids_batch) = pad_batch(
                shard_batch_size, X_batch, y_batch, w_batch, ids_batch)
          yield (X_batch, y_batch, w_batch, ids_batch)
      pool.close()

    return iterate(self)