Unverified Commit fd7c0266 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1324 from rbharath/cell_counting

Cell counting Dataset and ImageLoader
parents b0ccc4a5 0a35d6e7
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -19,4 +19,5 @@ from deepchem.data.data_loader import CSVLoader
from deepchem.data.data_loader import UserCSVLoader
from deepchem.data.data_loader import SDFLoader
from deepchem.data.data_loader import FASTALoader
from deepchem.data.data_loader import ImageLoader
import deepchem.data.tests
+55 −2
Original line number Diff line number Diff line
@@ -21,6 +21,9 @@ from deepchem.utils.save import load_sdf_files
from deepchem.utils.save import encode_fasta_sequence
from deepchem.feat import UserDefinedFeaturizer
from deepchem.data import DiskDataset
from deepchem.data import NumpyDataset
from scipy import misc
import zipfile


def convert_df_to_numpy(df, tasks, verbose=False):
@@ -226,7 +229,8 @@ class DataLoader(object):
          assert len(X) == len(ids)

        time2 = time.time()
        log("TIMING: featurizing shard %d took %0.3f s" %
        log(
            "TIMING: featurizing shard %d took %0.3f s" %
            (shard_num, time2 - time1), self.verbose)
        yield X, y, w, ids

@@ -290,7 +294,8 @@ class SDFLoader(DataLoader):

  def featurize_shard(self, shard):
    """Featurizes a shard of an input dataframe."""
    log("Currently featurizing feature_type: %s" %
    log(
        "Currently featurizing feature_type: %s" %
        self.featurizer.__class__.__name__, self.verbose)
    return featurize_mol_df(shard, self.featurizer, field=self.mol_field)

@@ -325,3 +330,51 @@ class FASTALoader(DataLoader):
        yield X, None, None, ids

    return DiskDataset.create_dataset(shard_generator(), data_dir)


class ImageLoader(DataLoader):
  """
  Handles loading of image files.

  Accepts individual .png files and .zip archives of image files, reads every
  image and returns them stacked in a single `NumpyDataset`.
  """

  def __init__(self, tasks=None):
    """Initialize image loader.

    Parameters
    ----------
    tasks: list, optional
      Task names to associate with the loaded images. Defaults to an empty
      list. The value is only stored on the instance here.
    """
    if tasks is None:
      tasks = []
    self.tasks = tasks

  def featurize(self, input_files):
    """Featurizes image files.

    Parameters
    ----------
    input_files: str or list
      A single path or a list/tuple of paths. Each file should either be of a
      supported image format (.png only for now) or a compressed folder of
      image files (only .zip for now). Extensions are matched
      case-insensitively.

    Returns
    -------
    NumpyDataset
      Dataset whose X is `np.array` of the loaded images, in input order
      (archive members in the order the archive lists them).

    Raises
    ------
    ValueError
      If any input file has an unsupported extension.
    """
    # Local import: only needed to clean up the extraction directories.
    import shutil

    if not isinstance(input_files, (list, tuple)):
      input_files = [input_files]

    image_files = []
    extraction_dirs = []
    try:
      for input_file in input_files:
        extension = os.path.splitext(input_file)[1].lower()
        # TODO(rbharath): Add support for more extensions
        if extension == ".zip":
          zip_dir = tempfile.mkdtemp()
          extraction_dirs.append(zip_dir)
          # Context manager closes the archive even if extraction fails.
          with zipfile.ZipFile(input_file, 'r') as zip_ref:
            zip_ref.extractall(path=zip_dir)
            member_names = zip_ref.namelist()
          # Archives of folders also list the folders themselves (names ending
          # in "/"); those are not images and must not be handed to imread.
          image_files += [
              os.path.join(zip_dir, name)
              for name in member_names
              if not name.endswith("/")
          ]
        elif extension == ".png":
          image_files.append(input_file)
        else:
          raise ValueError("Unsupported file format")

      images = np.array([misc.imread(image_file) for image_file in image_files])
    finally:
      # The pixel data has been read by now (or loading failed), so the
      # extracted copies are no longer needed.
      for zip_dir in extraction_dirs:
        shutil.rmtree(zip_dir, ignore_errors=True)
    return NumpyDataset(images)
+64 −0
Original line number Diff line number Diff line
"""
Tests for ImageLoader.
"""
from __future__ import division
from __future__ import unicode_literals

import os
import unittest
import tempfile
from scipy import misc
import deepchem as dc
import zipfile


class TestImageLoader(unittest.TestCase):
  """
  Test ImageLoader on single images, lists of images and zip archives.
  """

  def setUp(self):
    """Create two distinct .png files and two zip archives in a temp dir."""
    super(TestImageLoader, self).setUp()
    # Local import: only needed to remove the fixture directory.
    import shutil

    # Create image files. Registering the cleanup right away removes the
    # directory even if a later step of setUp fails.
    self.data_dir = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, self.data_dir, True)
    self.face = misc.face()
    self.face_path = os.path.join(self.data_dir, "face.png")
    misc.imsave(self.face_path, self.face)
    # The copy needs its own file name; otherwise the "multi" tests would
    # load the same file twice and the multi-image zip would contain two
    # entries with the same name.
    self.face_copy_path = os.path.join(self.data_dir, "face_copy.png")
    misc.imsave(self.face_copy_path, self.face)

    # Create zip of image file
    self.zip_path = os.path.join(self.data_dir, "face.zip")
    with zipfile.ZipFile(self.zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
      zipf.write(self.face_path)

    # Create zip of multiple image files
    self.multi_zip_path = os.path.join(self.data_dir, "multi_face.zip")
    with zipfile.ZipFile(self.multi_zip_path, "w",
                         zipfile.ZIP_DEFLATED) as zipf:
      zipf.write(self.face_path)
      zipf.write(self.face_copy_path)

  def test_simple_load(self):
    """A single .png path yields a one-image dataset."""
    loader = dc.data.ImageLoader()
    dataset = loader.featurize(self.face_path)
    # These are the known dimensions of face.png
    assert dataset.X.shape == (1, 768, 1024, 3)

  def test_multi_load(self):
    """A list of .png paths yields one image per path."""
    loader = dc.data.ImageLoader()
    dataset = loader.featurize([self.face_path, self.face_copy_path])
    assert dataset.X.shape == (2, 768, 1024, 3)

  def test_zip_load(self):
    """A zip holding one image yields a one-image dataset."""
    loader = dc.data.ImageLoader()
    dataset = loader.featurize(self.zip_path)
    assert dataset.X.shape == (1, 768, 1024, 3)

  def test_multi_zip_load(self):
    """A zip holding two images yields a two-image dataset."""
    loader = dc.data.ImageLoader()
    dataset = loader.featurize(self.multi_zip_path)
    assert dataset.X.shape == (2, 768, 1024, 3)
+1 −0
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ from __future__ import unicode_literals

from deepchem.molnet.load_function.bace_datasets import load_bace_classification, load_bace_regression
from deepchem.molnet.load_function.bbbp_datasets import load_bbbp
from deepchem.molnet.load_function.cell_counting_datasets import load_cell_counting
from deepchem.molnet.load_function.chembl_datasets import load_chembl
from deepchem.molnet.load_function.clearance_datasets import load_clearance
from deepchem.molnet.load_function.clintox_datasets import load_clintox
+60 −0
Original line number Diff line number Diff line
"""
Cell Counting Dataset.

Loads the cell counting dataset from
http://www.robots.ox.ac.uk/~vgg/research/counting/index_org.html. Labels aren't
available for this dataset, so only raw images are provided.
"""
from __future__ import division
from __future__ import unicode_literals

import os
import logging
import deepchem

logger = logging.getLogger(__name__)


def load_cell_counting(split=None, reload=True):
  """Load Cell Counting dataset.

  Loads the cell counting dataset from http://www.robots.ox.ac.uk/~vgg/research/counting/index_org.html.

  Parameters
  ----------
  split: str or None, optional
    None returns the whole dataset unsplit. 'index' or 'random' selects the
    splitter used to make train/valid/test sets.
  reload: bool, optional
    If True, first try to load a previously saved dataset from disk, and save
    a newly created split for later calls.

  Returns
  -------
  tuple
    `(tasks, datasets, transformers)`. `tasks` and `transformers` are empty
    lists. `datasets` is `(dataset, None, None)` when `split` is None and
    `(train, valid, test)` otherwise.

  Raises
  ------
  ValueError
    If `split` is neither None nor one of the supported splitter names.
  """
  splitters = {
      'index': deepchem.splits.IndexSplitter(),
      'random': deepchem.splits.RandomSplitter(),
  }
  # Reject an unknown split before downloading or featurizing anything.
  if split is not None and split not in splitters:
    raise ValueError("Only index and random splits supported.")

  data_dir = deepchem.utils.get_data_dir()
  # No tasks since no labels provided.
  cell_counting_tasks = []
  # For now images are loaded directly by ImageLoader
  featurizer = ""
  if reload:
    save_dir = os.path.join(data_dir,
                            "cell_counting/" + featurizer + "/" + str(split))
    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_dir)
    if loaded:
      return cell_counting_tasks, all_dataset, transformers
  dataset_file = os.path.join(data_dir, "cells.zip")
  if not os.path.exists(dataset_file):
    # NOTE(review): presumably download_url saves into the default data
    # directory used for dataset_file above -- confirm.
    deepchem.utils.download_url(
        'http://www.robots.ox.ac.uk/~vgg/research/counting/cells.zip')

  loader = deepchem.data.ImageLoader()
  dataset = loader.featurize(dataset_file)

  transformers = []

  if split is None:
    return cell_counting_tasks, (dataset, None, None), transformers

  # Look up the requested splitter; the original code used an undefined name
  # here and failed with NameError for every non-None split.
  splitter = splitters[split]
  train, valid, test = splitter.train_valid_test_split(dataset)
  all_dataset = (train, valid, test)
  if reload:
    deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
                                             transformers)
  return cell_counting_tasks, all_dataset, transformers