Unverified Commit 3574d1c3 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1507 from peastman/splitseed

Splitters use random seed correctly
parents 3d926441 ea451712
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -138,6 +138,7 @@ class Splitter(object):
    log("Computing train/valid/test indices", self.verbose)
    train_inds, valid_inds, test_inds = self.split(
        dataset,
        seed=seed,
        frac_train=frac_train,
        frac_test=frac_test,
        frac_valid=frac_valid,
@@ -179,12 +180,14 @@ class Splitter(object):
        frac_train=frac_train,
        frac_test=1 - frac_train,
        frac_valid=0.,
        seed=seed,
        verbose=verbose,
        **kwargs)
    return train_dataset, test_dataset

  def split(self,
            dataset,
            seed=None,
            frac_train=None,
            frac_valid=None,
            frac_test=None,
@@ -789,6 +792,7 @@ class ButinaSplitter(Splitter):

  def split(self,
            dataset,
            seed=None,
            frac_train=None,
            frac_valid=None,
            frac_test=None,
@@ -848,6 +852,7 @@ class ScaffoldSplitter(Splitter):

  def split(self,
            dataset,
            seed=None,
            frac_train=.8,
            frac_valid=.1,
            frac_test=.1,
@@ -896,6 +901,7 @@ class FingerprintSplitter(Splitter):

  def split(self,
            dataset,
            seed=None,
            frac_train=.8,
            frac_valid=.1,
            frac_test=.1,
@@ -987,6 +993,7 @@ class SpecifiedSplitter(Splitter):

  def split(self,
            dataset,
            seed=None,
            frac_train=.8,
            frac_valid=.1,
            frac_test=.1,
@@ -1023,6 +1030,7 @@ class SpecifiedIndexSplitter(Splitter):

  def split(self,
            dataset,
            seed=None,
            frac_train=.8,
            frac_valid=.1,
            frac_test=.1,
+15 −1
Original line number Diff line number Diff line
@@ -16,7 +16,7 @@ from deepchem.data import NumpyDataset
from deepchem.splits import IndexSplitter


class TestSplitters(unittest.TestCase):
class TestSplitter(unittest.TestCase):
  """
  Test some basic splitters.
  """
@@ -539,6 +539,20 @@ class TestSplitters(unittest.TestCase):
    assert len(valid_data) == 1
    assert len(test_data) == 1

  def test_random_seed(self):
    """Test that splitters use the random seed correctly."""
    dataset = dc.data.tests.load_solubility_data()
    splitter = dc.splits.RandomSplitter()
    train1, valid1, test1 = splitter.train_valid_test_split(dataset, seed=1)
    train2, valid2, test2 = splitter.train_valid_test_split(dataset, seed=2)
    train3, valid3, test3 = splitter.train_valid_test_split(dataset, seed=1)
    assert np.array_equal(train1.X, train3.X)
    assert np.array_equal(valid1.X, valid3.X)
    assert np.array_equal(test1.X, test3.X)
    assert not np.array_equal(train1.X, train2.X)
    assert not np.array_equal(valid1.X, valid2.X)
    assert not np.array_equal(test1.X, test2.X)


if __name__ == "__main__":
  import nose