diff --git a/arch/riscv/net/bpf_jit_comp32.c b/arch/riscv/net/bpf_jit_comp32.c
index 11083d4d5f2d5..b198eaa74456b 100644
--- a/arch/riscv/net/bpf_jit_comp32.c
+++ b/arch/riscv/net/bpf_jit_comp32.c
@@ -13,8 +13,35 @@
 #include <linux/filter.h>
 #include "bpf_jit.h"
 
+/*
+ * Stack layout during BPF program execution:
+ *
+ *                     high
+ *     RV32 fp =>  +----------+
+ *                 | saved ra |
+ *                 | saved fp | RV32 callee-saved registers
+ *                 |   ...    |
+ *                 +----------+ <= (fp - 4 * NR_SAVED_REGISTERS)
+ *                 |  hi(R6)  |
+ *                 |  lo(R6)  |
+ *                 |  hi(R7)  | JIT scratch space for BPF registers
+ *                 |  lo(R7)  |
+ *                 |   ...    |
+ *  BPF_REG_FP =>  +----------+ <= (fp - 4 * NR_SAVED_REGISTERS
+ *                 |          |        - 4 * BPF_JIT_SCRATCH_REGS)
+ *                 |          |
+ *                 |   ...    | BPF program stack
+ *                 |          |
+ *     RV32 sp =>  +----------+
+ *                 |          |
+ *                 |   ...    | Function call stack
+ *                 |          |
+ *                 +----------+
+ *                     low
+ */
+
 enum {
-	/* Stack layout - these are offsets from (top of stack - 4). */
+	/* Stack layout - these are offsets from top of JIT scratch space. */
 	BPF_R6_HI,
 	BPF_R6_LO,
 	BPF_R7_HI,
@@ -29,7 +56,11 @@ enum {
 	BPF_JIT_SCRATCH_REGS,
 };
 
-#define STACK_OFFSET(k) (-4 - ((k) * 4))
+/* Number of callee-saved registers stored to stack: ra, fp, s1--s7. */
+#define NR_SAVED_REGISTERS	9
+
+/* Offset from fp for BPF registers stored on stack. */
+#define STACK_OFFSET(k)	(-4 - (4 * NR_SAVED_REGISTERS) - (4 * (k)))
 
 #define TMP_REG_1	(MAX_BPF_JIT_REG + 0)
 #define TMP_REG_2	(MAX_BPF_JIT_REG + 1)
@@ -111,11 +142,9 @@ static void emit_imm64(const s8 *rd, s32 imm_hi, s32 imm_lo,
 
 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
 {
-	int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 4;
+	int stack_adjust = ctx->stack_size;
 	const s8 *r0 = bpf2rv32[BPF_REG_0];
 
-	store_offset -= 4 * BPF_JIT_SCRATCH_REGS;
-
 	/* Set return value if not tail call. */
 	if (!is_tail_call) {
 		emit(rv_addi(RV_REG_A0, lo(r0), 0), ctx);
@@ -123,15 +152,15 @@ static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
 	}
 
 	/* Restore callee-saved registers. */
-	emit(rv_lw(RV_REG_RA, store_offset - 0, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_FP, store_offset - 4, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_S1, store_offset - 8, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_S2, store_offset - 12, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_S3, store_offset - 16, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_S4, store_offset - 20, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_S5, store_offset - 24, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_S6, store_offset - 28, RV_REG_SP), ctx);
-	emit(rv_lw(RV_REG_S7, store_offset - 32, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_RA, stack_adjust - 4, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_FP, stack_adjust - 8, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_S1, stack_adjust - 12, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_S2, stack_adjust - 16, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_S3, stack_adjust - 20, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_S4, stack_adjust - 24, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_S5, stack_adjust - 28, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_S6, stack_adjust - 32, RV_REG_SP), ctx);
+	emit(rv_lw(RV_REG_S7, stack_adjust - 36, RV_REG_SP), ctx);
 
 	emit(rv_addi(RV_REG_SP, RV_REG_SP, stack_adjust), ctx);
 
@@ -1260,17 +1289,20 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
 
 void bpf_jit_build_prologue(struct rv_jit_context *ctx)
 {
-	/* Make space to save 9 registers: ra, fp, s1--s7. */
-	int stack_adjust = 9 * sizeof(u32), store_offset, bpf_stack_adjust;
 	const s8 *fp = bpf2rv32[BPF_REG_FP];
 	const s8 *r1 = bpf2rv32[BPF_REG_1];
-
-	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
+	int stack_adjust = 0;
+	int bpf_stack_adjust =
+		round_up(ctx->prog->aux->stack_depth, STACK_ALIGN);
+
+	/* Make space for callee-saved registers. */
+	stack_adjust += NR_SAVED_REGISTERS * sizeof(u32);
+	/* Make space for BPF registers on stack. */
+	stack_adjust += BPF_JIT_SCRATCH_REGS * sizeof(u32);
+	/* Make space for BPF stack. */
 	stack_adjust += bpf_stack_adjust;
-
-	store_offset = stack_adjust - 4;
-
-	stack_adjust += 4 * BPF_JIT_SCRATCH_REGS;
+	/* Round up for stack alignment. */
+	stack_adjust = round_up(stack_adjust, STACK_ALIGN);
 
 	/*
 	 * The first instruction sets the tail-call-counter (TCC) register.
@@ -1281,24 +1313,24 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx)
 	emit(rv_addi(RV_REG_SP, RV_REG_SP, -stack_adjust), ctx);
 
 	/* Save callee-save registers. */
-	emit(rv_sw(RV_REG_SP, store_offset - 0, RV_REG_RA), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 4, RV_REG_FP), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 8, RV_REG_S1), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 12, RV_REG_S2), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 16, RV_REG_S3), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 20, RV_REG_S4), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 24, RV_REG_S5), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 28, RV_REG_S6), ctx);
-	emit(rv_sw(RV_REG_SP, store_offset - 32, RV_REG_S7), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 4, RV_REG_RA), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 8, RV_REG_FP), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 12, RV_REG_S1), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 16, RV_REG_S2), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 20, RV_REG_S3), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 24, RV_REG_S4), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 28, RV_REG_S5), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 32, RV_REG_S6), ctx);
+	emit(rv_sw(RV_REG_SP, stack_adjust - 36, RV_REG_S7), ctx);
 
 	/* Set fp: used as the base address for stacked BPF registers. */
 	emit(rv_addi(RV_REG_FP, RV_REG_SP, stack_adjust), ctx);
 
-	/* Set up BPF stack pointer. */
+	/* Set up BPF frame pointer. */
 	emit(rv_addi(lo(fp), RV_REG_SP, bpf_stack_adjust), ctx);
 	emit(rv_addi(hi(fp), RV_REG_ZERO, 0), ctx);
 
-	/* Set up context pointer. */
+	/* Set up BPF context pointer. */
 	emit(rv_addi(lo(r1), RV_REG_A0, 0), ctx);
 	emit(rv_addi(hi(r1), RV_REG_ZERO, 0), ctx);