Commit 89d8cd31 authored by Bharath Ramsundar
Browse files

Changes

parent 5d64e281
Loading
Loading
Loading
Loading
+2 −1
Original line number Original line Diff line number Diff line
@@ -976,7 +976,8 @@ class InMemoryLoader(DataLoader):
    Parameters
    Parameters
    ----------
    ----------
    inputs: Sequence[Any]
    inputs: Sequence[Any]
      List of inputs to process. Entries can be filenames or arbitrary objects.
      List of inputs to process. Entries can be arbitrary objects so long as
      they are understood by `self.featurizer`.
    data_dir: str, optional
    data_dir: str, optional
      Directory to store featurized dataset.
      Directory to store featurized dataset.
    shard_size: int, optional
    shard_size: int, optional
+6 −1
Original line number Original line Diff line number Diff line
@@ -1943,7 +1943,12 @@ class ImageDataset(Dataset):
    self._X_shape = self._find_array_shape(X)
    self._X_shape = self._find_array_shape(X)
    self._y_shape = self._find_array_shape(y)
    self._y_shape = self._find_array_shape(y)
    if w is None:
    if w is None:
      if len(self._y_shape) == 1:
      if len(self._y_shape) == 0:
        # y is a scalar (zero-dimensional), which is only valid when n_samples == 1
        if n_samples != 1:
          raise ValueError("y can only be a scalar if n_samples == 1")
        w = np.ones_like(y)
      elif len(self._y_shape) == 1:
        w = np.ones(self._y_shape[0], np.float32)
        w = np.ones(self._y_shape[0], np.float32)
      else:
      else:
        w = np.ones((self._y_shape[0], 1), np.float32)
        w = np.ones((self._y_shape[0], 1), np.float32)
+10 −6
Original line number Original line Diff line number Diff line
@@ -11,7 +11,7 @@ import os
import deepchem
import deepchem
import warnings
import warnings
import logging
import logging
from typing import List, Optional, Iterator
from typing import List, Optional, Iterator, Any
from deepchem.utils.genomics import encode_bio_sequence as encode_sequence, encode_fasta_sequence as fasta_sequence, seq_one_hot_encode as seq_one_hotencode
from deepchem.utils.genomics import encode_bio_sequence as encode_sequence, encode_fasta_sequence as fasta_sequence, seq_one_hot_encode as seq_one_hotencode


logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
@@ -45,7 +45,8 @@ def get_input_type(input_file):
    raise ValueError("Unrecognized extension %s" % file_extension)
    raise ValueError("Unrecognized extension %s" % file_extension)




def load_data(input_files, shard_size=None):
def load_data(input_files: List[str],
              shard_size: Optional[int] = None) -> Iterator[Any]:
  """Loads data from disk.
  """Loads data from disk.


  For CSV files, supports sharded loading for large files.
  For CSV files, supports sharded loading for large files.
@@ -77,7 +78,9 @@ def load_data(input_files, shard_size=None):
      yield load_pickle_from_disk(input_file)
      yield load_pickle_from_disk(input_file)




def load_sdf_files(input_files, clean_mols, tasks=[]):
def load_sdf_files(input_files: List[str],
                   clean_mols: bool = True,
                   tasks: List[str] = []) -> List[pd.DataFrame]:
  """Load SDF file into dataframe.
  """Load SDF file into dataframe.


  Parameters
  Parameters
@@ -99,7 +102,7 @@ def load_sdf_files(input_files, clean_mols, tasks=[]):
  -------
  -------
  dataframes: list
  dataframes: list
    This function returns a list of pandas dataframes. Each dataframe will
    This function returns a list of pandas dataframes. Each dataframe will
    columns `('mol_id', 'smiles', 'mol')`.
    contain columns `('mol_id', 'smiles', 'mol')`.
  """
  """
  from rdkit import Chem
  from rdkit import Chem
  dataframes = []
  dataframes = []
@@ -130,12 +133,13 @@ def load_sdf_files(input_files, clean_mols, tasks=[]):
  return dataframes
  return dataframes




def load_csv_files(filenames, shard_size=None):
def load_csv_files(filenames: List[str],
                   shard_size: Optional[int] = None) -> Iterator[pd.DataFrame]:
  """Load data as pandas dataframe.
  """Load data as pandas dataframe.


  Parameters
  Parameters
  ----------
  ----------
  input_files: list[str]
  filenames: list[str]
    List of filenames
    List of filenames
  shard_size: int, optional (default None) 
  shard_size: int, optional (default None) 
    The shard size to yield at one time.
    The shard size to yield at one time.