Commit 0ae17d9f authored by leswing's avatar leswing
Browse files

distances not similarity

parent 129f70d4
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -714,7 +714,7 @@ class FingerprintSplitter(Splitter):
    distances = np.ones(shape=(data_len, data_len))
    for i in range(data_len):
      for j in range(data_len):
        distances[i][j] = DataStructs.FingerprintSimilarity(fingerprints[i],
        distances[i][j] = 1 - DataStructs.FingerprintSimilarity(fingerprints[i],
                                                                fingerprints[j])

    train_cutoff = int(frac_train * len(dataset))
+20 −0
Original line number Diff line number Diff line
@@ -5,6 +5,8 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from rdkit.Chem.Fingerprints import FingerprintMols

__author__ = "Bharath Ramsundar, Aneesh Pappu"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "MIT"
@@ -13,6 +15,7 @@ import tempfile
import unittest
import numpy as np
import deepchem as dc
from rdkit import Chem, DataStructs


class TestSplitters(unittest.TestCase):
@@ -71,6 +74,23 @@ class TestSplitters(unittest.TestCase):
    assert len(valid_data) == 1
    assert len(test_data) == 1

  def test_singletask_fingerprint_split(self):
    """
    Test singletask Fingerprint class.
    """
    solubility_dataset = dc.data.tests.load_solubility_data()
    assert (len(solubility_dataset.X) == 10)
    scaffold_splitter = dc.splits.FingerprintSplitter()
    train_data, valid_data, test_data = \
        scaffold_splitter.train_valid_test_split(
            solubility_dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1)
    assert len(train_data) == 8
    assert len(valid_data) == 1
    assert len(test_data) == 1
    s1 = set(train_data.ids)
    assert valid_data.ids[0] not in s1
    assert test_data.ids[0] not in s1

  def test_singletask_stratified_split(self):
    """
    Test singletask SingletaskStratifiedSplitter class.