Unverified Commit 8b52d26f authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1751 from KacperKubara/return_scaffolds

Added functionality to ScaffoldSplitter to return all scaffolds
parents 713b8128 72ad0b4d
Loading
Loading
Loading
Loading
+30 −20
Original line number Diff line number Diff line
@@ -852,7 +852,6 @@ class ScaffoldSplitter(Splitter):

  def split(self,
            dataset,
            seed=None,
            frac_train=.8,
            frac_valid=.1,
            frac_test=.1,
@@ -861,36 +860,47 @@ class ScaffoldSplitter(Splitter):
    Splits internal compounds into train/validation/test by scaffold.
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    scaffold_sets = self.generate_scaffolds(dataset)

    train_cutoff = frac_train * len(dataset)
    valid_cutoff = (frac_train + frac_valid) * len(dataset)
    train_inds, valid_inds, test_inds = [], [], []

    log("About to sort in scaffold sets", self.verbose)
    for scaffold_set in scaffold_sets:
      if len(train_inds) + len(scaffold_set) > train_cutoff:
        if len(train_inds) + len(valid_inds) + len(scaffold_set) > valid_cutoff:
          test_inds += scaffold_set
        else:
          valid_inds += scaffold_set
      else:
        train_inds += scaffold_set
    return train_inds, valid_inds, test_inds

  def generate_scaffolds(self, dataset, log_every_n=1000):
    """
    Returns all scaffolds from the dataset
    """
    scaffolds = {}
    log("About to generate scaffolds", self.verbose)
    data_len = len(dataset)

    log("About to generate scaffolds", self.verbose)
    for ind, smiles in enumerate(dataset.ids):
      if ind % log_every_n == 0:
        log("Generating scaffold %d/%d" % (ind, data_len), self.verbose)
        log(f"Generating scaffold {ind} {data_len}", self.verbose)
      scaffold = generate_scaffold(smiles)
      if scaffold not in scaffolds:
        scaffolds[scaffold] = [ind]
      else:
        scaffolds[scaffold].append(ind)

    # Sort from largest to smallest scaffold sets
    scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
    scaffold_sets = [
        scaffold_set for (scaffold, scaffold_set) in sorted(
            scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)
    ]
    train_cutoff = frac_train * len(dataset)
    valid_cutoff = (frac_train + frac_valid) * len(dataset)
    train_inds, valid_inds, test_inds = [], [], []
    log("About to sort in scaffold sets", self.verbose)
    for scaffold_set in scaffold_sets:
      if len(train_inds) + len(scaffold_set) > train_cutoff:
        if len(train_inds) + len(valid_inds) + len(scaffold_set) > valid_cutoff:
          test_inds += scaffold_set
        else:
          valid_inds += scaffold_set
      else:
        train_inds += scaffold_set
    return train_inds, valid_inds, test_inds
    return scaffold_sets


class FingerprintSplitter(Splitter):
+27 −0
Original line number Diff line number Diff line
import unittest
from unittest import TestCase

import numpy as np
import deepchem as dc
from deepchem.splits.splitters import ScaffoldSplitter


class TestScaffoldSplitter(TestCase):

  def test_scaffolds(self):
    tox21_tasks, tox21_datasets, transformers = \
      dc.molnet.load_tox21(featurizer='GraphConv')
    train_dataset, valid_dataset, test_dataset = tox21_datasets

    splitter = ScaffoldSplitter()
    scaffolds_separate = splitter.generate_scaffolds(train_dataset)
    scaffolds_train, scaffolds_valid, _ = splitter.split(train_dataset)

    # The amount of datapoints has to be the same
    data_cnt = sum([len(sfd) for sfd in scaffolds_separate])
    self.assertTrue(data_cnt == train_dataset.X.shape[0])

    # The number of scaffolds generated by the splitter
    # has to be smaller or equal than number of total molecules
    scaffolds_separate_cnt = len(scaffolds_separate)
    self.assertTrue(scaffolds_separate_cnt <= train_dataset.X.shape[0])