Skip to content

Commit

Permalink
riscv, bpf: Adapt bpf trampoline to optimized riscv ftrace framework
Browse files Browse the repository at this point in the history
Commit 6724a76 ("riscv: ftrace: Reduce the detour code size to
half") halves the detour code size of kernel functions by using the
T0 register, and the upcoming DYNAMIC_FTRACE_WITH_DIRECT_CALLS
support for riscv is based on that optimization, so we need to adapt
the riscv bpf trampoline accordingly. The first thing to do is to
reduce the detour code size of bpf programs; the second is to deal
with the return address after the execution of the bpf trampoline.
Meanwhile, we need to construct the frame of the parent function;
otherwise we will miss one layer when unwinding.
The related tests have passed.

Signed-off-by: Pu Lehui <pulehui@huawei.com>
Tested-by: Björn Töpel <bjorn@rivosinc.com>
Link: https://lore.kernel.org/r/20230721100627.2630326-1-pulehui@huaweicloud.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
  • Loading branch information
Pu Lehui authored and Alexei Starovoitov committed Aug 2, 2023
1 parent 94e38c9 commit 25ad106
Showing 1 changed file with 82 additions and 71 deletions.
153 changes: 82 additions & 71 deletions arch/riscv/net/bpf_jit_comp64.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <asm/patch.h>
#include "bpf_jit.h"

#define RV_FENTRY_NINSNS 2

#define RV_REG_TCC RV_REG_A6
#define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */

Expand Down Expand Up @@ -241,7 +243,7 @@ static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
if (!is_tail_call)
emit_mv(RV_REG_A0, RV_REG_A5, ctx);
emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
is_tail_call ? 20 : 0, /* skip reserved nops and TCC init */
is_tail_call ? (RV_FENTRY_NINSNS + 1) * 4 : 0, /* skip reserved nops and TCC init */
ctx);
}

Expand Down Expand Up @@ -618,32 +620,7 @@ static int add_exception_handler(const struct bpf_insn *insn,
return 0;
}

static int gen_call_or_nops(void *target, void *ip, u32 *insns)
{
s64 rvoff;
int i, ret;
struct rv_jit_context ctx;

ctx.ninsns = 0;
ctx.insns = (u16 *)insns;

if (!target) {
for (i = 0; i < 4; i++)
emit(rv_nop(), &ctx);
return 0;
}

rvoff = (s64)(target - (ip + 4));
emit(rv_sd(RV_REG_SP, -8, RV_REG_RA), &ctx);
ret = emit_jump_and_link(RV_REG_RA, rvoff, false, &ctx);
if (ret)
return ret;
emit(rv_ld(RV_REG_RA, -8, RV_REG_SP), &ctx);

return 0;
}

static int gen_jump_or_nops(void *target, void *ip, u32 *insns)
static int gen_jump_or_nops(void *target, void *ip, u32 *insns, bool is_call)
{
s64 rvoff;
struct rv_jit_context ctx;
Expand All @@ -658,38 +635,35 @@ static int gen_jump_or_nops(void *target, void *ip, u32 *insns)
}

rvoff = (s64)(target - ip);
return emit_jump_and_link(RV_REG_ZERO, rvoff, false, &ctx);
return emit_jump_and_link(is_call ? RV_REG_T0 : RV_REG_ZERO, rvoff, false, &ctx);
}

