Commit 624e77c7 authored by Yutong Zhao's avatar Yutong Zhao Committed by leswing
Browse files

Add ButinaSplitter class

Add support for benchmarking tox21 and butina. Clean up print statements.

cleanup

WIP Butina

Fix typo

Update documentation

Fix tests

Add example
parent fa3daf45
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -28,6 +28,18 @@ def load_solubility_data():
  
  return loader.featurize(input_file)

def load_butina_data():
  """Load the butina example classification dataset.

  Featurizes SMILES from ``../../models/tests/butina_example.csv`` with
  1024-bit circular (Morgan) fingerprints so the result can be split with
  ``ButinaSplitter``.

  Returns
  -------
  A featurized deepchem dataset produced by ``CSVLoader.featurize``.
  """
  current_dir = os.path.dirname(os.path.abspath(__file__))
  featurizer = dc.feat.CircularFingerprint(size=1024)
  # Single classification task named "task", matching the example CSV header.
  tasks = ["task"]
  input_file = os.path.join(current_dir,
                            "../../models/tests/butina_example.csv")
  loader = dc.data.CSVLoader(
      tasks=tasks, smiles_field="smiles", featurizer=featurizer)

  return loader.featurize(input_file)

def load_multitask_data():
  """Load example multitask data."""
  current_dir = os.path.dirname(os.path.abspath(__file__))
+11 −0
Original line number Diff line number Diff line
task,smiles
0,OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)C(O)C3O 
0,Cc1occc1C(=O)Nc2ccccc2
0,CCCCCCCCCCC
0,c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43
0,CCCCCCCCCCCC
1,CCCCCCCCCCCCC
0,Clc1cc(Cl)c(c(Cl)c1)c2c(Cl)cccc2Cl
0,C1CCNCC1
0,ClC4=C(Cl)C5(Cl)C3C1CC(C2OC12)C3C4(Cl)C5(Cl)Cl
1,COc5cc4OCC3Oc2c1CC(Oc1ccc2C(=O)C3c4cc5OC)C(C)=C 
+66 −0
Original line number Diff line number Diff line
@@ -11,7 +11,11 @@ __license__ = "GPL"

import tempfile
import numpy as np
import itertools
from rdkit import Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem
from rdkit.ML.Cluster import Butina
from deepchem.utils import ScaffoldGenerator
from deepchem.utils.save import log
from deepchem.data import NumpyDataset
@@ -492,6 +496,68 @@ class IndiceSplitter(Splitter):
    
    return (train_indices, self.valid_indices, self.test_indices)

def ClusterFps(fps, cutoff=0.2):
    """Cluster a list of fingerprints with the Butina algorithm.

    Adapted from Greg Landrum's RDKit clustering example. Builds the
    condensed lower-triangular Tanimoto *distance* matrix (1 - similarity)
    and hands it to ``Butina.ClusterData``.

    Parameters
    ----------
    fps: list
      RDKit bit-vector fingerprints.
    cutoff: float
      Distance threshold below which two fingerprints join a cluster.
    """
    nfps = len(fps)
    dists = []
    for idx in range(1, nfps):
        sims = DataStructs.BulkTanimotoSimilarity(fps[idx], fps[:idx])
        dists.extend(1 - sim for sim in sims)
    return Butina.ClusterData(dists, nfps, cutoff, isDistData=True)

class ButinaSplitter(Splitter):
  """
  Class for doing data splits based on the butina clustering of a bulk tanimoto
  fingerprint matrix.
  """

  def split(self, dataset, frac_train=None, frac_valid=None, frac_test=None,
            log_every_n=1000, cutoff=0.18):
    """
    Splits internal compounds into train and validation based on the butina
    clustering algorithm. This splitting algorithm has an O(N^2) run time, where N
    is the number of elements in the dataset. The dataset is expected to be a
    classification dataset.

    This algorithm is designed to generate validation data that are novel chemotypes.

    Note that this function entirely disregards the ratios for frac_train, frac_valid,
    and frac_test. Furthermore, it does not generate a test set, only a train and valid
    set.

    Setting a small cutoff value will generate smaller, finer clusters of high
    similarity, whereas setting a large cutoff value will generate larger, coarser
    clusters of low similarity.

    Returns
    -------
    Tuple of (train_inds, valid_inds, test_inds); test_inds is always empty.
    """
    print("Performing butina clustering with cutoff of", cutoff)
    mols = [Chem.MolFromSmiles(smiles) for smiles in dataset.ids]
    fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in mols]

    scaffold_sets = ClusterFps(fps, cutoff=cutoff)
    # Largest clusters first: validation is filled from the biggest clusters
    # until every task has at least one active.
    scaffold_sets = sorted(scaffold_sets, key=lambda x: -len(x))

    ys = dataset.y
    valid_inds = []
    # Initialize so an empty scaffold_sets (empty dataset) does not raise
    # NameError at the train_inds computation below.
    c_idx = -1
    for c_idx, cluster in enumerate(scaffold_sets):
      valid_inds.extend(cluster)
      # continue until we find an active in all the tasks, otherwise we can't
      # compute a meaningful AUC
      # TODO (ytz): really, we want at least one active and inactive in both scenarios.
      # TODO (ytz): for regression tasks we'd stop after only one cluster.
      active_populations = np.sum(ys[valid_inds], axis=0)
      if np.all(active_populations):
        print("# of actives per task in valid:", active_populations)
        print("Total # of validation points:", len(valid_inds))
        break

    # Everything not consumed by the validation set becomes training data.
    train_inds = list(itertools.chain.from_iterable(scaffold_sets[c_idx + 1:]))
    test_inds = []

    return train_inds, valid_inds, test_inds


class ScaffoldSplitter(Splitter):
  """
+19 −0
Original line number Diff line number Diff line
@@ -90,6 +90,20 @@ class TestSplitters(unittest.TestCase):
    assert sorted(merged_dataset.ids) == (
           sorted(solubility_dataset.ids))

  def test_singletask_butina_split(self):
    """
    Test singletask ButinaSplitter class.
    """
    solubility_dataset = dc.data.tests.load_butina_data()
    butina_splitter = dc.splits.ButinaSplitter()
    train_data, valid_data, test_data = \
        butina_splitter.train_valid_test_split(
            solubility_dataset)
    # With the default cutoff, the 10-molecule example set splits into
    # 7 train / 3 valid; ButinaSplitter never produces a test partition.
    assert len(train_data) == 7
    assert len(valid_data) == 3
    assert len(test_data) == 0

  def test_singletask_random_k_fold_split(self):
    """
    Test singletask RandomSplitter class.
@@ -422,3 +436,8 @@ class TestSplitters(unittest.TestCase):
      # verify that there are no rows (samples) in weights matrix w
      # that have no hits.
      assert len(np.where(~w.any(axis=1))[0]) == 0


if __name__ == "__main__":
  # Allow running this test module directly (outside a test runner) via nose.
  import nose
  nose.run(defaultTest=__name__)
 No newline at end of file
+1 −1
Original line number Diff line number Diff line
@@ -125,7 +125,7 @@ def benchmark_loading_datasets(hyper_parameters,
  elif split in ['indice']:
    if not dataset in ['gdb7']:
      return
  elif not split in [None, 'index','random','scaffold']:
  elif not split in [None, 'index','random','scaffold', 'butina']:
    raise ValueError('Splitter function not supported')
  
  loading_functions = {'tox21': load_tox21, 'muv': load_muv,
Loading