Commit 56365623 authored by KacperKubara's avatar KacperKubara
Browse files

Added tests for the new scaffoldsplitter

parent 03f1709e
Loading
Loading
Loading
Loading
+30 −20
Original line number Diff line number Diff line
@@ -852,7 +852,6 @@ class ScaffoldSplitter(Splitter):

  def split(self,
            dataset,
            seed=None,
            frac_train=.8,
            frac_valid=.1,
            frac_test=.1,
@@ -861,36 +860,47 @@ class ScaffoldSplitter(Splitter):
    Splits internal compounds into train/validation/test by scaffold.
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    scaffold_sets = self.generate_scaffolds(dataset)

    train_cutoff = frac_train * len(dataset)
    valid_cutoff = (frac_train + frac_valid) * len(dataset)
    train_inds, valid_inds, test_inds = [], [], []

    log("About to sort in scaffold sets", self.verbose)
    for scaffold_set in scaffold_sets:
      if len(train_inds) + len(scaffold_set) > train_cutoff:
        if len(train_inds) + len(valid_inds) + len(scaffold_set) > valid_cutoff:
          test_inds += scaffold_set
        else:
          valid_inds += scaffold_set
      else:
        train_inds += scaffold_set
    return train_inds, valid_inds, test_inds

  def generate_scaffolds(self, dataset, log_every_n=1000):
    """
    Returns all scaffolds from the dataset
    """
    scaffolds = {}
    log("About to generate scaffolds", self.verbose)
    data_len = len(dataset)

    log("About to generate scaffolds", self.verbose)
    for ind, smiles in enumerate(dataset.ids):
      if ind % log_every_n == 0:
        log("Generating scaffold %d/%d" % (ind, data_len), self.verbose)
        log(f"Generating scaffold {ind} {data_len}", self.verbose)
      scaffold = generate_scaffold(smiles)
      if scaffold not in scaffolds:
        scaffolds[scaffold] = [ind]
      else:
        scaffolds[scaffold].append(ind)

    # Sort from largest to smallest scaffold sets
    scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
    scaffold_sets = [
        scaffold_set for (scaffold, scaffold_set) in sorted(
            scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)
    ]
    train_cutoff = frac_train * len(dataset)
    valid_cutoff = (frac_train + frac_valid) * len(dataset)
    train_inds, valid_inds, test_inds = [], [], []
    log("About to sort in scaffold sets", self.verbose)
    for scaffold_set in scaffold_sets:
      if len(train_inds) + len(scaffold_set) > train_cutoff:
        if len(train_inds) + len(valid_inds) + len(scaffold_set) > valid_cutoff:
          test_inds += scaffold_set
        else:
          valid_inds += scaffold_set
      else:
        train_inds += scaffold_set
    return train_inds, valid_inds, test_inds
    return scaffold_sets


class FingerprintSplitter(Splitter):
+134 −0
Original line number Diff line number Diff line
import unittest
from unittest import TestCase

import numpy as np
import deepchem as dc
from deepchem.splits import Splitter
from deepchem.utils import ScaffoldGenerator
from deepchem.utils.save import log
from rdkit import Chem


def generate_scaffold(smiles, include_chirality=False):
  """Compute the Bemis-Murcko scaffold for a SMILES string.

  The SMILES string is parsed with RDKit and the resulting molecule is
  handed to a `ScaffoldGenerator` configured with `include_chirality`.
  Whatever `ScaffoldGenerator.get_scaffold` returns is passed back unchanged.
  """
  # Parse first, then build the generator: same evaluation order as before.
  parsed_mol = Chem.MolFromSmiles(smiles)
  return ScaffoldGenerator(
      include_chirality=include_chirality).get_scaffold(parsed_mol)


class ScaffoldSplitter(Splitter):
  """
  Class for doing data splits based on the scaffold of small molecules.

  Datapoints whose ids map to the same scaffold (see `generate_scaffold`)
  are always assigned to the same partition.
  """

  def split(self,
            dataset,
            frac_train=.8,
            frac_valid=.1,
            frac_test=.1,
            log_every_n=1000):
    """
    Splits internal compounds into train/validation/test by scaffold.

    Parameters
    ----------
    dataset:
      Object supporting `len()` and exposing `ids`, which are iterated and
      treated as SMILES strings.
    frac_train, frac_valid, frac_test: float
      Target fractions for each partition; they must sum to 1.
    log_every_n: int
      Progress is logged every `log_every_n` molecules while scaffolds are
      generated.

    Returns
    -------
    (train_inds, valid_inds, test_inds): three lists of integer positions
    into `dataset.ids`.

    Raises
    ------
    AssertionError
      If the three fractions do not sum to (almost) 1.
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    # Forward log_every_n so the caller's logging cadence is honoured.
    scaffold_sets = self.generate_scaffolds(dataset, log_every_n=log_every_n)

    train_cutoff = frac_train * len(dataset)
    valid_cutoff = (frac_train + frac_valid) * len(dataset)
    train_inds, valid_inds, test_inds = [], [], []

    log("About to sort in scaffold sets", self.verbose)
    for scaffold_set in scaffold_sets:
      # A whole scaffold set goes to the first partition it still fits in;
      # sets are never divided between partitions.
      if len(train_inds) + len(scaffold_set) > train_cutoff:
        if len(train_inds) + len(valid_inds) + len(scaffold_set) > valid_cutoff:
          test_inds += scaffold_set
        else:
          valid_inds += scaffold_set
      else:
        train_inds += scaffold_set
    return train_inds, valid_inds, test_inds

  def generate_scaffolds(self, dataset, log_every_n=1000):
    """
    Returns all scaffolds from the dataset

    Parameters
    ----------
    dataset:
      Object supporting `len()` and exposing `ids`, which are iterated and
      treated as SMILES strings.
    log_every_n: int
      Progress is logged every `log_every_n` molecules.

    Returns
    -------
    A list of lists. Each inner list holds the sorted positions (into
    `dataset.ids`) of the molecules sharing one scaffold. The outer list is
    ordered from the largest set to the smallest; among sets of equal size,
    the one whose smallest index is larger comes first.
    """
    scaffolds = {}
    data_len = len(dataset)

    log("About to generate scaffolds", self.verbose)
    for ind, smiles in enumerate(dataset.ids):
      if ind % log_every_n == 0:
        log(f"Generating scaffold {ind}/{data_len}", self.verbose)
      scaffold = generate_scaffold(smiles)
      scaffolds.setdefault(scaffold, []).append(ind)

    # Sort from largest to smallest scaffold sets
    scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
    scaffold_sets = [
        scaffold_set for (scaffold, scaffold_set) in sorted(
            scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)
    ]
    return scaffold_sets


class TestScaffoldSplitter(TestCase):
  """Consistency checks for `ScaffoldSplitter` run against Tox21."""

  def create_dataset(self):
    """Load Tox21 with the GraphConv featurizer and return its datasets.

    Returns the second element of what `dc.molnet.load_tox21` yields;
    `test_scaffolds` unpacks it as (train, valid, test).
    """
    # Load Tox21 dataset
    _, tox21_datasets, _ = dc.molnet.load_tox21(featurizer='GraphConv')
    return tox21_datasets

  def test_scaffolds(self):
    """Scaffold sets and split indices must account for every datapoint."""
    splitter = ScaffoldSplitter()
    train_dataset, _, _ = self.create_dataset()
    n_molecules = train_dataset.X.shape[0]

    scaffolds_separate = splitter.generate_scaffolds(train_dataset)
    train_inds, valid_inds, test_inds = splitter.split(train_dataset)

    # The amount of datapoints has to be the same
    data_cnt = sum(len(sfd) for sfd in scaffolds_separate)
    self.assertEqual(
        data_cnt, n_molecules,
        "Datapoints count from generate_scaffolds differs from "
        "train_dataset.X.shape[0]")

    # The number of scaffolds generated by the splitter
    # has to be smaller or equal than number of total molecules
    self.assertLessEqual(
        len(scaffolds_separate), n_molecules,
        "generate_scaffolds() produced more scaffolds than molecules")

    # split() places every scaffold set in exactly one partition, so the
    # three index lists together must cover each datapoint exactly once.
    all_inds = train_inds + valid_inds + test_inds
    self.assertEqual(len(all_inds), n_molecules)
    self.assertEqual(len(set(all_inds)), n_molecules)


# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
  unittest.main()
"""class TestSpecifiedIndexSplitter(TestCase):

  def create_dataset(self):
    n_samples, n_features = 20, 10
    X = np.random.random(size=(n_samples, n_features))
    y = np.random.random(size=(n_samples, 1))
    return deepchem.data.NumpyDataset(X, y)

  def test_split(self):
    ds = self.create_dataset()
    indexes = list(range(len(ds)))
    train, test = train_test_split(indexes)
    train, valid = train_test_split(train)

    splitter = SpecifiedIndexSplitter(train, valid, test)
    train_ds, valid_ds, test_ds = splitter.train_valid_test_split(ds)

    self.assertTrue(np.all(train_ds.X == ds.X[train]))
    self.assertTrue(np.all(valid_ds.X == ds.X[valid]))
    self.assertTrue(np.all(test_ds.X == ds.X[test]))
"""