diff --git a/arch/x86/include/asm/nospec-branch.h b/arch/x86/include/asm/nospec-branch.h
index 66094a0473a89..249f1c769f21f 100644
--- a/arch/x86/include/asm/nospec-branch.h
+++ b/arch/x86/include/asm/nospec-branch.h
@@ -195,4 +195,41 @@ static inline void vmexit_fill_RSB(void)
 }
 
 #endif /* __ASSEMBLY__ */
+
+/*
+ * Below is used in the eBPF JIT compiler and emits the byte sequence
+ * for the following assembly:
+ *
+ * With retpolines configured:
+ *
+ *    callq do_rop
+ *  spec_trap:
+ *    pause
+ *    lfence
+ *    jmp spec_trap
+ *  do_rop:
+ *    mov %rax,(%rsp)
+ *    retq
+ *
+ * Without retpolines configured:
+ *
+ *    jmp *%rax
+ */
+#ifdef CONFIG_RETPOLINE
+# define RETPOLINE_RAX_BPF_JIT_SIZE	17
+# define RETPOLINE_RAX_BPF_JIT()				\
+	EMIT1_off32(0xE8, 7);	 /* callq do_rop */		\
+	/* spec_trap: */					\
+	EMIT2(0xF3, 0x90);       /* pause */			\
+	EMIT3(0x0F, 0xAE, 0xE8); /* lfence */			\
+	EMIT2(0xEB, 0xF9);       /* jmp spec_trap */		\
+	/* do_rop: */						\
+	EMIT4(0x48, 0x89, 0x04, 0x24); /* mov %rax,(%rsp) */	\
+	EMIT1(0xC3);             /* retq */
+#else
+# define RETPOLINE_RAX_BPF_JIT_SIZE	2
+# define RETPOLINE_RAX_BPF_JIT()				\
+	EMIT2(0xFF, 0xE0);	 /* jmp *%rax */
+#endif
+
 #endif /* _ASM_X86_NOSPEC_BRANCH_H_ */
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 33c42b8267918..a889211e21c5a 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -12,6 +12,7 @@
 #include <linux/filter.h>
 #include <linux/if_vlan.h>
 #include <asm/cacheflush.h>
+#include <asm/nospec-branch.h>
 #include <linux/bpf.h>
 
 int bpf_jit_enable __read_mostly;
@@ -269,7 +270,7 @@ static void emit_bpf_tail_call(u8 **pprog)
 	EMIT2(0x89, 0xD2);                        /* mov edx, edx */
 	EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
 	      offsetof(struct bpf_array, map.max_entries));
-#define OFFSET1 43 /* number of bytes to jump */
+#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* number of bytes to jump */
 	EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
 	label1 = cnt;
 
@@ -278,7 +279,7 @@ static void emit_bpf_tail_call(u8 **pprog)
 	 */
 	EMIT2_off32(0x8B, 0x85, -STACKSIZE + 36); /* mov eax, dword ptr [rbp - 516] */
 	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
-#define OFFSET2 32
+#define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE)
 	EMIT2(X86_JA, OFFSET2);                   /* ja out */
 	label2 = cnt;
 	EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
@@ -292,7 +293,7 @@ static void emit_bpf_tail_call(u8 **pprog)
 	 *   goto out;
 	 */
 	EMIT3(0x48, 0x85, 0xC0);		  /* test rax,rax */
-#define OFFSET3 10
+#define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE)
 	EMIT2(X86_JE, OFFSET3);                   /* je out */
 	label3 = cnt;
 
@@ -305,7 +306,7 @@ static void emit_bpf_tail_call(u8 **pprog)
 	 * rdi == ctx (1st arg)
 	 * rax == prog->bpf_func + prologue_size
 	 */
-	EMIT2(0xFF, 0xE0);                        /* jmp rax */
+	RETPOLINE_RAX_BPF_JIT();
 
 	/* out: */
 	BUILD_BUG_ON(cnt - label1 != OFFSET1);