Commit af87706b authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Preliminary splits implementation. Still buggy

parent c32a4d83
Loading
Loading
Loading
Loading
+51 −9
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ __author__ = "Bharath Ramsundar, Aneesh Pappu "
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "GPL"

import tempfile
import numpy as np
from rdkit import Chem
from deepchem.utils import ScaffoldGenerator
@@ -34,6 +35,37 @@ class Splitter(object):
    """Creates splitter object."""
    self.verbosity = verbosity

  def k_fold_split(self, dataset, directories, compute_feature_statistics=True):
    """Does K-fold split of dataset.

    Partitions the dataset into k disjoint folds by repeatedly splitting
    off 1/(k-fold) of the remaining data.

    Parameters
    ----------
    dataset: Dataset
      Dataset to partition into folds.
    directories: list
      List of k directories, one per fold, where fold datasets are written.
    compute_feature_statistics: bool, optional
      Forwarded to ``select`` when materializing each fold.

    Returns
    -------
    List of k fold datasets whose union is the input dataset.
    """
    log("Computing K-fold split", self.verbosity)
    k = len(directories)
    fold_datasets = []
    # rem_dataset is the portion of the dataset not yet assigned to a fold.
    rem_dataset = dataset
    for fold in range(k):
      # frac_fold starts at 1/k (fold 0) and grows to 1 (fold k-1), so each
      # fold takes an equal share of the shrinking remainder.
      frac_fold = 1. / (k - fold)
      fold_dir = directories[fold]
      fold_inds, rem_inds, _ = self.split(
          rem_dataset,
          frac_train=frac_fold, frac_valid=1 - frac_fold, frac_test=0)
      # NOTE: fold_inds/rem_inds index into rem_dataset, not the original
      # dataset, so selection must be done on rem_dataset. Selecting from
      # `dataset` here produced wrong (overlapping) folds.
      fold_dataset = rem_dataset.select(
          fold_dir, fold_inds,
          compute_feature_statistics=compute_feature_statistics)
      # TODO(rbharath): Is making a tempfile the best way to handle remainders?
      # Would be nice to be able to do in-memory dataset construction...
      rem_dir = tempfile.mkdtemp()
      rem_dataset = rem_dataset.select(
          rem_dir, rem_inds,
          compute_feature_statistics=compute_feature_statistics)
      fold_datasets.append(fold_dataset)
    return fold_datasets

  def train_valid_test_split(self, dataset, train_dir,
                             valid_dir, test_dir, frac_train=.8,
                             frac_valid=.1, frac_test=.1, seed=None,
@@ -98,7 +130,8 @@ class StratifiedSplitter(Splitter):
    return array_list
  
  def __generate_required_hits(self, w, frac_split):
    """Return the per-task number of hits required for a frac_split split.

    ``w`` is a (samples, tasks) weight array; a nonzero weight marks a hit
    for that task. Each task's total hit count is scaled down to the
    requested fraction (truncated toward zero, matching int()).
    """
    # Per-column sum of nonzero elements, i.e. number of hits per task.
    total_hits = (w != 0).sum(0)
    # Scale counts by frac_split. The previous loop rebound its loop
    # variable (`col_hits = int(...)`) and so never modified the array,
    # returning the unscaled totals.
    return (frac_split * total_hits).astype(int)
@@ -106,7 +139,8 @@ class StratifiedSplitter(Splitter):
  def __generate_required_index(self, w, required_hit_list):
    col_index = 0
    index_hits = []
    # loop through each column and obtain index required to splice out for required fraction of hits
    # loop through each column and obtain index required to splice out for
    # required fraction of hits
    for col in w.T:
      num_hit = 0
      num_required = required_hit_list[col_index]
@@ -121,7 +155,7 @@ class StratifiedSplitter(Splitter):

  def __split(self, X, y, w, ids, frac_split):
    """
    Method that does bulk of splitting dataset appropriately based on desired split percentage
    Method that does bulk of splitting dataset.
    """
    # find the total number of hits for each task and calculate the required
    # number of hits for split based on frac_split
@@ -165,18 +199,26 @@ class StratifiedSplitter(Splitter):

    # Obtain original x, y, and w arrays and shuffle
    X, y, w, ids = self.__randomize_arrays(dataset.to_numpy())
    X_train, y_train, w_train, ids_train, X_test, y_test, w_test, ids_test = self.__split(X, y, w, ids, frac_train)
    arrays = self.__split(X, y, w, ids, frac_train)
    train_arrays, rem_arrays = arrays[:4], arrays[4:]
    (X_train, y_train, w_train, ids_train) = train_arrays
    (X_rem, y_rem, w_rem, ids_rem) = rem_arrays 

    # calculate percent split for valid (out of test and valid)
    valid_percentage = frac_valid / (frac_valid + frac_test)
    # split test data into valid and test, treating sub test set also as sparse
    X_valid, y_valid, w_valid, ids_valid, X_test, y_test, w_test, ids_test = self.__split(X_test, y_test, w_test,
                                                                                          ids_test, valid_percentage)
    arrays = self.__split(X_rem, y_rem, w_rem, ids_rem, valid_percentage)
    (valid_arrays, test_arrays) = arrays[:4], arrays[4:]
    (X_valid, y_valid, w_valid, ids_valid) = valid_arrays
    (X_test, y_test, w_test, ids_test) = test_arrays

    # turn back into dataset objects
    train_data = Dataset.from_numpy(train_dir, X_train, y_train, w_train, ids_train)
    valid_data = Dataset.from_numpy(valid_dir, X_valid, y_valid, w_valid, ids_valid)
    test_data = Dataset.from_numpy(test_dir, X_test, y_test, w_test, ids_test)
    train_data = Dataset.from_numpy(
        train_dir, X_train, y_train, w_train, ids_train)
    valid_data = Dataset.from_numpy(
        valid_dir, X_valid, y_valid, w_valid, ids_valid)
    test_data = Dataset.from_numpy(
        test_dir, X_test, y_test, w_test, ids_test)
    return train_data, valid_data, test_data


+163 −109
Original line number Diff line number Diff line
@@ -9,11 +9,13 @@ __author__ = "Bharath Ramsundar, Aneesh Pappu"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "GPL"

import tempfile
import numpy as np
from deepchem.datasets import Dataset
from deepchem.splits import RandomSplitter
from deepchem.splits import ScaffoldSplitter
from deepchem.splits import StratifiedSplitter
from deepchem.datasets.tests import TestDatasetAPI
import numpy as np


class TestSplitters(TestDatasetAPI):
@@ -36,6 +38,58 @@ class TestSplitters(TestDatasetAPI):
    assert len(valid_data) == 1
    assert len(test_data) == 1

  def test_singletask_random_k_fold_split(self):
    """
    Test singletask RandomSplitter class.

    Checks that k_fold_split produces K equal-sized, mutually disjoint
    folds that partition the original dataset.
    """
    solubility_dataset = self.load_solubility_data()
    random_splitter = RandomSplitter()
    ids_set = set(solubility_dataset.get_ids())

    K = 5
    fold_dirs = [tempfile.mkdtemp() for i in range(K)]
    fold_datasets = random_splitter.k_fold_split(solubility_dataset, fold_dirs)
    for fold in range(K):
      fold_dataset = fold_datasets[fold]
      # Verify lengths is 10/k == 2
      assert len(fold_dataset) == 2
      # Verify that compounds in this fold are subset of original compounds
      fold_ids_set = set(fold_dataset.get_ids())
      assert fold_ids_set.issubset(ids_set)
      # Verify that no two folds have overlapping compounds.
      for other_fold in range(K):
        if fold == other_fold:
          continue
        other_fold_dataset = fold_datasets[other_fold]
        other_fold_ids_set = set(other_fold_dataset.get_ids())
        assert fold_ids_set.isdisjoint(other_fold_ids_set)

    # Merging the folds back together should reconstitute the original
    # dataset exactly (same size, same compound ids).
    merge_dir = tempfile.mkdtemp()
    merged_dataset = Dataset.merge(merge_dir, fold_datasets)
    assert len(merged_dataset) == len(solubility_dataset)
    assert sorted(merged_dataset.get_ids()) == (
        sorted(solubility_dataset.get_ids()))

  def test_singletask_scaffold_split(self):
    """
    Test singletask ScaffoldSplitter class.
@@ -91,11 +145,11 @@ class TestSplitters(TestDatasetAPI):

    X, y, w, ids = sparse_dataset.to_numpy()

      """
      sparsity is determined by number of w weights that are 0 for a given task
      structure of w np array is such that each row corresponds to a sample -- e.g., analyze third column for third
      sparse task
      """
    
    # sparsity is determined by number of w weights that are 0 for a given
    # task structure of w np array is such that each row corresponds to a
    # sample -- e.g., analyze third column for third sparse task
  
    frac_train = 0.5
    cutoff = int(frac_train * w.shape[0])
    w = w[:cutoff, :]