Commit c56e0acc authored by peastman's avatar peastman
Browse files

Fixed error

parent 58095f0f
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -1937,7 +1937,7 @@ class ImageDataset(Dataset):
        sample_perm = np.arange(n_samples)
      if batch_size is None:
        batch_size = n_samples
      for epoch in epochs:
      for epoch in range(epochs):
        if not deterministic:
          sample_perm = np.random.permutation(n_samples)
        batch_idx = 0
+6 −3
Original line number Diff line number Diff line
@@ -74,15 +74,18 @@ class TestImageDataset(test_util.TensorFlowTestCase):
    ds = dc.data.ImageDataset(files, np.random.random(10))
    X = ds.X
    iterated_ids = set()
    for x, y, w, ids in ds.iterbatches(2):
    for x, y, w, ids in ds.iterbatches(2, epochs=2):
      np.testing.assert_array_equal([2, 28, 28], x.shape)
      np.testing.assert_array_equal([2], y.shape)
      np.testing.assert_array_equal([2], w.shape)
      np.testing.assert_array_equal([2], ids.shape)
      for i in (0, 1):
        assert ids[i] in files
        if len(iterated_ids) < 10:
          assert ids[i] not in iterated_ids
          iterated_ids.add(ids[i])
        else:
          assert ids[i] in iterated_ids
        index = files.index(ids[i])
        np.testing.assert_array_equal(x[i], X[index])
    assert len(iterated_ids) == 10