Commit 07cd2631 authored by John Fastabend's avatar John Fastabend Committed by Alexei Starovoitov
Browse files

bpf: Verifer, refactor adjust_scalar_min_max_vals



Pull per op ALU logic into individual functions. We are about to add
u32 versions of each of these by pull them out the code gets a bit
more readable here and nicer in the next patch.

Signed-off-by: default avatarJohn Fastabend <john.fastabend@gmail.com>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/158507149518.15666.15672349629329072411.stgit@john-Precision-5820-Tower
parent 8395f320
Loading
Loading
Loading
Loading
+239 −164
Original line number Diff line number Diff line
@@ -4843,62 +4843,14 @@ static int adjust_ptr_min_max_vals(struct bpf_verifier_env *env,
	return 0;
}

/* WARNING: This function does calculations on 64-bit values, but the actual
 * execution may occur on 32-bit values. Therefore, things like bitshifts
 * need extra checks in the 32-bit case.
 */
static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
				      struct bpf_insn *insn,
				      struct bpf_reg_state *dst_reg,
				      struct bpf_reg_state src_reg)
static void scalar_min_max_add(struct bpf_reg_state *dst_reg,
			       struct bpf_reg_state *src_reg)
{
	struct bpf_reg_state *regs = cur_regs(env);
	u8 opcode = BPF_OP(insn->code);
	bool src_known, dst_known;
	s64 smin_val, smax_val;
	u64 umin_val, umax_val;
	u64 insn_bitness = (BPF_CLASS(insn->code) == BPF_ALU64) ? 64 : 32;
	u32 dst = insn->dst_reg;
	int ret;

	if (insn_bitness == 32) {
		/* Relevant for 32-bit RSH: Information can propagate towards
		 * LSB, so it isn't sufficient to only truncate the output to
		 * 32 bits.
		 */
		coerce_reg_to_size(dst_reg, 4);
		coerce_reg_to_size(&src_reg, 4);
	}

	smin_val = src_reg.smin_value;
	smax_val = src_reg.smax_value;
	umin_val = src_reg.umin_value;
	umax_val = src_reg.umax_value;
	src_known = tnum_is_const(src_reg.var_off);
	dst_known = tnum_is_const(dst_reg->var_off);
	s64 smin_val = src_reg->smin_value;
	s64 smax_val = src_reg->smax_value;
	u64 umin_val = src_reg->umin_value;
	u64 umax_val = src_reg->umax_value;

	if ((src_known && (smin_val != smax_val || umin_val != umax_val)) ||
	    smin_val > smax_val || umin_val > umax_val) {
		/* Taint dst register if offset had invalid bounds derived from
		 * e.g. dead branches.
		 */
		__mark_reg_unknown(env, dst_reg);
		return 0;
	}

	if (!src_known &&
	    opcode != BPF_ADD && opcode != BPF_SUB && opcode != BPF_AND) {
		__mark_reg_unknown(env, dst_reg);
		return 0;
	}

	switch (opcode) {
	case BPF_ADD:
		ret = sanitize_val_alu(env, insn);
		if (ret < 0) {
			verbose(env, "R%d tried to add from different pointers or scalars\n", dst);
			return ret;
		}
	if (signed_add_overflows(dst_reg->smin_value, smin_val) ||
	    signed_add_overflows(dst_reg->smax_value, smax_val)) {
		dst_reg->smin_value = S64_MIN;
@@ -4915,14 +4867,17 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
		dst_reg->umin_value += umin_val;
		dst_reg->umax_value += umax_val;
	}
		dst_reg->var_off = tnum_add(dst_reg->var_off, src_reg.var_off);
		break;
	case BPF_SUB:
		ret = sanitize_val_alu(env, insn);
		if (ret < 0) {
			verbose(env, "R%d tried to sub from different pointers or scalars\n", dst);
			return ret;
	dst_reg->var_off = tnum_add(dst_reg->var_off, src_reg->var_off);
}

static void scalar_min_max_sub(struct bpf_reg_state *dst_reg,
			       struct bpf_reg_state *src_reg)
{
	s64 smin_val = src_reg->smin_value;
	s64 smax_val = src_reg->smax_value;
	u64 umin_val = src_reg->umin_value;
	u64 umax_val = src_reg->umax_value;

	if (signed_sub_overflows(dst_reg->smin_value, smax_val) ||
	    signed_sub_overflows(dst_reg->smax_value, smin_val)) {
		/* Overflow possible, we know nothing */
@@ -4941,15 +4896,22 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
		dst_reg->umin_value -= umax_val;
		dst_reg->umax_value -= umin_val;
	}
		dst_reg->var_off = tnum_sub(dst_reg->var_off, src_reg.var_off);
		break;
	case BPF_MUL:
		dst_reg->var_off = tnum_mul(dst_reg->var_off, src_reg.var_off);
	dst_reg->var_off = tnum_sub(dst_reg->var_off, src_reg->var_off);
}

static void scalar_min_max_mul(struct bpf_reg_state *dst_reg,
			       struct bpf_reg_state *src_reg)
{
	s64 smin_val = src_reg->smin_value;
	u64 umin_val = src_reg->umin_value;
	u64 umax_val = src_reg->umax_value;

	dst_reg->var_off = tnum_mul(dst_reg->var_off, src_reg->var_off);
	if (smin_val < 0 || dst_reg->smin_value < 0) {
		/* Ain't nobody got time to multiply that sign */
		__mark_reg_unbounded(dst_reg);
		__update_reg_bounds(dst_reg);
			break;
		return;
	}
	/* Both values are positive, so we can work with unsigned and
	 * copy the result to signed (unless it exceeds S64_MAX).
@@ -4959,7 +4921,7 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
		__mark_reg_unbounded(dst_reg);
		/* (except what we can learn from the var_off) */
		__update_reg_bounds(dst_reg);
			break;
		return;
	}
	dst_reg->umin_value *= umin_val;
	dst_reg->umax_value *= umax_val;
@@ -4971,17 +4933,18 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
		dst_reg->smin_value = dst_reg->umin_value;
		dst_reg->smax_value = dst_reg->umax_value;
	}
		break;
	case BPF_AND:
		if (src_known && dst_known) {
			__mark_reg_known(dst_reg, dst_reg->var_off.value &
						  src_reg.var_off.value);
			break;
}

static void scalar_min_max_and(struct bpf_reg_state *dst_reg,
			       struct bpf_reg_state *src_reg)
{
	s64 smin_val = src_reg->smin_value;
	u64 umax_val = src_reg->umax_value;

