Commit 0a58a65c authored by Daniel Borkmann's avatar Daniel Borkmann
Browse files

Merge branch 'bpf-ptrs-beyond-pkt-end'



Alexei Starovoitov says:

====================
v1->v2:
- removed set-but-unused variable.
- added Jiri's Tested-by.

In some cases LLVM uses the knowledge that branch is taken to optimze the code
which causes the verifier to reject valid programs.
Teach the verifier to recognize that
r1 = skb->data;
r1 += 10;
r2 = skb->data_end;
if (r1 > r2) {
  here r1 points beyond packet_end and subsequent
  if (r1 > r2) // always evaluates to "true".
}
====================

Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
parents c3653879 cb62d340
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -45,7 +45,7 @@ struct bpf_reg_state {
	enum bpf_reg_type type;
	union {
		/* valid when type == PTR_TO_PACKET */
		u16 range;
		int range;

		/* valid when type == CONST_PTR_TO_MAP | PTR_TO_MAP_VALUE |
		 *   PTR_TO_MAP_VALUE_OR_NULL
+107 −22
Original line number Diff line number Diff line
@@ -2739,7 +2739,9 @@ static int check_packet_access(struct bpf_verifier_env *env, u32 regno, int off,
			regno);
		return -EACCES;
	}
	err = __check_mem_access(env, regno, off, size, reg->range,

	err = reg->range < 0 ? -EINVAL :
	      __check_mem_access(env, regno, off, size, reg->range,
				 zero_size_allowed);
	if (err) {
		verbose(env, "R%d offset is outside of the packet\n", regno);
@@ -4697,6 +4699,32 @@ static void clear_all_pkt_pointers(struct bpf_verifier_env *env)
		__clear_all_pkt_pointers(env, vstate->frame[i]);
}

enum {
	AT_PKT_END = -1,
	BEYOND_PKT_END = -2,
};

static void mark_pkt_end(struct bpf_verifier_state *vstate, int regn, bool range_open)
{
	struct bpf_func_state *state = vstate->frame[vstate->curframe];
	struct bpf_reg_state *reg = &state->regs[regn];

	if (reg->type != PTR_TO_PACKET)
		/* PTR_TO_PACKET_META is not supported yet */
		return;

	/* The 'reg' is pkt > pkt_end or pkt >= pkt_end.
	 * How far beyond pkt_end it goes is unknown.
	 * if (!range_open) it's the case of pkt >= pkt_end
	 * if (range_open) it's the case of pkt > pkt_end
	 * hence this pointer is at least 1 byte bigger than pkt_end
	 */
	if (range_open)
		reg->range = BEYOND_PKT_END;
	else
		reg->range = AT_PKT_END;
}

static void release_reg_references(struct bpf_verifier_env *env,
				   struct bpf_func_state *state,
				   int ref_obj_id)
@@ -6708,7 +6736,7 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)

