Commit fa9c0b8b authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #516 from rbharath/balanced_accuracy

Adds balanced accuracy score
parents 82658299 28978f77
Loading
Loading
Loading
Loading
+14 −3
Original line number Diff line number Diff line
@@ -55,6 +55,16 @@ def compute_roc_auc_scores(y, y_pred):
  return score


def balanced_accuracy_score(y, y_pred):
  """Computes balanced accuracy score."""
  num_positive = float(np.count_nonzero(y))
  num_negative = float(len(y) - num_positive)
  pos_weight = num_negative / num_positive
  weights = np.ones_like(y)
  weights[y != 0] = pos_weight
  return accuracy_score(y, y_pred, sample_weight=weights)


def pearson_r2_score(y, y_pred):
  """Computes Pearson R^2 (square of Pearson correlation)."""
  return pearsonr(y, y_pred)[0]**2
@@ -137,7 +147,8 @@ class Metric(object):
    if mode is None:
      if self.metric.__name__ in [
          "roc_auc_score", "matthews_corrcoef", "recall_score",
          "accuracy_score", "kappa_score", "precision_score"
          "accuracy_score", "kappa_score", "precision_score",
          "balanced_accuracy_score"
      ]:
        mode = "classification"
      elif self.metric.__name__ in [
@@ -218,8 +229,8 @@ class Metric(object):
      if self.compute_energy_metric:
        # TODO(rbharath, joegomes): What is this magic number?
        force_error = self.task_averager(computed_metrics[1:]) * 4961.47596096
        print("Force error (metric: np.mean(%s)): %f kJ/mol/A" %
              (self.name, force_error))
        print("Force error (metric: np.mean(%s)): %f kJ/mol/A" % (self.name,
                                                                  force_error))
        return computed_metrics[0]
      elif not per_task_metrics:
        return self.task_averager(computed_metrics)