Commit 88138d3c authored by miaecle's avatar miaecle
Browse files

update graph generating scripts

parent fd922115
Loading
Loading
Loading
Loading
+9 −5
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ import os
import numpy as np
import matplotlib.pyplot as plt
import time
plt.switch_backend('agg')

TODO = {
  ('tox21', 'random'): ['weave', 'graphconv', 'tf', 'tf_robust', 'irv', 'xgb', 'logreg', 'textcnn'],
@@ -16,7 +17,6 @@ TODO = {
  ('sampl', 'random'): ['weave_regression', 'graphconvreg', 'tf_regression', 'xgb_regression', 'krr', 'textcnn_regression', 'dag_regression', 'mpnn'],
  ('lipo', 'random'): ['weave_regression', 'graphconvreg', 'tf_regression', 'xgb_regression', 'krr', 'textcnn_regression', 'dag_regression', 'mpnn'],
  ('qm7', 'stratified'): ['dtnn', 'graphconvreg', 'tf_regression_ft', 'krr_ft'],
  ('qm7b', 'random'): ['dtnn', 'tf_regression_ft', 'krr_ft'],
  ('qm8', 'random'): ['dtnn', 'graphconvreg', 'weave_regression', 'textcnn_regression', 'mpnn', 'tf_regression', 'tf_regression_ft'],
}

@@ -114,8 +114,8 @@ def plot(dataset, split, path, out_path):
    ax.set_xlabel('ROC-AUC')
    ax.set_xlim(left=0.4, right=1.)
  t = time.localtime(time.time())
  
  ax.set_title("Performance on %s (%s split), %i-%i-%i" % (dataset, split, t.tm_year, t.tm_mon, t.tm_mday))
  plt.tight_layout()
  for i in range(len(colors)):
    ax.get_children()[i].set_color(colors[i])
    ax.text(values[i]-0.1, y_pos[i]+0.1, str("%.3f" % values[i]), color='white')
@@ -127,7 +127,11 @@ if __name__ == '__main__':
  current_dir = os.path.dirname(os.path.realpath(__file__))
  DEEPCHEM_DIR = os.path.split(os.path.split(current_dir)[0])[0]
  FILE = os.path.join(os.path.join(DEEPCHEM_DIR, 'examples'), 'results.csv')
  run_benchmark(FILE, DEEPCHEM_DIR)
  #run_benchmark(FILE, DEEPCHEM_DIR)
  save_dir = os.path.join(DEEPCHEM_DIR, 'datasets/MolNet_pic')
  if not os.path.exists(save_dir):
    os.mkdir(save_dir)
  for pair in TODO.keys():
    plot(pair[0], pair[1], FILE, os.environ['DEEPCHEM_DATA_DIR'])
    plot(pair[0], pair[1], FILE, save_dir)
  os.system('aws s3 sync '+save_dir+' s3://deepchem.io/trained_models/MolNet_pic')