Unverified Commit b4cdf474 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #1857 from zealseeker/master

Fix bug in metric computing
parents 3467db1e a44decc0
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -355,6 +355,8 @@ class Metric(object):
    # If there are no nonzero examples, metric is ill-defined.
    if not y_true.size:
      return np.nan
    if self.threshold is not None and len(y_pred.shape) == 1:
      y_pred = np.expand_dims(y_pred, 0)
    if self.threshold is not None:
      y_pred = y_pred[:, 1]
      y_pred = np.greater(y_pred, self.threshold)
+18 −0
Original line number Diff line number Diff line
@@ -23,6 +23,24 @@ class MetricsTest(googletest.TestCase):
                                    1.0 - expected_agreement)
    self.assertAlmostEqual(kappa, expected_kappa)

  def test_one_sample(self):
    """Test that the metrics won't raise error even in an extreme condition
    where there is only one sample with w > 0.
    """
    np.random.seed(123)
    n_samples = 2
    y_true = np.array([0, 0])
    y_pred = np.random.rand(n_samples, 2)
    w = np.array([0, 1])
    all_metrics = [
        dc.metrics.Metric(dc.metrics.recall_score),
        dc.metrics.Metric(dc.metrics.matthews_corrcoef),
        dc.metrics.Metric(dc.metrics.roc_auc_score)
    ]
    for metric in all_metrics:
      score = metric.compute_singletask_metric(y_true, y_pred, w)
      self.assertTrue(np.isnan(score) or score == 0)

  def test_r2_score(self):
    """Test that R^2 metric passes basic sanity tests"""
    np.random.seed(123)