Commit 10b6cd8e authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #435 from miaecle/qm9

QM9 benchmark
parents cf486d7f ea070aeb
Loading
Loading
Loading
Loading
+5 −1
Original line number Diff line number Diff line
@@ -353,6 +353,9 @@ Scaffold splitting
|qm7b            |MT-NN regression    |Index       |0.931         |0.803         |
|                |MT-NN regression    |Random      |0.923         |0.884         |
|                |MT-NN regression    |Stratified  |0.934         |0.884         | 
|qm9             |MT-NN regression    |Index       |0.733         |0.791         |
|                |MT-NN regression    |Random      |0.811         |0.823         |
|                |MT-NN regression    |Stratified  |0.843         |0.818         | 
|kaggle          |MT-NN regression    |User-defined|0.748         |0.452         |

|Dataset         |Model            |Splitting   |Train score/MAE(kcal/mol)|Valid score/MAE(kcal/mol)|
@@ -385,7 +388,7 @@ Number of tasks and examples in the datasets
|chembl(5thresh) |691        |23871      |
|qm7             |1          |7165       |
|qm7b            |14         |7211       |

|qm9             |15         |133885     |


Time needed for benchmark test(~20h in total)
@@ -442,6 +445,7 @@ Time needed for benchmark test(~20h in total)
|chembl          |MT-NN regression    |200             |9000           |
|qm7             |MT-NN regression    |10              |400            |
|qm7b            |MT-NN regression    |10              |600            |
|qm9             |MT-NN regression    |220             |10000          |
|kaggle          |MT-NN regression    |2200            |3200           |


+16 −12
Original line number Diff line number Diff line
@@ -19,7 +19,7 @@ Giving regression performances of:
    Random forest(rf_regression),
    Graph convolution regression(graphconvreg)
on datasets: delaney(ESOL), nci, kaggle, pdbbind, 
             qm7, qm7b, chembl, sampl(FreeSolv)
             qm7, qm7b, qm9, chembl, sampl(FreeSolv)

time estimation listed in README file

@@ -52,6 +52,7 @@ from delaney.delaney_datasets import load_delaney
from pdbbind.pdbbind_datasets import load_pdbbind_grid
from chembl.chembl_datasets import load_chembl
from qm7.qm7_datasets import load_qm7_from_mat, load_qm7b_from_mat
from qm9.qm9_datasets import load_qm9
from sampl.sampl_datasets import load_sampl
from clintox.clintox_datasets import load_clintox
from hiv.hiv_datasets import load_hiv
@@ -74,7 +75,7 @@ def benchmark_loading_datasets(hyper_parameters,
  dataset: string, optional (default='tox21')
      choice of which dataset to use, should be: tox21, muv, sider, 
      toxcast, pcba, delaney, kaggle, nci, clintox, hiv, pdbbind, chembl,
      qm7, qm7b, sampl
      qm7, qm7b, qm9, sampl
  model: string,  optional (default='tf')
      choice of which model to use, should be: rf, tf, tf_robust, logreg,
      irv, graphconv, tf_regression, rf_regression, graphconvreg
@@ -87,7 +88,8 @@ def benchmark_loading_datasets(hyper_parameters,
  if dataset in ['muv', 'pcba', 'tox21', 'sider', 'toxcast', 'clintox', 'hiv']:
    mode = 'classification'
  elif dataset in [
      'kaggle', 'delaney', 'nci', 'pdbbind', 'chembl', 'qm7', 'qm7b', 'sampl'
      'kaggle', 'delaney', 'nci', 'pdbbind', 'chembl', 'qm7', 'qm7b', 'qm9',
      'sampl'
  ]:
    mode = 'regression'
  else:
@@ -122,18 +124,18 @@ def benchmark_loading_datasets(hyper_parameters,
    if not model in ['tf_regression', 'rf_regression']:
      return

  if dataset in ['qm7']:
    featurizer = None  # qm7 is already featurized
  if dataset in ['qm7', 'qm7b', 'qm9']:
    featurizer = None  # qm7, qm7b, qm9 is already featurized
    if split in ['scaffold', 'butina']:
      return  # qm7 accept index and random splitter
      return  # qm7, qm7b, qm9 accept index and random splitter
    if not model in ['tf_regression']:
      return

  if split in ['year']:
    if not dataset in ['chembl']:
      return
  if split in ['indice', 'stratified']:
    if not dataset in ['qm7', 'qm7b']:
  if split in ['stratified']:
    if not dataset in ['qm7', 'qm7b', 'qm9']:
      return
  elif not split in [None, 'index', 'random', 'scaffold', 'butina']:
    raise ValueError('Splitter function not supported')
@@ -151,6 +153,7 @@ def benchmark_loading_datasets(hyper_parameters,
      'chembl': load_chembl,
      'qm7': load_qm7_from_mat,
      'qm7b': load_qm7b_from_mat,
      'qm9': load_qm9,
      'sampl': load_sampl,
      'clintox': load_clintox,
      'hiv': load_hiv
@@ -171,14 +174,14 @@ def benchmark_loading_datasets(hyper_parameters,

  train_dataset, valid_dataset, test_dataset = all_dataset
  fit_transformers = []
  if dataset in ['qm7', 'qm7b']:
  if dataset in ['qm7', 'qm7b', 'qm9']:
    fit_transformers = [dc.trans.CoulombFitTransformer(train_dataset)]

  time_finish_loading = time.time()
  # time_finish_loading-time_start is the time(s) used for dataset loading
  if dataset in ['kaggle', 'pdbbind']:
    n_features = train_dataset.get_data_shape()[0]
  elif dataset in ['qm7', 'qm7b']:
  elif dataset in ['qm7', 'qm7b', 'qm9']:
    n_features = list(train_dataset.get_data_shape())
    # dataset has customized features

@@ -788,7 +791,8 @@ if __name__ == '__main__':
      dest='dataset_args',
      default=[],
      help='Choice of dataset: tox21, sider, muv, toxcast, pcba, ' +
      'kaggle, delaney, nci, pdbbind, chembl, sampl, qm7, qm7b, clintox, hiv')
      'kaggle, delaney, nci, pdbbind, chembl, sampl, qm7, qm7b, qm9, clintox, hiv'
  )
  parser.add_argument(
      '-t',
      action='store_true',
@@ -814,7 +818,7 @@ if __name__ == '__main__':
  if len(datasets) == 0:
    datasets = [
        'tox21', 'sider', 'muv', 'toxcast', 'pcba', 'clintox', 'hiv', 'sampl',
        'delaney', 'nci', 'kaggle', 'pdbbind', 'chembl', 'qm7b'
        'delaney', 'nci', 'kaggle', 'pdbbind', 'chembl', 'qm7b', 'qm9'
    ]

  #input hyperparameters
+0 −0

Empty file added.

+3 −0
Original line number Diff line number Diff line
@@ -17,6 +17,9 @@ def load_qm9(featurizer=None, split='random'):
  print("About to featurize qm9 dataset.")
  current_dir = os.path.dirname(os.path.realpath(__file__))
  dataset_file = os.path.join(current_dir, "./gdb9.sdf")
  if not os.path.exists(dataset_file):
    os.system('sh ' + current_dir + '/get_qm9.sh')

  qm9_tasks = [
      "A", "B", "C", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "cv",
      "u0_atom", "u298_atom", "h298_atom", "g298_atom"