Commit 641d6bbd authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Splitter tests now passing

parent af87706b
Loading
Loading
Loading
Loading
+20 −7
Original line number Diff line number Diff line
@@ -602,9 +602,21 @@ class Dataset(object):
  # TODO(rbharath): This change for general object types seems a little
  # kludgey.  Is there a more principled approach to support general objects?
  def select(self, select_dir, indices, compute_feature_statistics=False):
    """Creates a new dataset from a selection of indices from self."""
    """Creates a new dataset from a selection of indices from self.

    Parameters
    ----------
    select_dir: string
      Path to new directory that the selected indices will be copied to.
    indices: list
      List of indices to select.
    compute_feature_statistics: bool
      Whether or not to compute moments of features. Only meaningful if features
      are np.ndarrays. Not meaningful for other featurizations.
    """
    if not os.path.exists(select_dir):
      os.makedirs(select_dir)
    # Handle edge case with empty indices
    if not len(indices):
      return Dataset(
          data_dir=select_dir, metadata_rows=[], verbosity=self.verbosity)
@@ -622,12 +634,11 @@ class Dataset(object):
        if indices_count + num_shard_elts >= len(indices):
          break
      # Need to offset indices to fit within shard_size
      shard_indices = (
          indices[indices_count:indices_count+num_shard_elts] - count)
      X_sel = X[shard_indices]
      y_sel = y[shard_indices]
      w_sel = w[shard_indices]
      ids_sel = ids[shard_indices]
      shard_inds =  indices[indices_count:indices_count+num_shard_elts] - count
      X_sel = X[shard_inds]
      y_sel = y[shard_inds]
      w_sel = w[shard_inds]
      ids_sel = ids[shard_inds]
      basename = "dataset-%d" % shard_num
      metadata_rows.append(
          Dataset.write_data_to_disk(
@@ -696,6 +707,8 @@ class Dataset(object):
    """
    Returns all molecule-ids for this dataset.
    """
    if len(self) == 0:
      return np.array([])
    ids = []
    for (_, _, _, ids_b) in self.itershards():
      ids.append(np.atleast_1d(np.squeeze(ids_b)))
+54 −37
Original line number Diff line number Diff line
@@ -17,7 +17,6 @@ from deepchem.utils.save import log
from deepchem.datasets import Dataset
from deepchem.featurizers.featurize import load_data


def generate_scaffold(smiles, include_chirality=False):
  """Compute the Bemis-Murcko scaffold for a SMILES string."""
  mol = Chem.MolFromSmiles(smiles)
@@ -25,12 +24,18 @@ def generate_scaffold(smiles, include_chirality=False):
  scaffold = engine.get_scaffold(mol)
  return scaffold

def randomize_arrays(array_list):
  # assumes that every array is of the same dimension
  num_rows = array_list[0].shape[0]
  perm = np.random.permutation(num_rows)
  for array in array_list:
    array = array[perm]
  return array_list

class Splitter(object):
  """
  Abstract base class for chemically aware splits..
  """

  def __init__(self, verbosity=None):
    """Creates splitter object."""
    self.verbosity = verbosity
@@ -50,19 +55,15 @@ class Splitter(object):
      fold_inds, rem_inds, _ = self.split(
          rem_dataset,
          frac_train=frac_fold, frac_valid=1-frac_fold, frac_test=0)
      fold_dataset = dataset.select( 
      fold_dataset = rem_dataset.select( 
          fold_dir, fold_inds,
          compute_feature_statistics=compute_feature_statistics)
      # TODO(rbharath): Is making a tempfile the best way to handle remainders?
      # Would be  nice to be able to do in memory dataset construction...
      rem_dir = tempfile.mkdtemp()
      rem_dataset = dataset.select( 
      rem_dataset = rem_dataset.select( 
          rem_dir, rem_inds,
          compute_feature_statistics=compute_feature_statistics)
      ####################################################################### DEBUG
      print("frac_fold, fold, len(fold_dataset), len(rem_dataset)")
      print(frac_fold, fold, len(fold_dataset), len(rem_dataset))
      ####################################################################### DEBUG
      fold_datasets.append(fold_dataset)
    return fold_datasets

@@ -116,22 +117,28 @@ class Splitter(object):
    raise NotImplementedError

  

class StratifiedSplitter(Splitter):
  """
  Class for doing stratified splits -- where data is too sparse to do regular splits
  """
  Stratified Splitter class.

  def __randomize_arrays(self, array_list):
    # assumes that every array is of the same dimension
    num_rows = array_list[0].shape[0]
    perm = np.random.permutation(num_rows)
    for array in array_list:
      array = array[perm]
    return array_list
  For sparse multitask datasets, a standard split offers no guarantees that the
  splits will have any activate compounds. This class guarantees that each task
  will have a proportional split of the activates in a split. TO do this, a
  ragged split is performed with different numbers of compounds taken from each
  task. Thus, the length of the split arrays may exceed the split of the
  original array. That said, no datapoint is copied to more than one split, so
  correctness is still ensured.

  Note that this splitter is only valid for boolean label data.

  TODO(rbharath): This splitter should be refactored to match style of other
  splitter classes.
  """

  def __generate_required_hits(self, w, frac_split):
    # returns list of per column sum of non zero elements
    required_hits = (w != 0).sum(0)  
    required_hits = (w != 0).sum(axis=0)  
    for col_hits in required_hits:
      col_hits = int(frac_split * col_hits)
    return required_hits
@@ -169,46 +176,38 @@ class StratifiedSplitter(Splitter):
    for col_index, index in enumerate(index_list):
      # copy over up to required index for weight first_split
      w_1[:index, col_index] = w[:index, col_index]
      w_1[index:, col_index] = np.zeros(w_1[index:, col_index].shape)
      w_2[:index, col_index] = np.zeros(w_2[:index, col_index].shape)
      w_2[index:, col_index] = w[index:, col_index]

    # check out if any rows in either w_1 or w_2 are just zeros
    rows_to_keep_1 = w_1.any(axis=1)
    rows_to_keep_2 = w_2.any(axis=1)
    rows_1 = w_1.any(axis=1)
    rows_2 = w_2.any(axis=1)

    # prune first set
    w_1 = w_1[rows_to_keep_1]
    X_1 = X[rows_to_keep_1]
    y_1 = y[rows_to_keep_1]
    ids_1 = ids[rows_to_keep_1]
    w_1, X_1, y_1, ids_1 = w_1[rows_1], X[rows_1], y[rows_1], ids[rows_1]

    # prune second sets
    w_2 = w_2[rows_to_keep_2]
    X_2 = X[rows_to_keep_2]
    y_2 = y[rows_to_keep_2]
    ids_2 = ids[rows_to_keep_2]
    w_2, X_2, y_2, ids_2 = w_2[rows_2], X[rows_2], y[rows_2], ids[rows_2]

    return X_1, y_1, w_1, ids_1, X_2, \
           y_2, w_2, ids_2
    return ((X_1, y_1, w_1, ids_1), (X_2, y_2, w_2, ids_2))

  def train_valid_test_split(self, dataset, train_dir,
                             valid_dir, test_dir, frac_train=.8,
                             frac_valid=.1, frac_test=.1, seed=None,
                             log_every_n=1000):
    """Custom split due to raggedness in original split.
    """

    # Obtain original x, y, and w arrays and shuffle
    X, y, w, ids = self.__randomize_arrays(dataset.to_numpy())
    arrays = self.__split(X, y, w, ids, frac_train)
    train_arrays, rem_arrays = arrays[:4], arrays[4:]
    X, y, w, ids = randomize_arrays(dataset.to_numpy())
    train_arrays, rem_arrays = self.__split(X, y, w, ids, frac_train)
    (X_train, y_train, w_train, ids_train) = train_arrays
    (X_rem, y_rem, w_rem, ids_rem) = rem_arrays 

    # calculate percent split for valid (out of test and valid)
    valid_percentage = frac_valid / (frac_valid + frac_test)
    # split test data into valid and test, treating sub test set also as sparse
    arrays = self.__split(X_rem, y_rem, w_rem, ids_rem, valid_percentage)
    (valid_arrays, test_arrays) = arrays[:4], arrays[4:]
    valid_arrays, test_arrays = self.__split(
        X_rem, y_rem, w_rem, ids_rem, valid_percentage)
    (X_valid, y_valid, w_valid, ids_valid) = valid_arrays
    (X_test, y_test, w_test, ids_test) = test_arrays

@@ -273,6 +272,24 @@ class RandomSplitter(Splitter):
    return (shuffled[:train_cutoff], shuffled[train_cutoff:valid_cutoff],
            shuffled[valid_cutoff:])

class IndexSplitter(Splitter):
  """
  Class for simple order based splits. 
  """

  def split(self, dataset, seed=None, frac_train=.8, frac_valid=.1,
            frac_test=.1, log_every_n=None):
    """
    Splits internal compounds into train/validation/test in provided order.
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    num_datapoints = len(dataset)
    train_cutoff = int(frac_train * num_datapoints)
    valid_cutoff = int((frac_train + frac_valid) * num_datapoints)
    indices = range(num_datapoints)
    return (indices[:train_cutoff], indices[train_cutoff:valid_cutoff],
            indices[valid_cutoff:])


class ScaffoldSplitter(Splitter):
  """
+126 −78
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ import tempfile
import numpy as np
from deepchem.datasets import Dataset
from deepchem.splits import RandomSplitter
from deepchem.splits import IndexSplitter
from deepchem.splits import ScaffoldSplitter
from deepchem.splits import StratifiedSplitter
from deepchem.datasets.tests import TestDatasetAPI
@@ -38,6 +39,48 @@ class TestSplitters(TestDatasetAPI):
    assert len(valid_data) == 1
    assert len(test_data) == 1

    merge_dir = tempfile.mkdtemp()
    merged_dataset = Dataset.merge(
        merge_dir, [train_data, valid_data, test_data])
    assert sorted(merged_dataset.get_ids()) == (
           sorted(solubility_dataset.get_ids()))

  def test_singletask_index_split(self):
    """
    Test singletask RandomSplitter class.
    """
    solubility_dataset = self.load_solubility_data()
    random_splitter = IndexSplitter()
    train_data, valid_data, test_data = \
        random_splitter.train_valid_test_split(
            solubility_dataset,
            self.train_dir, self.valid_dir, self.test_dir,
            frac_train=0.8, frac_valid=0.1, frac_test=0.1)
    assert len(train_data) == 8
    assert len(valid_data) == 1
    assert len(test_data) == 1

    merge_dir = tempfile.mkdtemp()
    merged_dataset = Dataset.merge(
        merge_dir, [train_data, valid_data, test_data])
    assert sorted(merged_dataset.get_ids()) == (
           sorted(solubility_dataset.get_ids()))

  def test_singletask_scaffold_split(self):
    """
    Test singletask ScaffoldSplitter class.
    """
    solubility_dataset = self.load_solubility_data()
    scaffold_splitter = ScaffoldSplitter()
    train_data, valid_data, test_data = \
        scaffold_splitter.train_valid_test_split(
            solubility_dataset,
            self.train_dir, self.valid_dir, self.test_dir,
            frac_train=0.8, frac_valid=0.1, frac_test=0.1)
    assert len(train_data) == 8
    assert len(valid_data) == 1
    assert len(test_data) == 1

  def test_singletask_random_k_fold_split(self):
    """
    Test singletask RandomSplitter class.
@@ -45,23 +88,43 @@ class TestSplitters(TestDatasetAPI):
    solubility_dataset = self.load_solubility_data()
    random_splitter = RandomSplitter()
    ids_set = set(solubility_dataset.get_ids())
    #################################################### DEBUG
    print("ids_set")
    print(ids_set)
    #################################################### DEBUG

    K = 5
    fold_dirs = [tempfile.mkdtemp() for i in range(K)]
    fold_datasets = random_splitter.k_fold_split(solubility_dataset, fold_dirs)
    #################################################### DEBUG
    for fold in range(K):
      fold_dataset = fold_datasets[fold]
      # Verify lengths is 10/k == 2
      assert len(fold_dataset) == 2
      # Verify that compounds in this fold are subset of original compounds
      fold_ids_set = set(fold_dataset.get_ids())
      print("fold")
      print(fold)
      print("fold_ids_set")
      print(fold_ids_set)
    #################################################### DEBUG
      assert fold_ids_set.issubset(ids_set)
      # Verify that no two folds have overlapping compounds.
      for other_fold in range(K):
        if fold == other_fold:
          continue
        other_fold_dataset = fold_datasets[other_fold]
        other_fold_ids_set = set(other_fold_dataset.get_ids())
        assert fold_ids_set.isdisjoint(other_fold_ids_set)

    merge_dir = tempfile.mkdtemp()
    merged_dataset = Dataset.merge(merge_dir, fold_datasets)
    assert len(merged_dataset) == len(solubility_dataset)
    assert sorted(merged_dataset.get_ids()) == (
           sorted(solubility_dataset.get_ids()))

  def test_singletask_index_k_fold_split(self):
    """
    Test singletask IndexSplitter class.
    """
    solubility_dataset = self.load_solubility_data()
    index_splitter = IndexSplitter()
    ids_set = set(solubility_dataset.get_ids())

    K = 5
    fold_dirs = [tempfile.mkdtemp() for i in range(K)]
    fold_datasets = index_splitter.k_fold_split(solubility_dataset, fold_dirs)

    for fold in range(K):
      fold_dataset = fold_datasets[fold]
      # Verify lengths is 10/k == 2
@@ -75,12 +138,6 @@ class TestSplitters(TestDatasetAPI):
          continue
        other_fold_dataset = fold_datasets[other_fold]
        other_fold_ids_set = set(other_fold_dataset.get_ids())
        #################################################### DEBUG
        print("fold, other_fold")
        print(fold, other_fold)
        print("fold_ids_set, other_fold_ids_set")
        print(fold_ids_set, other_fold_ids_set)
        #################################################### DEBUG
        assert fold_ids_set.isdisjoint(other_fold_ids_set)

    merge_dir = tempfile.mkdtemp()
@@ -89,30 +146,63 @@ class TestSplitters(TestDatasetAPI):
    assert sorted(merged_dataset.get_ids()) == (
           sorted(solubility_dataset.get_ids()))
    

  def test_singletask_scaffold_split(self):
  def test_singletask_scaffold_k_fold_split(self):
    """
    Test singletask ScaffoldSplitter class.
    """
    solubility_dataset = self.load_solubility_data()
    scaffold_splitter = ScaffoldSplitter()
    ids_set = set(solubility_dataset.get_ids())

    K = 5
    fold_dirs = [tempfile.mkdtemp() for i in range(K)]
    fold_datasets = scaffold_splitter.k_fold_split(
        solubility_dataset, fold_dirs)

    for fold in range(K):
      fold_dataset = fold_datasets[fold]
      # Verify lengths is 10/k == 2
      assert len(fold_dataset) == 2
      # Verify that compounds in this fold are subset of original compounds
      fold_ids_set = set(fold_dataset.get_ids())
      assert fold_ids_set.issubset(ids_set)
      # Verify that no two folds have overlapping compounds.
      for other_fold in range(K):
        if fold == other_fold:
          continue
        other_fold_dataset = fold_datasets[other_fold]
        other_fold_ids_set = set(other_fold_dataset.get_ids())
        assert fold_ids_set.isdisjoint(other_fold_ids_set)

    merge_dir = tempfile.mkdtemp()
    merged_dataset = Dataset.merge(merge_dir, fold_datasets)
    assert len(merged_dataset) == len(solubility_dataset)
    assert sorted(merged_dataset.get_ids()) == (
           sorted(solubility_dataset.get_ids()))

  def test_multitask_random_split(self):
    """
    Test multitask RandomSplitter class.
    """
    multitask_dataset = self.load_multitask_data()
    random_splitter = RandomSplitter()
    train_data, valid_data, test_data = \
        scaffold_splitter.train_valid_test_split(
            solubility_dataset,
        random_splitter.train_valid_test_split(
            multitask_dataset,
            self.train_dir, self.valid_dir, self.test_dir,
            frac_train=0.8, frac_valid=0.1, frac_test=0.1)
    assert len(train_data) == 8
    assert len(valid_data) == 1
    assert len(test_data) == 1

  def test_multitask_random_split(self):
  def test_multitask_index_split(self):
    """
    Test multitask RandomSplitter class.
    Test multitask IndexSplitter class.
    """
    multitask_dataset = self.load_multitask_data()
    random_splitter = RandomSplitter()
    index_splitter = IndexSplitter()
    train_data, valid_data, test_data = \
        random_splitter.train_valid_test_split(
        index_splitter.train_valid_test_split(
            multitask_dataset,
            self.train_dir, self.valid_dir, self.test_dir,
            frac_train=0.8, frac_valid=0.1, frac_test=0.1)
@@ -139,63 +229,21 @@ class TestSplitters(TestDatasetAPI):
    """
    Test multitask StratifiedSplitter class
    """
    # ensure sparse dataset is actually sparse

    sparse_dataset = self.load_sparse_multitask_dataset()

    X, y, w, ids = sparse_dataset.to_numpy()

    
    # sparsity is determined by number of w weights that are 0 for a given
    # task structure of w np array is such that each row corresponds to a
    # sample -- e.g., analyze third column for third sparse task
  
    frac_train = 0.5
    cutoff = int(frac_train * w.shape[0])
    w = w[:cutoff, :]
    sparse_flag = False

    col_index = 0
    for col in w.T:
      if not np.any(col): #check to see if any columns are all zero
        sparse_flag = True
        break
      col_index+=1
    if not sparse_flag:
      print("Test dataset isn't sparse -- test failed")
    else:
      print("Column %d is sparse -- expected" % col_index)
    assert sparse_flag
    # sample. The loaded sparse dataset has many rows with only zeros
    sparse_dataset = self.load_sparse_multitask_dataset()
    X, y, w, ids = sparse_dataset.to_numpy()
    
    stratified_splitter = StratifiedSplitter()
    train_data, valid_data, test_data = \
        stratified_splitter.train_valid_test_split(
    datasets = stratified_splitter.train_valid_test_split(
        sparse_dataset,
        self.train_dir, self.valid_dir, self.test_dir,
            frac_train=0.8, frac_valid=0.1, frac_test=0.1
        )
        frac_train=0.8, frac_valid=0.1, frac_test=0.1)
    train_data, valid_data, test_data = datasets

    datasets = [train_data, valid_data, test_data]
    dataset_index = 0
    for dataset in datasets:
    for dataset_index, dataset in enumerate(datasets):
      X, y, w, ids = dataset.to_numpy()
      # verify that each task in the train dataset has some hits
      for col in w.T:
          if not np.any(col):
              print("Fail -- one column doesn't have results")
              if dataset_index == 0:
                  print("train_data failed")
              elif dataset_index == 1:
                  print("valid_data failed")
              elif dataset_index == 2:
                  print("test_data failed")
              assert np.any(col)
      if dataset_index == 0:
          print("train_data passed")
      elif dataset_index == 1:
          print("valid_data passed")
      elif dataset_index == 2:
          print("test_data passed")
      dataset_index+=1
    print("end of stratified test")
    assert 1 == 1
      # verify that there are no rows (samples) in weights matrix w
      # that have no hits.
      assert len(np.where(~w.any(axis=1))[0]) == 0