Commit b5a4d7c3 authored by peastman's avatar peastman
Browse files

Prevent DiskDataset from resizing arrays

parent 2c04394f
Loading
Loading
Loading
Loading
+33 −25
Original line number Diff line number Diff line
@@ -94,11 +94,15 @@ def pad_batch(batch_size, X_b, y_b, w_b, ids_b):

  if y_b is None:
    y_out = None
  elif len(y_b.shape) < 2:
    y_out = np.zeros(batch_size, dtype=y_b.dtype)
  else:
    y_out = np.zeros((batch_size, y_b.shape[1]), dtype=y_b.dtype)

  if w_b is None:
    w_out = None
  elif len(w_b.shape) < 2:
    w_out = np.zeros(batch_size, dtype=w_b.dtype)
  else:
    w_out = np.zeros((batch_size, w_b.shape[1]), dtype=w_b.dtype)

@@ -343,6 +347,8 @@ class NumpyDataset(Dataset):

  def get_task_names(self):
    """Get the names of the tasks associated with this dataset."""
    if len(self._y.shape) < 2:
      return np.array([0])
    return np.arange(self._y.shape[1])

  @property
@@ -956,20 +962,16 @@ class DiskDataset(Dataset):
                 data_dir=None,
                 verbose=True):
    """Creates a DiskDataset object from specified Numpy arrays."""
    # if data_dir is None:
    #  data_dir = tempfile.mkdtemp()
    n_samples = len(X)
    # The -1 indicates that y will be reshaped to have length -1
    if n_samples > 0:
      y = np.reshape(y, (n_samples, -1))
      if w is not None:
        w = np.reshape(w, (n_samples, -1))
    if ids is None:
      ids = np.arange(n_samples)
    if w is None:
      w = np.ones_like(y)
    if tasks is None:
      if len(y.shape) > 1:
        n_tasks = y.shape[1]
      else:
        n_tasks = 1
      tasks = np.arange(n_tasks)
    # raw_data = (X, y, w, ids)
    return DiskDataset.create_dataset(
@@ -1205,8 +1207,8 @@ class DiskDataset(Dataset):
          if indices_count + num_shard_elts >= len(indices):
            break
        # Need to offset indices to fit within shard_size
        shard_inds = indices[indices_count:
                             indices_count + num_shard_elts] - count
        shard_inds = indices[indices_count:indices_count +
                             num_shard_elts] - count
        X_sel = X[shard_inds]
        # Handle the case of datasets with y/w missing
        if y is not None:
@@ -1257,17 +1259,29 @@ class DiskDataset(Dataset):
  def y(self):
    """Get the y vector for this dataset as a single numpy array."""
    ys = []
    one_dimensional = False
    for (_, y_b, _, _) in self.itershards():
      ys.append(y_b)
      if len(y_b.shape) == 1:
        one_dimensional = True
    if not one_dimensional:
      return np.vstack(ys)
    else:
      return np.concatenate(ys)

  @property
  def w(self):
    """Get the weight vector for this dataset as a single numpy array."""
    ws = []
    one_dimensional = False
    for (_, _, w_b, _) in self.itershards():
      ws.append(np.array(w_b))
      if len(w_b.shape) == 1:
        one_dimensional = True
    if not one_dimensional:
      return np.vstack(ws)
    else:
      return np.concatenate(ws)

  def __len__(self):
    """
@@ -1282,22 +1296,16 @@ class DiskDataset(Dataset):
  def get_shape(self):
    """Finds shape of dataset."""
    n_tasks = len(self.get_task_names())
    X_shape = np.array((0,) + (0,) * len(self.get_data_shape()))
    ids_shape = np.array((0,))
    for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
      if shard_num == 0:
        X_shape = np.array(X.shape)
        if n_tasks > 0:
      y_shape = np.array((0,) + (0,))
      w_shape = np.array((0,) + (0,))
          y_shape = np.array(y.shape)
          w_shape = np.array(w.shape)
        else:
          y_shape = tuple()
          w_shape = tuple()

    for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
      if shard_num == 0:
        X_shape += np.array(X.shape)
        if n_tasks > 0:
          y_shape += np.array(y.shape)
          w_shape += np.array(w.shape)
        ids_shape += np.array(ids.shape)
        ids_shape = np.array(ids.shape)
      else:
        X_shape[0] += np.array(X.shape)[0]
        if n_tasks > 0: