Commit 81fda40d authored by miaecle's avatar miaecle
Browse files

little modification to low data benchmark

parent 9f6a0048
Loading
Loading
Loading
Loading
+11 −10
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@ import deepchem as dc
import tensorflow as tf
import argparse
from keras import backend as K
import csv

from low_data.datasets import load_tox21_convmol
from low_data.datasets import load_muv_convmol
@@ -107,16 +108,16 @@ def low_data_benchmark_loading_datasets(hyper_parameters, cross_valid=False,
                         train_dataset, valid_dataset, hp, n_feat,
                         model=model)
      time_finish_fitting = time.time() 
      with open(os.path.join(out_path, 'results.csv'),'a') as f:
        f.write('\n'+str(count_hp)+','+str(count_iter)+',')
        f.write(dataset+','+model+',')
        f.write('valid,')
      with open(os.path.join(out_path, 'results.csv'),'ab') as f:
        writer = csv.writer(f)
        output_line = [count_hp, count_iter, dataset, model, 'valid']
        for i in valid_scores:
          f.write(str(i)+',')
          output_line.append(i)
          for count in valid_scores[i]:
            f.write(str(valid_scores[i][count])+',')
        f.write('time_for_running,'+
              str(time_finish_fitting-time_start_fitting)+',')
            output_line.append(valid_scores[i][count])
        output_line.append('time_for_running')
        output_line.append(time_finish_fitting-time_start_fitting)
        writer.writerow(output_line)

  return None

@@ -211,7 +212,7 @@ def low_data_benchmark_classification(train_dataset, valid_dataset,
          support_batch_size=support_batch_size, learning_rate=learning_rate)
        
      print('-------------------------------------')
      print('Start fitting by graph convolution')
      print('Start fitting by low data model: ' + model)
      # Fit trained model
      model_low_data.fit(train_dataset, nb_epochs=nb_epochs,
            n_episodes_per_epoch=n_train_trials,
@@ -259,7 +260,7 @@ if __name__ == '__main__':
  #    batch_size
  hps = {}
  hps = {}
  hps['siamese'] = [{'K': 4, 'n_feat': 71, 'n_pos': 1, 'n_neg': 1,
  hps['siamese'] = [{'K': 4, 'n_feat': 75, 'n_pos': 1, 'n_neg': 1,
                     'test_batch_size': 128, 'n_filters': [64, 128, 64],
                     'n_fully_connected_nodes': [128], 'max_depth': 3,
                     'nb_epochs': 1, 'n_train_trials': 2000, 
+3 −0
Original line number Diff line number Diff line
@@ -20,6 +20,7 @@
0,pcba,index,classification,train,logreg,0.8087514041,valid,logreg,0.7757913202,time_for_running,11075.2340579
0,pcba,index,classification,train,graphconv,0.87647472,valid,graphconv,0.8523348204,time_for_running,14497.7339029
0,delaney,index,regression,train,tf_regression,0.7830983671,valid,tf_regression,0.5789729655,time_for_running,41.1367759705
0,delaney,index,regression,train,graphconvreg,0.9911206824,valid,graphconvreg,0.7892057714,time_for_running,101.8902909756
0,kaggle,None,regression,train,tf_regression,0.7480423542,valid,tf_regression,0.4516795145,time_for_running,3238.91535401
0,tox21,random,classification,train,tf,0.8565178786,valid,tf,0.7834036936,time_for_running,53.8197240829
0,tox21,random,classification,train,tf_robust,0.8549658589,valid,tf_robust,0.7735497329,time_for_running,88.9351768494
@@ -42,6 +43,7 @@
0,pcba,random,classification,train,logreg,0.8075221378,valid,logreg,0.7758407198,time_for_running,8754.5542872
0,pcba,random,classification,train,graphconv,0.872172184,valid,graphconv,0.8435271472,time_for_running,11502.8221002
0,delaney,random,regression,train,tf_regression,0.7791066217,valid,tf_regression,0.6164873014,time_for_running,35.6433098316
0,delaney,random,regression,train,graphconvreg,0.9951851944,valid,graphconvreg,0.8397307618,time_for_running,102.9403319359
0,tox21,scaffold,classification,train,tf,0.8626085326,valid,tf,0.7030201614,time_for_running,63.5685660839
0,tox21,scaffold,classification,train,tf_robust,0.8608722489,valid,tf_robust,0.7100530015,time_for_running,101.614424944
0,tox21,scaffold,classification,train,logreg,0.9004137009,valid,logreg,0.650190286,time_for_running,60.018599987
@@ -63,3 +65,4 @@
0,pcba,scaffold,classification,train,logreg,0.8099796593,valid,logreg,0.7423270057,time_for_running,9959.83747697
0,pcba,scaffold,classification,train,graphconv,0.8743221913,valid,graphconv,0.8166550236,time_for_running,14184.1512611
0,delaney,scaffold,regression,train,tf_regression,0.7893516465,valid,tf_regression,0.4218847009,time_for_running,35.2720739841
0,delaney,scaffold,regression,train,graphconvreg,0.992822139,valid,graphconvreg,0.5578625785,time_for_running,100.2594189644