Commit 03bca48f authored by Peter Eastman's avatar Peter Eastman
Browse files

Optimized one hot encoding

parent d77ccf9b
Loading
Loading
Loading
Loading
+1 −5
Original line number Diff line number Diff line
@@ -23,11 +23,7 @@ def to_one_hot(y):
  """
  n_samples = np.shape(y)[0]
  y_hot = np.zeros((n_samples, 2))
  for index, val in enumerate(y):
    if val == 0:
      y_hot[index] = np.array([1, 0])
    elif val == 1:
      y_hot[index] = np.array([0, 1])
  y_hot[np.arange(n_samples), y.astype(np.int64)] = 1
  return y_hot

def from_one_hot(y, axis=1):
+9 −1
Original line number Diff line number Diff line
@@ -38,6 +38,14 @@ class MetricsTest(googletest.TestCase):
    assert np.isclose(dc.metrics.r2_score(y_true, y_pred),
                      regression_metric.compute_metric(y_true, y_pred))

  def test_one_hot(self):
    y = np.array([0, 0, 1, 0, 1, 1, 0])
    y_hot = metrics.to_one_hot(y)
    expected = np.array([[1,0], [1,0], [0,1], [1,0], [0,1], [0,1], [1,0]])
    yp = metrics.from_one_hot(y_hot)
    assert np.array_equal(expected, y_hot)
    assert np.array_equal(y, yp)


if __name__ == '__main__':
  googletest.main()