Commit afc299c3 authored by calebgeniesse's avatar calebgeniesse
Browse files

register clintox dataset in benchmark.py

parent 68d00f4e
Loading
Loading
Loading
Loading
+7 −6
Original line number Diff line number Diff line
@@ -53,6 +53,7 @@ from pdbbind.pdbbind_datasets import load_pdbbind_grid
from chembl.chembl_datasets import load_chembl
from gdb7.gdb7_datasets import load_gdb7
from sampl.sampl_datasets import load_sampl
from clintox.clintox_datasets import load_clintox

def benchmark_loading_datasets(hyper_parameters, 
                               dataset='tox21', model='tf', split=None,
@@ -66,7 +67,7 @@ def benchmark_loading_datasets(hyper_parameters,
      hyper parameters including dropout rate, learning rate, etc.
  dataset: string, optional (default='tox21')
      choice of which dataset to use, should be: tox21, muv, sider, 
      toxcast, pcba, delaney, kaggle, nci
      toxcast, pcba, delaney, kaggle, nci, clintox
  model: string,  optional (default='tf')
      choice of which model to use, should be: rf, tf, tf_robust, logreg,
      graphconv, tf_regression, graphconvreg
@@ -76,7 +77,7 @@ def benchmark_loading_datasets(hyper_parameters,
      path of result file
  """
  
  if dataset in ['muv', 'pcba', 'tox21', 'sider', 'toxcast']:
  if dataset in ['muv', 'pcba', 'tox21', 'sider', 'toxcast', 'clintox']:
    mode = 'classification'
  elif dataset in ['kaggle', 'delaney', 'nci', 'pdbbind', 'chembl', 
                   'gdb7', 'sampl']:
@@ -133,7 +134,7 @@ def benchmark_loading_datasets(hyper_parameters,
                       'kaggle': load_kaggle, 'delaney': load_delaney,
                       'pdbbind': load_pdbbind_grid,
                       'chembl': load_chembl, 'gdb7': load_gdb7,
                       'sampl': load_sampl}
                       'sampl': load_sampl, 'clintox': load_clintox}
  
  print('-------------------------------------')
  print('Benchmark %s on dataset: %s' % (model, dataset))
@@ -545,7 +546,7 @@ if __name__ == '__main__':
           'tf_regression, graphconvreg')
  parser.add_argument('-d', action='append', dest='dataset_args', default=[], 
      help='Choice of dataset: tox21, sider, muv, toxcast, pcba, ' + 
           'kaggle, delaney, nci, pdbbindi, chembl, gdb7')
           'kaggle, delaney, nci, pdbbindi, chembl, gdb7, clintox')
  args = parser.parse_args()
  #Datasets and models used in the benchmark test
  splitters = args.splitter_args
@@ -558,7 +559,7 @@ if __name__ == '__main__':
    models = ['tf', 'tf_robust', 'logreg', 'graphconv', 
              'tf_regression', 'graphconvreg']
  if len(datasets) == 0:
    datasets = ['tox21', 'sider', 'muv', 'toxcast', 'pcba', 
    datasets = ['tox21', 'sider', 'muv', 'toxcast', 'pcba', 'clintox',
                'delaney', 'nci', 'kaggle', 'pdbbind', 'chembl', 'gdb7']

  #input hyperparameters
@@ -604,7 +605,7 @@ if __name__ == '__main__':

  for split in splitters:
    for dataset in datasets:
      if dataset in ['tox21', 'sider', 'muv', 'toxcast', 'pcba']:
      if dataset in ['tox21', 'sider', 'muv', 'toxcast', 'pcba', 'clintox']:
        for model in models:
          if model in ['tf', 'tf_robust', 'logreg', 'graphconv']:
            benchmark_loading_datasets(