Commit 2c3b296d authored by nd-02110114's avatar nd-02110114
Browse files

♻️ pass a reference to ImageDataset and NumpyDataset

parent 7b1c1612
Loading
Loading
Loading
Loading
+61 −24
Original line number Diff line number Diff line
import math
import multiprocessing

from typing import List, Union
import numpy as np
import torch

from deepchem.data.datasets import pad_batch
from deepchem.data.data_loader import ImageLoader
from deepchem.data.datasets import NumpyDataset, DiskDataset, ImageDataset


class TorchNumpyDataset(torch.utils.data.IterableDataset):  # type: ignore

  def __init__(self, X, y, w, ids, n_samples, epochs, deterministic):
    self._X = X
    self._y = y
    self._w = w
    self._ids = ids
    self.n_samples = n_samples
  def __init__(self, numpy_dataset: NumpyDataset, epochs: int, deterministic: bool):
    """
    Parameters
    ----------
    numpy_dataset: NumpyDataset
      The original NumpyDataset which you want to convert to PyTorch
    epochs: int
      the number of times to iterate over the Dataset
    deterministic: bool
      if True, the data is produced in order.  If False, a different random
      permutation of the data is used for each epoch.
    """
    self.numpy_dataset = numpy_dataset
    self.epochs = epochs
    self.deterministic = deterministic

  def __iter__(self):
    n_samples = self.n_samples
    n_samples = self.numpy_dataset._X.shape[0]
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
      first_sample = 0
@@ -34,12 +39,23 @@ class TorchNumpyDataset(torch.utils.data.IterableDataset): # type: ignore
      else:
        order = first_sample + np.random.permutation(last_sample - first_sample)
      for i in order:
        yield (self._X[i], self._y[i], self._w[i], self._ids[i])
        yield (self.numpy_dataset._X[i], self.numpy_dataset._y[i], self.numpy_dataset._w[i], self.numpy_dataset._ids[i])


class TorchDiskDataset(torch.utils.data.IterableDataset):  # type: ignore

  def __init__(self, disk_dataset, epochs, deterministic):
  def __init__(self, disk_dataset: DiskDataset, epochs: int, deterministic: bool):
    """
    Parameters
    ----------
    disk_dataset: DiskDataset
      The original DiskDataset which you want to convert to PyTorch
    epochs: int
      the number of times to iterate over the Dataset
    deterministic: bool
      if True, the data is produced in order.  If False, a different random
      permutation of the data is used for each epoch.
    """
    self.disk_dataset = disk_dataset
    self.epochs = epochs
    self.deterministic = deterministic
@@ -66,17 +82,24 @@ class TorchDiskDataset(torch.utils.data.IterableDataset): # type: ignore

class TorchImageDataset(torch.utils.data.IterableDataset):  # type: ignore

  def __init__(self, X, y, w, ids, n_samples, epochs, deterministic):
    self._X = X
    self._y = y
    self._w = w
    self._ids = ids
    self.n_samples = n_samples
  def __init__(self, image_dataset: ImageDataset, epochs: int, deterministic: bool):
    """
    Parameters
    ----------
    image_dataset: ImageDataset
      The original ImageDataset which you want to convert to PyTorch
    epochs: int
      the number of times to iterate over the Dataset
    deterministic: bool
      if True, the data is produced in order.  If False, a different random
      permutation of the data is used for each epoch.
    """
    self.image_dataset = image_dataset
    self.epochs = epochs
    self.deterministic = deterministic

  def __iter__(self):
    n_samples = self.n_samples
    n_samples = self.image_dataset._X.shape[0]
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
      first_sample = 0
@@ -90,10 +113,24 @@ class TorchImageDataset(torch.utils.data.IterableDataset): # type: ignore
      else:
        order = first_sample + np.random.permutation(last_sample - first_sample)
      for i in order:
        yield (self._get_image(self._X, i), self._get_image(self._y, i),
               self._w[i], self._ids[i])

  def _get_image(self, array, index):
        yield (self._get_image(self.image_dataset._X, i), self._get_image(self.image_dataset._y, i),
               self.image_dataset._w[i], self.image_dataset._ids[i])

  def _get_image(self, array: Union[np.ndarray, List[str]], index: int) -> np.ndarray:
    """Function for loading an image

    Parameters
    ----------
    array: Union[np.ndarray, List[str]]
      A numpy array which contains all images or List of image filenames
    index: int
      Index you want to get the images

    Returns
    -------
    np.ndarray
      Loaded image
    """
    if isinstance(array, np.ndarray):
      return array[index]
    return ImageLoader.load_img([array[index]])[0]