Commit 0fc31b10 authored by Jann Horn's avatar Jann Horn Committed by Alexei Starovoitov
Browse files

bpf: Simplify reg_set_min_max_inv handling



reg_set_min_max_inv() contains exactly the same logic as reg_set_min_max(),
just flipped around. While this makes sense in a cBPF verifier (where ALU
operations are not symmetric), it does not make sense for eBPF.

Replace reg_set_min_max_inv() with a helper that flips the opcode around,
then lets reg_set_min_max() do the complicated work.

Signed-off-by: default avatarJann Horn <jannh@google.com>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/20200330160324.15259-4-daniel@iogearbox.net
parent 604dca5e
Loading
Loading
Loading
Loading
+22 −86
Original line number Diff line number Diff line
@@ -5836,7 +5836,7 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg,
		break;
	}
	default:
		break;
		return;
	}

	__reg_deduce_bounds(false_reg);
@@ -5859,92 +5859,28 @@ static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
				struct bpf_reg_state *false_reg, u64 val,
				u8 opcode, bool is_jmp32)
{
	s64 sval;

	if (__is_pointer_value(false, false_reg))
		return;

	val = is_jmp32 ? (u32)val : val;
	sval = is_jmp32 ? (s64)(s32)val : (s64)val;

	switch (opcode) {
	case BPF_JEQ:
	case BPF_JNE:
	{
		struct bpf_reg_state *reg =
			opcode == BPF_JEQ ? true_reg : false_reg;

		if (is_jmp32) {
			u64 old_v = reg->var_off.value;
			u64 hi_mask = ~0xffffffffULL;

			reg->var_off.value = (old_v & hi_mask) | val;
			reg->var_off.mask &= hi_mask;
		} else {
			__mark_reg_known(reg, val);
		}
		break;
	}
	case BPF_JSET:
		false_reg->var_off = tnum_and(false_reg->var_off,
					      tnum_const(~val));
		if (is_power_of_2(val))
			true_reg->var_off = tnum_or(true_reg->var_off,
						    tnum_const(val));
		break;
	case BPF_JGE:
	case BPF_JGT:
	{
		set_lower_bound(false_reg, val, is_jmp32, opcode == BPF_JGE);
		set_upper_bound(true_reg, val, is_jmp32, opcode == BPF_JGT);
		break;
	}
	case BPF_JSGE:
	case BPF_JSGT:
	{
		s64 false_smin = opcode == BPF_JSGT ? sval    : sval + 1;
		s64 true_smax = opcode == BPF_JSGT ? sval - 1 : sval;

		if (is_jmp32 && !cmp_val_with_extended_s64(sval, false_reg))
			break;
		false_reg->smin_value = max(false_reg->smin_value, false_smin);
		true_reg->smax_value = min(true_reg->smax_value, true_smax);
		break;
	}
	case BPF_JLE:
	case BPF_JLT:
	{
		set_upper_bound(false_reg, val, is_jmp32, opcode == BPF_JLE);
		set_lower_bound(true_reg, val, is_jmp32, opcode == BPF_JLT);
		break;
	}
	case BPF_JSLE:
	case BPF_JSLT:
	{
		s64 false_smax = opcode == BPF_JSLT ? sval    : sval - 1;
		s64 true_smin = opcode == BPF_JSLT ? sval + 1 : sval;

		if (is_jmp32 && !cmp_val_with_extended_s64(sval, false_reg))
			break;
		false_reg->smax_value = min(false_reg->smax_value, false_smax);
		true_reg->smin_value = max(true_reg->smin_value, true_smin);
		break;
	}
	default:
		break;
	}

	__reg_deduce_bounds(false_reg);
	__reg_deduce_bounds(true_reg);
	/* We might have learned some bits from the bounds. */
	__reg_bound_offset(false_reg);
	__reg_bound_offset(true_reg);
	/* Intersecting with the old var_off might have improved our bounds
	 * slightly.  e.g. if umax was 0x7f...f and var_off was (0; 0xf...fc),
	 * then new var_off is (0; 0x7f...fc) which improves our umax.
	/* How can we transform "a <op> b" into "b <op> a"? */
	static const u8 opcode_flip[16] = {
		/* these stay the same */
		[BPF_JEQ  >> 4] = BPF_JEQ,
		[BPF_JNE  >> 4] = BPF_JNE,
		[BPF_JSET >> 4] = BPF_JSET,
		/* these swap "lesser" and "greater" (L and G in the opcodes) */
		[BPF_JGE  >> 4] = BPF_JLE,
		[BPF_JGT  >> 4] = BPF_JLT,
		[BPF_JLE  >> 4] = BPF_JGE,
		[BPF_JLT  >> 4] = BPF_JGT,
		[BPF_JSGE >> 4] = BPF_JSLE,
		[BPF_JSGT >> 4] = BPF_JSLT,
		[BPF_JSLE >> 4] = BPF_JSGE,
		[BPF_JSLT >> 4] = BPF_JSGT
	};
	opcode = opcode_flip[opcode >> 4];
	/* This uses zero as "not present in table"; luckily the zero opcode,
	 * BPF_JA, can't get here.
	 */
	__update_reg_bounds(false_reg);
	__update_reg_bounds(true_reg);
	if (opcode)
		reg_set_min_max(true_reg, false_reg, val, opcode, is_jmp32);
}

/* Regs are known to be equal, so intersect their min/max/var_off */