Commit e3e36f18 authored by nd-02110114's avatar nd-02110114
Browse files

💚 fix ci error

parent d28431bc
Loading
Loading
Loading
Loading
+10 −11
Original line number Diff line number Diff line
@@ -939,6 +939,15 @@ class NumpyDataset(Dataset):
    return NumpyDataset(X, y, w, ids, n_tasks=y.shape[1])


class Shard(object):

  def __init__(self, X, y, w, ids):
    self.X = X
    self.y = y
    self.w = w
    self.ids = ids


class DiskDataset(Dataset):
  """
  A Dataset that is stored as a set of files on disk.
@@ -1461,9 +1470,7 @@ class DiskDataset(Dataset):
      raise ValueError("This method requires PyTorch to be installed.")

    pytorch_ds = TorchDiskDataset(
        disk_dataset=self,
        epochs=epochs,
        deterministic=deterministic)
        disk_dataset=self, epochs=epochs, deterministic=deterministic)
    return pytorch_ds

  @staticmethod
@@ -1711,14 +1718,6 @@ class DiskDataset(Dataset):
  def get_shard(self, i: int) -> Batch:
    """Retrieves data for the i-th shard from disk."""

    class Shard(object):

      def __init__(self, X, y, w, ids):
        self.X = X
        self.y = y
        self.w = w
        self.ids = ids

    # See if we have a cached copy of this shard.
    if self._cached_shards is None:
      self._cached_shards = [None] * self.get_number_shards()
+15 −15
Original line number Diff line number Diff line
@@ -4,10 +4,11 @@ import multiprocessing
import numpy as np
import torch

import deepchem as dc
from deepchem.data.datasets import pad_batch
from deepchem.data.data_loader import ImageLoader


class TorchNumpyDataset(torch.utils.data.IterableDataset):
class TorchNumpyDataset(torch.utils.data.IterableDataset):  # type: ignore

  def __init__(self, X, y, w, ids, n_samples, epochs, deterministic):
    self._X = X
@@ -36,7 +37,7 @@ class TorchNumpyDataset(torch.utils.data.IterableDataset):
        yield (self._X[i], self._y[i], self._w[i], self._ids[i])


class TorchDiskDataset(torch.utils.data.IterableDataset):
class TorchDiskDataset(torch.utils.data.IterableDataset):  # type: ignore

  def __init__(self, disk_dataset, epochs, deterministic):
    self.disk_dataset = disk_dataset
@@ -161,8 +162,8 @@ class TorchDiskDataset(torch.utils.data.IterableDataset):

              # (ytz): this skips everything except possibly the last shard
              if pad_batches:
                (X_b, y_b, w_b, ids_b) = dc.data.datasets.pad_batch(
                    shard_batch_size, X_b, y_b, w_b, ids_b)
                (X_b, y_b, w_b, ids_b) = pad_batch(shard_batch_size, X_b, y_b,
                                                   w_b, ids_b)

              yield X_b, y_b, w_b, ids_b
              cur_global_batch += 1
@@ -172,13 +173,7 @@ class TorchDiskDataset(torch.utils.data.IterableDataset):
    return iterate(self.disk_dataset, batch_size, epochs)


def get_image(array, index):
  if isinstance(array, np.ndarray):
    return array[index]
  return dc.data.ImageLoader.load_img([array[index]])[0]


class TorchImageDataset(torch.utils.data.IterableDataset):
class TorchImageDataset(torch.utils.data.IterableDataset):  # type: ignore

  def __init__(self, X, y, w, ids, n_samples, epochs, deterministic):
    self._X = X
@@ -204,5 +199,10 @@ class TorchImageDataset(torch.utils.data.IterableDataset):
      else:
        order = first_sample + np.random.permutation(last_sample - first_sample)
      for i in order:
        yield (get_image(self._X, i), get_image(self._y, i), self._w[i],
               self._ids[i])
        yield (self._get_image(self._X, i), self._get_image(self._y, i),
               self._w[i], self._ids[i])

  def _get_image(self, array, index):
    if isinstance(array, np.ndarray):
      return array[index]
    return ImageLoader.load_img([array[index]])[0]
+2 −10
Original line number Diff line number Diff line
@@ -8,14 +8,9 @@ __license__ = "MIT"
import random
import math
import unittest
import tempfile
import os
import shutil
import numpy as np
import deepchem as dc
import tensorflow as tf
import pandas as pd
from tensorflow.python.framework import test_util

try:
  import torch
@@ -29,7 +24,6 @@ def load_solubility_data():
  current_dir = os.path.dirname(os.path.abspath(__file__))
  featurizer = dc.feat.CircularFingerprint(size=1024)
  tasks = ["log-solubility"]
  task_type = "regression"
  input_file = os.path.join(current_dir, "../../models/tests/example.csv")
  loader = dc.data.CSVLoader(
      tasks=tasks, smiles_field="smiles", featurizer=featurizer)
@@ -111,7 +105,6 @@ def test_pad_features():
  """Test that pad_features pads features correctly."""
  batch_size = 100
  num_features = 10
  num_tasks = 5

  # Test cases where n_samples < 2*n_samples < batch_size
  n_samples = 29
@@ -306,7 +299,6 @@ def test_select():

def test_complete_shuffle():
  shard_sizes = [1, 2, 3, 4, 5]
  batch_size = 10

  all_Xs, all_ys, all_ws, all_ids = [], [], [], []

@@ -550,7 +542,7 @@ def test_disk_iterate_y_w_None():
  shard_sizes = [21, 11, 41, 21, 51]
  batch_size = 10

  all_Xs, all_ys, all_ws, all_ids = [], [], [], []
  all_Xs, all_ids = [], []

  def shard_generator():
    for sz in shard_sizes:
@@ -839,7 +831,7 @@ def test_to_str():
  assert str(dataset) == ref_str


class TestDatasets(test_util.TensorFlowTestCase):
class TestDatasets(unittest.TestCase):
  """
  Test basic top-level API for dataset objects.
  """