Commit c85740e2 authored by nd-02110114's avatar nd-02110114
Browse files

👌 fix for review

parent cf797e14
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -339,11 +339,11 @@ class Dataset(object):

  def __repr__(self) -> str:
    """Convert self to REPL print representation."""
    threshold = 10
    threshold = dc.utils.get_print_threshold()
    task_str = np.array2string(
        np.array(self.get_task_names()), threshold=threshold)
    X_shape, y_shape, w_shape, _ = self.get_shape()
    if self.__len__() < 1000:
    if self.__len__() < dc.utils.get_max_print_size():
      id_str = np.array2string(self.ids, threshold=threshold)
      return "<%s X.shape: %s, y.shape: %s, w.shape: %s, ids: %s, task_names: %s>" % (
          self.__class__.__name__, str(X_shape), str(y_shape), str(w_shape),
+10 −0
Original line number Diff line number Diff line
@@ -782,16 +782,26 @@ def test_to_str():
  ref_str = '<NumpyDataset X.shape: (5, 3), y.shape: (5,), w.shape: (5,), ids: [0 1 2 3 4], task_names: [0]>'
  assert str(dataset) == ref_str

  # Test id shrinkage
  dc.utils.set_print_threshold(10)
  dataset = dc.data.NumpyDataset(
      X=np.random.rand(50, 3), y=np.random.rand(50,), ids=np.arange(50))
  ref_str = '<NumpyDataset X.shape: (50, 3), y.shape: (50,), w.shape: (50,), ids: [0 1 2 ... 47 48 49], task_names: [0]>'
  assert str(dataset) == ref_str

  # Test task shrinkage
  dataset = dc.data.NumpyDataset(
      X=np.random.rand(50, 3), y=np.random.rand(50, 20), ids=np.arange(50))
  ref_str = '<NumpyDataset X.shape: (50, 3), y.shape: (50, 20), w.shape: (50, 1), ids: [0 1 2 ... 47 48 49], task_names: [ 0  1  2 ... 17 18 19]>'
  assert str(dataset) == ref_str

  # Test max print size
  dc.utils.set_max_print_size(25)
  dataset = dc.data.NumpyDataset(
      X=np.random.rand(50, 3), y=np.random.rand(50,), ids=np.arange(50))
  ref_str = '<NumpyDataset X.shape: (50, 3), y.shape: (50,), w.shape: (50,), task_names: [0]>'
  assert str(dataset) == ref_str


class TestDatasets(unittest.TestCase):
  """
+88 −0
Original line number Diff line number Diff line
"""
Miscellaneous utility functions.
"""
# flake8: noqa
from deepchem.utils.conformers import ConformerGenerator
from deepchem.utils.evaluate import relative_difference
from deepchem.utils.evaluate import Evaluator
from deepchem.utils.evaluate import GeneratorEvaluator

from deepchem.utils.coordinate_box_utils import CoordinateBox
from deepchem.utils.coordinate_box_utils import intersect_interval
from deepchem.utils.coordinate_box_utils import intersection
from deepchem.utils.coordinate_box_utils import union
from deepchem.utils.coordinate_box_utils import merge_overlapping_boxes
from deepchem.utils.coordinate_box_utils import get_face_boxes

from deepchem.utils.data_utils import pad_array
from deepchem.utils.data_utils import get_data_dir
from deepchem.utils.data_utils import download_url
from deepchem.utils.data_utils import untargz_file
from deepchem.utils.data_utils import unzip_file
from deepchem.utils.data_utils import load_image_files
from deepchem.utils.data_utils import load_sdf_files
from deepchem.utils.data_utils import load_csv_files
from deepchem.utils.data_utils import load_json_files
from deepchem.utils.data_utils import load_pickle_files
from deepchem.utils.data_utils import load_data
from deepchem.utils.data_utils import save_to_disk
from deepchem.utils.data_utils import load_from_disk
from deepchem.utils.data_utils import save_dataset_to_disk
from deepchem.utils.data_utils import load_dataset_from_disk

from deepchem.utils.debug_utils import get_print_threshold
from deepchem.utils.debug_utils import set_print_threshold
from deepchem.utils.debug_utils import get_max_print_size
from deepchem.utils.debug_utils import set_max_print_size

from deepchem.utils.fragment_utils import AtomShim
from deepchem.utils.fragment_utils import MolecularFragment
from deepchem.utils.fragment_utils import get_partial_charge
from deepchem.utils.fragment_utils import merge_molecular_fragments
from deepchem.utils.fragment_utils import get_mol_subset
from deepchem.utils.fragment_utils import strip_hydrogens
from deepchem.utils.fragment_utils import get_contact_atom_indices
from deepchem.utils.fragment_utils import reduce_molecular_complex_to_contacts

from deepchem.utils.genomics_utils import seq_one_hot_encode
from deepchem.utils.genomics_utils import encode_bio_sequence

from deepchem.utils.geometry_utils import unit_vector
from deepchem.utils.geometry_utils import angle_between
from deepchem.utils.geometry_utils import generate_random_unit_vector
from deepchem.utils.geometry_utils import generate_random_rotation_matrix
from deepchem.utils.geometry_utils import is_angle_within_cutoff
from deepchem.utils.geometry_utils import compute_centroid
from deepchem.utils.geometry_utils import subtract_centroid
from deepchem.utils.geometry_utils import compute_protein_range
from deepchem.utils.geometry_utils import compute_pairwise_distances

from deepchem.utils.hash_utils import hash_ecfp
from deepchem.utils.hash_utils import hash_ecfp_pair
from deepchem.utils.hash_utils import vectorize

from deepchem.utils.molecule_feature_utils import one_hot_encode
from deepchem.utils.molecule_feature_utils import get_atom_type_one_hot
from deepchem.utils.molecule_feature_utils import construct_hydrogen_bonding_info
from deepchem.utils.molecule_feature_utils import get_atom_hydrogen_bonding_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_is_in_aromatic_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_hybridization_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_total_num_Hs_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_chirality_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_formal_charge
from deepchem.utils.molecule_feature_utils import get_atom_partial_charge
from deepchem.utils.molecule_feature_utils import get_atom_ring_size_one_hot
from deepchem.utils.molecule_feature_utils import get_atom_total_degree_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_type_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_is_in_same_ring_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_is_conjugated_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_stereo_one_hot
from deepchem.utils.molecule_feature_utils import get_bond_graph_distance_one_hot

from deepchem.utils.pdbqt_utils import pdbqt_to_pdb
from deepchem.utils.pdbqt_utils import convert_protein_to_pdbqt
from deepchem.utils.pdbqt_utils import convert_mol_to_pdbqt

from deepchem.utils.vina_utils import write_vina_conf
from deepchem.utils.vina_utils import load_docked_ligands

from deepchem.utils.voxel_utils import convert_atom_to_voxel
from deepchem.utils.voxel_utils import convert_atom_pair_to_voxel
from deepchem.utils.voxel_utils import voxelize


class ScaffoldGenerator(object):
+124 −119
Original line number Diff line number Diff line
@@ -143,19 +143,23 @@ def unzip_file(file: str,
    zip_ref.extractall(dest_dir)


def load_image_files(image_files: List[str]) -> np.ndarray:
def load_image_files(input_files: List[str]) -> np.ndarray:
  """Loads a set of images from disk.

  Parameters
  ----------
  image_files: List[str]
    List of image filenames to load.
  input_files: List[str]
    List of image filenames.

  Returns
  -------
  np.ndarray
    A numpy array that contains loaded images. The shape is, `(N,...)`.

  Notes
  -----
  This method requires Pillow to be installed.
  The supported file types are PNG and TIF.
  """
  try:
    from PIL import Image
@@ -163,18 +167,18 @@ def load_image_files(image_files: List[str]) -> np.ndarray:
    raise ValueError("This function requires Pillow to be installed.")

  images = []
  for image_file in image_files:
    _, extension = os.path.splitext(image_file)
  for input_file in input_files:
    _, extension = os.path.splitext(input_file)
    extension = extension.lower()
    if extension == ".png":
      image = np.array(Image.open(image_file))
      image = np.array(Image.open(input_file))
      images.append(image)
    elif extension == ".tif":
      im = Image.open(image_file)
      im = Image.open(input_file)
      imarray = np.array(im)
      images.append(imarray)
    else:
      raise ValueError("Unsupported image filetype for %s" % image_file)
      raise ValueError("Unsupported image filetype for %s" % input_file)
  return np.array(images)


@@ -252,13 +256,13 @@ def load_sdf_files(input_files: List[str],
      df_rows = []


def load_csv_files(filenames: List[str],
def load_csv_files(input_files: List[str],
                   shard_size: Optional[int] = None) -> Iterator[pd.DataFrame]:
  """Load data as pandas dataframe from CSV files.

  Parameters
  ----------
  filenames: List[str]
  input_files: List[str]
    List of filenames
  shard_size: int, default None
    The shard size to yield at one time.
@@ -270,12 +274,12 @@ def load_csv_files(filenames: List[str],
  """
  # First line of user-specified CSV *must* be header.
  shard_num = 1
  for filename in filenames:
  for input_file in input_files:
    if shard_size is None:
      yield pd.read_csv(filename)
      yield pd.read_csv(input_file)
    else:
      logger.info("About to start loading CSV from %s" % filename)
      for df in pd.read_csv(filename, chunksize=shard_size):
      logger.info("About to start loading CSV from %s" % input_file)
      for df in pd.read_csv(input_file, chunksize=shard_size):
        logger.info(
            "Loading shard %d of size %s." % (shard_num, str(shard_size)))
        df = df.replace(np.nan, str(""), regex=True)
@@ -283,13 +287,13 @@ def load_csv_files(filenames: List[str],
        yield df


def load_json_files(filenames: List[str],
def load_json_files(input_files: List[str],
                    shard_size: Optional[int] = None) -> Iterator[pd.DataFrame]:
  """Load data as pandas dataframe.

  Parameters
  ----------
  filenames: List[str]
  input_files: List[str]
    List of json filenames.
  shard_size: int, default None
    Chunksize for reading json files.
@@ -305,13 +309,13 @@ def load_json_files(filenames: List[str],
  must be originally saved with ``df.to_json('filename.json', orient='records', lines=True)``
  """
  shard_num = 1
  for filename in filenames:
  for input_file in input_files:
    if shard_size is None:
      yield pd.read_json(filename, orient='records', lines=True)
      yield pd.read_json(input_file, orient='records', lines=True)
    else:
      logger.info("About to start loading json from %s." % filename)
      logger.info("About to start loading json from %s." % input_file)
      for df in pd.read_json(
          filename, orient='records', chunksize=shard_size, lines=True):
          input_file, orient='records', chunksize=shard_size, lines=True):
        logger.info(
            "Loading shard %d of size %s." % (shard_num, str(shard_size)))
        df = df.replace(np.nan, str(""), regex=True)
@@ -319,6 +323,87 @@ def load_json_files(filenames: List[str],
        yield df


def load_pickle_files(input_files: List[str]) -> Iterator[Any]:
  """Load dataset from pickle file.

  Parameters
  ----------
  input_files: List[str]
    The list of filenames of pickle file. This function can load from
    gzipped pickle file like `XXXX.pkl.gz`.

  Returns
  -------
  Iterator[Any]
    Generator which yields the objects which is loaded from each pickle file.
  """
  for input_file in input_files:
    if ".gz" in input_file:
      with gzip.open(input_file, "rb") as f:
        df = pickle.load(f)
    else:
      with open(input_file, "rb") as f:
        df = pickle.load(f)
    yield df


def load_data(input_files: List[str],
              shard_size: Optional[int] = None) -> Iterator[Any]:
  """Loads data from files.

  Parameters
  ----------
  input_files: List[str]
    List of filenames.
  shard_size: int, default None
    Size of shard to yield

  Returns
  -------
  Iterator[Any]
    Iterator which iterates over provided files.

  Notes
  -----
  The supported file types are SDF, CSV and Pickle.
  """
  if len(input_files) == 0:
    raise ValueError("The length of `filenames` must be more than 1.")

  file_type = _get_file_type(input_files[0])
  if file_type == "sdf":
    if shard_size is not None:
      logger.info("Ignoring shard_size for sdf input.")
    for value in load_sdf_files(input_files):
      yield value
  elif file_type == "csv":
    for value in load_csv_files(input_files, shard_size):
      yield value
  elif file_type == "pickle":
    if shard_size is not None:
      logger.info("Ignoring shard_size for pickle input.")
    for value in load_pickle_files(input_files):
      yield value


def _get_file_type(input_file: str) -> str:
  """Get type of input file. Must be csv/pkl/sdf/joblib file."""
  filename, file_extension = os.path.splitext(input_file)
  # If gzipped, need to compute extension again
  if file_extension == ".gz":
    filename, file_extension = os.path.splitext(filename)
  if file_extension == ".csv":
    return "csv"
  elif file_extension == ".pkl":
    return "pickle"
  elif file_extension == ".joblib":
    return "joblib"
  elif file_extension == ".sdf":
    return "sdf"
  else:
    raise ValueError("Unrecognized extension %s" % file_extension)


def save_to_disk(dataset: Any, filename: str, compress: int = 3):
  """Save a dataset to file.

@@ -340,63 +425,23 @@ def save_to_disk(dataset: Any, filename: str, compress: int = 3):


def load_from_disk(filename: str) -> Any:
  """Load a dataset from file."""
  name = filename
  if os.path.splitext(name)[1] == ".gz":
    name = os.path.splitext(name)[0]
  extension = os.path.splitext(name)[1]
  if extension == ".pkl":
    return load_pickle_from_disk(filename)
  elif extension == ".joblib":
    return joblib.load(filename)
  elif extension == ".csv":
    # First line of user-specified CSV *must* be header.
    df = pd.read_csv(filename, header=0)
    df = df.replace(np.nan, str(""), regex=True)
    return df
  elif extension == ".npy":
    return np.load(filename, allow_pickle=True)
  else:
    raise ValueError("Unrecognized filetype for %s" % filename)


def load_sharded_csv(filenames) -> pd.DataFrame:
  """Load a dataset from multiple files. Each file MUST have same column headers"""
  dataframes = []
  for name in filenames:
    placeholder_name = name
    if os.path.splitext(name)[1] == ".gz":
      name = os.path.splitext(name)[0]
    if os.path.splitext(name)[1] == ".csv":
      # First line of user-specified CSV *must* be header.
      df = pd.read_csv(placeholder_name, header=0)
      df = df.replace(np.nan, str(""), regex=True)
      dataframes.append(df)
    else:
      raise ValueError("Unrecognized filetype for %s" % name)

  # combine dataframes
  combined_df = dataframes[0]
  for i in range(0, len(dataframes) - 1):
    combined_df = combined_df.append(dataframes[i + 1])
  combined_df = combined_df.reset_index(drop=True)
  return combined_df


def load_pickle_from_disk(filename: str) -> Any:
  """Load dataset from pickle file.
  """Load a dataset from file.

  Parameters
  ----------
  filename: str
    A filename of pickle file. This function can load from
    gzipped pickle file like `XXXX.pkl.gz`.
    A filename you want to load data.

  Returns
  -------
  Any
    A loaded object from pickle file.
    A loaded object from file.
  """
  name = filename
  if os.path.splitext(name)[1] == ".gz":
    name = os.path.splitext(name)[0]
  extension = os.path.splitext(name)[1]
  if extension == ".pkl":
    if ".gz" in filename:
      with gzip.open(filename, "rb") as f:
        df = pickle.load(f)
@@ -404,6 +449,17 @@ def load_pickle_from_disk(filename: str) -> Any:
      with open(filename, "rb") as f:
        df = pickle.load(f)
    return df
  elif extension == ".joblib":
    return joblib.load(filename)
  elif extension == ".csv":
    # First line of user-specified CSV *must* be header.
    df = pd.read_csv(filename, header=0)
    df = df.replace(np.nan, str(""), regex=True)
    return df
  elif extension == ".npy":
    return np.load(filename, allow_pickle=True)
  else:
    raise ValueError("Unrecognized filetype for %s" % filename)


def load_dataset_from_disk(save_dir: str) -> Tuple[bool, Optional[Tuple[
@@ -503,54 +559,3 @@ def save_dataset_to_disk(
  with open(os.path.join(save_dir, "transformers.pkl"), 'wb') as f:
    pickle.dump(transformers, f)
  return None


def get_input_type(input_file: str) -> str:
  """Get type of input file. Must be csv/pkl.gz/sdf file."""
  filename, file_extension = os.path.splitext(input_file)
  # If gzipped, need to compute extension again
  if file_extension == ".gz":
    filename, file_extension = os.path.splitext(filename)
  if file_extension == ".csv":
    return "csv"
  elif file_extension == ".pkl":
    return "pandas-pickle"
  elif file_extension == ".joblib":
    return "pandas-joblib"
  elif file_extension == ".sdf":
    return "sdf"
  else:
    raise ValueError("Unrecognized extension %s" % file_extension)


def load_data(input_files: List[str],
              shard_size: Optional[int] = None) -> Iterator[Any]:
  """Loads data from disk.

  For CSV files, supports sharded loading for large files.

  Parameters
  ----------
  input_files: List[str]
    List of filenames.
  shard_size: int, default None
    Size of shard to yield

  Returns
  -------
  Iterator which iterates over provided files.
  """
  if not len(input_files):
    return
  input_type = get_input_type(input_files[0])
  if input_type == "sdf":
    if shard_size is not None:
      logger.info("Ignoring shard_size for sdf input.")
    for value in load_sdf_files(input_files):
      yield value
  elif input_type == "csv":
    for value in load_csv_files(input_files, shard_size):
      yield value
  elif input_type == "pandas-pickle":
    for input_file in input_files:
      yield load_pickle_from_disk(input_file)
+69 −0
Original line number Diff line number Diff line
# The number of elements to print for dataset ids/tasks
_print_threshold = 10


def get_print_threshold() -> int:
  """Return the printing threshold for datasets.

  The print threshold is the number of elements from ids/tasks to
  print when printing representations of `Dataset` objects.

  Returns
  ----------
  threshold: int
    Number of elements that will be printed
  """
  return _print_threshold


def set_print_threshold(threshold: int):
  """Set print threshold

  The print threshold is the number of elements from ids/tasks to
  print when printing representations of `Dataset` objects.

  Parameters
  ----------
  threshold: int
    Number of elements to print.
  """
  global _print_threshold
  _print_threshold = threshold


# If a dataset contains more than this number of elements, it won't
# print any dataset ids
_max_print_size = 1000


def get_max_print_size() -> int:
  """Return the max print size for a datset.

  If a dataset is large, printing `self.ids` as part of a string
  representation can be very slow. This field controls the maximum
  size for a dataset before ids are no longer printed.

  Returns
  -------
  max_print_size: int
    Maximum length of a dataset for ids to be printed in string
    representation.
  """
  return _max_print_size


def set_max_print_size(max_print_size: int):
  """Set max_print_size

  If a dataset is large, printing `self.ids` as part of a string
  representation can be very slow. This field controls the maximum
  size for a dataset before ids are no longer printed.

  Parameters
  ----------
  max_print_size: int
    Maximum length of a dataset for ids to be printed in string
    representation.
  """
  global _max_print_size
  _max_print_size = max_print_size
Loading