Commit d82fd707 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #485 from XericZephyr/new_bm_design

Provide prototype to benchmark outside models.
parents 20862551 108afce3
Loading
Loading
Loading
Loading
+90 −2
Original line number Original line Diff line number Diff line
@@ -79,8 +79,8 @@ def run_benchmark(datasets,


    metric_all = {
    metric_all = {
        'auc': deepchem.metrics.Metric(deepchem.metrics.roc_auc_score, np.mean),
        'auc': deepchem.metrics.Metric(deepchem.metrics.roc_auc_score, np.mean),
        'r2':
        'r2': deepchem.metrics.Metric(deepchem.metrics.pearson_r2_score,
        deepchem.metrics.Metric(deepchem.metrics.pearson_r2_score, np.mean)
                                      np.mean)
    }
    }


    if isinstance(metric, str):
    if isinstance(metric, str):
@@ -197,3 +197,91 @@ def run_benchmark(datasets,
        output_line.extend(
        output_line.extend(
            ['time_for_running', time_finish_fitting - time_start_fitting])
            ['time_for_running', time_finish_fitting - time_start_fitting])
        writer.writerow(output_line)
        writer.writerow(output_line)


#
# Note by @XericZephyr. Reasons why I spun off this function:
#   1. Some models need dataset information.
#   2. It offers the possibility to **cache** the dataset
#      if the featurizer runs very slowly, e.g., GraphConv.
#   2+. The cache can even be used on Travis CI to accelerate
#       CI testing.
#
def load_dataset(dataset, featurizer, split='random'):
  """
  Load specific dataset for benchmark.

  Parameters
  ----------
  dataset: string
      choice of which dataset to use, should be one of: tox21, muv, sider,
      toxcast, pcba, delaney, kaggle, nci, clintox, hiv, pdbbind, chembl,
      qm7, qm7b, qm9, sampl
  featurizer: string or dc.feat.Featurizer.
      choice of featurization, forwarded unchanged to the loader.
  split: string or None, optional (default='random')
      choice of splitter function, forwarded unchanged to the loader.
      None is passed through as-is; presumably the loader then falls back
      to its own default splitting behaviour (TODO confirm per loader).

  Returns
  -------
  tasks, all_dataset, transformers
      The three values returned by the selected `deepchem.molnet` loader,
      in that order. `all_dataset` is expected to be the
      (train, valid, test) tuple consumed by `benchmark_model`.

  Raises
  ------
  KeyError
      If `dataset` is not one of the supported names listed above.
  """
  dataset_loading_functions = {
      'tox21': deepchem.molnet.load_tox21,
      'muv': deepchem.molnet.load_muv,
      'pcba': deepchem.molnet.load_pcba,
      'nci': deepchem.molnet.load_nci,
      'sider': deepchem.molnet.load_sider,
      'toxcast': deepchem.molnet.load_toxcast,
      'kaggle': deepchem.molnet.load_kaggle,
      'delaney': deepchem.molnet.load_delaney,
      'pdbbind': deepchem.molnet.load_pdbbind_grid,
      'chembl': deepchem.molnet.load_chembl,
      'qm7': deepchem.molnet.load_qm7_from_mat,
      'qm7b': deepchem.molnet.load_qm7b_from_mat,
      'qm9': deepchem.molnet.load_qm9,
      'sampl': deepchem.molnet.load_sampl,
      'clintox': deepchem.molnet.load_clintox,
      'hiv': deepchem.molnet.load_hiv
  }
  # Resolve the loader before printing anything so that a typo in the
  # dataset name fails immediately with a message listing the valid names.
  # KeyError is kept as the exception type for backward compatibility.
  try:
    loader = dataset_loading_functions[dataset]
  except KeyError:
    raise KeyError('Unknown dataset %r; expected one of: %s' %
                   (dataset, ', '.join(sorted(dataset_loading_functions))))

  print('-------------------------------------')
  print('Loading dataset: %s' % dataset)
  print('-------------------------------------')
  # loading datasets
  if split is not None:
    print('Splitting function: %s' % split)
  tasks, all_dataset, transformers = loader(featurizer=featurizer, split=split)
  return tasks, all_dataset, transformers


def benchmark_model(model, all_dataset, transformers, metric, test=False):
  """
  Benchmark custom model.

  Fits `model` on the training split, evaluates it on the training and
  validation splits (and optionally the test split), and reports how long
  the whole procedure took.

  Parameters
  ----------
  model: user-defined model structure
      For a user-defined model, it should include the functions
      `fit(dataset)` and `evaluate(dataset, metric, transformers)`.
  all_dataset: (train, valid, test) data tuple.
      Returned by `load_dataset` function. Note the order: the validation
      split comes second and the test split third.
  transformers:
      Passed unchanged to `model.evaluate`; presumably the transformers
      returned by `load_dataset` (TODO confirm).
  metric:
      choice of evaluation metrics, passed unchanged to `model.evaluate`.
  test: bool, optional (default=False)
      whether to also evaluate the model on the test split.

  Returns
  -------
  train_score, valid_score, test_score, time_for_running
      The first three are whatever `model.evaluate` returns for the
      respective split; `test_score` is 0.0 when `test` is false.
      `time_for_running` is the elapsed wall-clock time in seconds
      (measured with `time.time`) covering fitting AND all evaluations.
  """
  time_start = time.time()

  train_dataset, valid_dataset, test_dataset = all_dataset

  model.fit(train_dataset)
  train_score = model.evaluate(train_dataset, metric, transformers)
  valid_score = model.evaluate(valid_dataset, metric, transformers)
  # The test split is only touched on request; 0.0 is the placeholder
  # reported when it is skipped.
  if test:
    test_score = model.evaluate(test_dataset, metric, transformers)
  else:
    test_score = .0

  time_for_running = time.time() - time_start

  return train_score, valid_score, test_score, time_for_running