Commit 0a35d6e7 authored by Bharath Ramsundar
Browse files

Updates

parent 8a9be632
Loading
Loading
Loading
Loading
+13 −18
Original line number Diff line number Diff line
@@ -229,7 +229,8 @@ class DataLoader(object):
          assert len(X) == len(ids)

        time2 = time.time()
        log("TIMING: featurizing shard %d took %0.3f s" %
        log(
            "TIMING: featurizing shard %d took %0.3f s" %
            (shard_num, time2 - time1), self.verbose)
        yield X, y, w, ids

@@ -293,7 +294,8 @@ class SDFLoader(DataLoader):

  def featurize_shard(self, shard):
    """Featurizes a shard of an input dataframe."""
    log("Currently featurizing feature_type: %s" %
    log(
        "Currently featurizing feature_type: %s" %
        self.featurizer.__class__.__name__, self.verbose)
    return featurize_mol_df(shard, self.featurizer, field=self.mol_field)

@@ -329,10 +331,12 @@ class FASTALoader(DataLoader):

    return DiskDataset.create_dataset(shard_generator(), data_dir)


class ImageLoader(DataLoader):
  """
  Handles loading of image files.
  """

  def __init__(self, tasks=None):
    """Initialize image loader."""
    if tasks is None:
@@ -345,40 +349,31 @@ class ImageLoader(DataLoader):
    Parameters
    ----------
    input_files: list
      Each file in this list should either be of a supported image format (.png only for now) or of a compressed folder of image files.
      Each file in this list should either be of a supported image format (.png
      only for now) or of a compressed folder of image files (only .zip for now).
    """
    if not isinstance(input_files, list):
      input_files = [input_files]

    images = []
    image_files = []
    print("input_files")
    print(input_files)
    for input_file in input_files:
      print("input_file")
      print(input_file)
      filename, extension = os.path.splitext(input_file)
      print("filename, extension")
      print(filename, extension)
      # TODO(rbharath): Add support for more extensions
      if extension == ".zip":
        zip_dir = tempfile.mkdtemp()
        print("zip_dir")
        print(zip_dir)
        zip_ref = zipfile.ZipFile(input_file, 'r')
        zip_ref.extractall(path=zip_dir)
        zip_ref.close()
        print("os.listdir(zip_dir)")
        print(os.listdir(zip_dir))
        image_files += os.listdir(zip_dir)
        image_files += [
            os.path.join(zip_dir, name) for name in zip_ref.namelist()
        ]
      elif extension == ".png":
        image_files.append(input_file)
    print("image_files")
    print(image_files)
      else:
        raise ValueError("Unsupported file format")

    for image_file in image_files:
      print("image_file")
      print(image_file)
      image = misc.imread(image_file)
      images.append(image)
    images = np.array(images)
+24 −2
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ class TestImageLoader(unittest.TestCase):
  """
  Test ImageLoader
  """

  def setUp(self):
    super(TestImageLoader, self).setUp()

@@ -24,19 +25,40 @@ class TestImageLoader(unittest.TestCase):
    self.face = misc.face()
    self.face_path = os.path.join(self.data_dir, "face.png")
    misc.imsave(self.face_path, self.face)
    self.face_copy_path = os.path.join(self.data_dir, "face.png")
    misc.imsave(self.face_copy_path, self.face)

    # Create zip of image files
    # Create zip of image file
    #self.zip_path = "/home/rbharath/misc/cells.zip"
    self.zip_path = os.path.join(self.data_dir, "face.zip")
    zipf = zipfile.ZipFile(self.zip_path, "w", zipfile.ZIP_DEFLATED)
    zipf.write(self.face_path)
    zipf.close()

    # Create zip of multiple image files
    self.multi_zip_path = os.path.join(self.data_dir, "multi_face.zip")
    zipf = zipfile.ZipFile(self.multi_zip_path, "w", zipfile.ZIP_DEFLATED)
    zipf.write(self.face_path)
    zipf.write(self.face_copy_path)
    zipf.close()

  def test_simple_load(self):
    """Featurize a single .png path and check the shape of the loaded array."""
    # One image, with the known dimensions of face.png.
    expected_shape = (1, 768, 1024, 3)
    featurized = dc.data.ImageLoader().featurize(self.face_path)
    assert featurized.X.shape == expected_shape

  def test_multi_load(self):
    """Featurize a list of two .png paths and expect two stacked images."""
    # NOTE(review): setUp joins data_dir with "face.png" for both face_path
    # and face_copy_path, so the two list entries name the same file on disk.
    png_paths = [self.face_path, self.face_copy_path]
    expected_shape = (2, 768, 1024, 3)
    featurized = dc.data.ImageLoader().featurize(png_paths)
    assert featurized.X.shape == expected_shape

  def test_zip_load(self):
    """Featurize the zip archive that setUp fills with a single image."""
    expected_shape = (1, 768, 1024, 3)
    featurized = dc.data.ImageLoader().featurize(self.zip_path)
    assert featurized.X.shape == expected_shape

  def test_multi_zip_load(self):
    """Featurize the zip archive that setUp writes two entries into."""
    # NOTE(review): setUp writes face_path and face_copy_path into the archive,
    # and both are built from the same "face.png" name, so the two entries
    # share one archive name -- confirm this is the intended fixture.
    expected_shape = (2, 768, 1024, 3)
    featurized = dc.data.ImageLoader().featurize(self.multi_zip_path)
    assert featurized.X.shape == expected_shape
+1 −0
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ from __future__ import unicode_literals

from deepchem.molnet.load_function.bace_datasets import load_bace_classification, load_bace_regression
from deepchem.molnet.load_function.bbbp_datasets import load_bbbp
from deepchem.molnet.load_function.cell_counting_datasets import load_cell_counting
from deepchem.molnet.load_function.chembl_datasets import load_chembl
from deepchem.molnet.load_function.clearance_datasets import load_clearance
from deepchem.molnet.load_function.clintox_datasets import load_clintox
+19 −10
Original line number Diff line number Diff line
@@ -5,8 +5,17 @@ Loads the cell counting dataset from
http://www.robots.ox.ac.uk/~vgg/research/counting/index_org.html. Labels aren't
available for this dataset, so only raw images are provided.
"""
from __future__ import division
from __future__ import unicode_literals

def load_cell_counting(split=None):
import os
import logging
import deepchem

logger = logging.getLogger(__name__)


def load_cell_counting(split=None, reload=True):
  """Load Cell Counting dataset.
  
  Loads the cell counting dataset from http://www.robots.ox.ac.uk/~vgg/research/counting/index_org.html.
@@ -14,8 +23,11 @@ def load_cell_counting(split=None):
  data_dir = deepchem.utils.get_data_dir()
  # No tasks since no labels provided.
  cell_counting_tasks = []
  # For now images are loaded directly by ImageLoader
  featurizer = ""
  if reload:
    save_dir = os.path.join(data_dir, "cell_counting/" + featurizer + "/" + str(split))
    save_dir = os.path.join(data_dir,
                            "cell_counting/" + featurizer + "/" + str(split))
    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_dir)
    if loaded:
@@ -23,17 +35,15 @@ def load_cell_counting(split=None):
  dataset_file = os.path.join(data_dir, "cells.zip")
  if not os.path.exists(dataset_file):
    deepchem.utils.download_url(
        'http://www.robots.ox.ac.uk/~vgg/research/counting/cells.zip'
    )
        'http://www.robots.ox.ac.uk/~vgg/research/counting/cells.zip')

  loader = deepchem.data.ImageLoader(
      tasks=cell_counting_tasks)
  dataset = loader.featurize(dataset_file, shard_size=8192)
  loader = deepchem.data.ImageLoader()
  dataset = loader.featurize(dataset_file)

  transformers = []

  if split == None:
    return tox21_tasks, (dataset, None, None), transformers
    return cell_counting_tasks, (dataset, None, None), transformers

  splitters = {
      'index': deepchem.splits.IndexSplitter(),
@@ -48,4 +58,3 @@ def load_cell_counting(split=None):
    deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
                                             transformers)
  return cell_counting_tasks, all_dataset, transformers