diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
index e83ef6f6bf43a..306869d4743bc 100644
--- a/include/linux/bpf_verifier.h
+++ b/include/linux/bpf_verifier.h
@@ -45,7 +45,7 @@ struct bpf_reg_state {
 	enum bpf_reg_type type;
 	union {
 		/* valid when type == PTR_TO_PACKET */
-		u16 range;
+		int range;
 
 		/* valid when type == CONST_PTR_TO_MAP | PTR_TO_MAP_VALUE |
 		 *   PTR_TO_MAP_VALUE_OR_NULL
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 10da26e55130a..7b1f85aa9741b 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -2739,7 +2739,9 @@ static int check_packet_access(struct bpf_verifier_env *env, u32 regno, int off,
 			regno);
 		return -EACCES;
 	}
-	err = __check_mem_access(env, regno, off, size, reg->range,
+
+	err = reg->range < 0 ? -EINVAL :
+	      __check_mem_access(env, regno, off, size, reg->range,
 				 zero_size_allowed);
 	if (err) {
 		verbose(env, "R%d offset is outside of the packet\n", regno);
@@ -4697,6 +4699,32 @@ static void clear_all_pkt_pointers(struct bpf_verifier_env *env)
 		__clear_all_pkt_pointers(env, vstate->frame[i]);
 }
 
+enum {
+	AT_PKT_END = -1,
+	BEYOND_PKT_END = -2,
+};
+
+static void mark_pkt_end(struct bpf_verifier_state *vstate, int regn, bool range_open)
+{
+	struct bpf_func_state *state = vstate->frame[vstate->curframe];
+	struct bpf_reg_state *reg = &state->regs[regn];
+
+	if (reg->type != PTR_TO_PACKET)
+		/* PTR_TO_PACKET_META is not supported yet */
+		return;
+
+	/* The 'reg' is pkt > pkt_end or pkt >= pkt_end.
+	 * How far beyond pkt_end it goes is unknown.
+	 * if (!range_open) it's the case of pkt >= pkt_end
+	 * if (range_open) it's the case of pkt > pkt_end
+	 * hence this pointer is at least 1 byte bigger than pkt_end
+	 */
+	if (range_open)
+		reg->range = BEYOND_PKT_END;
+	else
+		reg->range = AT_PKT_END;
+}
+
 static void release_reg_references(struct bpf_verifier_env *env,
 				   struct bpf_func_state *state,
 				   int ref_obj_id)
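
The negative 'range' encoding above can be modeled in plain C. The sketch below is illustrative only: AT_PKT_END, BEYOND_PKT_END and the -EINVAL rejection mirror the patch, while the harness around them is assumed for demonstration and does not exist in the tree.

/* Illustrative model only -- not kernel code.
 * range >= 0  : access of up to 'range' bytes past the pointer has
 *               been proven safe (the pre-existing meaning).
 * range == AT_PKT_END     (-1): pointer may be equal to pkt_end.
 * range == BEYOND_PKT_END (-2): pointer is at least pkt_end + 1.
 * A u16 could not hold the negative markers, hence the switch to int.
 */
#include <stdio.h>
#include <errno.h>

enum { AT_PKT_END = -1, BEYOND_PKT_END = -2 };

/* Mirrors the new check in check_packet_access(): any access through
 * a pointer known to sit at or past pkt_end is rejected outright.
 */
static int model_check_packet_access(int range, int off, int size)
{
        if (range < 0)
                return -EINVAL;         /* at or beyond pkt_end */
        if (off < 0 || off + size > range)
                return -EACCES;         /* outside the proven window */
        return 0;
}

int main(void)
{
        printf("%d\n", model_check_packet_access(14, 12, 2));        /*   0 */
        printf("%d\n", model_check_packet_access(14, 13, 2));        /* -13 */
        printf("%d\n", model_check_packet_access(AT_PKT_END, 0, 1)); /* -22 */
        return 0;
}
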
@@ -6708,7 +6736,7 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
 
 static void __find_good_pkt_pointers(struct bpf_func_state *state,
 				     struct bpf_reg_state *dst_reg,
-				     enum bpf_reg_type type, u16 new_range)
+				     enum bpf_reg_type type, int new_range)
 {
 	struct bpf_reg_state *reg;
 	int i;
@@ -6733,8 +6761,7 @@ static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
 				   enum bpf_reg_type type,
 				   bool range_right_open)
 {
-	u16 new_range;
-	int i;
+	int new_range, i;
 
 	if (dst_reg->off < 0 ||
 	    (dst_reg->off == 0 && range_right_open))
@@ -6985,6 +7012,67 @@ static int is_branch_taken(struct bpf_reg_state *reg, u64 val, u8 opcode,
 	return is_branch64_taken(reg, val, opcode);
 }
 
+static int flip_opcode(u32 opcode)
+{
+	/* How can we transform "a <op> b" into "b <op> a"? */
+	static const u8 opcode_flip[16] = {
+		/* these stay the same */
+		[BPF_JEQ  >> 4] = BPF_JEQ,
+		[BPF_JNE  >> 4] = BPF_JNE,
+		[BPF_JSET >> 4] = BPF_JSET,
+		/* these swap "lesser" and "greater" (L and G in the opcodes) */
+		[BPF_JGE  >> 4] = BPF_JLE,
+		[BPF_JGT  >> 4] = BPF_JLT,
+		[BPF_JLE  >> 4] = BPF_JGE,
+		[BPF_JLT  >> 4] = BPF_JGT,
+		[BPF_JSGE >> 4] = BPF_JSLE,
+		[BPF_JSGT >> 4] = BPF_JSLT,
+		[BPF_JSLE >> 4] = BPF_JSGE,
+		[BPF_JSLT >> 4] = BPF_JSGT
+	};
+	return opcode_flip[opcode >> 4];
+}
+
+static int is_pkt_ptr_branch_taken(struct bpf_reg_state *dst_reg,
+				   struct bpf_reg_state *src_reg,
+				   u8 opcode)
+{
+	struct bpf_reg_state *pkt;
+
+	if (src_reg->type == PTR_TO_PACKET_END) {
+		pkt = dst_reg;
+	} else if (dst_reg->type == PTR_TO_PACKET_END) {
+		pkt = src_reg;
+		opcode = flip_opcode(opcode);
+	} else {
+		return -1;
+	}
+
+	if (pkt->range >= 0)
+		return -1;
+
+	switch (opcode) {
+	case BPF_JLE:
+		/* pkt <= pkt_end */
+		fallthrough;
+	case BPF_JGT:
+		/* pkt > pkt_end */
+		if (pkt->range == BEYOND_PKT_END)
+			/* pkt has at least one extra byte beyond pkt_end */
+			return opcode == BPF_JGT;
+		break;
+	case BPF_JLT:
+		/* pkt < pkt_end */
+		fallthrough;
+	case BPF_JGE:
+		/* pkt >= pkt_end */
+		if (pkt->range == BEYOND_PKT_END || pkt->range == AT_PKT_END)
+			return opcode == BPF_JGE;
+		break;
+	}
+	return -1;
+}
+
 /* Adjusts the register min/max values in the case that the dst_reg is the
  * variable register that we are working on, and src_reg is a constant or we're
  * simply doing a BPF_K check.
@@ -7148,23 +7236,7 @@ static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
 				u64 val, u32 val32,
 				u8 opcode, bool is_jmp32)
 {
-	/* How can we transform "a <op> b" into "b <op> a"? */
-	static const u8 opcode_flip[16] = {
-		/* these stay the same */
-		[BPF_JEQ  >> 4] = BPF_JEQ,
-		[BPF_JNE  >> 4] = BPF_JNE,
-		[BPF_JSET >> 4] = BPF_JSET,
-		/* these swap "lesser" and "greater" (L and G in the opcodes) */
-		[BPF_JGE  >> 4] = BPF_JLE,
-		[BPF_JGT  >> 4] = BPF_JLT,
-		[BPF_JLE  >> 4] = BPF_JGE,
-		[BPF_JLT  >> 4] = BPF_JGT,
-		[BPF_JSGE >> 4] = BPF_JSLE,
-		[BPF_JSGT >> 4] = BPF_JSLT,
-		[BPF_JSLE >> 4] = BPF_JSGE,
-		[BPF_JSLT >> 4] = BPF_JSGT
-	};
-	opcode = opcode_flip[opcode >> 4];
+	opcode = flip_opcode(opcode);
 	/* This uses zero as "not present in table"; luckily the zero opcode,
 	 * BPF_JA, can't get here.
 	 */
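
The pruning decision in is_pkt_ptr_branch_taken() boils down to a small truth table over the marker and the opcode. Below is a standalone sketch of that table (the opcode names are stand-ins for BPF_JGT and friends, and the harness is assumed); the kernel version additionally normalizes operand order with flip_opcode() when pkt_end sits on the left of the compare.

/* Illustrative model only.  'range' is the packet pointer's marker,
 * 'op' the comparison against pkt_end.  Return values follow the
 * is_branch_taken() convention: 1 = always taken, 0 = never taken,
 * -1 = unknown.
 */
#include <assert.h>

enum { AT_PKT_END = -1, BEYOND_PKT_END = -2 };
enum op { JGT, JGE, JLT, JLE };         /* stand-ins for BPF_JGT etc. */

static int model_pkt_branch(int range, enum op op)
{
        if (range >= 0)
                return -1;              /* no marker, nothing known */
        switch (op) {
        case JLE:                       /* pkt <= pkt_end */
        case JGT:                       /* pkt >  pkt_end */
                if (range == BEYOND_PKT_END)
                        return op == JGT;  /* pkt >= pkt_end + 1 */
                break;                  /* AT_PKT_END: could be equal */
        case JLT:                       /* pkt <  pkt_end */
        case JGE:                       /* pkt >= pkt_end */
                return op == JGE;       /* both markers imply pkt >= pkt_end */
        }
        return -1;
}

int main(void)
{
        assert(model_pkt_branch(BEYOND_PKT_END, JGT) == 1);
        assert(model_pkt_branch(BEYOND_PKT_END, JLE) == 0);
        assert(model_pkt_branch(AT_PKT_END, JGE) == 1);
        assert(model_pkt_branch(AT_PKT_END, JGT) == -1);  /* could be equal */
        assert(model_pkt_branch(42, JGT) == -1);          /* no marker */
        return 0;
}
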
@@ -7346,6 +7418,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
 		/* pkt_data' > pkt_end, pkt_meta' > pkt_data */
 		find_good_pkt_pointers(this_branch, dst_reg,
 				       dst_reg->type, false);
+		mark_pkt_end(other_branch, insn->dst_reg, true);
 	} else if ((dst_reg->type == PTR_TO_PACKET_END &&
 		    src_reg->type == PTR_TO_PACKET) ||
 		   (reg_is_init_pkt_pointer(dst_reg, PTR_TO_PACKET) &&
@@ -7353,6 +7426,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
 		/* pkt_end > pkt_data', pkt_data > pkt_meta' */
 		find_good_pkt_pointers(other_branch, src_reg,
 				       src_reg->type, true);
+		mark_pkt_end(this_branch, insn->src_reg, false);
 	} else {
 		return false;
 	}
@@ -7365,6 +7439,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
 		/* pkt_data' < pkt_end, pkt_meta' < pkt_data */
 		find_good_pkt_pointers(other_branch, dst_reg,
 				       dst_reg->type, true);
+		mark_pkt_end(this_branch, insn->dst_reg, false);
 	} else if ((dst_reg->type == PTR_TO_PACKET_END &&
 		    src_reg->type == PTR_TO_PACKET) ||
 		   (reg_is_init_pkt_pointer(dst_reg, PTR_TO_PACKET) &&
@@ -7372,6 +7447,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
 		/* pkt_end < pkt_data', pkt_data > pkt_meta' */
 		find_good_pkt_pointers(this_branch, src_reg,
 				       src_reg->type, false);
+		mark_pkt_end(other_branch, insn->src_reg, true);
 	} else {
 		return false;
 	}
@@ -7384,6 +7460,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
 		/* pkt_data' >= pkt_end, pkt_meta' >= pkt_data */
 		find_good_pkt_pointers(this_branch, dst_reg,
 				       dst_reg->type, true);
+		mark_pkt_end(other_branch, insn->dst_reg, false);
 	} else if ((dst_reg->type == PTR_TO_PACKET_END &&
 		    src_reg->type == PTR_TO_PACKET) ||
 		   (reg_is_init_pkt_pointer(dst_reg, PTR_TO_PACKET) &&
@@ -7391,6 +7468,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
 		/* pkt_end >= pkt_data', pkt_data >= pkt_meta' */
 		find_good_pkt_pointers(other_branch, src_reg,
 				       src_reg->type, false);
+		mark_pkt_end(this_branch, insn->src_reg, true);
 	} else {
 		return false;
 	}
@@ -7403,6 +7481,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
 		/* pkt_data' <= pkt_end, pkt_meta' <= pkt_data */
 		find_good_pkt_pointers(other_branch, dst_reg,
 				       dst_reg->type, false);
+		mark_pkt_end(this_branch, insn->dst_reg, true);
 	} else if ((dst_reg->type == PTR_TO_PACKET_END &&
 		    src_reg->type == PTR_TO_PACKET) ||
 		   (reg_is_init_pkt_pointer(dst_reg, PTR_TO_PACKET) &&
@@ -7410,6 +7489,7 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
 		/* pkt_end <= pkt_data', pkt_data <= pkt_meta' */
 		find_good_pkt_pointers(this_branch, src_reg,
 				       src_reg->type, true);
+		mark_pkt_end(other_branch, insn->src_reg, false);
 	} else {
 		return false;
 	}
@@ -7509,6 +7589,10 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
 					       src_reg->var_off.value,
 					       opcode,
 					       is_jmp32);
+	} else if (reg_is_pkt_pointer_any(dst_reg) &&
+		   reg_is_pkt_pointer_any(src_reg) &&
+		   !is_jmp32) {
+		pred = is_pkt_ptr_branch_taken(dst_reg, src_reg, opcode);
 	}
 
 	if (pred >= 0) {
@@ -7517,7 +7601,8 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
 		 */
 		if (!__is_pointer_value(false, dst_reg))
 			err = mark_chain_precision(env, insn->dst_reg);
-		if (BPF_SRC(insn->code) == BPF_X && !err)
+		if (BPF_SRC(insn->code) == BPF_X && !err &&
+		    !__is_pointer_value(false, src_reg))
 			err = mark_chain_precision(env, insn->src_reg);
 		if (err)
 			return err;
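
At the source level, the kind of program these verifier changes make verifiable looks roughly like the hypothetical XDP fragment below: the compiler keeps an out-of-bounds cursor alive past a failed bounds check, and the repeated compare is now statically pruned rather than causing a rejection. Everything here (program name, the 42-byte cursor) is an assumed example, not code from this patch set.

/* Hypothetical XDP fragment, for illustration only. */
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

SEC("xdp")
int pkt_end_example(struct xdp_md *ctx)
{
        void *data_end = (void *)(long)ctx->data_end;
        __u8 *data = (void *)(long)ctx->data;
        __u8 *guard = data + 42;        /* may point beyond the packet */

        if ((void *)guard > data_end)   /* on the taken path 'guard' is */
                goto out;               /* marked BEYOND_PKT_END        */

        /* ... here data[0..41] is provably inside the packet ... */
out:
        if ((void *)guard > data_end)   /* known taken on the marked    */
                return XDP_DROP;        /* path, so the load below is   */
                                        /* pruned there, not rejected   */
        return data[9] ? XDP_PASS : XDP_DROP;
}

char _license[] SEC("license") = "GPL";

Without mark_pkt_end(), the verifier would have to walk the final load on the out-of-bounds path too, where 'data' has no proven range, and reject the program.
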
diff --git a/tools/testing/selftests/bpf/prog_tests/test_skb_pkt_end.c b/tools/testing/selftests/bpf/prog_tests/test_skb_pkt_end.c
new file mode 100644
index 0000000000000..cf1215531920f
--- /dev/null
+++ b/tools/testing/selftests/bpf/prog_tests/test_skb_pkt_end.c
@@ -0,0 +1,41 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Copyright (c) 2020 Facebook */
+#include <test_progs.h>
+#include <network_helpers.h>
+#include "skb_pkt_end.skel.h"
+
+static int sanity_run(struct bpf_program *prog)
+{
+	__u32 duration, retval;
+	int err, prog_fd;
+
+	prog_fd = bpf_program__fd(prog);
+	err = bpf_prog_test_run(prog_fd, 1, &pkt_v4, sizeof(pkt_v4),
+				NULL, NULL, &retval, &duration);
+	if (CHECK(err || retval != 123, "test_run",
+		  "err %d errno %d retval %d duration %d\n",
+		  err, errno, retval, duration))
+		return -1;
+	return 0;
+}
+
+void test_test_skb_pkt_end(void)
+{
+	struct skb_pkt_end *skb_pkt_end_skel = NULL;
+	__u32 duration = 0;
+	int err;
+
+	skb_pkt_end_skel = skb_pkt_end__open_and_load();
+	if (CHECK(!skb_pkt_end_skel, "skb_pkt_end_skel_load", "skb_pkt_end skeleton failed\n"))
+		goto cleanup;
+
+	err = skb_pkt_end__attach(skb_pkt_end_skel);
+	if (CHECK(err, "skb_pkt_end_attach", "skb_pkt_end attach failed: %d\n", err))
+		goto cleanup;
+
+	if (sanity_run(skb_pkt_end_skel->progs.main_prog))
+		goto cleanup;
+
+cleanup:
+	skb_pkt_end__destroy(skb_pkt_end_skel);
+}
diff --git a/tools/testing/selftests/bpf/progs/skb_pkt_end.c b/tools/testing/selftests/bpf/progs/skb_pkt_end.c
new file mode 100644
index 0000000000000..cf6823f42e80b
--- /dev/null
+++ b/tools/testing/selftests/bpf/progs/skb_pkt_end.c
@@ -0,0 +1,54 @@
+// SPDX-License-Identifier: GPL-2.0
+#define BPF_NO_PRESERVE_ACCESS_INDEX
+#include <vmlinux.h>
+#include <bpf/bpf_core_read.h>
+#include <bpf/bpf_helpers.h>
+
+#define NULL 0
+#define INLINE __always_inline
+
+#define skb_shorter(skb, len) ((void *)(long)(skb)->data + (len) > (void *)(long)skb->data_end)
+
+#define ETH_IPV4_TCP_SIZE (14 + sizeof(struct iphdr) + sizeof(struct tcphdr))
+
+static INLINE struct iphdr *get_iphdr(struct __sk_buff *skb)
+{
+	struct iphdr *ip = NULL;
+	struct ethhdr *eth;
+
+	if (skb_shorter(skb, ETH_IPV4_TCP_SIZE))
+		goto out;
+
+	eth = (void *)(long)skb->data;
+	ip = (void *)(eth + 1);
+
+out:
+	return ip;
+}
+
+SEC("classifier/cls")
+int main_prog(struct __sk_buff *skb)
+{
+	struct iphdr *ip = NULL;
+	struct tcphdr *tcp;
+	__u8 proto = 0;
+
+	if (!(ip = get_iphdr(skb)))
+		goto out;
+
+	proto = ip->protocol;
+
+	if (proto != IPPROTO_TCP)
+		goto out;
+
+	tcp = (void *)(ip + 1);
+	if (tcp->dest != 0)
+		goto out;
+	if (!tcp)
+		goto out;
+
+	return tcp->urg_ptr;
+out:
+	return -1;
+}
+char _license[] SEC("license") = "GPL";
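
For comparison, the same sanity check could be expressed with the newer libbpf test-run API; the sketch below is an assumed port (only 'pkt_v4' and the retval-123 convention come from the test above) and is not part of this patch.

/* Sketch of an alternative runner using bpf_prog_test_run_opts(),
 * which superseded the bpf_prog_test_run() wrapper used above.
 */
#include <bpf/bpf.h>
#include <bpf/libbpf.h>

static int sanity_run_opts(struct bpf_program *prog,
                           void *pkt, unsigned int len)
{
        LIBBPF_OPTS(bpf_test_run_opts, topts,
                .data_in = pkt,
                .data_size_in = len,
                .repeat = 1,
        );
        int err;

        err = bpf_prog_test_run_opts(bpf_program__fd(prog), &topts);
        if (err || topts.retval != 123)
                return -1;      /* same pass criterion as sanity_run() */
        return 0;
}
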
diff --git a/tools/testing/selftests/bpf/verifier/ctx_skb.c b/tools/testing/selftests/bpf/verifier/ctx_skb.c
index 2e16b8e268f2f..2022c0f2cd759 100644
--- a/tools/testing/selftests/bpf/verifier/ctx_skb.c
+++ b/tools/testing/selftests/bpf/verifier/ctx_skb.c
@@ -1089,3 +1089,45 @@
 	.errstr_unpriv = "R1 leaks addr",
 	.result = REJECT,
 },
+{
+	"pkt > pkt_end taken check",
+	.insns = {
+	BPF_LDX_MEM(BPF_W, BPF_REG_2, BPF_REG_1,	// 0. r2 = *(u32 *)(r1 + data_end)
+		    offsetof(struct __sk_buff, data_end)),
+	BPF_LDX_MEM(BPF_W, BPF_REG_4, BPF_REG_1,	// 1. r4 = *(u32 *)(r1 + data)
+		    offsetof(struct __sk_buff, data)),
+	BPF_MOV64_REG(BPF_REG_3, BPF_REG_4),		// 2. r3 = r4
+	BPF_ALU64_IMM(BPF_ADD, BPF_REG_3, 42),		// 3. r3 += 42
+	BPF_MOV64_IMM(BPF_REG_1, 0),			// 4. r1 = 0
+	BPF_JMP_REG(BPF_JGT, BPF_REG_3, BPF_REG_2, 2),	// 5. if r3 > r2 goto 8
+	BPF_ALU64_IMM(BPF_ADD, BPF_REG_4, 14),		// 6. r4 += 14
+	BPF_MOV64_REG(BPF_REG_1, BPF_REG_4),		// 7. r1 = r4
+	BPF_JMP_REG(BPF_JGT, BPF_REG_3, BPF_REG_2, 1),	// 8. if r3 > r2 goto 10
+	BPF_LDX_MEM(BPF_H, BPF_REG_2, BPF_REG_1, 9),	// 9. r2 = *(u16 *)(r1 + 9)
+	BPF_MOV64_IMM(BPF_REG_0, 0),			// 10. r0 = 0
+	BPF_EXIT_INSN(),				// 11. exit
+	},
+	.result = ACCEPT,
+	.prog_type = BPF_PROG_TYPE_SK_SKB,
+},
+{
+	"pkt_end < pkt taken check",
+	.insns = {
+	BPF_LDX_MEM(BPF_W, BPF_REG_2, BPF_REG_1,	// 0. r2 = *(u32 *)(r1 + data_end)
+		    offsetof(struct __sk_buff, data_end)),
+	BPF_LDX_MEM(BPF_W, BPF_REG_4, BPF_REG_1,	// 1. r4 = *(u32 *)(r1 + data)
+		    offsetof(struct __sk_buff, data)),
+	BPF_MOV64_REG(BPF_REG_3, BPF_REG_4),		// 2. r3 = r4
+	BPF_ALU64_IMM(BPF_ADD, BPF_REG_3, 42),		// 3. r3 += 42
+	BPF_MOV64_IMM(BPF_REG_1, 0),			// 4. r1 = 0
+	BPF_JMP_REG(BPF_JGT, BPF_REG_3, BPF_REG_2, 2),	// 5. if r3 > r2 goto 8
+	BPF_ALU64_IMM(BPF_ADD, BPF_REG_4, 14),		// 6. r4 += 14
+	BPF_MOV64_REG(BPF_REG_1, BPF_REG_4),		// 7. r1 = r4
+	BPF_JMP_REG(BPF_JLT, BPF_REG_2, BPF_REG_3, 1),	// 8. if r2 < r3 goto 10
+	BPF_LDX_MEM(BPF_H, BPF_REG_2, BPF_REG_1, 9),	// 9. r2 = *(u16 *)(r1 + 9)
+	BPF_MOV64_IMM(BPF_REG_0, 0),			// 10. r0 = 0
+	BPF_EXIT_INSN(),				// 11. exit
+	},
+	.result = ACCEPT,
+	.prog_type = BPF_PROG_TYPE_SK_SKB,
+},
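
To make the instruction streams easier to follow, the first test corresponds roughly to the C below; this is an illustrative rendering (the function name is invented, register names match the // comments in the test).

/* Rough C rendering of "pkt > pkt_end taken check". */
#include <linux/bpf.h>

int rendered(struct __sk_buff *skb)
{
        void *r2 = (void *)(long)skb->data_end;         /* insn 0 */
        __u8 *r4 = (void *)(long)skb->data;             /* insn 1 */
        __u8 *r3 = r4 + 42;                             /* insns 2-3 */
        __u8 *r1 = 0;                                   /* insn 4 */

        if (!((void *)r3 > r2)) {                       /* insn 5 */
                r4 += 14;                               /* insn 6 */
                r1 = r4;                                /* insn 7 */
        }
        /* insn 8 repeats the compare.  Where insn 5 branched, r3 is
         * marked BEYOND_PKT_END, the branch is known taken, and the
         * load in insn 9 is never verified with r1 == NULL.
         */
        if (!((void *)r3 > r2))                         /* insn 8 */
                (void)*(__u16 *)(r1 + 9);               /* insn 9, BPF_H */
        return 0;                                       /* insns 10-11 */
}

The second test is identical except that insn 8 writes the compare with pkt_end on the left (r2 < r3), exercising the flip_opcode() normalization in is_pkt_ptr_branch_taken().
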