Commit 9ee46f4e authored by Hongbin Yang's avatar Hongbin Yang
Browse files

Fix bug in metric

When sample number is 0, `y_pred` will be squeezed and `y_pred[:, 1]` will raise error.
parent 05e7b07c
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -355,6 +355,8 @@ class Metric(object):
    # If there are no nonzero examples, metric is ill-defined.
    if not y_true.size:
      return np.nan
    if self.threshold is not None and len(y_pred.shape) == 1:
      y_pred = np.expand_dims(y_pred, 0)
    if self.threshold is not None:
      y_pred = y_pred[:, 1]
      y_pred = np.greater(y_pred, self.threshold)