Commit a3f1fd67 authored by nd-02110114's avatar nd-02110114
Browse files

👌 fix for review

parent fb077a27
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -420,7 +420,7 @@ class Dataset(object):
    """
    raise NotImplementedError()

  def select(self, indices: Sequence[int], select_dir: Optional[str] = None):
  def select(self, indices: Sequence[int], select_dir: Optional[str] = None) -> "Dataset":
    """Creates a new dataset from a selection of indices from self.

    Parameters
@@ -2294,8 +2294,9 @@ class DiskDataset(Dataset):

    Returns
    -------
    DiskDataset
      A Dataset containing the selected samples
    Dataset
      A dataset containing the selected samples. The default dataset is `DiskDataset`.
      If `output_numpy_dataset` is True, the dataset is `NumpyDataset`.
    """
    if output_numpy_dataset and (select_dir is not None or
                                 select_shard_size is not None):
+14 −13
Original line number Diff line number Diff line
@@ -103,7 +103,8 @@ class Splitter(object):
      train_ds_base = DiskDataset.merge(update_train_base_merge)
    return list(zip(train_datasets, cv_datasets))

  def train_valid_test_split(self,
  def train_valid_test_split(
      self,
      dataset: Dataset,
      train_dir: Optional[str] = None,
      valid_dir: Optional[str] = None,
@@ -113,7 +114,7 @@ class Splitter(object):
      frac_test: float = 0.1,
      seed: Optional[int] = None,
      log_every_n: int = 1000,
                             **kwargs) -> Tuple[Dataset, Dataset, Dataset]:
      **kwargs) -> Tuple[Dataset, Optional[Dataset], Dataset]:
    """ Splits self into train/validation/test sets.

    Returns Dataset objects for train, valid, test.
@@ -150,7 +151,7 @@ class Splitter(object):

    Returns
    -------
    Tuple[Dataset, Dataset, Dataset]
    Tuple[Dataset, Optional[Dataset], Dataset]
      A tuple of train, valid and test datasets as dc.data.Dataset objects.
      The valid dataset is None when `frac_valid` is 0.
    """
    logger.info("Computing train/valid/test indices")
@@ -169,7 +170,7 @@ class Splitter(object):
      test_dir = tempfile.mkdtemp()
    train_dataset = dataset.select(train_inds, train_dir)
    if frac_valid != 0:
      valid_dataset = dataset.select(valid_inds, valid_dir)
      valid_dataset: Optional[Dataset] = dataset.select(valid_inds, valid_dir)
    else:
      valid_dataset = None
    test_dataset = dataset.select(test_inds, test_dir)