Commit 2b220df7 authored by miaecle's avatar miaecle
Browse files

fix data loader

parent 45c27e86
Loading
Loading
Loading
Loading
+21 −9
Original line number Diff line number Diff line
@@ -349,7 +349,11 @@ class ImageLoader(DataLoader):
      tasks = []
    self.tasks = tasks

  def featurize(self, input_files, labels=None, in_memory=True):
  def featurize(self, 
                input_files, 
                labels=None, 
                read_img=True,
                in_memory=True):
    """Featurizes image files.

    Parameters
@@ -395,6 +399,21 @@ class ImageLoader(DataLoader):
          raise ValueError("Unsupported file format")
      input_files = remainder

    if read_img:
      X = self.load_img(image_files)
    else:
      X = [None] * len(image_files)
    if in_memory:
      return NumpyDataset(X, y=labels, ids=image_files)
      
    else:
      # from_numpy currently requires labels. Make dummy labels
      if labels is None:
        labels = np.zeros((len(image_files), 1))
      return DiskDataset.from_numpy(X, labels, ids=image_files)
  
  @staticmethod
  def load_img(image_files):
    images = []
    for image_file in image_files:
      _, extension = os.path.splitext(image_file)
@@ -407,11 +426,4 @@ class ImageLoader(DataLoader):
        images.append(imarray)
      else:
        raise ValueError("Unsupported image filetype for %s" % image_file)
    images = np.array(images)
    if in_memory:
      return NumpyDataset(images, y=labels)
    else:
      # from_numpy currently requires labels. Make dummy labels
      if labels is None:
        labels = np.zeros((len(images), 1))
      return DiskDataset.from_numpy(images, labels)
    return np.array(images)
 No newline at end of file