int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
void *old_addr, void *new_addr)
{
u32 old_insns[4], new_insns[4];
u32 old_insns[RV_FENTRY_NINSNS], new_insns[RV_FENTRY_NINSNS];
bool is_call = poke_type == BPF_MOD_CALL;
int (*gen_insns)(void *target, void *ip, u32 *insns);
int ninsns = is_call ? 4 : 2;
int ret;

if (!is_bpf_text_address((unsigned long)ip))
if (!is_kernel_text((unsigned long)ip) &&
!is_bpf_text_address((unsigned long)ip))
return -ENOTSUPP;

gen_insns = is_call ? gen_call_or_nops : gen_jump_or_nops;

ret = gen_insns(old_addr, ip, old_insns);
ret = gen_jump_or_nops(old_addr, ip, old_insns, is_call);
if (ret)
return ret;

if (memcmp(ip, old_insns, ninsns * 4))
if (memcmp(ip, old_insns, RV_FENTRY_NINSNS * 4))
return -EFAULT;

ret = gen_insns(new_addr, ip, new_insns);
ret = gen_jump_or_nops(new_addr, ip, new_insns, is_call);
if (ret)
return ret;

cpus_read_lock();
mutex_lock(&text_mutex);
if (memcmp(ip, new_insns, ninsns * 4))
ret = patch_text(ip, new_insns, ninsns);
if (memcmp(ip, new_insns, RV_FENTRY_NINSNS * 4))
ret = patch_text(ip, new_insns, RV_FENTRY_NINSNS);
mutex_unlock(&text_mutex);
cpus_read_unlock();

Expand Down Expand Up @@ -787,22 +761,35 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
int i, ret, offset;
int *branches_off = NULL;
int stack_size = 0, nregs = m->nr_args;
int retaddr_off, fp_off, retval_off, args_off;
int nregs_off, ip_off, run_ctx_off, sreg_off;
int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off;
struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
void *orig_call = func_addr;
bool save_ret;
u32 insn;

/* Generated trampoline stack layout:
/* Two types of generated trampoline stack layout:
*
* 1. trampoline called from function entry
* --------------------------------------
* FP + 8 [ RA to parent func ] return address to parent
* function
* FP + 0 [ FP of parent func ] frame pointer of parent
* function
* FP - 8 [ T0 to traced func ] return address of traced
* function
* FP - 16 [ FP of traced func ] frame pointer of traced
* function
* --------------------------------------
*
* FP - 8 [ RA of parent func ] return address of parent
* 2. trampoline called directly
* --------------------------------------
* FP - 8 [ RA to caller func ] return address to caller
* function
* FP - retaddr_off [ RA of traced func ] return address of traced
* FP - 16 [ FP of caller func ] frame pointer of caller
* function
* FP - fp_off [ FP of parent func ]
* --------------------------------------
*
* FP - retval_off [ return value ] BPF_TRAMP_F_CALL_ORIG or
* BPF_TRAMP_F_RET_FENTRY_RET
Expand Down Expand Up @@ -833,14 +820,8 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
if (nregs > 8)
return -ENOTSUPP;

/* room for parent function return address */
stack_size += 8;

stack_size += 8;
retaddr_off = stack_size;

stack_size += 8;
fp_off = stack_size;
/* room of trampoline frame to store return address and frame pointer */
stack_size += 16;

save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
if (save_ret) {
Expand All @@ -867,12 +848,29 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,

stack_size = round_up(stack_size, 16);

emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);

emit_sd(RV_REG_SP, stack_size - retaddr_off, RV_REG_RA, ctx);
emit_sd(RV_REG_SP, stack_size - fp_off, RV_REG_FP, ctx);

emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
if (func_addr) {
/* For the trampoline called from function entry,
* the frame of traced function and the frame of
* trampoline need to be considered.
*/
emit_addi(RV_REG_SP, RV_REG_SP, -16, ctx);
emit_sd(RV_REG_SP, 8, RV_REG_RA, ctx);
emit_sd(RV_REG_SP, 0, RV_REG_FP, ctx);
emit_addi(RV_REG_FP, RV_REG_SP, 16, ctx);

emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
emit_sd(RV_REG_SP, stack_size - 8, RV_REG_T0, ctx);
emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
} else {
/* For the trampoline called directly, just handle
* the frame of trampoline.
*/
emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
emit_sd(RV_REG_SP, stack_size - 8, RV_REG_RA, ctx);
emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
}

/* callee saved register S1 to pass start time */
emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
Expand All @@ -890,7 +888,7 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,

/* skip to actual body of traced function */
if (flags & BPF_TRAMP_F_SKIP_FRAME)
orig_call += 16;
orig_call += RV_FENTRY_NINSNS * 4;

if (flags & BPF_TRAMP_F_CALL_ORIG) {
emit_imm(RV_REG_A0, (const s64)im, ctx);
Expand Down Expand Up @@ -967,17 +965,30 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,

emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);

if (flags & BPF_TRAMP_F_SKIP_FRAME)
/* return address of parent function */
emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
else
/* return address of traced function */
emit_ld(RV_REG_RA, stack_size - retaddr_off, RV_REG_SP, ctx);
if (func_addr) {
/* trampoline called from function entry */
emit_ld(RV_REG_T0, stack_size - 8, RV_REG_SP, ctx);
emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);

emit_ld(RV_REG_FP, stack_size - fp_off, RV_REG_SP, ctx);
emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
emit_ld(RV_REG_RA, 8, RV_REG_SP, ctx);
emit_ld(RV_REG_FP, 0, RV_REG_SP, ctx);
emit_addi(RV_REG_SP, RV_REG_SP, 16, ctx);

emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
if (flags & BPF_TRAMP_F_SKIP_FRAME)
/* return to parent function */
emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
else
/* return to traced function */
emit_jalr(RV_REG_ZERO, RV_REG_T0, 0, ctx);
} else {
/* trampoline called directly */
emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);

emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
}

ret = ctx->ninsns;
out:
Expand Down Expand Up @@ -1691,8 +1702,8 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx)

store_offset = stack_adjust - 8;

/* reserve 4 nop insns */
for (i = 0; i < 4; i++)
/* nops reserved for auipc+jalr pair */
for (i = 0; i < RV_FENTRY_NINSNS; i++)
emit(rv_nop(), ctx);

/* First instruction is always setting the tail-call-counter
Expand Down

0 comments on commit 25ad106

Please sign in to comment.