Commit 277fc976 authored by nd-02110114
Browse files

👌 update codes by review

parent 3c8f339f
Loading
Loading
Loading
Loading
+37 −18
Original line number Diff line number Diff line
@@ -317,7 +317,6 @@ class Dataset(object):
    `iterbatches()` or `itersamples()` may be more efficient for
    larger datasets.
    """

    raise NotImplementedError()

  @property
@@ -387,7 +386,16 @@ class Dataset(object):
    raise NotImplementedError()

  def itersamples(self) -> Iterator[Batch]:
    """Get an object that iterates over the samples in the dataset."""
    """Get an object that iterates over the samples in the dataset.

    Examples
    --------
    >>> dataset = NumpyDataset(np.ones((2,2)))
    >>> for x, y, w, id in dataset.itersamples():
    ...   print(x.tolist(), y.tolist(), w.tolist(), id)
    [1.0, 1.0] [0.0] [0.0] 0
    [1.0, 1.0] [0.0] [0.0] 1
    """
    raise NotImplementedError()

  def transform(self, transformer: "dc.trans.Transformer", **args) -> "Dataset":
@@ -733,7 +741,10 @@ class NumpyDataset(Dataset):
    return len(self._y)

  def get_shape(self) -> Tuple[Shape, Shape, Shape, Shape]:
    """Get the shape of the dataset."""
    """Get the shape of the dataset.

    Returns four tuples, giving the shape of the X, y, w, and ids arrays.
    """
    return self._X.shape, self._y.shape, self._w.shape, self._ids.shape

  def get_task_names(self) -> np.ndarray:
@@ -2018,7 +2029,7 @@ class DiskDataset(Dataset):
    N = len(self)
    perm = np.random.permutation(N)
    shard_size = self.get_shard_size()
    return self.select(perm, data_dir, self.get_shard_size())
    return self.select(perm, data_dir, shard_size)

  def shuffle_each_shard(self,
                         shard_basenames: Optional[List[str]] = None) -> None:
@@ -2272,7 +2283,7 @@ class DiskDataset(Dataset):
    Returns
    -------
    DiskDataset
      A DiskDataset contains selected samples.
      A Dataset containing the selected samples
    """
    if output_numpy_dataset and (select_dir is not None or
                                 select_shard_size is not None):
@@ -2507,7 +2518,10 @@ class DiskDataset(Dataset):
    return X_shape, y_shape, w_shape, ids_shape

  def get_shape(self) -> Tuple[Shape, Shape, Shape, Shape]:
    """Finds shape of dataset."""
    """Finds shape of dataset.

    Returns four tuples, giving the shape of the X, y, w, and ids arrays.
    """
    n_tasks = len(self.get_task_names())
    n_rows = len(self.metadata_df.index)
    # If shape metadata is available use it to directly compute shape from
@@ -2621,7 +2635,10 @@ class ImageDataset(Dataset):
    return self._X_shape[0]

  def get_shape(self) -> Tuple[Shape, Shape, Shape, Shape]:
    """Get the shape of the dataset."""
    """Get the shape of the dataset.

    Returns four tuples, giving the shape of the X, y, w, and ids arrays.
    """
    return self._X_shape, self._y_shape, self._w.shape, self._ids.shape

  def get_task_names(self) -> np.ndarray:
@@ -2716,24 +2733,26 @@ class ImageDataset(Dataset):
    return iterate(self, batch_size, epochs, deterministic, pad_batches)

  def _get_image(self, array: Union[np.ndarray, List[str]],
                 index: int) -> np.ndarray:
                 indices: Union[int, np.ndarray]) -> np.ndarray:
    """Method for loading an image

    Parameters
    ----------
    array: Union[np.ndarray, List[str]]
      A numpy array which contains images or List of image filenames
    index: int
      Index you want to get the image
    indices: Union[int, np.ndarray]
      Index you want to get the images

    Returns
    -------
    np.ndarray
      Loaded image
      Loaded images
    """
    if isinstance(array, np.ndarray):
      return array[index]
    return load_image_files([array[index]])[0]
      return array[indices]
    if isinstance(indices, np.ndarray):
      return load_image_files([array[i] for i in indices])
    return load_image_files([array[indices]])[0]

  def itersamples(self) -> Iterator[Batch]:
    """Get an object that iterates over the samples in the dataset.
@@ -2751,7 +2770,7 @@ class ImageDataset(Dataset):
      self,
      transformer: "dc.trans.Transformer",
      **args,
  ) -> "ImageDataset":
  ) -> "NumpyDataset":
    """Construct a new dataset by applying a transformation to every sample in this dataset.

    The argument is a function that can be called as follows:
@@ -2769,12 +2788,12 @@ class ImageDataset(Dataset):

    Returns
    -------
    ImageDataset
      A newly constructed ImageDataset object
    NumpyDataset
      A newly constructed NumpyDataset object
    """
    newx, newy, neww, newids = transformer.transform_array(
        self.X, self.y, self.w, self.ids)
    return ImageDataset(newx, newy, neww, newids)
    return NumpyDataset(newx, newy, neww, newids)

  def select(self, indices: Sequence[int],
             select_dir: Optional[str] = None) -> "ImageDataset":
+2 −2
Original line number Diff line number Diff line
@@ -160,6 +160,6 @@ class _TorchImageDataset(torch.utils.data.IterableDataset): # type: ignore
        for i in range(0, len(order), self.batch_size):
          indices = order[i:i + self.batch_size]
          yield (self.image_dataset._get_image(self.image_dataset._X, indices),
                 self.image_dataset._get_image(self.image_dataset._y,
                                 indices), self.image_dataset._w[indices],
                 self.image_dataset._get_image(self.image_dataset._y, indices),
                 self.image_dataset._w[indices],
                 self.image_dataset._ids[indices])