Skip to content

Commit

Permalink
Merge branch 'optimize-bpf_tail_call'
Browse files Browse the repository at this point in the history
Daniel Borkmann says:

====================
This gets rid of indirect jumps for BPF tail calls whenever possible.
The series adds emission for *direct* jumps for tail call maps in order
to avoid the retpoline overhead from a493a87 ("bpf, x64: implement
retpoline for tail call") for situations that allow for it, meaning,
for known constant keys at verification time which are used as index
into the tail call map. See patch 7/8 for more general details.

Thanks!

v1  -> v2:
  - added more test cases
  - u8 ip_stable -> bool (Andrii)
  - removed bpf_map_poke_{un,}lock and simplified the code (Andrii)
  - added break into prog_array_map_poke_untrack since there's just
    one prog (Andrii)
  - fixed typo: for for in commit msg (Andrii)
  - reworked __bpf_arch_text_poke (Andrii)
  - added subtests, and comment on tests themselves, NULL-NULL
    transistion (Andrii)
  - in constant map key tracking I've moved the map_poke_track callback
    to once we've finished creating the poke tab as otherwise concurrent
    access from tail call map would blow up (since we realloc the table)
rfc -> v1:
  - Applied Alexei's and Andrii's feeback from
    https://lore.kernel.org/bpf/cover.1573779287.git.daniel@iogearbox.net/T/#t
====================

Signed-off-by: Alexei Starovoitov <ast@kernel.org>
  • Loading branch information
Alexei Starovoitov committed Nov 25, 2019
2 parents c4781e3 + 79d49ba commit 6dbae03
Show file tree
Hide file tree
Showing 15 changed files with 1,364 additions and 140 deletions.
264 changes: 192 additions & 72 deletions arch/x86/net/bpf_jit_comp.c
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,9 @@ struct jit_context {
/* Maximum number of bytes emitted while JITing one eBPF insn */
#define BPF_MAX_INSN_SIZE 128
#define BPF_INSN_SAFETY 64
/* number of bytes emit_call() needs to generate call instruction */
#define X86_CALL_SIZE 5

/* Number of bytes emit_patch() needs to generate instructions */
#define X86_PATCH_SIZE 5

#define PROLOGUE_SIZE 25

Expand All @@ -215,7 +216,7 @@ struct jit_context {
static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
{
u8 *prog = *pprog;
int cnt = X86_CALL_SIZE;
int cnt = X86_PATCH_SIZE;

/* BPF trampoline can be made to work without these nops,
* but let's waste 5 bytes for now and optimize later
Expand All @@ -238,6 +239,123 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
*pprog = prog;
}

static int emit_patch(u8 **pprog, void *func, void *ip, u8 opcode)
{
u8 *prog = *pprog;
int cnt = 0;
s64 offset;

offset = func - (ip + X86_PATCH_SIZE);
if (!is_simm32(offset)) {
pr_err("Target call %p is out of range\n", func);
return -ERANGE;
}
EMIT1_off32(opcode, offset);
*pprog = prog;
return 0;
}

static int emit_call(u8 **pprog, void *func, void *ip)
{
return emit_patch(pprog, func, ip, 0xE8);
}

static int emit_jump(u8 **pprog, void *func, void *ip)
{
return emit_patch(pprog, func, ip, 0xE9);
}

static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
void *old_addr, void *new_addr,
const bool text_live)
{
int (*emit_patch_fn)(u8 **pprog, void *func, void *ip);
const u8 *nop_insn = ideal_nops[NOP_ATOMIC5];
u8 old_insn[X86_PATCH_SIZE] = {};
u8 new_insn[X86_PATCH_SIZE] = {};
u8 *prog;
int ret;

switch (t) {
case BPF_MOD_NOP_TO_CALL ... BPF_MOD_CALL_TO_NOP:
emit_patch_fn = emit_call;
break;
case BPF_MOD_NOP_TO_JUMP ... BPF_MOD_JUMP_TO_NOP:
emit_patch_fn = emit_jump;
break;
default:
return -ENOTSUPP;
}

switch (t) {
case BPF_MOD_NOP_TO_CALL:
case BPF_MOD_NOP_TO_JUMP:
if (!old_addr && new_addr) {
memcpy(old_insn, nop_insn, X86_PATCH_SIZE);

prog = new_insn;
ret = emit_patch_fn(&prog, new_addr, ip);
if (ret)
return ret;
break;
}
return -ENXIO;
case BPF_MOD_CALL_TO_CALL:
case BPF_MOD_JUMP_TO_JUMP:
if (old_addr && new_addr) {
prog = old_insn;
ret = emit_patch_fn(&prog, old_addr, ip);
if (ret)
return ret;

prog = new_insn;
ret = emit_patch_fn(&prog, new_addr, ip);
if (ret)
return ret;
break;
}
return -ENXIO;
case BPF_MOD_CALL_TO_NOP:
case BPF_MOD_JUMP_TO_NOP:
if (old_addr && !new_addr) {
memcpy(new_insn, nop_insn, X86_PATCH_SIZE);

prog = old_insn;
ret = emit_patch_fn(&prog, old_addr, ip);
if (ret)
return ret;
break;
}
return -ENXIO;
default:
return -ENOTSUPP;
}

ret = -EBUSY;
mutex_lock(&text_mutex);
if (memcmp(ip, old_insn, X86_PATCH_SIZE))
goto out;
if (text_live)
text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL);
else
memcpy(ip, new_insn, X86_PATCH_SIZE);
ret = 0;
out:
mutex_unlock(&text_mutex);
return ret;
}

int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
void *old_addr, void *new_addr)
{
if (!is_kernel_text((long)ip) &&
!is_bpf_text_address((long)ip))
/* BPF poking in modules is not supported */
return -EINVAL;

return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true);
}

/*
* Generate the following code:
*
Expand All @@ -252,7 +370,7 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
* goto *(prog->bpf_func + prologue_size);
* out:
*/
static void emit_bpf_tail_call(u8 **pprog)
static void emit_bpf_tail_call_indirect(u8 **pprog)
{
u8 *prog = *pprog;
int label1, label2, label3;
Expand Down Expand Up @@ -319,6 +437,69 @@ static void emit_bpf_tail_call(u8 **pprog)
*pprog = prog;
}

static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
u8 **pprog, int addr, u8 *image)
{
u8 *prog = *pprog;
int cnt = 0;

/*
* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
* goto out;
*/
EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */
EMIT2(X86_JA, 14); /* ja out */
EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */
EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */

poke->ip = image + (addr - X86_PATCH_SIZE);
poke->adj_off = PROLOGUE_SIZE;

memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE);
prog += X86_PATCH_SIZE;
/* out: */

*pprog = prog;
}

