Commit be1fff2e authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Add a unit test

parent e0369766
Loading
Loading
Loading
Loading
+19 −3
Original line number Diff line number Diff line
@@ -794,6 +794,22 @@ class TestDatasets(test_util.TensorFlowTestCase):
    np.testing.assert_array_equal(
        np.stack([dataset.y[:, 0], dataset.X[:, 0]], axis=1), dataset3.w)


if __name__ == "__main__":
  unittest.main()
  def test_to_str(self):
    """Tests to string representation of Dataset."""
    dataset = dc.data.NumpyDataset(
        X=np.random.rand(5, 3), y=np.random.rand(5,), ids=np.arange(5))
    ref_str = '<NumpyDataset X.shape: (5, 3), y.shape: (5,), w.shape: (5,), ids: [0 1 2 3 4], task_names: [0]>'
    assert str(dataset) == ref_str

    # Test id shrinkage
    dc.utils.set_print_threshold(10)
    dataset = dc.data.NumpyDataset(
        X=np.random.rand(50, 3), y=np.random.rand(50,), ids=np.arange(50))
    ref_str = '<NumpyDataset X.shape: (50, 3), y.shape: (50,), w.shape: (50,), ids: [0 1 2 ... 47 48 49], task_names: [0]>'
    assert str(dataset) == ref_str

    # Test task shrinkage
    dataset = dc.data.NumpyDataset(
        X=np.random.rand(50, 3), y=np.random.rand(50, 20), ids=np.arange(50))
    ref_str = '<NumpyDataset X.shape: (50, 3), y.shape: (50, 20), w.shape: (50, 1), ids: [0 1 2 ... 47 48 49], task_names: [ 0  1  2 ... 17 18 19]>'
    assert str(dataset) == ref_str