Commit 72ad0b4d authored by KacperKubara
Browse files

Refactored the ScaffoldSplitter test

parent 84efd4e7
Loading
Loading
Loading
Loading
+4 −16
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@ import unittest
from unittest import TestCase

import numpy as np
import deepchem as dc
from deepchem.splits.splitters import ScaffoldSplitter


@@ -10,30 +11,17 @@ class TestScaffoldSplitter(TestCase):
  def test_scaffolds(self):
    """Check scaffold generation on the Tox21 training set.

    Loads Tox21 with the 'GraphConv' featurizer, generates scaffold groups
    for the training set and verifies that:

    * the sizes of the scaffold groups sum to the number of datapoints, and
    * the number of scaffold groups does not exceed the number of datapoints.
    """
    # Only the datasets are needed; the task names and transformers that
    # load_tox21 also returns are discarded.
    _, tox21_datasets, _ = \
      dc.molnet.load_tox21(featurizer='GraphConv')
    train_dataset, _, _ = tox21_datasets
    n_datapoints = train_dataset.X.shape[0]

    splitter = ScaffoldSplitter()
    scaffolds_separate = splitter.generate_scaffolds(train_dataset)
    # The split result is not inspected; the call (and its three-way
    # unpacking) is kept so that a failure while splitting fails the test.
    _, _, _ = splitter.split(train_dataset)

    # The total number of datapoints across all scaffold groups has to match
    # the number of datapoints in the dataset.
    data_cnt = sum(len(sfd) for sfd in scaffolds_separate)
    self.assertEqual(data_cnt, n_datapoints)

    # The number of scaffolds generated by the splitter has to be less than
    # or equal to the total number of molecules.
    scaffolds_separate_cnt = len(scaffolds_separate)
    self.assertLessEqual(scaffolds_separate_cnt, n_datapoints)


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  unittest.main()