Commit 6ff63122 authored by ZHENQIN WU's avatar ZHENQIN WU
Browse files

benchmark modified

parent c825292a
Loading
Loading
Loading
Loading
+7 −7
Original line number Diff line number Diff line
@@ -81,7 +81,7 @@ def benchmark_loading_datasets(base_dir_o, hyper_parameters,
  if model in ['graphconv']:
    featurizer = 'GraphConv'
    n_features = 71
  elif model in ['rf', 'tf', 'tf_robust', 'logreg']:
  elif model in ['tf', 'tf_robust', 'logreg', 'rf']:
    featurizer = 'ECFP'
    n_features = 1024
  else:
@@ -99,13 +99,14 @@ def benchmark_loading_datasets(base_dir_o, hyper_parameters,
  
  for dname in dataset_name:
    print('-------------------------------------')
    print('Benchmark test on dataset: '+dname)
    print('Benchmark %s on dataset: %s' % (model, dname))
    print('-------------------------------------')
    base_dir = os.path.join(base_dir_o, dname)
    
    time_start = time.time()
    #loading datasets     
    tasks,datasets,transformers = loading_functions[dname](featurizer=featurizer)
    tasks,datasets,transformers = loading_functions[dname](
        featurizer=featurizer)
    train_dataset, valid_dataset, test_dataset = datasets
    time_finish_loading = time.time()
    #time_finish_loading-time_start is the time(s) used for dataset loading
@@ -140,7 +141,7 @@ def benchmark_loading_datasets(base_dir_o, hyper_parameters,

def benchmark_train_and_valid(base_dir, train_dataset, valid_dataset, tasks,
                              transformers, hyper_parameters, 
                              n_features, model='tf',
                              n_features, model='tf', seed=123, 
                              verbosity='high'):
  """
  Calculate performance of different models on the specific dataset & tasks
@@ -305,7 +306,6 @@ def benchmark_train_and_valid(base_dir, train_dataset, valid_dataset, tasks,
    K.set_session(sess)
    # Building graph convolution model
    with g.as_default():
      if seed is not None:
      tf.set_random_seed(seed)
      graph_model = dc.nn.SequentialGraph(n_features)
      graph_model.add(dc.nn.GraphConv(int(n_filters), activation='relu'))