static void __find_good_pkt_pointers(struct bpf_func_state *state,
				     struct bpf_reg_state *dst_reg,
				     enum bpf_reg_type type, u16 new_range)
				     enum bpf_reg_type type, int new_range)
{
	struct bpf_reg_state *reg;
	int i;
@@ -6733,8 +6761,7 @@ static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
				   enum bpf_reg_type type,
				   bool range_right_open)
{
	u16 new_range;
	int i;
	int new_range, i;

	if (dst_reg->off < 0 ||
	    (dst_reg->off == 0 && range_right_open))
@@ -6985,6 +7012,67 @@ static int is_branch_taken(struct bpf_reg_state *reg, u64 val, u8 opcode,
	return is_branch64_taken(reg, val, opcode);
}

static int flip_opcode(u32 opcode)
{
	/* How can we transform "a <op> b" into "b <op> a"? */
	static const u8 opcode_flip[16] = {
		/* these stay the same */
		[BPF_JEQ  >> 4] = BPF_JEQ,
		[BPF_JNE  >> 4] = BPF_JNE,
		[BPF_JSET >> 4] = BPF_JSET,
		/* these swap "lesser" and "greater" (L and G in the opcodes) */
		[BPF_JGE  >> 4] = BPF_JLE,
		[BPF_JGT  >> 4] = BPF_JLT,
		[BPF_JLE  >> 4] = BPF_JGE,
		[BPF_JLT  >> 4] = BPF_JGT,
		[BPF_JSGE >> 4] = BPF_JSLE,
		[BPF_JSGT >> 4] = BPF_JSLT,
		[BPF_JSLE >> 4] = BPF_JSGE,
		[BPF_JSLT >> 4] = BPF_JSGT
	};
	return opcode_flip[opcode >> 4];
}

static int is_pkt_ptr_branch_taken(struct bpf_reg_state *dst_reg,
				   struct bpf_reg_state *src_reg,
				   u8 opcode)
{
	struct bpf_reg_state *pkt;

	if (src_reg->type == PTR_TO_PACKET_END) {
		pkt = dst_reg;
	} else if (dst_reg->type == PTR_TO_PACKET_END) {
		pkt = src_reg;
		opcode = flip_opcode(opcode);
	} else {
		return -1;
	}

	if (pkt->range >= 0)
		return -1;

	switch (opcode) {
	case BPF_JLE:
		/* pkt <= pkt_end */
		fallthrough;
	case BPF_JGT:
		/* pkt > pkt_end */
		if (pkt->range == BEYOND_PKT_END)
			/* pkt has at last one extra byte beyond pkt_end */
			return opcode == BPF_JGT;
		break;
	case BPF_JLT:
		/* pkt < pkt_end */
		fallthrough;
	case BPF_JGE:
		/* pkt >= pkt_end */
		if (pkt->range == BEYOND_PKT_END || pkt->range == AT_PKT_END)
			return opcode == BPF_JGE;
		break;
	}
	return -1;
}

/* Adjusts the register min/max values in the case that the dst_reg is the
 * variable register that we are working on, and src_reg is a constant or we're
 * simply doing a BPF_K check.
@@ -7148,23 +7236,7 @@ static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
				u64 val, u32 val32,
				u8 opcode, bool is_jmp32)
{
	/* How can we transform "a <op> b" into "b <op> a"? */
	static const u8 opcode_flip[16] = {
		/* these stay the same */
		[BPF_JEQ  >> 4] = BPF_JEQ,
		[BPF_JNE  >> 4] = BPF_JNE,
		[BPF_JSET >> 4] = BPF_JSET,
		/* these swap "lesser" and "greater" (L and G in the opcodes) */
		[BPF_JGE  >> 4] = BPF_JLE,
		[BPF_JGT  >> 4] = BPF_JLT,
		[BPF_JLE  >> 4] = BPF_JGE,
		[BPF_JLT  >> 4] = BPF_JGT,
		[BPF_JSGE >> 4] = BPF_JSLE,
		[BPF_JSGT >> 4] = BPF_JSLT,
		[BPF_JSLE >> 4] = BPF_JSGE,
		[BPF_JSLT >> 4] = BPF_JSGT
	};
	opcode = opcode_flip[opcode >> 4];
	opcode = flip_opcode(opcode);
	/* This uses zero as "not present in table"; luckily the zero opcode,
	 * BPF_JA, can't get here.
	 */
@@ -7346,6 +7418,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
			/* pkt_data' > pkt_end, pkt_meta' > pkt_data */
			find_good_pkt_pointers(this_branch, dst_reg,
					       dst_reg->type, false);
			mark_pkt_end(other_branch, insn->dst_reg, true);
		} else if ((dst_reg->type == PTR_TO_PACKET_END &&
			    src_reg->type == PTR_TO_PACKET) ||
			   (reg_is_init_pkt_pointer(dst_reg, PTR_TO_PACKET) &&
@@ -7353,6 +7426,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
			/* pkt_end > pkt_data', pkt_data > pkt_meta' */
			find_good_pkt_pointers(other_branch, src_reg,
					       src_reg->type, true);
			mark_pkt_end(this_branch, insn->src_reg, false);
		} else {
			return false;
		}
@@ -7365,6 +7439,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
			/* pkt_data' < pkt_end, pkt_meta' < pkt_data */
			find_good_pkt_pointers(other_branch, dst_reg,
					       dst_reg->type, true);
			mark_pkt_end(this_branch, insn->dst_reg, false);
		} else if ((dst_reg->type == PTR_TO_PACKET_END &&
			    src_reg->type == PTR_TO_PACKET) ||
			   (reg_is_init_pkt_pointer(dst_reg, PTR_TO_PACKET) &&
@@ -7372,6 +7447,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
			/* pkt_end < pkt_data', pkt_data > pkt_meta' */
			find_good_pkt_pointers(this_branch, src_reg,
					       src_reg->type, false);
			mark_pkt_end(other_branch, insn->src_reg, true);
		} else {
			return false;
		}
@@ -7384,6 +7460,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
			/* pkt_data' >= pkt_end, pkt_meta' >= pkt_data */
			find_good_pkt_pointers(this_branch, dst_reg,
					       dst_reg->type, true);
			mark_pkt_end(other_branch, insn->dst_reg, false);
		} else if ((dst_reg->type == PTR_TO_PACKET_END &&
			    src_reg->type == PTR_TO_PACKET) ||
			   (reg_is_init_pkt_pointer(dst_reg, PTR_TO_PACKET) &&
@@ -7391,6 +7468,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
			/* pkt_end >= pkt_data', pkt_data >= pkt_meta' */
			find_good_pkt_pointers(other_branch, src_reg,
					       src_reg->type, false);
			mark_pkt_end(this_branch, insn->src_reg, true);
		} else {
			return false;
		}
@@ -7403,6 +7481,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
			/* pkt_data' <= pkt_end, pkt_meta' <= pkt_data */
			find_good_pkt_pointers(other_branch, dst_reg,
					       dst_reg->type, false);
			mark_pkt_end(this_branch, insn->dst_reg, true);
		} else if ((dst_reg->type == PTR_TO_PACKET_END &&
			    src_reg->type == PTR_TO_PACKET) ||
			   (reg_is_init_pkt_pointer(dst_reg, PTR_TO_PACKET) &&
@@ -7410,6 +7489,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
			/* pkt_end <= pkt_data', pkt_data <= pkt_meta' */
			find_good_pkt_pointers(this_branch, src_reg,
					       src_reg->type, true);
			mark_pkt_end(other_branch, insn->src_reg, false);
		} else {
			return false;
		}
@@ -7509,6 +7589,10 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
				       src_reg->var_off.value,
				       opcode,
				       is_jmp32);
	} else if (reg_is_pkt_pointer_any(dst_reg) &&
		   reg_is_pkt_pointer_any(src_reg) &&
		   !is_jmp32) {
		pred = is_pkt_ptr_branch_taken(dst_reg, src_reg, opcode);
	}

	if (pred >= 0) {
@@ -7517,7 +7601,8 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
		 */
		if (!__is_pointer_value(false, dst_reg))
			err = mark_chain_precision(env, insn->dst_reg);
		if (BPF_SRC(insn->code) == BPF_X && !err)
		if (BPF_SRC(insn->code) == BPF_X && !err &&
		    !__is_pointer_value(false, src_reg))
			err = mark_chain_precision(env, insn->src_reg);
		if (err)
			return err;
+41 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2020 Facebook */
#include <test_progs.h>
#include <network_helpers.h>
#include "skb_pkt_end.skel.h"

static int sanity_run(struct bpf_program *prog)
{
	__u32 duration, retval;
	int err, prog_fd;

	prog_fd = bpf_program__fd(prog);
	err = bpf_prog_test_run(prog_fd, 1, &pkt_v4, sizeof(pkt_v4),
				NULL, NULL, &retval, &duration);
	if (CHECK(err || retval != 123, "test_run",
		  "err %d errno %d retval %d duration %d\n",
		  err, errno, retval, duration))
		return -1;
	return 0;
}

void test_test_skb_pkt_end(void)
{
	struct skb_pkt_end *skb_pkt_end_skel = NULL;
	__u32 duration = 0;
	int err;

	skb_pkt_end_skel = skb_pkt_end__open_and_load();
	if (CHECK(!skb_pkt_end_skel, "skb_pkt_end_skel_load", "skb_pkt_end skeleton failed\n"))
		goto cleanup;

	err = skb_pkt_end__attach(skb_pkt_end_skel);
	if (CHECK(err, "skb_pkt_end_attach", "skb_pkt_end attach failed: %d\n", err))
		goto cleanup;

	if (sanity_run(skb_pkt_end_skel->progs.main_prog))
		goto cleanup;

cleanup:
	skb_pkt_end__destroy(skb_pkt_end_skel);
}
+54 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: GPL-2.0
#define BPF_NO_PRESERVE_ACCESS_INDEX
#include <vmlinux.h>
#include <bpf/bpf_core_read.h>
#include <bpf/bpf_helpers.h>

#define NULL 0
#define INLINE __always_inline

#define skb_shorter(skb, len) ((void *)(long)(skb)->data + (len) > (void *)(long)skb->data_end)

#define ETH_IPV4_TCP_SIZE (14 + sizeof(struct iphdr) + sizeof(struct tcphdr))

static INLINE struct iphdr *get_iphdr(struct __sk_buff *skb)
{
	struct iphdr *ip = NULL;
	struct ethhdr *eth;

	if (skb_shorter(skb, ETH_IPV4_TCP_SIZE))
		goto out;

	eth = (void *)(long)skb->data;
	ip = (void *)(eth + 1);

out:
	return ip;
}

SEC("classifier/cls")
int main_prog(struct __sk_buff *skb)
{
	struct iphdr *ip = NULL;
	struct tcphdr *tcp;
	__u8 proto = 0;

	if (!(ip = get_iphdr(skb)))
		goto out;

	proto = ip->protocol;

	if (proto != IPPROTO_TCP)
		goto out;

	tcp = (void*)(ip + 1);
	if (tcp->dest != 0)
		goto out;
	if (!tcp)
		goto out;

	return tcp->urg_ptr;
out:
	return -1;
}
char _license[] SEC("license") = "GPL";
+42 −0
Original line number Diff line number Diff line
@@ -1089,3 +1089,45 @@
	.errstr_unpriv = "R1 leaks addr",
	.result = REJECT,
},
{
       "pkt > pkt_end taken check",
       .insns = {
       BPF_LDX_MEM(BPF_W, BPF_REG_2, BPF_REG_1,                //  0. r2 = *(u32 *)(r1 + data_end)
                   offsetof(struct __sk_buff, data_end)),
       BPF_LDX_MEM(BPF_W, BPF_REG_4, BPF_REG_1,                //  1. r4 = *(u32 *)(r1 + data)
                   offsetof(struct __sk_buff, data)),
       BPF_MOV64_REG(BPF_REG_3, BPF_REG_4),                    //  2. r3 = r4
       BPF_ALU64_IMM(BPF_ADD, BPF_REG_3, 42),                  //  3. r3 += 42
       BPF_MOV64_IMM(BPF_REG_1, 0),                            //  4. r1 = 0
       BPF_JMP_REG(BPF_JGT, BPF_REG_3, BPF_REG_2, 2),          //  5. if r3 > r2 goto 8
       BPF_ALU64_IMM(BPF_ADD, BPF_REG_4, 14),                  //  6. r4 += 14
       BPF_MOV64_REG(BPF_REG_1, BPF_REG_4),                    //  7. r1 = r4
       BPF_JMP_REG(BPF_JGT, BPF_REG_3, BPF_REG_2, 1),          //  8. if r3 > r2 goto 10
       BPF_LDX_MEM(BPF_H, BPF_REG_2, BPF_REG_1, 9),            //  9. r2 = *(u8 *)(r1 + 9)
       BPF_MOV64_IMM(BPF_REG_0, 0),                            // 10. r0 = 0
       BPF_EXIT_INSN(),                                        // 11. exit
       },
       .result = ACCEPT,
       .prog_type = BPF_PROG_TYPE_SK_SKB,
},
{
       "pkt_end < pkt taken check",
       .insns = {
       BPF_LDX_MEM(BPF_W, BPF_REG_2, BPF_REG_1,                //  0. r2 = *(u32 *)(r1 + data_end)
                   offsetof(struct __sk_buff, data_end)),
       BPF_LDX_MEM(BPF_W, BPF_REG_4, BPF_REG_1,                //  1. r4 = *(u32 *)(r1 + data)
                   offsetof(struct __sk_buff, data)),
       BPF_MOV64_REG(BPF_REG_3, BPF_REG_4),                    //  2. r3 = r4
       BPF_ALU64_IMM(BPF_ADD, BPF_REG_3, 42),                  //  3. r3 += 42
       BPF_MOV64_IMM(BPF_REG_1, 0),                            //  4. r1 = 0
       BPF_JMP_REG(BPF_JGT, BPF_REG_3, BPF_REG_2, 2),          //  5. if r3 > r2 goto 8
       BPF_ALU64_IMM(BPF_ADD, BPF_REG_4, 14),                  //  6. r4 += 14
       BPF_MOV64_REG(BPF_REG_1, BPF_REG_4),                    //  7. r1 = r4
       BPF_JMP_REG(BPF_JLT, BPF_REG_2, BPF_REG_3, 1),          //  8. if r2 < r3 goto 10
       BPF_LDX_MEM(BPF_H, BPF_REG_2, BPF_REG_1, 9),            //  9. r2 = *(u8 *)(r1 + 9)
       BPF_MOV64_IMM(BPF_REG_0, 0),                            // 10. r0 = 0
       BPF_EXIT_INSN(),                                        // 11. exit
       },
       .result = ACCEPT,
       .prog_type = BPF_PROG_TYPE_SK_SKB,
},