Unverified Commit 82cb9075 authored by Bharath Ramsundar, committed by GitHub
Browse files

Merge pull request #1934 from peastman/epochs

Fixes to fit_on_batch() and iterbatches()
parents c9eaf1b6 c56e0acc
Loading
Loading
Loading
Loading
+141 −140
Original line number Diff line number Diff line
@@ -336,7 +336,7 @@ class Dataset(object):

  def iterbatches(self,
                  batch_size=None,
                  epoch=0,
                  epochs=1,
                  deterministic=False,
                  pad_batches=False):
    """Get an object that iterates over minibatches from the dataset.
@@ -348,7 +348,7 @@ class Dataset(object):
    ----------
    batch_size: int, optional
      Number of elements in each batch
    epoch: int, optional
    epochs: int, optional
      Number of epochs to walk over dataset
    deterministic: bool, optional
      If True, follow deterministic order.
@@ -485,8 +485,7 @@ class Dataset(object):
    # Create a Tensorflow Dataset.

    def gen_data():
      for epoch in range(epochs):
        for X, y, w, ids in self.iterbatches(batch_size, epoch, deterministic,
      for X, y, w, ids in self.iterbatches(batch_size, epochs, deterministic,
                                           pad_batches):
        yield (X, y, w)

@@ -727,7 +726,7 @@ class NumpyDataset(Dataset):

  def iterbatches(self,
                  batch_size=None,
                  epoch=0,
                  epochs=1,
                  deterministic=False,
                  pad_batches=False):
    """Get an object that iterates over minibatches from the dataset.
@@ -739,7 +738,7 @@ class NumpyDataset(Dataset):
    ----------
    batch_size: int, optional
      Number of elements in each batch
    epoch: int, optional
    epochs: int, optional
      Number of epochs to walk over dataset
    deterministic: bool, optional
      If True, follow deterministic order.
@@ -751,14 +750,15 @@ class NumpyDataset(Dataset):
    Generator which yields tuples of four numpy arrays `(X, y, w, ids)`
    """

    def iterate(dataset, batch_size, deterministic, pad_batches):
    def iterate(dataset, batch_size, epochs, deterministic, pad_batches):
      n_samples = dataset._X.shape[0]
      if not deterministic:
        sample_perm = np.random.permutation(n_samples)
      else:
      if deterministic:
        sample_perm = np.arange(n_samples)
      if batch_size is None:
        batch_size = n_samples
      for epoch in range(epochs):
        if not deterministic:
          sample_perm = np.random.permutation(n_samples)
        batch_idx = 0
        num_batches = np.math.ceil(n_samples / batch_size)
        while batch_idx < num_batches:
@@ -776,7 +776,7 @@ class NumpyDataset(Dataset):
          batch_idx += 1
          yield (X_batch, y_batch, w_batch, ids_batch)

    return iterate(self, batch_size, deterministic, pad_batches)
    return iterate(self, batch_size, epochs, deterministic, pad_batches)

  def itersamples(self):
    """Get an object that iterates over the samples in the dataset.
