Commit cf7eef84 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #400 from joegomes/sss

SingletaskStratifiedSplit + QM7b Example
parents e5a1f49d c4ccffd0
Loading
Loading
Loading
Loading
+128 −0
Original line number Diff line number Diff line
@@ -255,6 +255,134 @@ class RandomStratifiedSplitter(Splitter):
      fold_datasets.append(fold_dataset)
    return fold_datasets

class SingletaskStratifiedSplitter(Splitter):
  """
  Class for doing data splits by stratification on a single task.

  Examples are ordered by their label for one chosen task
  (`task_number`) and the splits are built from that ordering.

  Example:

  >>> n_samples = 100
  >>> n_features = 10
  >>> n_tasks = 10
  >>> X = np.random.rand(n_samples, n_features)
  >>> y = np.random.rand(n_samples, n_tasks)
  >>> w = np.ones_like(y)
  >>> dataset = dc.data.NumpyDataset(X, y, w, ids=None)
  >>> splitter = SingletaskStratifiedSplitter(task_number=5)
  >>> train_idx, valid_idx, test_idx = splitter.split(dataset)

  """

  def __init__(self, task_number=0, verbose=False):
    """
    Creates splitter object.

    Parameters
    ----------
    task_number: int (Optional, Default 0)
      Task number for stratification.
    verbose: bool (Optional, Default False)
      Controls logging frequency.
    """
    self.task_number = task_number
    self.verbose = verbose

  def k_fold_split(self, dataset, k, seed=None, log_every_n=None,
                   directories=None):
    """
    Splits compounds into k-folds using stratified sampling.
    Overriding base class k_fold_split.

    Examples are sorted by the label of task `task_number` and the sorted
    order is cut into k contiguous, near-equal chunks, one per fold.

    Parameters
    ----------
    dataset: dc.data.Dataset object
      Dataset.
    k: int
      Number of folds.
    seed: int (Optional, Default None)
      Random seed (not currently used; the fold assignment is deterministic).
    log_every_n: int (Optional, Default None)
      Log every n examples (not currently used).
    directories: list (Optional, Default None)
      One output directory per fold. If None, a fresh temporary directory
      is created for every fold.

    Returns
    -------
    fold_datasets: List
      List containing dc.data.Dataset objects

    Raises
    ------
    ValueError
      If `directories` is given but does not contain exactly k entries.
    """
    log("Computing K-fold split", self.verbose)
    if directories is None:
      directories = [tempfile.mkdtemp() for _ in range(k)]
    elif len(directories) != k:
      raise ValueError(
          "Expected %d fold directories, got %d" % (k, len(directories)))

    y_s = dataset.y[:, self.task_number]
    sortidx = np.argsort(y_s)
    # NOTE(review): contiguous chunks of the sorted order mean each fold
    # covers its own label range rather than the whole distribution.
    sortidx_list = np.array_split(sortidx, k)

    return [
        dataset.select(fold_ind, fold_dir)
        for fold_ind, fold_dir in zip(sortidx_list, directories)
    ]

  @staticmethod
  def _join_index_chunks(chunks, dtype):
    """Concatenates index chunks into one array (empty if no chunks)."""
    if not chunks:
      return np.array([], dtype=dtype)
    return np.concatenate(chunks).astype(dtype, copy=False)

  def split(self, dataset, seed=None, frac_train=.8, frac_valid=.1,
            frac_test=.1, log_every_n=None):
    """
    Splits compounds into train/validation/test using stratified sampling.

    Examples are sorted by the label of task `task_number` and consumed in
    consecutive groups of 10. Each group is shuffled and divided between
    train/valid/test according to the requested fractions. Any examples
    left over after the last full group are added to train.

    Parameters
    ----------
    dataset: dc.data.Dataset object
      Dataset.
    seed: int (Optional, Default None)
      Random seed. When given, it is used to seed numpy's global RNG.
    frac_train: float (Optional, Default .8)
      Fraction of dataset put into training data.
    frac_valid: float (Optional, Default .1)
      Fraction of dataset put into validation data.
    frac_test: float (Optional, Default .1)
      Fraction of dataset put into test data.
    log_every_n: int (Optional, Default None)
      Log every n examples (not currently used).

    Returns
    -------
    retval: Tuple
      Tuple containing train indices, valid indices, and test indices
      (integer numpy arrays).

    Raises
    ------
    AssertionError
      If the fractions do not sum to 1 or are not multiples of 0.1.
    """
    # JSG Assert that split fractions can be written as proper fractions
    # over 10. This can be generalized in the future with some common
    # denominator determination. This will work for 80/20 train/test or
    # 80/10/10 train/valid/test (most use cases).
    split_cd = 10
    # Tolerant comparisons: fractions such as (1 - 0.8) are not exact floats.
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    tenths = np.array([frac_train, frac_valid, frac_test]) * split_cd
    np.testing.assert_allclose(
        tenths, np.round(tenths), rtol=0, atol=1e-8,
        err_msg="Split fractions must be multiples of 0.1")

    if seed is not None:
      np.random.seed(seed)

    y_s = dataset.y[:, self.task_number]
    sortidx = np.argsort(y_s)

    # round() before int() so that e.g. 1.9999999999999996 becomes 2, not 1.
    train_cutoff = int(round(frac_train * split_cd))
    valid_cutoff = int(round(frac_valid * split_cd)) + train_cutoff

    train_chunks = []
    valid_chunks = []
    test_chunks = []

    while sortidx.shape[0] >= split_cd:
      sortidx_split, sortidx = np.split(sortidx, [split_cd])
      shuffled = np.random.permutation(range(split_cd))
      train_chunks.append(sortidx_split[shuffled[:train_cutoff]])
      valid_chunks.append(sortidx_split[shuffled[train_cutoff:valid_cutoff]])
      test_chunks.append(sortidx_split[shuffled[valid_cutoff:]])

    # Append remaining examples to train
    if sortidx.shape[0] > 0:
      train_chunks.append(sortidx)

    # Keep the integer dtype of the argsort result so that the returned
    # arrays can be used directly as indices.
    idx_dtype = sortidx.dtype
    train_idx = self._join_index_chunks(train_chunks, idx_dtype)
    valid_idx = self._join_index_chunks(valid_chunks, idx_dtype)
    test_idx = self._join_index_chunks(test_chunks, idx_dtype)

    return (train_idx, valid_idx, test_idx)

