Commit b0cae084 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #606 from lilleswing/fingerprint-splitter

Fingerprint Splitter
parents d257e334 fbe19a20
Loading
Loading
Loading
Loading
+222 −137
Original line number Diff line number Diff line
@@ -16,6 +16,8 @@ from rdkit import Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem
from rdkit.ML.Cluster import Butina
from rdkit import DataStructs
from rdkit.Chem.Fingerprints import FingerprintMols
import deepchem as dc
from deepchem.data import DiskDataset
from deepchem.utils import ScaffoldGenerator
@@ -684,6 +686,89 @@ class ScaffoldSplitter(Splitter):
    return train_inds, valid_inds, test_inds


class FingerprintSplitter(Splitter):
  """
    Class for doing data splits based on the fingerprints of small molecules
    O(N**2) algorithm
  """

  def split(self,
            dataset,
            frac_train=.8,
            frac_valid=.1,
            frac_test=.1,
            log_every_n=1000):
    """
        Splits internal compounds into train/validation/test by fingerprint.
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    data_len = len(dataset)
    mols, fingerprints = [], []
    train_inds, valid_inds, test_inds = [], [], []
    for ind, smiles in enumerate(dataset.ids):
      mol = Chem.MolFromSmiles(smiles, sanitize=False)
      mols.append(mol)
      fp = FingerprintMols.FingerprintMol(mol)
      fingerprints.append(fp)

    distances = np.ones(shape=(data_len, data_len))
    for i in range(data_len):
      for j in range(data_len):
        distances[i][j] = 1 - DataStructs.FingerprintSimilarity(fingerprints[i],
                                                                fingerprints[j])

    train_cutoff = int(frac_train * len(dataset))
    valid_cutoff = int(frac_valid * len(dataset))

    # Pick the mol closest to everything as the first element of training
    closest_ligand = np.argmin(np.sum(distances, axis=1))
    train_inds.append(closest_ligand)
    cur_distances = [float('inf')] * data_len
    self.update_distances(closest_ligand, cur_distances, distances, train_inds)
    for i in range(1, train_cutoff):
      closest_ligand = np.argmin(cur_distances)
      train_inds.append(closest_ligand)
      self.update_distances(closest_ligand, cur_distances, distances,
                            train_inds)

    # Pick the closest mol from what is left
    index, best_dist = 0, float('inf')
    for i in range(data_len):
      if i in train_inds:
        continue
      dist = np.sum(distances[i])
      if dist < best_dist:
        index, best_dist = i, dist
    valid_inds.append(index)

    leave_out_indexes = train_inds + valid_inds
    cur_distances = [float('inf')] * data_len
    self.update_distances(index, cur_distances, distances, leave_out_indexes)
    for i in range(1, valid_cutoff):
      closest_ligand = np.argmin(cur_distances)
      valid_inds.append(closest_ligand)
      leave_out_indexes.append(closest_ligand)
      self.update_distances(closest_ligand, cur_distances, distances,
                            leave_out_indexes)

    # Test is everything else
    for i in range(data_len):
      if i in leave_out_indexes:
        continue
      test_inds.append(i)
    return train_inds, valid_inds, test_inds

  def update_distances(self, last_selected, cur_distances, distance_matrix,
                       dont_update):
    for i in range(len(cur_distances)):
      if i in dont_update:
        cur_distances[i] = float('inf')
        continue
      new_dist = distance_matrix[i][last_selected]
      if new_dist < cur_distances[i]:
        cur_distances[i] = new_dist


class SpecifiedSplitter(Splitter):
  """
    Class that splits data according to user specification.
+20 −0
Original line number Diff line number Diff line
@@ -5,6 +5,8 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from rdkit.Chem.Fingerprints import FingerprintMols

__author__ = "Bharath Ramsundar, Aneesh Pappu"
__copyright__ = "Copyright 2016, Stanford University"
__license__ = "MIT"
@@ -13,6 +15,7 @@ import tempfile
import unittest
import numpy as np
import deepchem as dc
from rdkit import Chem, DataStructs


class TestSplitters(unittest.TestCase):
@@ -71,6 +74,23 @@ class TestSplitters(unittest.TestCase):
    assert len(valid_data) == 1
    assert len(test_data) == 1

  def test_singletask_fingerprint_split(self):
    """
    Test singletask Fingerprint class.
    """
    solubility_dataset = dc.data.tests.load_solubility_data()
    assert (len(solubility_dataset.X) == 10)
    scaffold_splitter = dc.splits.FingerprintSplitter()
    train_data, valid_data, test_data = \
        scaffold_splitter.train_valid_test_split(
            solubility_dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1)
    assert len(train_data) == 8
    assert len(valid_data) == 1
    assert len(test_data) == 1
    s1 = set(train_data.ids)
    assert valid_data.ids[0] not in s1
    assert test_data.ids[0] not in s1

  def test_singletask_stratified_split(self):
    """
    Test singletask SingletaskStratifiedSplitter class.