Unverified Commit c9fe0176 authored by Karl Leswing's avatar Karl Leswing Committed by GitHub
Browse files

Merge pull request #1425 from lilleswing/fix-classification-metrics

Fix Binary Classification Metrics Metrics
parents c6986172 09df7d78
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -140,7 +140,7 @@ class TestHyperparamOpt(unittest.TestCase):

    transformers = []
    metric = dc.metrics.Metric(
        dc.metrics.matthews_corrcoef, np.mean, mode="classification")
        dc.metrics.roc_auc_score, np.mean, mode="classification")
    params_dict = {"layer_sizes": [(10,), (100,)]}

    def model_builder(model_params, model_dir):
+64 −57
Original line number Diff line number Diff line
@@ -207,6 +207,11 @@ class Metric(object):
      else:
        raise ValueError("Must specify mode for new metric.")
    assert mode in ["classification", "regression"]
    if self.metric.__name__ in [
        "accuracy_score", "balanced_accuracy_score", "recall_score",
        "matthews_corrcoef"
    ] and threshold is None:
      self.threshold = 0.5
    self.mode = mode
    # The convention used is that the first task is the metric.
    # TODO(rbharath, joegomes): This doesn't seem like it should be hard-coded as
@@ -293,6 +298,7 @@ class Metric(object):
    Raises:
      NotImplementedError: If metric_str is not in METRICS.
    """

    y_true = np.array(np.squeeze(y_true[w != 0]))
    y_pred = np.array(np.squeeze(y_pred[w != 0]))

@@ -304,7 +310,8 @@ class Metric(object):
    if not y_true.size:
      return np.nan
    if self.threshold is not None:
      y_pred = np.greater(y_pred, threshold)
      y_pred = y_pred[:, 1]
      y_pred = np.greater(y_pred, self.threshold)
    if len(y_true.shape) == 0:
      y_true = np.expand_dims(y_true, 0)
    if len(y_pred.shape) == 0: