Commit 89d8cd31 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Changes

parent 5d64e281
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -976,7 +976,8 @@ class InMemoryLoader(DataLoader):
    Parameters
    ----------
    inputs: Sequence[Any]
      List of inputs to process. Entries can be filenames or arbitrary objects.
      List of inputs to process. Entries can be arbitrary objects so long as
      they are understood by `self.featurizer`
    data_dir: str, optional
      Directory to store featurized dataset.
    shard_size: int, optional
+6 −1
Original line number Diff line number Diff line
@@ -1943,7 +1943,12 @@ class ImageDataset(Dataset):
    self._X_shape = self._find_array_shape(X)
    self._y_shape = self._find_array_shape(y)
    if w is None:
      if len(self._y_shape) == 1:
      if len(self._y_shape) == 0:
        # Case n_samples should be 1
        if n_samples != 1:
          raise ValueError("y can only be a scalar if n_samples == 1")
        w = np.ones_like(y)
      elif len(self._y_shape) == 1:
        w = np.ones(self._y_shape[0], np.float32)
      else:
        w = np.ones((self._y_shape[0], 1), np.float32)
+10 −6
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ import os
import deepchem
import warnings
import logging
from typing import List, Optional, Iterator
from typing import List, Optional, Iterator, Any
from deepchem.utils.genomics import encode_bio_sequence as encode_sequence, encode_fasta_sequence as fasta_sequence, seq_one_hot_encode as seq_one_hotencode

logger = logging.getLogger(__name__)
@@ -45,7 +45,8 @@ def get_input_type(input_file):
    raise ValueError("Unrecognized extension %s" % file_extension)


def load_data(input_files, shard_size=None):
def load_data(input_files: List[str],
              shard_size: Optional[int] = None) -> Iterator[Any]:
  """Loads data from disk.

  For CSV files, supports sharded loading for large files.
@@ -77,7 +78,9 @@ def load_data(input_files, shard_size=None):
      yield load_pickle_from_disk(input_file)


def load_sdf_files(input_files, clean_mols, tasks=[]):
def load_sdf_files(input_files: List[str],
                   clean_mols: bool = True,
                   tasks: List[str] = []) -> List[pd.DataFrame]:
  """Load SDF file into dataframe.

  Parameters
@@ -99,7 +102,7 @@ def load_sdf_files(input_files, clean_mols, tasks=[]):
  -------
  dataframes: list
    This function returns a list of pandas dataframes. Each dataframe will
    columns `('mol_id', 'smiles', 'mol')`.
    contain columns `('mol_id', 'smiles', 'mol')`.
  """
  from rdkit import Chem
  dataframes = []
@@ -130,12 +133,13 @@ def load_sdf_files(input_files, clean_mols, tasks=[]):
  return dataframes


def load_csv_files(filenames, shard_size=None):
def load_csv_files(filenames: List[str],
                   shard_size: Optional[int] = None) -> Iterator[pd.DataFrame]:
  """Load data as pandas dataframe.

  Parameters
  ----------
  input_files: list[str]
  filenames: list[str]
    List of filenames
  shard_size: int, optional (default None) 
    The shard size to yield at one time.