Commit 8a9be632 authored by Bharath Ramsundar
Browse files

More ImageLoader work

parent b9f74b99
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -19,4 +19,5 @@ from deepchem.data.data_loader import CSVLoader
from deepchem.data.data_loader import UserCSVLoader
from deepchem.data.data_loader import SDFLoader
from deepchem.data.data_loader import FASTALoader
from deepchem.data.data_loader import ImageLoader
import deepchem.data.tests
+58 −0
Original line number Diff line number Diff line
@@ -21,6 +21,9 @@ from deepchem.utils.save import load_sdf_files
from deepchem.utils.save import encode_fasta_sequence
from deepchem.feat import UserDefinedFeaturizer
from deepchem.data import DiskDataset
from deepchem.data import NumpyDataset
from scipy import misc
import zipfile


def convert_df_to_numpy(df, tasks, verbose=False):
@@ -325,3 +328,58 @@ class FASTALoader(DataLoader):
        yield X, None, None, ids

    return DiskDataset.create_dataset(shard_generator(), data_dir)

class ImageLoader(DataLoader):
  """
  Handles loading of image files.

  Images are read eagerly into memory and returned as a single
  NumpyDataset; nothing is written to a DiskDataset.
  """

  def __init__(self, tasks=None):
    """Initialize image loader.

    Parameters
    ----------
    tasks: list, optional
      Stored on the loader as `self.tasks`. Defaults to an empty list.
    """
    if tasks is None:
      tasks = []
    self.tasks = tasks

  def featurize(self, input_files):
    """Featurizes image files.

    Parameters
    ----------
    input_files: str or list
      A single path or a list of paths. Each file should either be of a
      supported image format (.png only for now) or a .zip archive of
      image files. Files with any other extension are ignored.

    Returns
    -------
    NumpyDataset
      Dataset whose X is `np.array` applied to the list of loaded images,
      in the order the files were encountered.
    """
    # Local import: only needed for cleaning up extracted archives.
    import shutil

    if not isinstance(input_files, list):
      input_files = [input_files]

    images = []
    for input_file in input_files:
      _, extension = os.path.splitext(input_file)
      # TODO(rbharath): Add support for more extensions
      if extension == ".zip":
        zip_dir = tempfile.mkdtemp()
        try:
          with zipfile.ZipFile(input_file, 'r') as zip_ref:
            zip_ref.extractall(path=zip_dir)
          # Archive members may live in nested directories, so walk the
          # whole tree and build full paths; bare names from os.listdir
          # are relative to zip_dir and cannot be opened directly.
          for root, dirnames, filenames in os.walk(zip_dir):
            dirnames.sort()  # deterministic traversal order
            for filename in sorted(filenames):
              if os.path.splitext(filename)[1] == ".png":
                images.append(misc.imread(os.path.join(root, filename)))
        finally:
          # Pixels are already in memory; do not leak the extraction dir.
          shutil.rmtree(zip_dir, ignore_errors=True)
      elif extension == ".png":
        images.append(misc.imread(input_file))

    return NumpyDataset(np.array(images))
+42 −0
Original line number Diff line number Diff line
"""
Tests for ImageLoader.
"""
from __future__ import division
from __future__ import unicode_literals

import os
import unittest
import tempfile
from scipy import misc
import deepchem as dc
import zipfile


class TestImageLoader(unittest.TestCase):
  """
  Test ImageLoader
  """

  def setUp(self):
    """Create a temporary .png image and a .zip archive containing it."""
    # Local import: only needed to remove the temporary directory.
    import shutil

    super(TestImageLoader, self).setUp()

    # Create image file
    self.data_dir = tempfile.mkdtemp()
    # Remove the temporary directory (and everything in it) after each test.
    self.addCleanup(shutil.rmtree, self.data_dir, True)
    self.face = misc.face()
    self.face_path = os.path.join(self.data_dir, "face.png")
    misc.imsave(self.face_path, self.face)

    # Create zip of image files
    self.zip_path = os.path.join(self.data_dir, "face.zip")
    with zipfile.ZipFile(self.zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
      # Store under the bare file name so the archive does not embed the
      # absolute temporary-directory path.
      zipf.write(self.face_path, arcname=os.path.basename(self.face_path))

  def test_simple_load(self):
    """A single .png path loads as a one-image dataset."""
    loader = dc.data.ImageLoader()
    dataset = loader.featurize(self.face_path)
    # These are the known dimensions of face.png
    assert dataset.X.shape == (1, 768, 1024, 3)

  def test_zip_load(self):
    """A .zip archive holding one .png loads as a one-image dataset."""
    loader = dc.data.ImageLoader()
    dataset = loader.featurize(self.zip_path)
    # The archive holds only face.png, so the shape matches the simple load.
    assert dataset.X.shape == (1, 768, 1024, 3)