class MolecularWeightSplitter(Splitter):
  """
+18 −0
Original line number Diff line number Diff line
@@ -72,6 +72,24 @@ class TestSplitters(unittest.TestCase):
    assert len(valid_data) == 1
    assert len(test_data) == 1

  def test_singletask_stratified_split(self):
    """
    Test singletask SingletaskStratifiedSplitter class.

    Splits the solubility test dataset 80/10/10, checks the size of each
    split, and checks that merging the splits recovers every original id.
    """
    solubility_dataset = dc.data.tests.load_solubility_data()
    # Use the splitter under test (previously a ScaffoldSplitter was built
    # here, so SingletaskStratifiedSplitter was never exercised).
    stratified_splitter = dc.splits.SingletaskStratifiedSplitter(
        task_number=0)
    train_data, valid_data, test_data = \
        stratified_splitter.train_valid_test_split(
            solubility_dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1)
    assert len(train_data) == 8
    assert len(valid_data) == 1
    assert len(test_data) == 1

    # No example may be lost or duplicated by the split.
    merged_dataset = dc.data.DiskDataset.merge(
        [train_data, valid_data, test_data])
    assert sorted(merged_dataset.ids) == (
           sorted(solubility_dataset.ids))

  def test_singletask_random_k_fold_split(self):
    """
    Test singletask RandomSplitter class.

examples/gdb7/gdb7_datasets.py

