Commit dc687afd authored by unknown's avatar unknown
Browse files

fix auPR

parent 35d9e3f1
Loading
Loading
Loading
Loading
+3 −6
Original line number Diff line number Diff line
@@ -75,12 +75,9 @@ def pearson_r2_score(y, y_pred):
def auPR_score(y, y_pred):
  """Compute area under precision-recall curve"""
  assert y_pred.shape == y.shape
  n_classes = y_pred.shape[1]
  scores = []
  for i in range(n_classes):
    precision, recall, _ = precision_recall_curve(y[:, i], y_pred[:, i])
    scores.append(auc(recall, precision))
  return np.mean(scores)
  assert y_pred.shape[1] == 2
  precision, recall, _ = precision_recall_curve(y[:, 1], y_pred[:, 1])
  return auc(recall, precision)


def rms_score(y_true, y_pred):