Commit d5fb7d1a authored by Bharath's avatar Bharath
Browse files

Fixes to metrics/datasets

parent 925d223e
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -332,6 +332,12 @@ class Dataset(object):
        os.path.join(self.data_dir, row['ids'])), dtype=object)
    return (X, y, w, ids)

  def set_shard(self, shard_num, X, y, w, ids):
    """Writes data shard to disk"""
    basename = "shard-%d" % shard_num 
    tasks = self.get_task_names()
    Dataset.write_data_to_disk(self.data_dir, basename, tasks, X, y, w, ids)

  def set_verbosity(self, new_verbosity):
    """Sets verbosity."""
    self.verbosity = new_verbosity
+6 −1
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ from sklearn.metrics import accuracy_score
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from scipy.stats import pearsonr

def to_one_hot(y):
  """Transforms label vector into one-hot encoding.
@@ -53,6 +54,10 @@ def compute_roc_auc_scores(y, y_pred):
    score = 0.5
  return score

def pearson_r2_score(y, y_pred):
  """Computes Pearson R^2 (square of Pearson correlation)."""
  return pearsonr(y, y_pred)[0]**2

def rms_score(y_true, y_pred):
  """Computes RMS error."""
  return np.sqrt(mean_squared_error(y_true, y_pred))
@@ -123,7 +128,7 @@ class Metric(object):
      if self.name in ["roc_auc_score", "matthews_corrcoef", "recall_score",
                       "accuracy_score", "kappa_score"]:
        mode = "classification"
      elif self.name in ["r2_score", "mean_squared_error",
      elif self.name in ["pearson_r2_score", "r2_score", "mean_squared_error",
                         "mean_absolute_error", "rms_score",
                         "mae_score"]:
        mode = "regression"