Commit 4344e3e3 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Removing det_iterbatches

parent ff4bcb0b
Loading
Loading
Loading
Loading
+29 −23
Original line number Diff line number Diff line
@@ -214,34 +214,40 @@ class Dataset(object):
          os.path.join(self.data_dir, row['ids'])), dtype=object)
      yield (X, y, w, ids)

  def det_iterbatches(self, batch_size=None, epoch=0):
    """Deterministically yield minibatches from the dataset.

    Walks the shards in their stored order (via ``self.itershards``) and
    slices each shard into contiguous batches, so repeated iteration
    produces the same batches in the same order.

    Parameters
    ----------
    batch_size: int, optional
      Approximate number of samples per batch. If None, each shard is
      yielded whole as a single batch. Note that because batch boundaries
      come from ``np.linspace``, individual batches may be one sample
      smaller or larger than ``batch_size``.
    epoch: int, optional
      Unused; kept for interface compatibility with callers.

    Yields
    ------
    tuple
      ``(X_batch, y_batch, w_batch, ids_batch)`` for each batch.
    """
    for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
      nb_sample = np.shape(X)[0]
      if batch_size is None:
        shard_batch_size = nb_sample
      else:
        shard_batch_size = batch_size
      # np.linspace requires an integer `num`; np.ceil returns a float,
      # which recent numpy versions reject, so cast explicitly.
      n_batches = int(np.ceil(float(nb_sample) / shard_batch_size))
      interval_points = np.linspace(0, nb_sample, n_batches + 1, dtype=int)
      for j in range(len(interval_points) - 1):
        indices = range(interval_points[j], interval_points[j + 1])
        X_batch = X[indices, :]
        y_batch = y[indices]
        w_batch = w[indices]
        ids_batch = ids[indices]
        yield (X_batch, y_batch, w_batch, ids_batch)

  def iterbatches(self, batch_size=None, epoch=0):
  #def det_iterbatches(self, batch_size=None, epoch=0):
  #  """
  #  Returns minibatches from dataset.
  #  """
  #  for i, (X, y, w, ids) in enumerate(self.itershards()):
  #    nb_sample = np.shape(X)[0]
  #    if batch_size is None:
  #      shard_batch_size = nb_sample
  #    else:
  #      shard_batch_size = batch_size 
  #    interval_points = np.linspace(
  #        0, nb_sample, np.ceil(float(nb_sample)/shard_batch_size)+1, dtype=int)
  #    for j in range(len(interval_points)-1):
  #      indices = range(interval_points[j], interval_points[j+1])
  #      X_batch = X[indices, :]
  #      y_batch = y[indices]
  #      w_batch = w[indices]
  #      ids_batch = ids[indices]
  #      yield (X_batch, y_batch, w_batch, ids_batch)

  def iterbatches(self, batch_size=None, epoch=0, deterministic=False):
    """Returns minibatches from dataset randomly."""
    num_shards = self.get_number_shards()
    if not deterministic:
      shard_perm = np.random.permutation(num_shards)
    else:
      shard_perm = np.arange(num_shards)
    for i in range(num_shards):
      X, y, w, ids = self.get_shard(shard_perm[i])
      n_samples = X.shape[0]
      if not deterministic:
        sample_perm = np.random.permutation(n_samples)
      else:
        sample_perm = np.arange(n_samples)
      if batch_size is None:
        shard_batch_size = n_samples
      else:
+4 −2
Original line number Diff line number Diff line
@@ -137,7 +137,8 @@ class Model(object):
    """
    y_preds = []
    batch_size = self.model_params["batch_size"]
    for (X_batch, y_batch, w_batch, ids_batch) in dataset.det_iterbatches(batch_size):
    for (X_batch, y_batch, w_batch, ids_batch) in dataset.iterbatches(
        batch_size, deterministic=True):
      y_pred_batch = np.reshape(self.predict_on_batch(X_batch), y_batch.shape)
      y_pred_batch = undo_transforms(y_pred_batch, transformers)
      y_preds.append(y_pred_batch)
@@ -159,7 +160,8 @@ class Model(object):
    y_preds = []
    batch_size = self.model_params["batch_size"]
    n_tasks = len(self.tasks)
    for (X_batch, y_batch, w_batch, ids_batch) in dataset.det_iterbatches(batch_size):
    for (X_batch, y_batch, w_batch, ids_batch) in dataset.iterbatches(
        batch_size, deterministic=True):
      y_pred_batch = self.predict_proba_on_batch(X_batch)
      batch_size = len(y_batch)
      y_pred_batch = np.reshape(y_pred_batch, (batch_size, n_tasks, n_classes))