Commit 3f50197a authored by Lucas Romero's avatar Lucas Romero Committed by Anas Nashif
Browse files

lib: bitarray: add method to find nth bit set in region



This is part one of several changes to add more methods to the bitarray api
so that it can be used for broader usecases, specifically LoRaWAN forward
error correction.

Signed-off-by: default avatarLucas Romero <luqasn@gmail.com>
parent 752d3c52
Loading
Loading
Loading
Loading
+22 −0
Original line number Diff line number Diff line
@@ -184,6 +184,28 @@ int sys_bitarray_alloc(sys_bitarray_t *bitarray, size_t num_bits,
 */
int sys_bitarray_xor(sys_bitarray_t *dst, sys_bitarray_t *other, size_t num_bits, size_t offset);

/**
 * Find nth bit set in region
 *
 * This counts the number of bits set (@p count) in a
 * region (@p offset, @p num_bits) and returns the index (@p found_at)
 * of the nth set bit, if it exists, as long with a zero return value.
 *
 * If it does not exist, @p found_at is not updated and the method returns
 *
 * @param[in]  bitarray Bitarray struct
 * @param[in]  n        Nth bit set to look for
 * @param[in]  num_bits Number of bits to check, must be larger than 0
 * @param[in]  offset   Starting bit position
 * @param[out] found_at Index of the nth bit set, if found
 *
 * @retval 0       Operation successful
 * @retval 1       Nth bit set was not found in region
 * @retval -EINVAL Invalid argument (e.g. out-of-bounds access, trying to count 0 bits, etc.)
 */
int sys_bitarray_find_nth_set(sys_bitarray_t *bitarray, size_t n, size_t num_bits, size_t offset,
			      size_t *found_at);

/**
 * Count bits set in a bit array region
 *
+75 −0
Original line number Diff line number Diff line
@@ -559,6 +559,81 @@ out:
	return ret;
}

int sys_bitarray_find_nth_set(sys_bitarray_t *bitarray, size_t n, size_t num_bits, size_t offset,
			      size_t *found_at)
{
	k_spinlock_key_t key;
	size_t count, idx;
	uint32_t mask;
	struct bundle_data bd;
	int ret;

	__ASSERT_NO_MSG(bitarray != NULL);
	__ASSERT_NO_MSG(bitarray->num_bits > 0);

	key = k_spin_lock(&bitarray->lock);

	if (n == 0 || num_bits == 0 || offset + num_bits > bitarray->num_bits) {
		ret = -EINVAL;
		goto out;
	}

	ret = 1;
	mask = 0;
	setup_bundle_data(bitarray, &bd, offset, num_bits);

	count = POPCOUNT(bitarray->bundles[bd.sidx] & bd.smask);
	/* If we already found more bits set than n, we found the target bundle */
	if (count >= n) {
		idx = bd.sidx;
		mask = bd.smask;
		goto found;
	}
	/* Keep looking if there are more bundles */
	if (bd.sidx != bd.eidx) {
		/* We are now only looking for the remaining bits */
		n -= count;
		/* First bundle was already checked, keep looking in middle (complete)
		 * bundles.
		 */
		for (idx = bd.sidx + 1; idx < bd.eidx; idx++) {
			count = POPCOUNT(bitarray->bundles[idx]);
			if (count >= n) {
				mask = ~(mask & 0);
				goto found;
			}
			n -= count;
		}
		/* Continue searching in last bundle */
		count = POPCOUNT(bitarray->bundles[bd.eidx] & bd.emask);
		if (count >= n) {
			idx = bd.eidx;
			mask = bd.emask;
			goto found;
		}
	}

	goto out;

found:
	/* The bit we are looking for must be in the current bundle idx.
	 * Find out the exact index of the bit.
	 */
	for (int j = 0; j <= bundle_bitness(bitarray) - 1; j++) {
		if (bitarray->bundles[idx] & mask & BIT(j)) {
			if (--n <= 0) {
				*found_at = idx * bundle_bitness(bitarray) + j;
				ret = 0;
				break;
			}
		}
	}

out:
	k_spin_unlock(&bitarray->lock, key);
	return ret;
}

int sys_bitarray_free(sys_bitarray_t *bitarray, size_t num_bits,
		      size_t offset)
{
+107 −0
Original line number Diff line number Diff line
@@ -760,6 +760,113 @@ ZTEST(bitarray, test_bitarray_xor)
	zassert_equal(ret, -EINVAL, "sys_bitarray_xor() returned unexpected value: %d", ret);
}

