Commit 72d36b32 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Changes

parent b67fb466
Loading
Loading
Loading
Loading
+2 −11
Original line number Diff line number Diff line
@@ -367,13 +367,11 @@ class ImageLoader(DataLoader):
    image_files = []
    # Sometimes zip files contain directories within. Traverse directories
    while len(input_files) > 0:
      print("ITERATION!!")
      remainder = []
      for input_file in input_files:
        filename, extension = os.path.splitext(input_file)
        # TODO(rbharath): Add support for more extensions
        if os.path.isdir(input_file):
          print("DIRECTORY!")
          dirfiles = [os.path.join(input_file, subfile) for subfile in os.listdir(input_file)]
          remainder += dirfiles
        elif extension == ".zip":
@@ -383,23 +381,16 @@ class ImageLoader(DataLoader):
          zip_ref.close()
          zip_files = [os.path.join(zip_dir, name) for name in zip_ref.namelist()]
          for zip_file in zip_files:
            if os.path.isdir(zip_file):
              remainder.append(zip_file)
            else:
            _, extension = os.path.splitext(zip_file)
            if extension in [".png", ".tif"]:
              image_files.append(zip_file)
        elif extension in [".png", ".tif"]:
          image_files.append(input_file)
        else:
          raise ValueError("Unsupported file format")
      input_files = remainder
      print("remainder")
      print(remainder)

    images = []
    print("image_files")
    print(image_files)
    print("len(image_files)")
    print(len(image_files))
    for image_file in image_files:
      _, extension = os.path.splitext(image_file) 
      if extension == ".png":
+13 −3
Original line number Diff line number Diff line
@@ -27,7 +27,7 @@ class TestImageLoader(unittest.TestCase):
    self.face = misc.face()
    self.face_path = os.path.join(self.data_dir, "face.png")
    misc.imsave(self.face_path, self.face)
    self.face_copy_path = os.path.join(self.data_dir, "face.png")
    self.face_copy_path = os.path.join(self.data_dir, "face_copy.png")
    misc.imsave(self.face_copy_path, self.face)

    # Create zip of image file
@@ -51,6 +51,13 @@ class TestImageLoader(unittest.TestCase):
    zipf.write(self.tif_image_path)
    zipf.close()

    # Create image directory 
    self.image_dir = tempfile.mkdtemp()
    face_path = os.path.join(self.image_dir, "face.png")
    misc.imsave(face_path, self.face)
    face_copy_path = os.path.join(self.image_dir, "face_copy.png")
    misc.imsave(face_copy_path, self.face)

  def test_png_simple_load(self):
    loader = dc.data.ImageLoader()
    dataset = loader.featurize(self.face_path)
@@ -76,8 +83,6 @@ class TestImageLoader(unittest.TestCase):
  def test_png_multi_zip_load(self):
    loader = dc.data.ImageLoader()
    dataset = loader.featurize(self.multi_zip_path)
    print("dataset.X.shape")
    print(dataset.X.shape)
    assert dataset.X.shape == (2, 768, 1024, 3)

  def test_multitype_zip_load(self):
@@ -85,3 +90,8 @@ class TestImageLoader(unittest.TestCase):
    dataset = loader.featurize(self.multitype_zip_path)
    # Since the different files have different shapes, makes an object array
    assert dataset.X.shape == (2,)

  def test_directory_load(self):
    loader = dc.data.ImageLoader()
    dataset = loader.featurize(self.image_dir)
    assert dataset.X.shape == (2, 768, 1024, 3)