Commit b8749616 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Steps towards cached shape

parent 7691cc59
Loading
Loading
Loading
Loading
+34 −19
Original line number Diff line number Diff line
@@ -1073,34 +1073,45 @@ class DiskDataset(Dataset):

    Returns
    -------
    List with values `[out_ids, out_X, out_y, out_w]` with filenames of locations to disk which these respective arrays were written.
    List with values `[out_ids, out_X, out_y, out_w, out_ids_shape, out_X_shape, out_y_shape, out_w_shape]` with filenames of locations to disk which these respective arrays were written.
    """
    if X is not None:
      out_X: Optional[str] = "%s-X.npy" % basename
      save_to_disk(X, os.path.join(data_dir, out_X))  # type: ignore
      out_X_shape = X.shape
    else:
      out_X = None
      out_X_shape = None

    if y is not None:
      out_y: Optional[str] = "%s-y.npy" % basename
      save_to_disk(y, os.path.join(data_dir, out_y))  # type: ignore
      out_y_shape = y.shape
    else:
      out_y = None
      out_y_shape = None

    if w is not None:
      out_w: Optional[str] = "%s-w.npy" % basename
      save_to_disk(w, os.path.join(data_dir, out_w))  # type: ignore
      out_w_shape = w.shape
    else:
      out_w = None
      out_w_shape = None

    if ids is not None:
      out_ids: Optional[str] = "%s-ids.npy" % basename
      save_to_disk(ids, os.path.join(data_dir, out_ids))  # type: ignore
      out_ids_shape = ids.shape
    else:
      out_ids = None
      out_ids_shape = None

    # note that this corresponds to the _construct_metadata column order
    return [out_ids, out_X, out_y, out_w]
    return [
        out_ids, out_X, out_y, out_w, out_ids_shape, out_X_shape, out_y_shape,
        out_w_shape
    ]

  def save_to_disk(self) -> None:
    """Save dataset to disk."""
@@ -1674,6 +1685,10 @@ class DiskDataset(Dataset):
      A DiskDataset with a single shard.

    """
    # Create temp directory to store shuffled version
    shuffle_dir = tempfile.mkdtemp()
    n_shards = self.get_number_shards()

    all_X = []
    all_y = []
    all_w = []
@@ -2003,23 +2018,23 @@ class DiskDataset(Dataset):
  def get_shape(self) -> Tuple[Shape, Shape, Shape, Shape]:
    """Finds shape of dataset."""
    n_tasks = len(self.get_task_names())
    for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
      if shard_num == 0:
        X_shape = np.array(X.shape)
        if n_tasks > 0:
          y_shape = np.array(y.shape)
          w_shape = np.array(w.shape)
        else:
          y_shape = tuple()
          w_shape = tuple()
        ids_shape = np.array(ids.shape)
      else:
        X_shape[0] += np.array(X.shape)[0]
        if n_tasks > 0:
          y_shape[0] += np.array(y.shape)[0]
          w_shape[0] += np.array(w.shape)[0]
        ids_shape[0] += np.array(ids.shape)[0]
    return tuple(X_shape), tuple(y_shape), tuple(w_shape), tuple(ids_shape)
    #for shard_num, (X, y, w, ids) in enumerate(self.itershards()):
    #  if shard_num == 0:
    #    X_shape = np.array(X.shape)
    #    if n_tasks > 0:
    #      y_shape = np.array(y.shape)
    #      w_shape = np.array(w.shape)
    #    else:
    #      y_shape = tuple()
    #      w_shape = tuple()
    #    ids_shape = np.array(ids.shape)
    #  else:
    #    X_shape[0] += np.array(X.shape)[0]
    #    if n_tasks > 0:
    #      y_shape[0] += np.array(y.shape)[0]
    #      w_shape[0] += np.array(w.shape)[0]
    #    ids_shape[0] += np.array(ids.shape)[0]
    #return tuple(X_shape), tuple(y_shape), tuple(w_shape), tuple(ids_shape)

  def get_label_means(self) -> pd.DataFrame:
    """Return pandas series of label means."""