@@ -1151,7 +1151,7 @@ class DiskDataset(Dataset):

  def iterbatches(self,
                  batch_size=None,
                  epoch=0,
                  epochs=1,
                  deterministic=False,
                  pad_batches=False):
    """ Get an object that iterates over minibatches from the dataset.
@@ -1166,7 +1166,7 @@ class DiskDataset(Dataset):
      Number of elements in a batch. If None, then it yields batches
      with size equal to the size of each individual shard.
    epoch: int
      Not used
      Number of epochs to walk over dataset
    deterministic: bool
      Whether or not we should shuffle each shard before
      generating the batches.  Note that this is only local in the
@@ -1176,21 +1176,20 @@ class DiskDataset(Dataset):
      it has exactly batch_size elements.
    """
    shard_indices = list(range(self.get_number_shards()))
    return self._iterbatches_from_shards(shard_indices, batch_size,
    return self._iterbatches_from_shards(shard_indices, batch_size, epochs,
                                         deterministic, pad_batches)

  def _iterbatches_from_shards(self,
                               shard_indices,
                               batch_size=None,
                               epochs=1,
                               deterministic=False,
                               pad_batches=False):
    """Get an object that iterates over batches from a restricted set of shards."""

    def iterate(dataset, batch_size):
    def iterate(dataset, batch_size, epochs):
      num_shards = len(shard_indices)
      if not deterministic:
        shard_perm = np.random.permutation(num_shards)
      else:
      if deterministic:
        shard_perm = np.arange(num_shards)

      # (ytz): Depending on the application, thread-based pools may be faster
@@ -1198,16 +1197,17 @@ class DiskDataset(Dataset):
      # objects as an extra overhead. Also, as hideously as un-thread safe this looks,
      # we're actually protected by the GIL.
      pool = Pool(1)  # mp.dummy aliases ThreadPool to Pool
      next_shard = pool.apply_async(dataset.get_shard,
                                    (shard_indices[shard_perm[0]],))

      total_yield = 0

      if batch_size is None:
        num_global_batches = num_shards
      else:
        num_global_batches = math.ceil(dataset.get_shape()[0][0] / batch_size)

      for epoch in range(epochs):
        if not deterministic:
          shard_perm = np.random.permutation(num_shards)
        next_shard = pool.apply_async(dataset.get_shard,
                                      (shard_indices[shard_perm[0]],))
        cur_global_batch = 0
        cur_shard = 0
        carry = None
@@ -1218,7 +1218,7 @@ class DiskDataset(Dataset):
          if cur_shard < num_shards - 1:
            next_shard = pool.apply_async(
                dataset.get_shard, (shard_indices[shard_perm[cur_shard + 1]],))
        else:
          elif epoch == epochs - 1:
            pool.close()

          if carry is not None:
@@ -1285,7 +1285,7 @@ class DiskDataset(Dataset):
            cur_local_batch += 1
          cur_shard += 1

    return iterate(self, batch_size)
    return iterate(self, batch_size, epochs)

  def itersamples(self):
    """Get an object that iterates over the samples in the dataset.
@@ -1922,7 +1922,7 @@ class ImageDataset(Dataset):

  def iterbatches(self,
                  batch_size=None,
                  epoch=0,
                  epochs=1,
                  deterministic=False,
                  pad_batches=False):
    """Get an object that iterates over minibatches from the dataset.
@@ -1931,14 +1931,15 @@ class ImageDataset(Dataset):
    w, ids).
    """

    def iterate(dataset, batch_size, deterministic, pad_batches):
    def iterate(dataset, batch_size, epochs, deterministic, pad_batches):
      n_samples = dataset._X_shape[0]
      if not deterministic:
        sample_perm = np.random.permutation(n_samples)
      else:
      if deterministic:
        sample_perm = np.arange(n_samples)
      if batch_size is None:
        batch_size = n_samples
      for epoch in range(epochs):
        if not deterministic:
          sample_perm = np.random.permutation(n_samples)
        batch_idx = 0
        num_batches = np.math.ceil(n_samples / batch_size)
        while batch_idx < num_batches:
@@ -1964,7 +1965,7 @@ class ImageDataset(Dataset):
          batch_idx += 1
          yield (X_batch, y_batch, w_batch, ids_batch)

    return iterate(self, batch_size, deterministic, pad_batches)
    return iterate(self, batch_size, epochs, deterministic, pad_batches)

  def itersamples(self):
    """Get an object that iterates over the samples in the dataset.
@@ -2143,7 +2144,7 @@ class Databag(object):
    ----------
    batch_size: int
      Number of samples from each dataset to return
    epoch: int
    epochs: int
      Number of times to loop through the datasets
    pad_batches: boolean
      Should all batches==batch_size
+4 −4
Original line number Diff line number Diff line
@@ -436,9 +436,9 @@ class TestDatasets(test_util.TensorFlowTestCase):
                  solubility_dataset.w, solubility_dataset.ids)
    batch_sizes = []
    for X, y, _, _ in solubility_dataset.iterbatches(
        3, pad_batches=False, deterministic=True):
        3, epochs=2, pad_batches=False, deterministic=True):
      batch_sizes.append(len(X))
    self.assertEqual([3, 3, 3, 1], batch_sizes)
    self.assertEqual([3, 3, 3, 1, 3, 3, 3, 1], batch_sizes)

  def test_disk_pad_batches(self):
    shard_sizes = [21, 11, 41, 21, 51]
