Unverified Commit 17e58ef6 authored by Karl Leswing's avatar Karl Leswing Committed by GitHub
Browse files

Merge pull request #1429 from VIGS25/bedroc_metric

[WIP] #412: BEDROC metric
parents 1124ab6e 87f62e8d
Loading
Loading
Loading
Loading
+41 −0
Original line number Diff line number Diff line
@@ -161,6 +161,47 @@ def kappa_score(y_true, y_pred):
  return kappa


def bedroc_score(y_true, y_pred, alpha=20.0):
  """BEDROC metric implemented according to Truchon and Bayly that modifies
  the ROC score by allowing for a factor of early recognition

    References:
      The original paper by Truchon et al. is located at
      https://pubs.acs.org/doi/pdf/10.1021/ci600426e

    Args:
      y_true (array_like):
        Binary class labels. 1 for positive class, 0 otherwise. Flattened
        before use, so shapes (N,) and (N, 1) are both accepted.
      y_pred (array_like):
        Predicted scores. Either a 2-D array whose column 1 holds the score
        of the positive class (e.g. one-hot or per-class predictions), or a
        1-D array of positive-class scores.
      alpha (float), default 20.0:
        Early recognition parameter

    Returns:
      float: Value in [0, 1] that indicates the degree of early recognition

    Raises:
      AssertionError: If the number of labels differs from the number of
        predictions, or if the labels are not exactly the two values 0 and 1.

  """

  yt = np.asarray(y_true).flatten()
  yp = np.asarray(y_pred)

  # The checks below raise explicitly instead of using `assert` so that they
  # are not stripped when Python runs with -O. AssertionError is kept so that
  # callers see the same exception type as before.
  if yp.ndim == 0 or yt.size != len(yp):
    # Comparing the flattened label count guards against zip() silently
    # truncating a multi-column y_true further down.
    raise AssertionError('Number of examples do not match')

  labels = np.unique(yt)
  # Compare the raw values: casting to int first would truncate labels such
  # as 0.5 / 1.5 to 0 / 1 and wrongly accept them as binary.
  if not np.array_equal(labels, [0, 1]):
    raise AssertionError('Class labels must be binary: %s' % labels)

  # Imported lazily, after validation, so rdkit is only needed when a score
  # is actually computed.
  from rdkit.ML.Scoring.Scoring import CalcBEDROC

  if yp.ndim == 1:
    active_scores = yp  # Already one score per example
  else:
    active_scores = yp[:, 1].flatten()  # Index 1 because one_hot predictions

  # CalcBEDROC receives (label, score) pairs ordered from highest to lowest
  # score; the 0 passed below is the column index of the label in each pair.
  ranked = sorted(
      zip(yt, active_scores), key=lambda pair: pair[1], reverse=True)

  return CalcBEDROC(ranked, 0, alpha)


class Metric(object):
  """Wrapper class for computing user-defined metrics."""

+24 −1
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ class MetricsTest(googletest.TestCase):
    expected_agreement = ((2 * 1) + (2 * 3)) / 4.0**2
    expected_kappa = np.true_divide(observed_agreement - expected_agreement,
                                    1.0 - expected_agreement)
    self.assertAlmostEquals(kappa, expected_kappa)
    self.assertAlmostEqual(kappa, expected_kappa)

  def test_r2_score(self):
    """Test that R^2 metric passes basic sanity tests"""
@@ -46,6 +46,29 @@ class MetricsTest(googletest.TestCase):
    assert np.array_equal(expected, y_hot)
    assert np.array_equal(y, yp)

  def test_bedroc_score(self):
    """Test bedroc_score on a best-case and a worst-case set of predictions"""
    n_actives = 20
    n_inactives = 400 - n_actives

    active_labels = np.ones(n_actives)
    inactive_labels = np.zeros(n_inactives)
    labels = np.concatenate([active_labels, inactive_labels])

    # Best case: predictions are the one-hot encoding of the labels themselves.
    perfect_pred = dc.metrics.to_one_hot(
        np.concatenate([active_labels, inactive_labels]))
    self.assertAlmostEqual(dc.metrics.bedroc_score(labels, perfect_pred), 1.0)

    # Worst case: predictions are the one-hot encoding of the flipped labels.
    inverted_pred = dc.metrics.to_one_hot(
        np.concatenate([np.zeros(n_actives), np.ones(n_inactives)]))
    self.assertAlmostEqual(
        dc.metrics.bedroc_score(labels, inverted_pred), 0.0, places=4)



# Hand control to googletest's entry point when this file is run as a script.
if __name__ == '__main__':
  googletest.main()