	/* We get our minimum from the var_off, since that's inherently
	 * bitwise.  Our maximum is the minimum of the operands' maxima.
	 */
		dst_reg->var_off = tnum_and(dst_reg->var_off, src_reg.var_off);
	dst_reg->var_off = tnum_and(dst_reg->var_off, src_reg->var_off);
	dst_reg->umin_value = dst_reg->var_off.value;
	dst_reg->umax_value = min(dst_reg->umax_value, umax_val);
	if (dst_reg->smin_value < 0 || smin_val < 0) {
@@ -4999,20 +4962,20 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
	}
	/* We may learn something more from the var_off */
	__update_reg_bounds(dst_reg);
		break;
	case BPF_OR:
		if (src_known && dst_known) {
			__mark_reg_known(dst_reg, dst_reg->var_off.value |
						  src_reg.var_off.value);
			break;
}

static void scalar_min_max_or(struct bpf_reg_state *dst_reg,
			      struct bpf_reg_state *src_reg)
{
	s64 smin_val = src_reg->smin_value;
	u64 umin_val = src_reg->umin_value;

	/* We get our maximum from the var_off, and our minimum is the
	 * maximum of the operands' minima
	 */
		dst_reg->var_off = tnum_or(dst_reg->var_off, src_reg.var_off);
	dst_reg->var_off = tnum_or(dst_reg->var_off, src_reg->var_off);
	dst_reg->umin_value = max(dst_reg->umin_value, umin_val);
		dst_reg->umax_value = dst_reg->var_off.value |
				      dst_reg->var_off.mask;
	dst_reg->umax_value = dst_reg->var_off.value | dst_reg->var_off.mask;
	if (dst_reg->smin_value < 0 || smin_val < 0) {
		/* Lose signed bounds when ORing negative numbers,
		 * ain't nobody got time for that.
@@ -5028,15 +4991,14 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
	}
	/* We may learn something more from the var_off */
	__update_reg_bounds(dst_reg);
		break;
	case BPF_LSH:
		if (umax_val >= insn_bitness) {
			/* Shifts greater than 31 or 63 are undefined.
			 * This includes shifts by a negative number.
			 */
			mark_reg_unknown(env, regs, insn->dst_reg);
			break;
}

static void scalar_min_max_lsh(struct bpf_reg_state *dst_reg,
			       struct bpf_reg_state *src_reg)
{
	u64 umax_val = src_reg->umax_value;
	u64 umin_val = src_reg->umin_value;

	/* We lose all sign bit information (except what we can pick
	 * up from var_off)
	 */
@@ -5053,15 +5015,14 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
	dst_reg->var_off = tnum_lshift(dst_reg->var_off, umin_val);
	/* We may learn something more from the var_off */
	__update_reg_bounds(dst_reg);
		break;
	case BPF_RSH:
		if (umax_val >= insn_bitness) {
			/* Shifts greater than 31 or 63 are undefined.
			 * This includes shifts by a negative number.
			 */
			mark_reg_unknown(env, regs, insn->dst_reg);
			break;
}

static void scalar_min_max_rsh(struct bpf_reg_state *dst_reg,
			       struct bpf_reg_state *src_reg)
{
	u64 umax_val = src_reg->umax_value;
	u64 umin_val = src_reg->umin_value;

	/* BPF_RSH is an unsigned shift.  If the value in dst_reg might
	 * be negative, then either:
	 * 1) src_reg might be zero, so the sign bit of the result is
@@ -5083,16 +5044,14 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
	dst_reg->umax_value >>= umin_val;
	/* We may learn something more from the var_off */
	__update_reg_bounds(dst_reg);
		break;
	case BPF_ARSH:
		if (umax_val >= insn_bitness) {
			/* Shifts greater than 31 or 63 are undefined.
			 * This includes shifts by a negative number.
			 */
			mark_reg_unknown(env, regs, insn->dst_reg);
			break;
}

static void scalar_min_max_arsh(struct bpf_reg_state *dst_reg,
			        struct bpf_reg_state *src_reg,
				u64 insn_bitness)
{
	u64 umin_val = src_reg->umin_value;

	/* Upon reaching here, src_known is true and
	 * umax_val is equal to umin_val.
	 */
@@ -5113,6 +5072,122 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
	dst_reg->umin_value = 0;
	dst_reg->umax_value = U64_MAX;
	__update_reg_bounds(dst_reg);
}

/* WARNING: This function does calculations on 64-bit values, but the actual
 * execution may occur on 32-bit values. Therefore, things like bitshifts
 * need extra checks in the 32-bit case.
 */
static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
				      struct bpf_insn *insn,
				      struct bpf_reg_state *dst_reg,
				      struct bpf_reg_state src_reg)
{
	struct bpf_reg_state *regs = cur_regs(env);
	u8 opcode = BPF_OP(insn->code);
	bool src_known, dst_known;
	s64 smin_val, smax_val;
	u64 umin_val, umax_val;
	u64 insn_bitness = (BPF_CLASS(insn->code) == BPF_ALU64) ? 64 : 32;
	u32 dst = insn->dst_reg;
	int ret;

	if (insn_bitness == 32) {
		/* Relevant for 32-bit RSH: Information can propagate towards
		 * LSB, so it isn't sufficient to only truncate the output to
		 * 32 bits.
		 */
		coerce_reg_to_size(dst_reg, 4);
		coerce_reg_to_size(&src_reg, 4);
	}

	smin_val = src_reg.smin_value;
	smax_val = src_reg.smax_value;
	umin_val = src_reg.umin_value;
	umax_val = src_reg.umax_value;
	src_known = tnum_is_const(src_reg.var_off);
	dst_known = tnum_is_const(dst_reg->var_off);

	if ((src_known && (smin_val != smax_val || umin_val != umax_val)) ||
	    smin_val > smax_val || umin_val > umax_val) {
		/* Taint dst register if offset had invalid bounds derived from
		 * e.g. dead branches.
		 */
		__mark_reg_unknown(env, dst_reg);
		return 0;
	}

	if (!src_known &&
	    opcode != BPF_ADD && opcode != BPF_SUB && opcode != BPF_AND) {
		__mark_reg_unknown(env, dst_reg);
		return 0;
	}

	switch (opcode) {
	case BPF_ADD:
		ret = sanitize_val_alu(env, insn);
		if (ret < 0) {
			verbose(env, "R%d tried to add from different pointers or scalars\n", dst);
			return ret;
		}
		scalar_min_max_add(dst_reg, &src_reg);
		break;
	case BPF_SUB:
		ret = sanitize_val_alu(env, insn);
		if (ret < 0) {
			verbose(env, "R%d tried to sub from different pointers or scalars\n", dst);
			return ret;
		}
		scalar_min_max_sub(dst_reg, &src_reg);
		break;
	case BPF_MUL:
		scalar_min_max_mul(dst_reg, &src_reg);
		break;
	case BPF_AND:
		if (src_known && dst_known) {
			__mark_reg_known(dst_reg, dst_reg->var_off.value &
						  src_reg.var_off.value);
			break;
		}
		scalar_min_max_and(dst_reg, &src_reg);
		break;
	case BPF_OR:
		if (src_known && dst_known) {
			__mark_reg_known(dst_reg, dst_reg->var_off.value |
						  src_reg.var_off.value);
			break;
		}
		scalar_min_max_or(dst_reg, &src_reg);
		break;
	case BPF_LSH:
		if (umax_val >= insn_bitness) {
			/* Shifts greater than 31 or 63 are undefined.
			 * This includes shifts by a negative number.
			 */
			mark_reg_unknown(env, regs, insn->dst_reg);
			break;
		}
		scalar_min_max_lsh(dst_reg, &src_reg);
		break;
	case BPF_RSH:
		if (umax_val >= insn_bitness) {
			/* Shifts greater than 31 or 63 are undefined.
			 * This includes shifts by a negative number.
			 */
			mark_reg_unknown(env, regs, insn->dst_reg);
			break;
		}
		scalar_min_max_rsh(dst_reg, &src_reg);
		break;
	case BPF_ARSH:
		if (umax_val >= insn_bitness) {
			/* Shifts greater than 31 or 63 are undefined.
			 * This includes shifts by a negative number.
			 */
			mark_reg_unknown(env, regs, insn->dst_reg);
			break;
		}
		scalar_min_max_arsh(dst_reg, &src_reg, insn_bitness);
		break;
	default:
		mark_reg_unknown(env, regs, insn->dst_reg);