Commit 545b3701 authored by peastman's avatar peastman
Browse files

Fixed more assumptions about array shapes

parent b75985db
Loading
Loading
Loading
Loading
+5 −2
Original line number Diff line number Diff line
@@ -56,7 +56,9 @@ def _ensure_class_labels(y):

def roc_auc_score(y, y_pred):
  """Area under the receiver operating characteristic curve."""
  return sklearn.metrics.roc_auc_score(_ensure_one_hot(y), y_pred)
  if y.shape != y_pred.shape:
    y = _ensure_one_hot(y)
  return sklearn.metrics.roc_auc_score(y, y_pred)


def accuracy_score(y, y_pred):
@@ -106,6 +108,7 @@ def pixel_error(y, y_pred):

def prc_auc_score(y, y_pred):
  """Compute area under precision-recall curve"""
  if y.shape != y_pred.shape:
    y = _ensure_one_hot(y)
  assert y_pred.shape == y.shape
  assert y_pred.shape[1] == 2