Unverified Commit 361207f8 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Merge pull request #2463 from peastman/pearson

Fixed error in computing Pearson correlation coefficient
parents a3c3874c 80ba9072
Loading
Loading
Loading
Loading
+20 −2
Original line number Diff line number Diff line
"""Evaluation metrics."""

import numpy as np
import scipy.stats
from sklearn.metrics import matthews_corrcoef  # noqa
from sklearn.metrics import recall_score  # noqa
from sklearn.metrics import cohen_kappa_score
@@ -15,12 +16,29 @@ from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score  # noqa
from sklearn.metrics import accuracy_score  # noqa
from sklearn.metrics import balanced_accuracy_score  # noqa
from scipy.stats import pearsonr

# kappa_score is an alias for `sklearn.metrics.cohen_kappa_score`
kappa_score = cohen_kappa_score


def pearsonr(y: np.ndarray, y_pred: np.ndarray) -> float:
  """Return the Pearson correlation coefficient between two arrays.

  This delegates to `scipy.stats.pearsonr` and keeps only the first
  element of its result (the correlation coefficient itself); whatever
  else SciPy returns alongside it is discarded.

  Parameters
  ----------
  y: np.ndarray
    Ground truth values.
  y_pred: np.ndarray
    Predicted values.

  Returns
  -------
  float
    The Pearson correlation coefficient of `y` and `y_pred`.
  """
  scipy_result = scipy.stats.pearsonr(y, y_pred)
  # Index rather than unpack so the result may carry any number of fields.
  coefficient = scipy_result[0]
  return coefficient


def pearson_r2_score(y: np.ndarray, y_pred: np.ndarray) -> float:
  """Computes Pearson R^2 (square of Pearson correlation).

@@ -36,7 +54,7 @@ def pearson_r2_score(y: np.ndarray, y_pred: np.ndarray) -> float:
  float
    The Pearson-R^2 score.
  """
  return pearsonr(y, y_pred)[0]**2
  return scipy.stats.pearsonr(y, y_pred)[0]**2


def jaccard_index(y: np.ndarray, y_pred: np.ndarray) -> float:
+14 −0
Original line number Diff line number Diff line
@@ -34,6 +34,20 @@ def test_one_sample():
    _ = metric.compute_singletask_metric(y_true, y_pred, w)


def test_pearsonr():
  """Check the Pearson metric on correlated, anti-correlated and flat data."""
  metric = dc.metrics.Metric(dc.metrics.pearsonr)
  # Each case: (expected coefficient, ground truth values, predicted values).
  cases = [
      (1.0, [1.0, 2.0, 3.0], [2.0, 3.0, 4.0]),
      (-1.0, [1.0, 2.0, 3.0], [-2.0, -3.0, -4.0]),
      (0.0, [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 2.0, 1.0]),
  ]
  for expected, truth, predicted in cases:
    r = metric.compute_metric(np.array(truth), np.array(predicted))
    np.testing.assert_almost_equal(expected, r)


def test_r2_score():
  """Test that R^2 metric passes basic sanity tests"""
  np.random.seed(123)