Commit 1fadab6e authored by peastman's avatar peastman
Browse files

More fixes to metric code

parent 1eebea74
Loading
Loading
Loading
Loading
+24 −34
Original line number Diff line number Diff line
@@ -3,10 +3,9 @@
import numpy as np
import warnings
from deepchem.utils.save import log
from sklearn.metrics import roc_auc_score
import sklearn.metrics
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import recall_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
@@ -39,22 +38,29 @@ def from_one_hot(y, axis=1):
  return np.argmax(y, axis=axis)


def compute_roc_auc_scores(y, y_pred):
  """Transforms the results dict into roc-auc-scores and prints scores.
def _ensure_one_hot(y):
  """If neceessary, convert class labels to one-hot encoding."""
  if len(y.shape) == 1:
    return to_one_hot(y)
  return y

  Parameters
  ----------
  results: dict
  task_types: dict
    dict mapping task names to output type. Each output type must be either
    "classification" or "regression".
  """
  try:
    score = roc_auc_score(y, y_pred)
  except ValueError:
    warnings.warn("ROC AUC score calculation failed.")
    score = 0.5
  return score

def _ensure_class_labels(y):
  """If necessary, convert one-hot encoding to class labels."""
  if len(y.shape) == 2:
    return from_one_hot(y)
  return y


def roc_auc_score(y, y_pred):
  """Area under the receiver operating characteristic curve."""
  return sklearn.metrics.roc_auc_score(_ensure_one_hot(y), y_pred)


def accuracy_score(y, y_pred):
  y = _ensure_class_labels(y)
  y_pred = _ensure_class_labels(y_pred)
  return sklearn.metrics.accuracy_score(y, y_pred)


def balanced_accuracy_score(y, y_pred):
@@ -74,6 +80,7 @@ def pearson_r2_score(y, y_pred):

def prc_auc_score(y, y_pred):
  """Compute area under precision-recall curve"""
  y = _ensure_one_hot(y)
  assert y_pred.shape == y.shape
  assert y_pred.shape[1] == 2
  precision, recall, _ = precision_recall_curve(y[:, 1], y_pred[:, 1])
@@ -271,23 +278,6 @@ class Metric(object):
    # If there are no nonzero examples, metric is ill-defined.
    if not y_true.size:
      return np.nan

    if self.mode == "classification":
      n_classes = y_pred.shape[-1]
      # TODO(rbharath): This has been a major source of bugs. Is there a more
      # robust characterization of which metrics require class-probs and which
      # don't?
      if "roc_auc_score" in self.name or "prc_auc_score" in self.name:
        y_true = to_one_hot(y_true).astype(int)
        y_pred = np.reshape(y_pred, (n_samples, n_classes))
      else:
        y_true = y_true.astype(int)
        # Reshape to handle 1-d edge cases
        y_pred = np.reshape(y_pred, (n_samples, n_classes))
        y_pred = from_one_hot(y_pred)
    else:
      y_pred = np.reshape(y_pred, (n_samples,))

    if self.threshold is not None:
      y_pred = np.greater(y_pred, threshold)
    try: