Commit d0b90f4f authored by nd-02110114's avatar nd-02110114
Browse files

🚨 fix yapf and mypy error

parent 3a9e6dc9
Loading
Loading
Loading
Loading
+2 −14
Original line number Diff line number Diff line
@@ -871,13 +871,7 @@ class NumpyDataset(Dataset):
      raise ValueError("This method requires PyTorch to be installed.")

    pytorch_ds = TorchNumpyDataset(
        X=self._X,
        y=self._y,
        w=self._w,
        ids=self._ids,
        n_samples=self._X.shape[0],
        epochs=epochs,
        deterministic=deterministic)
        numpy_dataset=self, epochs=epochs, deterministic=deterministic)
    return pytorch_ds

  @staticmethod
@@ -2243,13 +2237,7 @@ class ImageDataset(Dataset):
      raise ValueError("This method requires PyTorch to be installed.")

    pytorch_ds = TorchImageDataset(
        X=self.X,
        y=self.y,
        w=self.w,
        ids=self._ids,
        n_samples=self._X_shape[0],
        epochs=epochs,
        deterministic=deterministic)
        image_dataset=self, epochs=epochs, deterministic=deterministic)
    return pytorch_ds


+15 −9
Original line number Diff line number Diff line
@@ -8,7 +8,8 @@ from deepchem.data.datasets import NumpyDataset, DiskDataset, ImageDataset

class TorchNumpyDataset(torch.utils.data.IterableDataset):  # type: ignore

  def __init__(self, numpy_dataset: NumpyDataset, epochs: int, deterministic: bool):
  def __init__(self, numpy_dataset: NumpyDataset, epochs: int,
               deterministic: bool):
    """
    Parameters
    ----------
@@ -39,12 +40,14 @@ class TorchNumpyDataset(torch.utils.data.IterableDataset): # type: ignore
      else:
        order = first_sample + np.random.permutation(last_sample - first_sample)
      for i in order:
        yield (self.numpy_dataset._X[i], self.numpy_dataset._y[i], self.numpy_dataset._w[i], self.numpy_dataset._ids[i])
        yield (self.numpy_dataset._X[i], self.numpy_dataset._y[i],
               self.numpy_dataset._w[i], self.numpy_dataset._ids[i])


class TorchDiskDataset(torch.utils.data.IterableDataset):  # type: ignore

  def __init__(self, disk_dataset: DiskDataset, epochs: int, deterministic: bool):
  def __init__(self, disk_dataset: DiskDataset, epochs: int,
               deterministic: bool):
    """
    Parameters
    ----------
@@ -82,7 +85,8 @@ class TorchDiskDataset(torch.utils.data.IterableDataset): # type: ignore

class TorchImageDataset(torch.utils.data.IterableDataset):  # type: ignore

  def __init__(self, image_dataset: ImageDataset, epochs: int, deterministic: bool):
  def __init__(self, image_dataset: ImageDataset, epochs: int,
               deterministic: bool):
    """
    Parameters
    ----------
@@ -99,7 +103,7 @@ class TorchImageDataset(torch.utils.data.IterableDataset): # type: ignore
    self.deterministic = deterministic

  def __iter__(self):
    n_samples = self.image_dataset._X.shape[0]
    n_samples = self.image_dataset._X_shape[0]
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
      first_sample = 0
@@ -113,18 +117,20 @@ class TorchImageDataset(torch.utils.data.IterableDataset): # type: ignore
      else:
        order = first_sample + np.random.permutation(last_sample - first_sample)
      for i in order:
        yield (self._get_image(self.image_dataset._X, i), self._get_image(self.image_dataset._y, i),
        yield (self._get_image(self.image_dataset._X, i),
               self._get_image(self.image_dataset._y, i),
               self.image_dataset._w[i], self.image_dataset._ids[i])

  def _get_image(self, array: Union[np.ndarray, List[str]], index: int) -> np.ndarray:
  def _get_image(self, array: Union[np.ndarray, List[str]],
                 index: int) -> np.ndarray:
    """Function for loading an image

    Parameters
    ----------
    array: Union[np.ndarray, List[str]]
      A numpy array which contains all images or List of image filenames
      A numpy array which contains images or List of image filenames
    index: int
      Index you want to get the images
      Index you want to get the image

    Returns
    -------