Commit 84efd4e7 authored by KacperKubara
Browse files

Deleted the ScaffoldSplitter class from the test file; code refactoring

parent 56365623
Loading
Loading
Loading
Loading
+3 −98
Original line number Diff line number Diff line
@@ -2,88 +2,14 @@ import unittest
from unittest import TestCase

import numpy as np
import deepchem as dc
from deepchem.splits import Splitter
from deepchem.utils import ScaffoldGenerator
from deepchem.utils.save import log
from rdkit import Chem


def generate_scaffold(smiles, include_chirality=False):
  """Return the Bemis-Murcko scaffold of the molecule given as SMILES.

  The SMILES string is parsed with `Chem.MolFromSmiles` and the resulting
  molecule is handed to a `ScaffoldGenerator` configured with
  `include_chirality`; whatever `get_scaffold` returns is passed through
  unchanged.
  """
  parsed_mol = Chem.MolFromSmiles(smiles)
  generator = ScaffoldGenerator(include_chirality=include_chirality)
  return generator.get_scaffold(parsed_mol)


class ScaffoldSplitter(Splitter):
  """
  Class for doing data splits based on the scaffold of small molecules.

  Every entry of `dataset.ids` is passed to `generate_scaffold`; entries that
  yield the same scaffold form one group, and each group is assigned as a
  whole to exactly one of the train, validation or test index lists.
  """

  def split(self,
            dataset,
            frac_train=.8,
            frac_valid=.1,
            frac_test=.1,
            log_every_n=1000):
    """
    Splits internal compounds into train/validation/test by scaffold.

    Parameters
    ----------
    dataset:
      Object supporting `len()` and exposing an iterable `ids` attribute
      whose entries are passed to `generate_scaffold`.
    frac_train, frac_valid, frac_test: float
      Target fractions of the dataset for each split; they must sum to 1.
    log_every_n: int
      Forwarded to `generate_scaffolds`; a progress message is logged every
      `log_every_n` entries.

    Returns
    -------
    (train_inds, valid_inds, test_inds): tuple of three lists of int
      Positions into `dataset` for each split.

    Raises
    ------
    AssertionError
      If the three fractions do not sum to (almost) 1.
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    # Forward the caller's logging interval; it was previously dropped and
    # generate_scaffolds always fell back to its own default.
    scaffold_sets = self.generate_scaffolds(dataset, log_every_n=log_every_n)

    train_cutoff = frac_train * len(dataset)
    valid_cutoff = (frac_train + frac_valid) * len(dataset)
    train_inds, valid_inds, test_inds = [], [], []

    log("About to sort in scaffold sets", self.verbose)
    for scaffold_set in scaffold_sets:
      # Groups arrive largest first. A group goes to train while it still
      # fits under the train cutoff, otherwise to valid while train + valid
      # stays under the valid cutoff, otherwise to test.
      if len(train_inds) + len(scaffold_set) > train_cutoff:
        if len(train_inds) + len(valid_inds) + len(scaffold_set) > valid_cutoff:
          test_inds += scaffold_set
        else:
          valid_inds += scaffold_set
      else:
        train_inds += scaffold_set
    return train_inds, valid_inds, test_inds

  def generate_scaffolds(self, dataset, log_every_n=1000):
    """
    Returns all scaffolds from the dataset

    Parameters
    ----------
    dataset:
      Object supporting `len()` and exposing an iterable `ids` attribute.
    log_every_n: int
      A progress message is logged whenever the running index is a multiple
      of this value.

    Returns
    -------
    scaffold_sets: list of lists of int
      One list of dataset positions per distinct scaffold, ordered from the
      largest group to the smallest; groups of equal size are ordered by
      their first index, highest first.
    """
    scaffolds = {}
    data_len = len(dataset)

    log("About to generate scaffolds", self.verbose)
    for ind, smiles in enumerate(dataset.ids):
      if ind % log_every_n == 0:
        log(f"Generating scaffold {ind} {data_len}", self.verbose)
      scaffolds.setdefault(generate_scaffold(smiles), []).append(ind)

    # Sort from largest to smallest scaffold sets. Indices are appended in
    # increasing order above, so every group is already sorted and its first
    # element is its smallest index.
    scaffold_sets = sorted(
        scaffolds.values(),
        key=lambda group: (len(group), group[0]),
        reverse=True)
    return scaffold_sets
from deepchem.splits.splitters import ScaffoldSplitter


class TestScaffoldSplitter(TestCase):

  def create_dataset(self):
    """Load Tox21 with the 'GraphConv' featurizer and return its datasets."""
    # load_tox21 yields a 3-tuple; only the middle element (the datasets)
    # is handed back to the caller.
    _tasks, datasets, _transformers = dc.molnet.load_tox21(
        featurizer='GraphConv')
    return datasets

  def test_scaffolds(self):
    tox21_tasks, tox21_datasets, transformers = \
      dc.molnet.load_tox21(featurizer='GraphConv')
    splitter = ScaffoldSplitter()
    train_dataset, valid_dataset, test_dataset = self.create_dataset()

@@ -111,24 +37,3 @@ class TestScaffoldSplitter(TestCase):

# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()
# The string literal below is a bare expression statement: it is evaluated
# and discarded, so the test class written inside it is never defined or run.
# NOTE(review): the disabled code refers to `deepchem`, `train_test_split`
# and `SpecifiedIndexSplitter`, none of which are imported in the visible
# part of this file - confirm the imports before re-enabling it, or delete it.
"""class TestSpecifiedIndexSplitter(TestCase):

  def create_dataset(self):
    n_samples, n_features = 20, 10
    X = np.random.random(size=(n_samples, n_features))
    y = np.random.random(size=(n_samples, 1))
    return deepchem.data.NumpyDataset(X, y)

  def test_split(self):
    ds = self.create_dataset()
    indexes = list(range(len(ds)))
    train, test = train_test_split(indexes)
    train, valid = train_test_split(train)

    splitter = SpecifiedIndexSplitter(train, valid, test)
    train_ds, valid_ds, test_ds = splitter.train_valid_test_split(ds)

    self.assertTrue(np.all(train_ds.X == ds.X[train]))
    self.assertTrue(np.all(valid_ds.X == ds.X[valid]))
    self.assertTrue(np.all(test_ds.X == ds.X[test]))
"""