@@ -663,9 +663,9 @@ class TestDatasets(test_util.TensorFlowTestCase):
        solubility_dataset)
    batch_sizes = []
    for X, y, _, _ in solubility_dataset.iterbatches(
        3, pad_batches=False, deterministic=True):
        3, epochs=2, pad_batches=False, deterministic=True):
      batch_sizes.append(len(X))
    self.assertEqual([3, 3, 3, 1], batch_sizes)
    self.assertEqual([3, 3, 3, 1, 3, 3, 3, 1], batch_sizes)

  def test_merge(self):
    """Test that dataset merge works."""
+6 −3
Original line number Diff line number Diff line
@@ -74,15 +74,18 @@ class TestImageDataset(test_util.TensorFlowTestCase):
    ds = dc.data.ImageDataset(files, np.random.random(10))
    X = ds.X
    iterated_ids = set()
    for x, y, w, ids in ds.iterbatches(2):
    for x, y, w, ids in ds.iterbatches(2, epochs=2):
      np.testing.assert_array_equal([2, 28, 28], x.shape)
      np.testing.assert_array_equal([2], y.shape)
      np.testing.assert_array_equal([2], w.shape)
      np.testing.assert_array_equal([2], ids.shape)
      for i in (0, 1):
        assert ids[i] in files
        if len(iterated_ids) < 10:
          assert ids[i] not in iterated_ids
          iterated_ids.add(ids[i])
        else:
          assert ids[i] in iterated_ids
        index = files.index(ids[i])
        np.testing.assert_array_equal(x[i], X[index])
    assert len(iterated_ids) == 10
+19 −6
Original line number Diff line number Diff line
@@ -415,7 +415,15 @@ class KerasModel(Model):

    return apply_gradient_for_batch

  def fit_on_batch(self, X, y, w, variables=None, loss=None, callbacks=[]):
  def fit_on_batch(self,
                   X,
                   y,
                   w,
                   variables=None,
                   loss=None,
                   callbacks=[],
                   checkpoint=True,
                   max_checkpoints_to_keep=5):
    """Perform a single step of training.

    Parameters
@@ -436,13 +444,18 @@ class KerasModel(Model):
    callbacks: function or list of functions
      one or more functions of the form f(model, step) that will be invoked after
      every step.  This can be used to perform validation, logging, etc.
    checkpoint: bool
      if true, save a checkpoint after performing the training step
    max_checkpoints_to_keep: int
      the maximum number of checkpoints to keep.  Older checkpoints are discarded.
    """
    if not self.built:
      self.build()
    self._ensure_built()
    dataset = NumpyDataset(X, y, w)
    return self.fit(
        dataset,
        nb_epoch=1,
        max_checkpoints_to_keep=max_checkpoints_to_keep,
        checkpoint_interval=self._global_step.numpy() + 2 if checkpoint else 0,
        variables=variables,
        loss=loss,
        callbacks=callbacks)
+24 −0
Original line number Diff line number Diff line
@@ -58,6 +58,30 @@ class TestKerasModel(unittest.TestCase):
    scores = model.evaluate_generator(generator, [metric])
    assert scores[metric.name] > 0.9

  def test_fit_on_batch(self):
    """Test fitting a KerasModel to individual batches."""
    n_data_points = 10
    n_features = 2
    X = np.random.rand(n_data_points, n_features)
    y = (X[:, 0] > X[:, 1]).astype(np.float32)
    dataset = dc.data.NumpyDataset(X, y)
    keras_model = tf.keras.Sequential([
        tf.keras.layers.Dense(10, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    model = dc.models.KerasModel(
        keras_model, dc.models.losses.BinaryCrossEntropy(), learning_rate=0.005)
    i = 0
    for X, y, w, ids in dataset.iterbatches(model.batch_size, 500):
      i += 1
      model.fit_on_batch(X, y, w, checkpoint=False)
    prediction = np.squeeze(model.predict_on_batch(X))
    assert np.array_equal(y, np.round(prediction))
    metric = dc.metrics.Metric(dc.metrics.roc_auc_score)
    generator = model.default_generator(dataset, pad_batches=False)
    scores = model.evaluate_generator(generator, [metric])
    assert scores[metric.name] > 0.9

  def test_checkpointing(self):
    """Test loading and saving checkpoints with KerasModel."""
    # Create two models using the same model directory.