Commit 431b143b authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #237 from rbharath/k_fold_splits

Adds support for K-fold splits to Splitter classes
parents c32a4d83 641d6bbd
Loading
Loading
Loading
Loading
+20 −7
Original line number Diff line number Diff line
@@ -602,9 +602,21 @@ class Dataset(object):
  # TODO(rbharath): This change for general object types seems a little
  # kludgey.  Is there a more principled approach to support general objects?
  def select(self, select_dir, indices, compute_feature_statistics=False):
    """Creates a new dataset from a selection of indices from self."""
    """Creates a new dataset from a selection of indices from self.

    Parameters
    ----------
    select_dir: string
      Path to new directory that the selected indices will be copied to.
    indices: list
      List of indices to select.
    compute_feature_statistics: bool
      Whether or not to compute moments of features. Only meaningful if features
      are np.ndarrays. Not meaningful for other featurizations.
    """
    if not os.path.exists(select_dir):
      os.makedirs(select_dir)
    # Handle edge case with empty indices
    if not len(indices):
      return Dataset(
          data_dir=select_dir, metadata_rows=[], verbosity=self.verbosity)
@@ -622,12 +634,11 @@ class Dataset(object):
        if indices_count + num_shard_elts >= len(indices):
          break
      # Need to offset indices to fit within shard_size
      shard_indices = (
          indices[indices_count:indices_count+num_shard_elts] - count)
      X_sel = X[shard_indices]
      y_sel = y[shard_indices]
      w_sel = w[shard_indices]
      ids_sel = ids[shard_indices]
      shard_inds =  indices[indices_count:indices_count+num_shard_elts] - count
      X_sel = X[shard_inds]
      y_sel = y[shard_inds]
      w_sel = w[shard_inds]
      ids_sel = ids[shard_inds]
      basename = "dataset-%d" % shard_num
      metadata_rows.append(
          Dataset.write_data_to_disk(
@@ -696,6 +707,8 @@ class Dataset(object):
    """
    Returns all molecule-ids for this dataset.
    """
    if len(self) == 0:
      return np.array([])
    ids = []
    for (_, _, _, ids_b) in self.itershards():
      ids.append(np.atleast_1d(np.squeeze(ids_b)))
+94 −35
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ __author__ = "Bharath Ramsundar, Aneesh Pappu "
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "GPL"

import tempfile
import numpy as np
from rdkit import Chem
from deepchem.utils import ScaffoldGenerator
@@ -16,7 +17,6 @@ from deepchem.utils.save import log
from deepchem.datasets import Dataset
from deepchem.featurizers.featurize import load_data


def generate_scaffold(smiles, include_chirality=False):
  """Compute the Bemis-Murcko scaffold for a SMILES string."""
  mol = Chem.MolFromSmiles(smiles)
@@ -24,16 +24,49 @@ def generate_scaffold(smiles, include_chirality=False):
  scaffold = engine.get_scaffold(mol)
  return scaffold

def randomize_arrays(array_list):
  """Shuffles a list of arrays with one shared random permutation.

  Assumes every array in `array_list` has the same number of rows (axis 0),
  so corresponding rows (e.g. X/y/w/ids for one sample) stay aligned.

  Parameters
  ----------
  array_list: list of np.ndarray
    Arrays to shuffle in lockstep.

  Returns
  -------
  list of np.ndarray
    New list of permuted arrays (inputs are not modified in place).
  """
  num_rows = array_list[0].shape[0]
  perm = np.random.permutation(num_rows)
  # NOTE: the previous implementation rebound the loop variable
  # (`array = array[perm]`) which never updated the list, so no shuffling
  # actually happened. Build a new list of permuted arrays instead.
  return [array[perm] for array in array_list]

class Splitter(object):
  """
  Abstract base class for chemically aware splits.
  """

  def __init__(self, verbosity=None):
    """Creates splitter object.

    Parameters
    ----------
    verbosity: str, optional
      Verbosity level forwarded to `log` calls made by split methods;
      presumably one of the project's logging levels -- TODO confirm.
    """
    self.verbosity = verbosity

  def k_fold_split(self, dataset, directories, compute_feature_statistics=True):
    """Partitions `dataset` into len(directories) disjoint folds.

    Each fold is materialized on disk (via `dataset.select`) in the
    corresponding entry of `directories`.

    Parameters
    ----------
    dataset: Dataset
      Dataset to partition.
    directories: list of str
      One output directory per fold; the number of folds is len(directories).
    compute_feature_statistics: bool
      Forwarded to `Dataset.select` for each fold.

    Returns
    -------
    list of Dataset
      The K fold datasets.
    """
    log("Computing K-fold split", self.verbosity)
    num_folds = len(directories)
    folds = []
    # Portion of the dataset not yet assigned to any fold.
    remainder = dataset
    for fold_num in range(num_folds):
      # Take 1/(K - i) of what remains so every fold ends up the same size
      # (1/K of the original on the first pass, up to all of the remainder
      # on the last).
      frac_fold = 1. / (num_folds - fold_num)
      fold_inds, rem_inds, _ = self.split(
          remainder,
          frac_train=frac_fold, frac_valid=1 - frac_fold, frac_test=0)
      fold_dataset = remainder.select(
          directories[fold_num], fold_inds,
          compute_feature_statistics=compute_feature_statistics)
      # TODO(rbharath): Is making a tempfile the best way to handle remainders?
      # Would be nice to be able to do in memory dataset construction...
      remainder = remainder.select(
          tempfile.mkdtemp(), rem_inds,
          compute_feature_statistics=compute_feature_statistics)
      folds.append(fold_dataset)
    return folds

  def train_valid_test_split(self, dataset, train_dir,
                             valid_dir, test_dir, frac_train=.8,
                             frac_valid=.1, frac_test=.1, seed=None,
@@ -84,21 +117,28 @@ class Splitter(object):
    raise NotImplementedError

  

class StratifiedSplitter(Splitter):
  """
  Class for doing stratified splits -- where data is too sparse to do regular splits
  """
  Stratified Splitter class.

  def __randomize_arrays(self, array_list):
    # assumes that every array is of the same dimension
    num_rows = array_list[0].shape[0]
    perm = np.random.permutation(num_rows)
    for array in array_list:
      array = array[perm]
    return array_list
  For sparse multitask datasets, a standard split offers no guarantees that the
  splits will have any active compounds. This class guarantees that each task
  will have a proportional split of the actives in each split. To do this, a
  ragged split is performed with different numbers of compounds taken from each
  task. Thus, the length of the split arrays may exceed the split of the
  original array. That said, no datapoint is copied to more than one split, so
  correctness is still ensured.

  Note that this splitter is only valid for boolean label data.

  TODO(rbharath): This splitter should be refactored to match style of other
  splitter classes.
  """

  def __generate_required_hits(self, w, frac_split):
    """Returns per-task number of hits the split must contain.

    Parameters
    ----------
    w: np.ndarray
      Weight matrix, one row per sample and one column per task; nonzero
      entries mark samples that carry a label ("hit") for that task.
    frac_split: float
      Fraction of each task's hits that the split should receive.

    Returns
    -------
    np.ndarray
      Integer count of required hits per task (column).
    """
    # Per-column count of nonzero weights, i.e. hits per task.
    required_hits = (w != 0).sum(axis=0)
    # The previous loop (`for col_hits in required_hits: col_hits = int(...)`)
    # only rebound the loop variable and never scaled the array; scale each
    # per-task count by frac_split explicitly (truncating, matching int()).
    return (frac_split * required_hits).astype(int)
@@ -106,7 +146,8 @@ class StratifiedSplitter(Splitter):
  def __generate_required_index(self, w, required_hit_list):
    col_index = 0
    index_hits = []
    # loop through each column and obtain index required to splice out for required fraction of hits
    # loop through each column and obtain index required to splice out for
    # required fraction of hits
    for col in w.T:
      num_hit = 0
      num_required = required_hit_list[col_index]
@@ -121,7 +162,7 @@ class StratifiedSplitter(Splitter):

  def __split(self, X, y, w, ids, frac_split):
    """
    Method that does bulk of splitting dataset appropriately based on desired split percentage
    Method that does bulk of splitting dataset.
    """
    # find the total number of hits for each task and calculate the required
    # number of hits for split based on frac_split
@@ -135,48 +176,48 @@ class StratifiedSplitter(Splitter):
    for col_index, index in enumerate(index_list):
      # copy over up to required index for weight first_split
      w_1[:index, col_index] = w[:index, col_index]
      w_1[index:, col_index] = np.zeros(w_1[index:, col_index].shape)
      w_2[:index, col_index] = np.zeros(w_2[:index, col_index].shape)
      w_2[index:, col_index] = w[index:, col_index]

    # check out if any rows in either w_1 or w_2 are just zeros
    rows_to_keep_1 = w_1.any(axis=1)
    rows_to_keep_2 = w_2.any(axis=1)
    rows_1 = w_1.any(axis=1)
    rows_2 = w_2.any(axis=1)

    # prune first set
    w_1 = w_1[rows_to_keep_1]
    X_1 = X[rows_to_keep_1]
    y_1 = y[rows_to_keep_1]
    ids_1 = ids[rows_to_keep_1]
    w_1, X_1, y_1, ids_1 = w_1[rows_1], X[rows_1], y[rows_1], ids[rows_1]

    # prune second sets
    w_2 = w_2[rows_to_keep_2]
    X_2 = X[rows_to_keep_2]
    y_2 = y[rows_to_keep_2]
    ids_2 = ids[rows_to_keep_2]
    w_2, X_2, y_2, ids_2 = w_2[rows_2], X[rows_2], y[rows_2], ids[rows_2]

    return X_1, y_1, w_1, ids_1, X_2, \
           y_2, w_2, ids_2
    return ((X_1, y_1, w_1, ids_1), (X_2, y_2, w_2, ids_2))

  def train_valid_test_split(self, dataset, train_dir,
                             valid_dir, test_dir, frac_train=.8,
                             frac_valid=.1, frac_test=.1, seed=None,
                             log_every_n=1000):
    """Custom split due to raggedness in original split.

    Stratified splits are ragged (per-task hit counts drive the cut points),
    so this overrides the generic three-way split.

    Parameters
    ----------
    dataset: Dataset
      Dataset to split.
    train_dir, valid_dir, test_dir: str
      Directories where the three split datasets are written.
    frac_train, frac_valid, frac_test: float
      Desired fractions for each split.
    seed: int, optional
      Unused here; kept for interface compatibility -- TODO confirm.
    log_every_n: int
      Unused here; kept for interface compatibility -- TODO confirm.

    Returns
    -------
    (train_data, valid_data, test_data): tuple of Dataset
    """
    # Obtain original X, y, w, id arrays and shuffle them in lockstep.
    X, y, w, ids = randomize_arrays(dataset.to_numpy())
    # Carve off the training fraction first; the remainder feeds valid/test.
    train_arrays, rem_arrays = self.__split(X, y, w, ids, frac_train)
    (X_train, y_train, w_train, ids_train) = train_arrays
    (X_rem, y_rem, w_rem, ids_rem) = rem_arrays

    # Fraction of the remainder that should become the validation set.
    valid_percentage = frac_valid / (frac_valid + frac_test)
    # Split remainder into valid and test, treating it as sparse as well.
    valid_arrays, test_arrays = self.__split(
        X_rem, y_rem, w_rem, ids_rem, valid_percentage)
    (X_valid, y_valid, w_valid, ids_valid) = valid_arrays
    (X_test, y_test, w_test, ids_test) = test_arrays

    # Materialize the three splits as on-disk Dataset objects.
    train_data = Dataset.from_numpy(
        train_dir, X_train, y_train, w_train, ids_train)
    valid_data = Dataset.from_numpy(
        valid_dir, X_valid, y_valid, w_valid, ids_valid)
    test_data = Dataset.from_numpy(
        test_dir, X_test, y_test, w_test, ids_test)
    return train_data, valid_data, test_data


@@ -231,6 +272,24 @@ class RandomSplitter(Splitter):
    return (shuffled[:train_cutoff], shuffled[train_cutoff:valid_cutoff],
            shuffled[valid_cutoff:])

class IndexSplitter(Splitter):
  """
  Class for simple order based splits.
  """

  def split(self, dataset, seed=None, frac_train=.8, frac_valid=.1,
            frac_test=.1, log_every_n=None):
    """
    Splits internal compounds into train/validation/test in provided order.

    No shuffling is performed; `seed` and `log_every_n` are unused.
    Returns a (train, valid, test) triple of index ranges.
    """
    # The three fractions must partition the dataset.
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    num_datapoints = len(dataset)
    train_end = int(frac_train * num_datapoints)
    valid_end = int((frac_train + frac_valid) * num_datapoints)
    all_indices = range(num_datapoints)
    train_split = all_indices[:train_end]
    valid_split = all_indices[train_end:valid_end]
    test_split = all_indices[valid_end:]
    return (train_split, valid_split, test_split)


class ScaffoldSplitter(Splitter):
  """
+212 −110
Original line number Diff line number Diff line
@@ -9,11 +9,14 @@ __author__ = "Bharath Ramsundar, Aneesh Pappu"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "GPL"

import tempfile
import numpy as np
from deepchem.datasets import Dataset
from deepchem.splits import RandomSplitter
from deepchem.splits import IndexSplitter
from deepchem.splits import ScaffoldSplitter
from deepchem.splits import StratifiedSplitter
from deepchem.datasets.tests import TestDatasetAPI
import numpy as np


class TestSplitters(TestDatasetAPI):
@@ -36,6 +39,33 @@ class TestSplitters(TestDatasetAPI):
    assert len(valid_data) == 1
    assert len(test_data) == 1

    merge_dir = tempfile.mkdtemp()
    merged_dataset = Dataset.merge(
        merge_dir, [train_data, valid_data, test_data])
    assert sorted(merged_dataset.get_ids()) == (
           sorted(solubility_dataset.get_ids()))

  def test_singletask_index_split(self):
    """
    Test singletask IndexSplitter class.
    """
    # NOTE: docstring previously said "RandomSplitter" and the local was named
    # random_splitter; this test exercises IndexSplitter.
    solubility_dataset = self.load_solubility_data()
    index_splitter = IndexSplitter()
    train_data, valid_data, test_data = \
        index_splitter.train_valid_test_split(
            solubility_dataset,
            self.train_dir, self.valid_dir, self.test_dir,
            frac_train=0.8, frac_valid=0.1, frac_test=0.1)
    # 10 compounds split 8/1/1.
    assert len(train_data) == 8
    assert len(valid_data) == 1
    assert len(test_data) == 1

    # Merging the splits back together should recover the original compounds.
    merge_dir = tempfile.mkdtemp()
    merged_dataset = Dataset.merge(
        merge_dir, [train_data, valid_data, test_data])
    assert sorted(merged_dataset.get_ids()) == (
        sorted(solubility_dataset.get_ids()))

  def test_singletask_scaffold_split(self):
    """
    Test singletask ScaffoldSplitter class.
@@ -51,6 +81,105 @@ class TestSplitters(TestDatasetAPI):
    assert len(valid_data) == 1
    assert len(test_data) == 1

  def test_singletask_random_k_fold_split(self):
    """
    Test singletask RandomSplitter class.
    """
    solubility_dataset = self.load_solubility_data()
    splitter = RandomSplitter()
    all_ids = set(solubility_dataset.get_ids())

    num_folds = 5
    fold_dirs = [tempfile.mkdtemp() for _ in range(num_folds)]
    fold_datasets = splitter.k_fold_split(solubility_dataset, fold_dirs)
    for i, fold_dataset in enumerate(fold_datasets):
      # Each fold should hold 10/5 == 2 compounds.
      assert len(fold_dataset) == 2
      fold_ids = set(fold_dataset.get_ids())
      # Fold compounds must come from the original dataset.
      assert fold_ids.issubset(all_ids)
      # No compound may appear in more than one fold.
      for j, other_dataset in enumerate(fold_datasets):
        if i == j:
          continue
        assert fold_ids.isdisjoint(set(other_dataset.get_ids()))

    # Merging every fold should recover the original dataset exactly.
    merge_dir = tempfile.mkdtemp()
    merged_dataset = Dataset.merge(merge_dir, fold_datasets)
    assert len(merged_dataset) == len(solubility_dataset)
    assert sorted(merged_dataset.get_ids()) == (
        sorted(solubility_dataset.get_ids()))

  def test_singletask_index_k_fold_split(self):
    """
    Test singletask IndexSplitter class.
    """
    solubility_dataset = self.load_solubility_data()
    splitter = IndexSplitter()
    all_ids = set(solubility_dataset.get_ids())

    num_folds = 5
    fold_dirs = [tempfile.mkdtemp() for _ in range(num_folds)]
    fold_datasets = splitter.k_fold_split(solubility_dataset, fold_dirs)

    for i, fold_dataset in enumerate(fold_datasets):
      # Each fold should hold 10/5 == 2 compounds.
      assert len(fold_dataset) == 2
      fold_ids = set(fold_dataset.get_ids())
      # Fold compounds must come from the original dataset.
      assert fold_ids.issubset(all_ids)
      # No compound may appear in more than one fold.
      for j, other_dataset in enumerate(fold_datasets):
        if i == j:
          continue
        assert fold_ids.isdisjoint(set(other_dataset.get_ids()))

    # Merging every fold should recover the original dataset exactly.
    merge_dir = tempfile.mkdtemp()
    merged_dataset = Dataset.merge(merge_dir, fold_datasets)
    assert len(merged_dataset) == len(solubility_dataset)
    assert sorted(merged_dataset.get_ids()) == (
        sorted(solubility_dataset.get_ids()))
    
  def test_singletask_scaffold_k_fold_split(self):
    """
    Test singletask ScaffoldSplitter class.
    """
    solubility_dataset = self.load_solubility_data()
    splitter = ScaffoldSplitter()
    all_ids = set(solubility_dataset.get_ids())

    num_folds = 5
    fold_dirs = [tempfile.mkdtemp() for _ in range(num_folds)]
    fold_datasets = splitter.k_fold_split(
        solubility_dataset, fold_dirs)

    for i, fold_dataset in enumerate(fold_datasets):
      # Each fold should hold 10/5 == 2 compounds.
      assert len(fold_dataset) == 2
      fold_ids = set(fold_dataset.get_ids())
      # Fold compounds must come from the original dataset.
      assert fold_ids.issubset(all_ids)
      # No compound may appear in more than one fold.
      for j, other_dataset in enumerate(fold_datasets):
        if i == j:
          continue
        assert fold_ids.isdisjoint(set(other_dataset.get_ids()))

    # Merging every fold should recover the original dataset exactly.
    merge_dir = tempfile.mkdtemp()
    merged_dataset = Dataset.merge(merge_dir, fold_datasets)
    assert len(merged_dataset) == len(solubility_dataset)
    assert sorted(merged_dataset.get_ids()) == (
        sorted(solubility_dataset.get_ids()))

  def test_multitask_random_split(self):
    """
    Test multitask RandomSplitter class.
@@ -66,6 +195,21 @@ class TestSplitters(TestDatasetAPI):
    assert len(valid_data) == 1
    assert len(test_data) == 1

  def test_multitask_index_split(self):
    """
    Test multitask IndexSplitter class.
    """
    multitask_dataset = self.load_multitask_data()
    splitter = IndexSplitter()
    splits = splitter.train_valid_test_split(
        multitask_dataset,
        self.train_dir, self.valid_dir, self.test_dir,
        frac_train=0.8, frac_valid=0.1, frac_test=0.1)
    train_data, valid_data, test_data = splits
    # 10 compounds split 8/1/1.
    assert len(train_data) == 8
    assert len(valid_data) == 1
    assert len(test_data) == 1

  def test_multitask_scaffold_split(self):
    """
    Test multitask ScaffoldSplitter class.
@@ -85,63 +229,21 @@ class TestSplitters(TestDatasetAPI):
    """
    Test multitask StratifiedSplitter class
    """
      # ensure sparse dataset is actually sparse

    # sparsity is determined by number of w weights that are 0 for a given
    # task structure of w np array is such that each row corresponds to a
    # sample. The loaded sparse dataset has many rows with only zeros
    sparse_dataset = self.load_sparse_multitask_dataset()

    X, y, w, ids = sparse_dataset.to_numpy()
    
      """
      sparsity is determined by number of w weights that are 0 for a given task
      structure of w np array is such that each row corresponds to a sample -- e.g., analyze third column for third
      sparse task
      """
      frac_train = 0.5
      cutoff = int(frac_train * w.shape[0])
      w = w[:cutoff, :]
      sparse_flag = False

      col_index = 0
      for col in w.T:
        if not np.any(col): #check to see if any columns are all zero
          sparse_flag = True
          break
        col_index+=1
      if not sparse_flag:
        print("Test dataset isn't sparse -- test failed")
      else:
        print("Column %d is sparse -- expected" % col_index)
      assert sparse_flag

    stratified_splitter = StratifiedSplitter()
      train_data, valid_data, test_data = \
          stratified_splitter.train_valid_test_split(
    datasets = stratified_splitter.train_valid_test_split(
        sparse_dataset,
        self.train_dir, self.valid_dir, self.test_dir,
              frac_train=0.8, frac_valid=0.1, frac_test=0.1
          )
        frac_train=0.8, frac_valid=0.1, frac_test=0.1)
    train_data, valid_data, test_data = datasets

      datasets = [train_data, valid_data, test_data]
      dataset_index = 0
      for dataset in datasets:
    for dataset_index, dataset in enumerate(datasets):
      X, y, w, ids = dataset.to_numpy()
        # verify that each task in the train dataset has some hits
        for col in w.T:
            if not np.any(col):
                print("Fail -- one column doesn't have results")
                if dataset_index == 0:
                    print("train_data failed")
                elif dataset_index == 1:
                    print("valid_data failed")
                elif dataset_index == 2:
                    print("test_data failed")
                assert np.any(col)
        if dataset_index == 0:
            print("train_data passed")
        elif dataset_index == 1:
            print("valid_data passed")
        elif dataset_index == 2:
            print("test_data passed")
        dataset_index+=1
      print("end of stratified test")
      assert 1 == 1
      # verify that there are no rows (samples) in weights matrix w
      # that have no hits.
      assert len(np.where(~w.any(axis=1))[0]) == 0