ZTEST(bitarray, test_bitarray_find_nth_set)
{
	int ret;
	size_t found_at;

	/* Bitarrays have embedded spinlocks and can't on the stack. */
	if (IS_ENABLED(CONFIG_KERNEL_COHERENCE)) {
		ztest_test_skip();
	}

	SYS_BITARRAY_DEFINE(ba, 128);

	printk("Testing bit array nth bit set finding spanning single bundle\n");

	/* Pre-populate the bits */
	ba.bundles[0] = 0x80000001;
	ba.bundles[1] = 0x80000001;
	ba.bundles[2] = 0x80000001;
	ba.bundles[3] = 0x80000001;

	ret = sys_bitarray_find_nth_set(&ba, 1, 1, 0, &found_at);
	zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
	zassert_equal(found_at, 0, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
		      found_at);

	ret = sys_bitarray_find_nth_set(&ba, 1, 32, 0, &found_at);
	zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
	zassert_equal(found_at, 0, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
		      found_at);

	ret = sys_bitarray_find_nth_set(&ba, 2, 32, 0, &found_at);
	zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
	zassert_equal(found_at, 31, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
		      found_at);

	ret = sys_bitarray_find_nth_set(&ba, 1, 31, 1, &found_at);
	zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
	zassert_equal(found_at, 31, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
		      found_at);

	ret = sys_bitarray_find_nth_set(&ba, 2, 31, 1, &found_at);
	zassert_equal(ret, 1, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);

	printk("Testing bit array nth bit set finding spanning multiple bundles\n");

	ret = sys_bitarray_find_nth_set(&ba, 1, 128, 0, &found_at);
	zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
	zassert_equal(found_at, 0, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
		      found_at);

	ret = sys_bitarray_find_nth_set(&ba, 8, 128, 0, &found_at);
	zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
	zassert_equal(found_at, 127, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
		      found_at);

	ret = sys_bitarray_find_nth_set(&ba, 8, 128, 1, &found_at);
	zassert_equal(ret, -EINVAL, "sys_bitarray_find_nth_set() returned unexpected value: %d",
		      ret);

	ret = sys_bitarray_find_nth_set(&ba, 7, 127, 1, &found_at);
	zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
	zassert_equal(found_at, 127, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
		      found_at);

	ret = sys_bitarray_find_nth_set(&ba, 7, 127, 0, &found_at);
	zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
	zassert_equal(found_at, 96, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
		      found_at);

	ret = sys_bitarray_find_nth_set(&ba, 6, 127, 1, &found_at);
	zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
	zassert_equal(found_at, 96, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
		      found_at);

	ret = sys_bitarray_find_nth_set(&ba, 6, 127, 1, &found_at);
	zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
	zassert_equal(found_at, 96, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
		      found_at);

	ret = sys_bitarray_find_nth_set(&ba, 1, 32, 48, &found_at);
	zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
	zassert_equal(found_at, 63, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
		      found_at);

	ret = sys_bitarray_find_nth_set(&ba, 2, 32, 48, &found_at);
	zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);
	zassert_equal(found_at, 64, "sys_bitarray_find_nth_set() returned unexpected found_at: %d",
		      found_at);

	printk("Testing error cases\n");

	ret = sys_bitarray_find_nth_set(&ba, 1, 128, 0, &found_at);
	zassert_equal(ret, 0, "sys_bitarray_find_nth_set() returned unexpected value: %d", ret);

	ret = sys_bitarray_find_nth_set(&ba, 1, 128, 1, &found_at);
	zassert_equal(ret, -EINVAL, "sys_bitarray_find_nth_set() returned unexpected value: %d",
		      ret);

	ret = sys_bitarray_find_nth_set(&ba, 1, 129, 0, &found_at);
	zassert_equal(ret, -EINVAL, "sys_bitarray_find_nth_set() returned unexpected value: %d",
		      ret);

	ret = sys_bitarray_find_nth_set(&ba, 0, 128, 0, &found_at);
	zassert_equal(ret, -EINVAL, "sys_bitarray_find_nth_set() returned unexpected value: %d",
		      ret);
}

ZTEST(bitarray, test_bitarray_region_set_clear)
{
	int ret;