deleted100644 → 0
+0 −79
Original line number Diff line number Diff line
"""
gdb7 dataset loader.
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import os
import numpy as np
import shutil
import deepchem as dc
import scipy.io
import csv

def load_gdb7_from_mat(split=0):
  """Load train/test datasets from the qm7.mat file.

  If 'qm7.mat' is not present in the current working directory it is
  fetched with wget. Row `split` of the 'P' index table selects the test
  examples; the other rows (indices 0-4, skipping `split`) are flattened
  to select the training examples.

  Parameters
  ----------
  split: int (Optional, Default 0)
    Row of the 'P' index table held out as the test set.

  Returns
  -------
  retval: Tuple
    (task names, (train_dataset, test_dataset), transformers). Both
    datasets have already been passed through the transformers.
  """
  if not os.path.exists('qm7.mat'):
    os.system('wget http://www.quantum-machine.org/data/qm7.mat')
  mat = scipy.io.loadmat('qm7.mat')

  def _dataset_for(indices):
    # Gather features/labels for the given example indices, unit weights.
    features = mat['X'][indices]
    labels = mat['T'][0, indices]
    weights = np.ones_like(labels)
    return dc.data.NumpyDataset(features, labels, weights, ids=None)

  train_rows = list(range(0, split)) + list(range(split + 1, 5))
  train_dataset = _dataset_for(mat['P'][train_rows].flatten())
  test_dataset = _dataset_for(mat['P'][split])

  # Normalization statistics come from the training data only.
  transformers = [
      dc.trans.NormalizationTransformer(
          transform_y=True, dataset=train_dataset)
  ]
  for transformer in transformers:
    train_dataset = transformer.transform(train_dataset)
    test_dataset = transformer.transform(test_dataset)

  gdb7_tasks = ["atomization_energy"]
  return gdb7_tasks, (train_dataset, test_dataset), transformers

def load_gdb7(featurizer=None, split='random'):
  """Load gdb7 datasets.

  Featurizes gdb7.sdf (located next to this file), normalizes the labels,
  and splits the result into train/valid/test.

  Parameters
  ----------
  featurizer: featurizer object (Optional, Default None)
    Featurizer handed to the SDF loader. When None,
    dc.feat.CoulombMatrixEig(23) is used.
  split: str (Optional, Default 'random')
    Which splitter to use: 'index', 'random' or 'indice'. The 'indice'
    splitter takes its validation indices from the second row of
    gdb7_splits.csv.

  Returns
  -------
  retval: Tuple
    (task names, (train, valid, test), transformers)
  """
  # Featurize gdb7 dataset
  print("About to featurize gdb7 dataset.")
  here = os.path.dirname(os.path.realpath(__file__))
  sdf_path = os.path.join(here, "./gdb7.sdf")
  gdb7_tasks = ["u0_atom"]
  if featurizer is None:
    featurizer = dc.feat.CoulombMatrixEig(23)
  sdf_loader = dc.data.SDFLoader(
      tasks=gdb7_tasks,
      smiles_field="smiles",
      mol_field="mol",
      featurizer=featurizer)
  dataset = sdf_loader.featurize(sdf_path)

  # Initialize transformers
  transformers = [
      dc.trans.NormalizationTransformer(transform_y=True, dataset=dataset)
  ]

  print("About to transform data")
  for transformer in transformers:
    dataset = transformer.transform(dataset)

  # Each CSV row is parsed into one list of integer indices.
  splits_path = os.path.join(here, "./gdb7_splits.csv")
  with open(splits_path, 'r') as f:
    split_indices = [
        np.asarray([int(field) for field in row]).tolist()
        for row in csv.reader(f)
    ]

  # All splitters are constructed up front; `split` picks one by name.
  available = {
      'index': dc.splits.IndexSplitter(),
      'random': dc.splits.RandomSplitter(),
      'indice': dc.splits.IndiceSplitter(valid_indices=split_indices[1]),
  }
  chosen = available[split]
  train, valid, test = chosen.train_valid_test_split(dataset)
  return gdb7_tasks, (train, valid, test), transformers
+0 −0

File moved.

+0 −0

File moved.

Loading