Unverified Commit 8e551190 authored by Daiki Nishikawa's avatar Daiki Nishikawa Committed by GitHub
Browse files

Merge pull request #2143 from nd-02110114/update-splitter

Refactor splitter
parents 1cbafde0 82c6da8a
Loading
Loading
Loading
Loading
+20 −6
Original line number Diff line number Diff line
@@ -420,6 +420,19 @@ class Dataset(object):
    """
    raise NotImplementedError()

  def select(self, indices: Sequence[int],
             select_dir: Optional[str] = None) -> "Dataset":
    """Build a new dataset holding only the samples of self at `indices`.

    This base implementation performs no selection and always raises
    `NotImplementedError`.
    NOTE(review): concrete dataset classes are presumably expected to
    override this method -- confirm against the subclasses.

    Parameters
    ----------
    indices: Sequence
      Indices of the samples to select.
    select_dir: str, optional (default None)
      Path to the new directory that the selected samples will be copied to.

    Returns
    -------
    Dataset
      Declared return type: a dataset made of the selected samples.

    Raises
    ------
    NotImplementedError
      Always, in this base implementation.
    """
    raise NotImplementedError()

  def get_statistics(self, X_stats: bool = True,
                     y_stats: bool = True) -> Tuple[float, ...]:
    """Compute and return statistics of this dataset.
@@ -1868,13 +1881,13 @@ class DiskDataset(Dataset):
        tasks=tasks)

  @staticmethod
  def merge(datasets: Iterable["DiskDataset"],
  def merge(datasets: Iterable["Dataset"],
            merge_dir: Optional[str] = None) -> "DiskDataset":
    """Merges provided datasets into a merged dataset.

    Parameters
    ----------
    datasets: Iterable[DiskDataset]
    datasets: Iterable[Dataset]
      List of datasets to merge.
    merge_dir: str, optional (default None)
      The new directory path to store the merged DiskDataset.
@@ -1897,7 +1910,7 @@ class DiskDataset(Dataset):
    tasks = []
    for dataset in datasets:
      try:
        tasks.append(dataset.tasks)
        tasks.append(dataset.tasks)  # type: ignore
      except AttributeError:
        pass
    if tasks:
@@ -2033,7 +2046,7 @@ class DiskDataset(Dataset):

  def shuffle_each_shard(self,
                         shard_basenames: Optional[List[str]] = None) -> None:
    """Shuffles elements within each shard of the datset.
    """Shuffles elements within each shard of the dataset.

    Parameters
    ----------
@@ -2282,8 +2295,9 @@ class DiskDataset(Dataset):

    Returns
    -------
    DiskDataset
      A Dataset containing the selected samples
    Dataset
      A dataset containing the selected samples. The default dataset is `DiskDataset`.
      If `output_numpy_dataset` is True, the dataset is `NumpyDataset`.
    """
    if output_numpy_dataset and (select_dir is not None or
                                 select_shard_size is not None):
+1 −2
Original line number Diff line number Diff line
@@ -41,7 +41,6 @@ def test_singletask_sklearn_rf_ECFP_regression_API():

def test_singletask_sklearn_rf_user_specified_regression_API():
  """Test of singletask RF USF regression API."""
  splittype = "specified"
  featurizer = dc.feat.UserDefinedFeaturizer(
      ["user-specified1", "user-specified2"])
  tasks = ["log-solubility"]
@@ -51,7 +50,7 @@ def test_singletask_sklearn_rf_user_specified_regression_API():
      tasks=tasks, smiles_field="smiles", featurizer=featurizer)
  dataset = loader.create_dataset(input_file)

  splitter = dc.splits.SpecifiedSplitter(input_file, "split")
  splitter = dc.splits.RandomSplitter()
  train_dataset, test_dataset = splitter.train_test_split(dataset)

  transformers = [
+17 −5
Original line number Diff line number Diff line
"""
Gathers all splitters in one place for convenient imports
"""
# TODO(rbharath): Get rid of * import
from deepchem.splits.splitters import *
from deepchem.splits.splitters import ScaffoldSplitter
from deepchem.splits.splitters import SpecifiedSplitter
# flake8: noqa

# basic splitter
from deepchem.splits.splitters import Splitter
from deepchem.splits.splitters import RandomSplitter
from deepchem.splits.splitters import RandomStratifiedSplitter
from deepchem.splits.splitters import RandomGroupSplitter
from deepchem.splits.splitters import SingletaskStratifiedSplitter
from deepchem.splits.splitters import IndexSplitter
from deepchem.splits.splitters import IndiceSplitter
from deepchem.splits.splitters import RandomGroupSplitter

# molecule splitter
from deepchem.splits.splitters import ScaffoldSplitter
from deepchem.splits.splitters import MolecularWeightSplitter
from deepchem.splits.splitters import MaxMinSplitter
from deepchem.splits.splitters import FingerprintSplitter
from deepchem.splits.splitters import ButinaSplitter

# other splitter
from deepchem.splits.task_splitter import merge_fold_datasets
from deepchem.splits.task_splitter import TaskSplitter
+720 −496

File changed.

Preview size limit exceeded, changes collapsed.

+0 −4
Original line number Diff line number Diff line
@@ -5,11 +5,8 @@ __author__ = "Bharath Ramsundar"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "MIT"

import tempfile
import numpy as np
from deepchem.utils import ScaffoldGenerator
from deepchem.data import NumpyDataset
from deepchem.utils.save import load_data
from deepchem.splits import Splitter


@@ -74,7 +71,6 @@ class TaskSplitter(Splitter):
    n_tasks = len(dataset.get_task_names())
    n_train = int(np.round(frac_train * n_tasks))
    n_valid = int(np.round(frac_valid * n_tasks))
    n_test = n_tasks - n_train - n_valid

    X, y, w, ids = dataset.X, dataset.y, dataset.w, dataset.ids

Loading