Commit 326a23b5 authored by peastman's avatar peastman
Browse files

Added test case for fix to array shape handling

parent 14e7e8d3
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -172,3 +172,17 @@ def test_transform_to_directory():
                             np.zeros_like(w_task[w_orig_task == 0]))
  # Check that sum of 0s equals sum of 1s in transformed for each task
  assert np.isclose(np.sum(w_task[y_task == 0]), np.sum(w_task[y_task == 1]))


def test_array_shapes():
  """Test BalancingTransformer when y and w have different shapes."""
  n_samples = 20
  X = np.random.rand(n_samples, 5)
  y = np.random.randint(2, size=n_samples)
  w = np.ones((n_samples, 1))
  dataset = dc.data.NumpyDataset(X, y, w)
  transformer = dc.trans.BalancingTransformer(dataset)
  Xt, yt, wt, ids = transformer.transform_array(X, y, w, dataset.ids)
  sum0 = np.sum(wt[np.where(y == 0)])
  sum1 = np.sum(wt[np.where(y == 1)])
  assert np.isclose(sum0, sum1)