Unverified Commit e9226da3 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2105 from nd-02110114/update-data-2

Update dataset docstrings
parents b8277808 2a8abbb3
Loading
Loading
Loading
Loading
+524 −308

File changed.

Preview size limit exceeded, changes collapsed.

+5 −29
Original line number Diff line number Diff line
from typing import List, Union
import numpy as np
import torch

from deepchem.utils.save import load_image_files
from deepchem.data.datasets import NumpyDataset, DiskDataset, ImageDataset


@@ -155,35 +153,13 @@ class _TorchImageDataset(torch.utils.data.IterableDataset): # type: ignore
        order = random.permutation(n_samples)[first_sample:last_sample]
      if self.batch_size is None:
        for i in order:
          yield (self._get_image(self.image_dataset._X, i),
                 self._get_image(self.image_dataset._y, i),
          yield (self.image_dataset._get_image(self.image_dataset._X, i),
                 self.image_dataset._get_image(self.image_dataset._y, i),
                 self.image_dataset._w[i], self.image_dataset._ids[i])
      else:
        for i in range(0, len(order), self.batch_size):
          indices = order[i:i + self.batch_size]
          yield (self._get_image(self.image_dataset._X, indices),
                 self._get_image(self.image_dataset._y,
                                 indices), self.image_dataset._w[indices],
          yield (self.image_dataset._get_image(self.image_dataset._X, indices),
                 self.image_dataset._get_image(self.image_dataset._y, indices),
                 self.image_dataset._w[indices],
                 self.image_dataset._ids[indices])

  def _get_image(self, array: Union[np.ndarray, List[str]],
                 indices: int) -> np.ndarray:
    """Method for loading an image

    Parameters
    ----------
    array: Union[np.ndarray, List[str]]
      A numpy array which contains images or List of image filenames
    index: int
      Index you want to get the image

    Returns
    -------
    np.ndarray
      Loaded image
    """
    if isinstance(array, np.ndarray):
      return array[indices]
    if isinstance(indices, np.ndarray):
      return load_image_files([array[i] for i in indices])
    return load_image_files([array[indices]])[0]
+2 −2
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ import numpy as np
import deepchem as dc

try:
  import torch
  import torch  # noqa
  PYTORCH_IMPORT_FAILED = False
except ImportError:
  PYTORCH_IMPORT_FAILED = True
@@ -739,7 +739,7 @@ def _validate_pytorch_dataset(dataset):

  # Test iterating with multiple workers.

  import torch
  import torch  # noqa
  ds = dataset.make_pytorch_dataset(epochs=2, deterministic=False)
  loader = torch.utils.data.DataLoader(ds, num_workers=3)
  id_count = dict((id, 0) for id in ids)
+0 −8
Original line number Diff line number Diff line
import os
import shutil
import logging
import unittest
import tempfile
import deepchem as dc
import numpy as np
from sklearn.ensemble import RandomForestClassifier

logger = logging.getLogger(__name__)

@@ -19,10 +15,6 @@ class TestDrop(unittest.TestCase):

  def test_drop(self):
    """Test on dataset where RDKit fails on some strings."""
    # Set some global variables up top
    reload = True
    len_full = 25

    current_dir = os.path.dirname(os.path.realpath(__file__))
    logger.info("About to load emols dataset.")
    dataset_file = os.path.join(current_dir, "mini_emols.csv")
+1 −12
Original line number Diff line number Diff line
@@ -6,14 +6,8 @@ import tempfile

def test_make_legacy_dataset_from_numpy():
  """Test that legacy DiskDataset objects can be constructed."""
  # This is the shape of legacy_data
  num_datapoints = 100
  num_features = 10
  num_tasks = 10

  current_dir = os.path.dirname(os.path.abspath(__file__))
  # legacy_dataset is a dataset in the legacy format kept around for testing
  # purposes.
  # legacy_dataset is a dataset in the legacy format kept around for testing purposes.
  data_dir = os.path.join(current_dir, "legacy_dataset")
  dataset = dc.data.DiskDataset(data_dir)
  assert dataset.legacy_metadata
@@ -29,11 +23,6 @@ def test_make_legacy_dataset_from_numpy():

def test_reshard():
  """Test that resharding updates legacy datasets."""
  # This is the shape of legacy_data_reshard
  num_datapoints = 100
  num_features = 10
  num_tasks = 10

  # legacy_dataset_reshard is a sharded dataset in the legacy format kept
  # around for testing resharding.
  current_dir = os.path.dirname(os.path.abspath(__file__))
Loading