Commit 8b5708d4 authored by Bharath Ramsundar
Browse files

Changes

parent da54ecdc
Loading
Loading
Loading
Loading
+18 −9
Original line number Diff line number Diff line
@@ -1000,7 +1000,7 @@ class DiskDataset(Dataset):

  The basic structure of `DiskDataset` is quite robust and will likely serve
  you well for datasets up to about 100 GB or larger. However note that
  `DiskDataset` has note been tested for very large datasets at the terabyte
  `DiskDataset` has not been tested for very large datasets at the terabyte
  range and beyond. You may be better served by implementing a custom
  `Dataset` class for those use cases.

@@ -1037,7 +1037,7 @@ class DiskDataset(Dataset):
  projects.
  """

  def __init__(self, data_dir: str, legacy_metadata: bool = False) -> None:
  def __init__(self, data_dir: str) -> None:
    """Load a constructed DiskDataset from disk

    Note that this method cannot construct a new disk dataset. Instead use
@@ -1049,18 +1049,25 @@ class DiskDataset(Dataset):
    ----------
    data_dir: str
      Location on disk of an existing `DiskDataset`.
    legacy_metadata: bool, optional (default False)
      If `True` use the legacy format for metadata without shape information
      in metadata.
    """
    self.data_dir = data_dir
    self.legacy_metadata = legacy_metadata

    logger.info("Loading dataset from disk.")
    self.tasks, self.metadata_df = self.load_metadata()
    if len(self.metadata_df.columns) == 4:
    if len(self.metadata_df.columns) == 4 and list(
        self.metadata_df.columns) == ['ids', 'X', 'y', 'w']:
      logger.info("Detected legacy metatadata on disk.")
      self.legacy_metadata = True
    elif len(self.metadata_df.columns) == 8 and list(
        self.metadata_df.columns) == [
            'ids', 'X', 'y', 'w', 'ids_shape', 'X_shape', 'y_shape', 'w_shape'
        ]:
      self.legacy_metadata = False
    else:
      raise ValueError(
          "Malformed metadata on disk. Metadata must have columns 'ids', 'X', 'y', 'w', 'ids_shape', 'X_shape', 'y_shape', 'w_shape' (or if in legacy metadata format, columns 'ids', 'X', 'y', 'w')"
      )
    self._cached_shards: Optional[List] = None
    self._memory_cache_size = 20 * (1 << 20)  # 20 MB
    self._cache_used = 0
@@ -1083,7 +1090,8 @@ class DiskDataset(Dataset):
      List of tasks for this dataset.
    legacy_metadata: bool, optional (default False)
      If `True` use the legacy format for metadata without shape information
      in metadata.
      in metadata. This option is not recommended since the legacy metadata
      format will have worse performance.

    Returns
    -------
@@ -1108,7 +1116,7 @@ class DiskDataset(Dataset):
    logger.info("TIMING: dataset construction took %0.3f s" % (time2 - time1))
    return DiskDataset(data_dir, legacy_metadata)

  def load_metadata(self):
  def load_metadata(self) -> Tuple[List[str], pd.DataFrame]:
    """Helper method that loads metadata from disk."""
    try:
      tasks_filename, metadata_filename = self._get_metadata_filename()
@@ -1205,7 +1213,8 @@ class DiskDataset(Dataset):
      The identifiers array 
    legacy_metadata: bool, optional (default False)
      If `True` use the legacy format for metadata without shape information
      in metadata.
      in metadata. Setting this option is not recommended since legacy
      metadata will have worse performance.

    Returns
    -------