Commit 145ba9a1 authored by miaecle

molnet metrics change

parent dc687afd
+3 −3
@@ -72,7 +72,7 @@ def pearson_r2_score(y, y_pred):
  return pearsonr(y, y_pred)[0]**2


-def auPR_score(y, y_pred):
+def prc_auc_score(y, y_pred):
  """Compute area under precision-recall curve"""
  assert y_pred.shape == y.shape
  assert y_pred.shape[1] == 2
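
Only the renamed signature appears in this hunk; the function body is unchanged and elided. For reference, a minimal sketch of a PRC-AUC computation consistent with these asserts, assuming sklearn's precision_recall_curve and auc helpers (which this diff does not show), would be:

from sklearn.metrics import auc, precision_recall_curve

def prc_auc_score(y, y_pred):
  """Compute area under precision-recall curve (sketch, not the diffed body)."""
  assert y_pred.shape == y.shape
  assert y_pred.shape[1] == 2
  # Column 1 holds the positive class in both the one-hot labels
  # and the predicted probabilities.
  precision, recall, _ = precision_recall_curve(y[:, 1], y_pred[:, 1])
  return auc(recall, precision)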
@@ -158,7 +158,7 @@ class Metric(object):
      if self.metric.__name__ in [
          "roc_auc_score", "matthews_corrcoef", "recall_score",
          "accuracy_score", "kappa_score", "precision_score",
          "balanced_accuracy_score", "auPR_score"
          "balanced_accuracy_score", "prc_auc_score"
      ]:
        mode = "classification"
      elif self.metric.__name__ in [
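
Mode inference keys off the wrapped callable's __name__, which is why the rename must be mirrored in this list; otherwise a Metric wrapping prc_auc_score would no longer be recognized as a classification metric. Illustratively (assuming the constructor shown in these hunks):

import numpy as np
import deepchem

# The wrapped function's __name__ is matched against the list above,
# so this Metric is treated as a classification metric.
m = deepchem.metrics.Metric(deepchem.metrics.prc_auc_score, np.mean)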
@@ -277,7 +277,7 @@ class Metric(object):
      # TODO(rbharath): This has been a major source of bugs. Is there a more
      # robust characterization of which metrics require class-probs and which
      # don't?
      if "roc_auc_score" in self.name or "auPR_score" in self.name:
      if "roc_auc_score" in self.name or "prc_auc_score" in self.name:
        y_true = to_one_hot(y_true).astype(int)
        y_pred = np.reshape(y_pred, (n_samples, n_classes))
      else:
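
For context, roc_auc_score and prc_auc_score consume per-class probabilities, so the integer labels are expanded to one-hot form to match y_pred's (n_samples, n_classes) shape. A minimal sketch of what a binary to_one_hot does (the actual DeepChem helper may differ in signature):

import numpy as np

def to_one_hot(y, n_classes=2):
  # Expand integer labels of shape (n_samples,) to (n_samples, n_classes).
  out = np.zeros((len(y), n_classes))
  out[np.arange(len(y)), y.astype(int)] = 1
  return out

print(to_one_hot(np.array([0, 1, 1])))
# [[1. 0.]
#  [0. 1.]
#  [0. 1.]]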
+12 −17
@@ -69,26 +69,22 @@ def run_benchmark(datasets,
    ]:
      mode = 'classification'
      if metric == None:
-        metric = str('auc')
+        metric = [
+            deepchem.metrics.Metric(deepchem.metrics.roc_auc_score, np.mean),
+            deepchem.metrics.Metric(deepchem.metrics.prc_auc_score, np.mean)
+        ]
    elif dataset in [
        'bace_r', 'chembl', 'clearance', 'delaney', 'hopv', 'kaggle', 'lipo',
        'nci', 'pdbbind', 'ppb', 'qm7', 'qm7b', 'qm8', 'qm9', 'sampl'
    ]:
      mode = 'regression'
      if metric == None:
-        metric = str('r2')
+        metric = [
+            deepchem.metrics.Metric(deepchem.metrics.pearson_r2_score, np.mean)
+        ]
    else:
      raise ValueError('Dataset not supported')

-    metric_all = {
-        'auc': deepchem.metrics.Metric(deepchem.metrics.roc_auc_score, np.mean),
-        'r2': deepchem.metrics.Metric(deepchem.metrics.pearson_r2_score,
-                                      np.mean)
-    }
-
-    if isinstance(metric, str):
-      metric = [metric_all[metric]]
-
    if featurizer == None and isinstance(model, str):
      # Assigning featurizer if not user defined
      pair = (dataset, model)
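
With the metric_all lookup table removed, metric must now be either None (picking up the per-mode defaults above) or a list of deepchem.metrics.Metric objects; the old 'auc'/'r2' strings are no longer accepted. A usage sketch, with an illustrative dataset and model name:

import numpy as np
import deepchem

# Explicit Metric objects replace the old metric='auc' string.
metric = [
    deepchem.metrics.Metric(deepchem.metrics.roc_auc_score, np.mean),
    deepchem.metrics.Metric(deepchem.metrics.prc_auc_score, np.mean)
]
deepchem.molnet.run_benchmark(['tox21'], 'tf', metric=metric)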
@@ -188,15 +184,14 @@ def run_benchmark(datasets,

    with open(os.path.join(out_path, 'results.csv'), 'a') as f:
      writer = csv.writer(f)
-      for i in train_score:
+      model_name = list(train_score.keys())[0]
+      for i in train_score[model_name]:
        output_line = [
-            dataset, str(split), mode, 'train', i,
-            train_score[i][list(train_score[i].keys())[0]], 'valid', i,
-            valid_score[i][list(valid_score[i].keys())[0]]
+            dataset, str(split), mode, model_name, i, 'train',
+            train_score[model_name][i], 'valid', valid_score[model_name][i]
        ]
        if test:
-          output_line.extend(
-              ['test', i, test_score[i][list(test_score[i].keys())[0]]])
+          output_line.extend(['test', test_score[model_name][i]])
        output_line.extend(
            ['time_for_running', time_finish_fitting - time_start_fitting])
        writer.writerow(output_line)
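
Scores are now keyed as train_score[model_name][metric_name], and each row records the model and metric names explicitly (note that only the first model in train_score is written out). The resulting row layout, which the test assertions below index into, is:

# Without test:
#   [dataset, split, mode, model_name, metric_name,
#    'train', train_val, 'valid', valid_val,
#    'time_for_running', elapsed]
#   so lastrow[-4] == 'valid' and lastrow[-3] is the valid score.
# With test=True, ['test', test_val] is inserted before the timing pair,
#   so lastrow[-4] == 'test' and lastrow[-3] is the test score.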
+19 −8
@@ -29,14 +29,19 @@ class TestMolnet(unittest.TestCase):
    model = 'graphconvreg'
    split = 'random'
    out_path = tempfile.mkdtemp()
+    metric = [dc.metrics.Metric(dc.metrics.pearson_r2_score, np.mean)]
    dc.molnet.run_benchmark(
-        datasets, str(model), split=split, out_path=out_path, reload=False)
+        datasets,
+        str(model),
+        metric=metric,
+        split=split,
+        out_path=out_path,
+        reload=False)
    with open(os.path.join(out_path, 'results.csv'), 'r') as f:
      reader = csv.reader(f)
      for lastrow in reader:
        pass
-      assert lastrow[-4] == model
-      assert lastrow[-5] == 'valid'
+      assert lastrow[-4] == 'valid'
      assert float(lastrow[-3]) > 0.75
    os.remove(os.path.join(out_path, 'results.csv'))

@@ -46,14 +51,19 @@ class TestMolnet(unittest.TestCase):
    model = 'tf_regression_ft'
    split = 'random'
    out_path = tempfile.mkdtemp()
+    metric = [dc.metrics.Metric(dc.metrics.pearson_r2_score, np.mean)]
    dc.molnet.run_benchmark(
-        datasets, str(model), split=split, out_path=out_path, reload=False)
+        datasets,
+        str(model),
+        metric=metric,
+        split=split,
+        out_path=out_path,
+        reload=False)
    with open(os.path.join(out_path, 'results.csv'), 'r') as f:
      reader = csv.reader(f)
      for lastrow in reader:
        pass
-      assert lastrow[-4] == model
-      assert lastrow[-5] == 'valid'
+      assert lastrow[-4] == 'valid'
      assert float(lastrow[-3]) > 0.95
    os.remove(os.path.join(out_path, 'results.csv'))

@@ -63,9 +73,11 @@ class TestMolnet(unittest.TestCase):
    model = 'tf'
    split = 'random'
    out_path = tempfile.mkdtemp()
+    metric = [dc.metrics.Metric(dc.metrics.roc_auc_score, np.mean)]
    dc.molnet.run_benchmark(
        datasets,
        str(model),
+        metric=metric,
        split=split,
        out_path=out_path,
        test=True,
@@ -74,7 +86,6 @@ class TestMolnet(unittest.TestCase):
      reader = csv.reader(f)
      for lastrow in reader:
        pass
-      assert lastrow[-4] == model
-      assert lastrow[-5] == 'test'
+      assert lastrow[-4] == 'test'
      assert float(lastrow[-3]) > 0.7
    os.remove(os.path.join(out_path, 'results.csv'))