Commit ffeb4b7c authored by peastman's avatar peastman
Browse files

Fixes to support ImageDataset

parent 9fb18d53
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -52,7 +52,7 @@ def load_images_DR(split='random', seed=None):

  loader = deepchem.data.ImageLoader()
  dat = loader.featurize(
      image_full_paths, labels=labels, weights=weights, read_img=False)
      image_full_paths, labels=labels, weights=weights)
  if split == None:
    return dat

+2 −28
Original line number Diff line number Diff line
@@ -142,32 +142,6 @@ class DRModel(TensorGraph):
    # weighted_loss = WeightDecay(0.1, 'l2', in_layers=[weighted_loss])
    self.set_loss(weighted_loss)

  def default_generator(self,
                        dataset,
                        epochs=1,
                        predict=False,
                        deterministic=True,
                        pad_batches=True):
    for epoch in range(epochs):
      for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
          batch_size=self.batch_size,
          deterministic=deterministic,
          pad_batches=pad_batches):
        feed_dict = dict()

        if None in X_b:
          # load images on the fly
          feed_dict[self.features[0]] = ImageLoader.load_img(ids_b)
        else:
          feed_dict[self.features[0]] = X_b

        if y_b is not None and not predict:
          feed_dict[self.labels[0]] = y_b
        if w_b is not None and not predict:
          feed_dict[self.task_weights[0]] = w_b

        yield feed_dict


def DRAccuracy(y, y_pred):
  y_pred = np.argmax(y_pred, 1)
+11 −5
Original line number Diff line number Diff line
@@ -1529,11 +1529,17 @@ class ImageDataset(Dataset):
    select_dir: string
      Ignored.
    """
    X = self.X[indices]
    y = self.y[indices]
    w = self.w[indices]
    ids = self.ids[indices]
    return NumpyDataset(X, y, w, ids)
    if isinstance(self._X, np.ndarray):
      X = self._X[indices]
    else:
      X = [self._X[i] for i in indices]
    if isinstance(self._y, np.ndarray):
      y = self._y[indices]
    else:
      y = [self._y[i] for i in indices]
    w = self._w[indices]
    ids = self._ids[indices]
    return ImageDataset(X, y, w, ids)


class Databag(object):
+296 B
Loading image diff...
+275 B
Loading image diff...
Loading