static void bpf_tail_call_direct_fixup(struct bpf_prog *prog)
{
static const enum bpf_text_poke_type type = BPF_MOD_NOP_TO_JUMP;
struct bpf_jit_poke_descriptor *poke;
struct bpf_array *array;
struct bpf_prog *target;
int i, ret;

for (i = 0; i < prog->aux->size_poke_tab; i++) {
poke = &prog->aux->poke_tab[i];
WARN_ON_ONCE(READ_ONCE(poke->ip_stable));

if (poke->reason != BPF_POKE_REASON_TAIL_CALL)
continue;

array = container_of(poke->tail_call.map, struct bpf_array, map);
mutex_lock(&array->aux->poke_mutex);
target = array->ptrs[poke->tail_call.key];
if (target) {
/* Plain memcpy is used when image is not live yet
* and still not locked as read-only. Once poke
* location is active (poke->ip_stable), any parallel
* bpf_arch_text_poke() might occur still on the
* read-write image until we finally locked it as
* read-only. Both modifications on the given image
* are under text_mutex to avoid interference.
*/
ret = __bpf_arch_text_poke(poke->ip, type, NULL,
(u8 *)target->bpf_func +
poke->adj_off, false);
BUG_ON(ret < 0);
}
WRITE_ONCE(poke->ip_stable, true);
mutex_unlock(&array->aux->poke_mutex);
}
}

static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
u32 dst_reg, const u32 imm32)
{
Expand Down Expand Up @@ -480,72 +661,6 @@ static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
*pprog = prog;
}

static int emit_call(u8 **pprog, void *func, void *ip)
{
u8 *prog = *pprog;
int cnt = 0;
s64 offset;

offset = func - (ip + X86_CALL_SIZE);
if (!is_simm32(offset)) {
pr_err("Target call %p is out of range\n", func);
return -EINVAL;
}
EMIT1_off32(0xE8, offset);
*pprog = prog;
return 0;
}

int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
void *old_addr, void *new_addr)
{
u8 old_insn[X86_CALL_SIZE] = {};
u8 new_insn[X86_CALL_SIZE] = {};
u8 *prog;
int ret;

if (!is_kernel_text((long)ip) &&
!is_bpf_text_address((long)ip))
/* BPF trampoline in modules is not supported */
return -EINVAL;

if (old_addr) {
prog = old_insn;
ret = emit_call(&prog, old_addr, (void *)ip);
if (ret)
return ret;
}
if (new_addr) {
prog = new_insn;
ret = emit_call(&prog, new_addr, (void *)ip);
if (ret)
return ret;
}
ret = -EBUSY;
mutex_lock(&text_mutex);
switch (t) {
case BPF_MOD_NOP_TO_CALL:
if (memcmp(ip, ideal_nops[NOP_ATOMIC5], X86_CALL_SIZE))
goto out;
text_poke_bp(ip, new_insn, X86_CALL_SIZE, NULL);
break;
case BPF_MOD_CALL_TO_CALL:
if (memcmp(ip, old_insn, X86_CALL_SIZE))
goto out;
text_poke_bp(ip, new_insn, X86_CALL_SIZE, NULL);
break;
case BPF_MOD_CALL_TO_NOP:
if (memcmp(ip, old_insn, X86_CALL_SIZE))
goto out;
text_poke_bp(ip, ideal_nops[NOP_ATOMIC5], X86_CALL_SIZE, NULL);
break;
}
ret = 0;
out:
mutex_unlock(&text_mutex);
return ret;
}

static bool ex_handler_bpf(const struct exception_table_entry *x,
struct pt_regs *regs, int trapnr,
unsigned long error_code, unsigned long fault_addr)
Expand Down Expand Up @@ -1013,7 +1128,11 @@ xadd: if (is_imm8(insn->off))
break;

case BPF_JMP | BPF_TAIL_CALL:
emit_bpf_tail_call(&prog);
if (imm32)
emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1],
&prog, addrs[i], image);
else
emit_bpf_tail_call_indirect(&prog);
break;

/* cond jump */
Expand Down Expand Up @@ -1394,7 +1513,7 @@ int arch_prepare_bpf_trampoline(void *image, struct btf_func_model *m, u32 flags
/* skip patched call instruction and point orig_call to actual
* body of the kernel function.
*/
orig_call += X86_CALL_SIZE;
orig_call += X86_PATCH_SIZE;

prog = image;

Expand Down Expand Up @@ -1571,6 +1690,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)

if (image) {
if (!prog->is_func || extra_pass) {
bpf_tail_call_direct_fixup(prog);
bpf_jit_binary_lock_ro(header);
} else {
jit_data->addrs = addrs;
Expand Down
Loading

0 comments on commit 6dbae03

Please sign in to comment.