Commit 6604d60b authored by nd-02110114's avatar nd-02110114
Browse files

♻️ move ImageLoader.load_img to utils

parent 7c87e8e0
Loading
Loading
Loading
Loading
+12 −15
Original line number Diff line number Diff line
@@ -4,22 +4,21 @@ Contains wrapper class for datasets.
import json
import os
import math
import deepchem as dc
import numpy as np
import pandas as pd
import random
import logging
import tempfile
import time
import shutil
import warnings
import multiprocessing
from deepchem.utils.save import save_to_disk
from deepchem.utils.save import load_from_disk
from ast import literal_eval as make_tuple

from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd

import deepchem as dc
from deepchem.utils.typing import OneOrMany, Shape
from deepchem.utils.save import save_to_disk, load_from_disk, load_image_files

Batch = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]

@@ -2374,7 +2373,7 @@ class ImageDataset(Dataset):
  def _find_array_shape(self, array: Sequence) -> Shape:
    """Determine the shape of `array` without loading every image.

    A numpy array simply reports its own shape. For any other sequence only
    the first element is loaded, and the result is the number of elements
    followed by the dimensions of that first loaded image.
    """
    if isinstance(array, np.ndarray):
      return array.shape
    # NOTE(review): this listing is a diff view, so the next two lines appear
    # to be the removed and the added form of one statement -- confirm against
    # the real file. As written, the second assignment overwrites the first.
    # `.shape[1:]` drops the leading axis of the single-element load result.
    image_shape = dc.data.ImageLoader.load_img([array[0]]).shape[1:]
    image_shape = load_image_files([array[0]]).shape[1:]
    # Presumably every image shares the first image's shape -- TODO confirm.
    return np.concatenate([[len(array)], image_shape])

  def __len__(self) -> int:
@@ -2402,14 +2401,14 @@ class ImageDataset(Dataset):
    """Get the X vector for this dataset as a single numpy array."""
    if isinstance(self._X, np.ndarray):
      return self._X
    return dc.data.ImageLoader.load_img(self._X)
    return load_image_files(self._X)

  @property
  def y(self) -> np.ndarray:
    """Get the y vector for this dataset as a single numpy array.

    When `self._y` is already a numpy array it is returned as is; otherwise
    `self._y` is handed to the image loading helper and the loaded array is
    returned.
    """
    if isinstance(self._y, np.ndarray):
      return self._y
    # NOTE(review): this listing is a diff view, so the two returns below
    # appear to be the removed and the added form of one statement -- confirm
    # against the real file. As written, the second return is unreachable.
    return dc.data.ImageLoader.load_img(self._y)
    return load_image_files(self._y)

  @property
  def ids(self) -> np.ndarray:
@@ -2451,13 +2450,11 @@ class ImageDataset(Dataset):
          if isinstance(dataset._X, np.ndarray):
            X_batch = dataset._X[perm_indices]
          else:
            X_batch = dc.data.ImageLoader.load_img(
                [dataset._X[i] for i in perm_indices])
            X_batch = load_image_files([dataset._X[i] for i in perm_indices])
          if isinstance(dataset._y, np.ndarray):
            y_batch = dataset._y[perm_indices]
          else:
            y_batch = dc.data.ImageLoader.load_img(
                [dataset._y[i] for i in perm_indices])
            y_batch = load_image_files([dataset._y[i] for i in perm_indices])
          w_batch = dataset._w[perm_indices]
          ids_batch = dataset._ids[perm_indices]
          if pad_batches:
@@ -2483,7 +2480,7 @@ class ImageDataset(Dataset):
    def get_image(array, index):
      if isinstance(array, np.ndarray):
        return array[index]
      return dc.data.ImageLoader.load_img([array[index]])[0]
      return load_image_files([array[index]])[0]

    n_samples = self._X_shape[0]
    return ((get_image(self._X, i), get_image(self._y, i), self._w[i],
+2 −2
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ from typing import List, Union
import numpy as np
import torch

from deepchem.data.data_loader import ImageLoader
from deepchem.utils.save import load_image_files
from deepchem.data.datasets import NumpyDataset, DiskDataset, ImageDataset


@@ -139,4 +139,4 @@ class _TorchImageDataset(torch.utils.data.IterableDataset): # type: ignore
    """
    if isinstance(array, np.ndarray):
      return array[index]
    return ImageLoader.load_img([array[index]])[0]
    return load_image_files([array[index]])[0]
+38 −0
Original line number Diff line number Diff line
@@ -80,6 +80,44 @@ def load_data(input_files: List[str],
      yield load_pickle_from_disk(input_file)


def load_image_files(image_files: List[str]) -> np.ndarray:
  """Loads a set of images from disk.

  Parameters
  ----------
  image_files: List[str]
    List of image filenames to load. Only files whose extension is `.png`
    or `.tif` (compared case-insensitively) are supported.

  Returns
  -------
  np.ndarray
    A numpy array that contains loaded images. The shape is, `(N,...)`.

  Raises
  ------
  ValueError
    If Pillow is not installed, or if a filename has an unsupported
    extension. Files listed before the offending one are still read first.

  Notes
  -----
  This method requires Pillow to be installed.
  """
  try:
    from PIL import Image
  except ModuleNotFoundError:
    raise ValueError("This function requires Pillow to be installed.")

  supported_extensions = (".png", ".tif")
  images = []
  for image_file in image_files:
    extension = os.path.splitext(image_file)[1].lower()
    if extension not in supported_extensions:
      raise ValueError("Unsupported image filetype for %s" % image_file)
    # Image.open is lazy and may keep the underlying file handle open (for
    # example with multi-frame TIFFs). Converting to an array inside the
    # `with` block forces the pixel data to be read, and leaving the block
    # closes the file deterministically.
    with Image.open(image_file) as image:
      images.append(np.array(image))
  return np.array(images)


def load_sdf_files(input_files: List[str],
                   clean_mols: bool = True,
                   tasks: List[str] = [],