Commit 4961d8f4 authored by Alexei Starovoitov
Browse files

Merge branch 'bpf-arm64-simplify-jited-prologue-epilogue'



Xu Kuohai says:

====================
bpf, arm64: Simplify jited prologue/epilogue

From: Xu Kuohai <xukuohai@huawei.com>

The arm64 jit blindly saves/restores all callee-saved registers, making
the jited result look a bit too complicated. For example, for an empty
prog, the jited result is:

   0:   bti jc
   4:   mov     x9, lr
   8:   nop
   c:   paciasp
  10:   stp     fp, lr, [sp, #-16]!
  14:   mov     fp, sp
  18:   stp     x19, x20, [sp, #-16]!
  1c:   stp     x21, x22, [sp, #-16]!
  20:   stp     x26, x25, [sp, #-16]!
  24:   mov     x26, #0
  28:   stp     x26, x25, [sp, #-16]!
  2c:   mov     x26, sp
  30:   stp     x27, x28, [sp, #-16]!
  34:   mov     x25, sp
  38:   bti j 		// tailcall target
  3c:   sub     sp, sp, #0
  40:   mov     x7, #0
  44:   add     sp, sp, #0
  48:   ldp     x27, x28, [sp], #16
  4c:   ldp     x26, x25, [sp], #16
  50:   ldp     x26, x25, [sp], #16
  54:   ldp     x21, x22, [sp], #16
  58:   ldp     x19, x20, [sp], #16
  5c:   ldp     fp, lr, [sp], #16
  60:   mov     x0, x7
  64:   autiasp
  68:   ret

Clearly, there is no need to save/restore unused callee-saved registers.
This patch makes that change, so that the jited image only saves/restores
the callee-saved registers it uses.

Now the jited result of empty prog is:

   0:   bti jc
   4:   mov     x9, lr
   8:   nop
   c:   paciasp
  10:   stp     fp, lr, [sp, #-16]!
  14:   mov     fp, sp
  18:   stp     xzr, x26, [sp, #-16]!
  1c:   mov     x26, sp
  20:   bti j		// tailcall target
  24:   mov     x7, #0
  28:   ldp     xzr, x26, [sp], #16
  2c:   ldp     fp, lr, [sp], #16
  30:   mov     x0, x7
  34:   autiasp
  38:   ret
====================

Acked-by: Puranjay Mohan <puranjay@kernel.org>
Link: https://lore.kernel.org/r/20240826071624.350108-1-xukuohai@huaweicloud.com


Signed-off-by: Alexei Starovoitov <ast@kernel.org>
parents d205d4af 5d4fa9ec
Loading
Loading
Loading
Loading
+192 −202
Original line number Diff line number Diff line
@@ -28,7 +28,6 @@
#define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
#define TCCNT_PTR (MAX_BPF_JIT_REG + 2)
#define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
#define FP_BOTTOM (MAX_BPF_JIT_REG + 4)
#define ARENA_VM_START (MAX_BPF_JIT_REG + 5)

#define check_imm(bits, imm) do {				\
@@ -67,7 +66,6 @@ static const int bpf2a64[] = {
	[TCCNT_PTR] = A64_R(26),
	/* temporary register for blinding constants */
	[BPF_REG_AX] = A64_R(9),
	[FP_BOTTOM] = A64_R(27),
	/* callee saved register for kern_vm_start address */
	[ARENA_VM_START] = A64_R(28),
};
@@ -78,11 +76,14 @@ struct jit_ctx {
	int epilogue_offset;
	int *offset;
	int exentry_idx;
	int nr_used_callee_reg;
	u8 used_callee_reg[8]; /* r6~r9, fp, arena_vm_start */
	__le32 *image;
	__le32 *ro_image;
	u32 stack_size;
	int fpb_offset;
	u64 user_vm_start;
	u64 arena_vm_start;
	bool fp_used;
};

struct bpf_plt {
@@ -273,41 +274,141 @@ static bool is_lsi_offset(int offset, int scale)
	return true;
}

/* generated prologue:
/* generated main prog prologue:
 *      bti c // if CONFIG_ARM64_BTI_KERNEL
 *      mov x9, lr
 *      nop  // POKE_OFFSET
 *      paciasp // if CONFIG_ARM64_PTR_AUTH_KERNEL
 *      stp x29, lr, [sp, #-16]!
 *      mov x29, sp
 *      stp x19, x20, [sp, #-16]!
 *      stp x21, x22, [sp, #-16]!
 *      stp x26, x25, [sp, #-16]!
 *      stp x26, x25, [sp, #-16]!
 *      stp x27, x28, [sp, #-16]!
 *      mov x25, sp
 *      mov tcc, #0
 *      stp xzr, x26, [sp, #-16]!
 *      mov x26, sp
 *      // PROLOGUE_OFFSET
 *	// save callee-saved registers
 */

/*
 * Emit the prologue instructions that set up the tail call counter.
 *
 * NOTE(review): this listing is a diff rendering with the +/- markers
 * stripped, so lines from the old and the new version of this function
 * appear together (is_main_prog is declared twice, and several pushes are
 * listed back to back). Confirm against the upstream source which lines
 * belong to the final body before relying on it.
 */
static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
{
	const struct bpf_prog *prog = ctx->prog;
	const bool is_main_prog = !bpf_is_subprog(prog);
	const bool is_main_prog = !bpf_is_subprog(ctx->prog);
	const u8 ptr = bpf2a64[TCCNT_PTR];
	const u8 fp = bpf2a64[BPF_REG_FP];
	/* tcc is an alias: it names the same arm64 register as ptr */
	const u8 tcc = ptr;

	emit(A64_PUSH(ptr, fp, A64_SP), ctx);
	if (is_main_prog) {
		/* Initialize tail_call_cnt. */
		emit(A64_MOVZ(1, tcc, 0, 0), ctx);
		emit(A64_PUSH(tcc, fp, A64_SP), ctx);
		/*
		 * The XZR push followed by "mov ptr, sp" matches the
		 * "stp xzr, x26" / "mov x26, sp" pair in the prologue
		 * comment above (TCCNT_PTR maps to x26 in bpf2a64[]).
		 */
		emit(A64_PUSH(A64_ZR, ptr, A64_SP), ctx);
		emit(A64_MOV(1, ptr, A64_SP), ctx);
	} else
		/* Subprog: push the ptr register twice, no initialization. */
		emit(A64_PUSH(ptr, ptr, A64_SP), ctx);
}

/*
 * Walk every instruction of the program and record, in
 * ctx->used_callee_reg[], the arm64 registers backing the BPF registers
 * r6-r9 and fp that appear as a source or destination operand. The
 * ARENA_VM_START register is appended when ctx->arena_vm_start is non-zero.
 *
 * Also sets ctx->fp_used when BPF_REG_FP is referenced, and stores the
 * number of recorded registers in ctx->nr_used_callee_reg.
 */
static void find_used_callee_regs(struct jit_ctx *ctx)
{
	/* BPF registers to look for, in the order they are recorded */
	static const int tracked[] = {
		BPF_REG_6, BPF_REG_7, BPF_REG_8, BPF_REG_9, BPF_REG_FP,
	};
	const int nr_tracked = sizeof(tracked) / sizeof(tracked[0]);
	const struct bpf_prog *prog = ctx->prog;
	int seen = 0;	/* bit k is set once tracked[k] shows up in an insn */
	int nr = 0;
	int n, k;

	for (n = 0; n < prog->len; n++) {
		const struct bpf_insn *insn = &prog->insnsi[n];

		for (k = 0; k < nr_tracked; k++) {
			if (insn->dst_reg == tracked[k] ||
			    insn->src_reg == tracked[k])
				seen |= 1 << k;
		}
	}

	/* BPF_REG_FP is the last entry of tracked[] */
	if (seen & (1 << (nr_tracked - 1)))
		ctx->fp_used = true;

	for (k = 0; k < nr_tracked; k++) {
		if (seen & (1 << k))
			ctx->used_callee_reg[nr++] = bpf2a64[tracked[k]];
	}

	if (ctx->arena_vm_start)
		ctx->used_callee_reg[nr++] = bpf2a64[ARENA_VM_START];

	ctx->nr_used_callee_reg = nr;
}

/*
 * Save callee-saved registers.
 *
 * A program acting as exception boundary pushes the fixed set x19-x28.
 * Any other program pushes only the registers collected by
 * find_used_callee_regs(), two per push; an odd leftover register is
 * paired with XZR.
 */
static void push_callee_regs(struct jit_ctx *ctx)
{
	int reg1, reg2, i;

	/*
	 * Program acting as exception boundary should save all ARM64
	 * Callee-saved registers as the exception callback needs to recover
	 * all ARM64 Callee-saved registers in its epilogue.
	 */
	if (ctx->prog->aux->exception_boundary) {
		emit(A64_PUSH(A64_R(19), A64_R(20), A64_SP), ctx);
		emit(A64_PUSH(A64_R(21), A64_R(22), A64_SP), ctx);
		emit(A64_PUSH(A64_R(23), A64_R(24), A64_SP), ctx);
		emit(A64_PUSH(A64_R(25), A64_R(26), A64_SP), ctx);
		emit(A64_PUSH(A64_R(27), A64_R(28), A64_SP), ctx);
	} else {
		/*
		 * NOTE(review): `ptr` and `fp` are not declared in this
		 * function. The next three lines look like removed lines
		 * left over from the diff rendering rather than part of the
		 * new body; confirm against the upstream source.
		 */
		emit(A64_PUSH(ptr, fp, A64_SP), ctx);
		emit(A64_NOP, ctx);
		emit(A64_NOP, ctx);
		/* Fills ctx->used_callee_reg[] and ctx->nr_used_callee_reg. */
		find_used_callee_regs(ctx);
		/* Push the recorded registers pairwise, in array order. */
		for (i = 0; i + 1 < ctx->nr_used_callee_reg; i += 2) {
			reg1 = ctx->used_callee_reg[i];
			reg2 = ctx->used_callee_reg[i + 1];
			emit(A64_PUSH(reg1, reg2, A64_SP), ctx);
		}
		/* Odd count: one register is left over after the loop. */
		if (i < ctx->nr_used_callee_reg) {
			reg1 = ctx->used_callee_reg[i];
			/* keep SP 16-byte aligned */
			emit(A64_PUSH(reg1, A64_ZR, A64_SP), ctx);
		}
	}
}

/*
 * Restore callee-saved registers, undoing push_callee_regs() in reverse
 * order.
 */
static void pop_callee_regs(struct jit_ctx *ctx)
{
	struct bpf_prog_aux *aux = ctx->prog->aux;
	int idx;

	/*
	 * An exception boundary program saves the whole x19-x28 set, and an
	 * exception callback runs on the boundary program's stack frame, so
	 * both of them restore the whole set.
	 */
	if (aux->exception_boundary || aux->exception_cb) {
		emit(A64_POP(A64_R(27), A64_R(28), A64_SP), ctx);
		emit(A64_POP(A64_R(25), A64_R(26), A64_SP), ctx);
		emit(A64_POP(A64_R(23), A64_R(24), A64_SP), ctx);
		emit(A64_POP(A64_R(21), A64_R(22), A64_SP), ctx);
		emit(A64_POP(A64_R(19), A64_R(20), A64_SP), ctx);
		return;
	}

	idx = ctx->nr_used_callee_reg - 1;

	/* An odd count means the last register was pushed paired with XZR. */
	if (ctx->nr_used_callee_reg % 2 != 0) {
		emit(A64_POP(ctx->used_callee_reg[idx], A64_ZR, A64_SP), ctx);
		idx--;
	}

	/* The remaining registers come off as pairs, last pushed first. */
	for (; idx > 0; idx -= 2)
		emit(A64_POP(ctx->used_callee_reg[idx - 1],
			     ctx->used_callee_reg[idx], A64_SP), ctx);
}

@@ -318,19 +419,13 @@ static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
#define POKE_OFFSET (BTI_INSNS + 1)

/* Tail call offset to jump into */
#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 10)
#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 4)

static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
			  bool is_exception_cb, u64 arena_vm_start)
static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
{
	const struct bpf_prog *prog = ctx->prog;
	const bool is_main_prog = !bpf_is_subprog(prog);
	const u8 r6 = bpf2a64[BPF_REG_6];
	const u8 r7 = bpf2a64[BPF_REG_7];
	const u8 r8 = bpf2a64[BPF_REG_8];
	const u8 r9 = bpf2a64[BPF_REG_9];
	const u8 fp = bpf2a64[BPF_REG_FP];
	const u8 fpb = bpf2a64[FP_BOTTOM];
	const u8 arena_vm_base = bpf2a64[ARENA_VM_START];
	const int idx0 = ctx->idx;
	int cur_offset;
@@ -369,19 +464,28 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
	emit(A64_MOV(1, A64_R(9), A64_LR), ctx);
	emit(A64_NOP, ctx);

	if (!is_exception_cb) {
	if (!prog->aux->exception_cb) {
		/* Sign lr */
		if (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL))
			emit(A64_PACIASP, ctx);

		/* Save FP and LR registers to stay align with ARM64 AAPCS */
		emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
		emit(A64_MOV(1, A64_FP, A64_SP), ctx);

		/* Save callee-saved registers */
		emit(A64_PUSH(r6, r7, A64_SP), ctx);
		emit(A64_PUSH(r8, r9, A64_SP), ctx);
		prepare_bpf_tail_call_cnt(ctx);
		emit(A64_PUSH(fpb, A64_R(28), A64_SP), ctx);

		if (!ebpf_from_cbpf && is_main_prog) {
			cur_offset = ctx->idx - idx0;
			if (cur_offset != PROLOGUE_OFFSET) {
				pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
						cur_offset, PROLOGUE_OFFSET);
				return -1;
			}
			/* BTI landing pad for the tail call, done with a BR */
			emit_bti(A64_BTI_J, ctx);
		}
		push_callee_regs(ctx);
	} else {
		/*
		 * Exception callback receives FP of Main Program as third
@@ -398,50 +502,23 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
		emit(A64_SUB_I(1, A64_SP, A64_FP, 96), ctx);
	}

	if (ctx->fp_used)
		/* Set up BPF prog stack base register */
		emit(A64_MOV(1, fp, A64_SP), ctx);

	if (!ebpf_from_cbpf && is_main_prog) {
		cur_offset = ctx->idx - idx0;
		if (cur_offset != PROLOGUE_OFFSET) {
			pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
				    cur_offset, PROLOGUE_OFFSET);
			return -1;
		}

		/* BTI landing pad for the tail call, done with a BR */
		emit_bti(A64_BTI_J, ctx);
	}

	/*
	 * Program acting as exception boundary should save all ARM64
	 * Callee-saved registers as the exception callback needs to recover
	 * all ARM64 Callee-saved registers in its epilogue.
	 */
	if (prog->aux->exception_boundary) {
		/*
		 * As we are pushing two more registers, BPF_FP should be moved
		 * 16 bytes
		 */
		emit(A64_SUB_I(1, fp, fp, 16), ctx);
		emit(A64_PUSH(A64_R(23), A64_R(24), A64_SP), ctx);
	}

	emit(A64_SUB_I(1, fpb, fp, ctx->fpb_offset), ctx);

	/* Stack must be multiples of 16B */
	ctx->stack_size = round_up(prog->aux->stack_depth, 16);

	/* Set up function call stack */
	if (ctx->stack_size)
		emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);

	if (arena_vm_start)
		emit_a64_mov_i64(arena_vm_base, arena_vm_start, ctx);
	if (ctx->arena_vm_start)
		emit_a64_mov_i64(arena_vm_base, ctx->arena_vm_start, ctx);

	return 0;
}

static int out_offset = -1; /* initialized on the first pass of build_body() */
static int emit_bpf_tail_call(struct jit_ctx *ctx)
{
	/* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
@@ -452,10 +529,10 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
	const u8 prg = bpf2a64[TMP_REG_2];
	const u8 tcc = bpf2a64[TMP_REG_3];
	const u8 ptr = bpf2a64[TCCNT_PTR];
	const int idx0 = ctx->idx;
#define cur_offset (ctx->idx - idx0)
#define jmp_offset (out_offset - (cur_offset))
	size_t off;
	__le32 *branch1 = NULL;
	__le32 *branch2 = NULL;
	__le32 *branch3 = NULL;

	/* if (index >= array->map.max_entries)
	 *     goto out;
@@ -465,17 +542,20 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
	emit(A64_LDR32(tmp, r2, tmp), ctx);
	emit(A64_MOV(0, r3, r3), ctx);
	emit(A64_CMP(0, r3, tmp), ctx);
	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
	branch1 = ctx->image + ctx->idx;
	emit(A64_NOP, ctx);

	/*
	 * if ((*tail_call_cnt_ptr) >= MAX_TAIL_CALL_CNT)
	 *     goto out;
	 * (*tail_call_cnt_ptr)++;
	 */
	emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
	emit(A64_LDR64I(tcc, ptr, 0), ctx);
	emit(A64_CMP(1, tcc, tmp), ctx);
	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
	branch2 = ctx->image + ctx->idx;
	emit(A64_NOP, ctx);

	/* (*tail_call_cnt_ptr)++; */
	emit(A64_ADD_I(1, tcc, tcc, 1), ctx);

	/* prog = array->ptrs[index];
@@ -487,30 +567,37 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
	emit(A64_ADD(1, tmp, r2, tmp), ctx);
	emit(A64_LSL(1, prg, r3, 3), ctx);
	emit(A64_LDR64(prg, tmp, prg), ctx);
	emit(A64_CBZ(1, prg, jmp_offset), ctx);
	branch3 = ctx->image + ctx->idx;
	emit(A64_NOP, ctx);

	/* Update tail_call_cnt if the slot is populated. */
	emit(A64_STR64I(tcc, ptr, 0), ctx);

	/* restore SP */
	if (ctx->stack_size)
		emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);

	pop_callee_regs(ctx);

	/* goto *(prog->bpf_func + prologue_offset); */
	off = offsetof(struct bpf_prog, bpf_func);
	emit_a64_mov_i64(tmp, off, ctx);
	emit(A64_LDR64(tmp, prg, tmp), ctx);
	emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
	emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
	emit(A64_BR(tmp), ctx);

	/* out: */
	if (out_offset == -1)
		out_offset = cur_offset;
	if (cur_offset != out_offset) {
		pr_err_once("tail_call out_offset = %d, expected %d!\n",
			    cur_offset, out_offset);
		return -1;
	if (ctx->image) {
		off = &ctx->image[ctx->idx] - branch1;
		*branch1 = cpu_to_le32(A64_B_(A64_COND_CS, off));

		off = &ctx->image[ctx->idx] - branch2;
		*branch2 = cpu_to_le32(A64_B_(A64_COND_CS, off));

		off = &ctx->image[ctx->idx] - branch3;
		*branch3 = cpu_to_le32(A64_CBZ(1, prg, off));
	}

	return 0;
#undef cur_offset
#undef jmp_offset
}

#ifdef CONFIG_ARM64_LSE_ATOMICS
@@ -736,38 +823,18 @@ static void build_plt(struct jit_ctx *ctx)
		plt->target = (u64)&dummy_tramp;
}

static void build_epilogue(struct jit_ctx *ctx, bool is_exception_cb)
static void build_epilogue(struct jit_ctx *ctx)
{
	const u8 r0 = bpf2a64[BPF_REG_0];
	const u8 r6 = bpf2a64[BPF_REG_6];
	const u8 r7 = bpf2a64[BPF_REG_7];
	const u8 r8 = bpf2a64[BPF_REG_8];
	const u8 r9 = bpf2a64[BPF_REG_9];
	const u8 fp = bpf2a64[BPF_REG_FP];
	const u8 ptr = bpf2a64[TCCNT_PTR];
	const u8 fpb = bpf2a64[FP_BOTTOM];

	/* We're done with BPF stack */
	if (ctx->stack_size)
		emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);

	/*
	 * Program acting as exception boundary pushes R23 and R24 in addition
	 * to BPF callee-saved registers. Exception callback uses the boundary
	 * program's stack frame, so recover these extra registers in the above
	 * two cases.
	 */
	if (ctx->prog->aux->exception_boundary || is_exception_cb)
		emit(A64_POP(A64_R(23), A64_R(24), A64_SP), ctx);

	/* Restore x27 and x28 */
	emit(A64_POP(fpb, A64_R(28), A64_SP), ctx);
	/* Restore fs (x25) and x26 */
	emit(A64_POP(ptr, fp, A64_SP), ctx);
	emit(A64_POP(ptr, fp, A64_SP), ctx);
	pop_callee_regs(ctx);

	/* Restore callee-saved register */
	emit(A64_POP(r8, r9, A64_SP), ctx);
	emit(A64_POP(r6, r7, A64_SP), ctx);
	emit(A64_POP(A64_ZR, ptr, A64_SP), ctx);

	/* Restore FP/LR registers */
	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
@@ -887,7 +954,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
	const u8 tmp = bpf2a64[TMP_REG_1];
	const u8 tmp2 = bpf2a64[TMP_REG_2];
	const u8 fp = bpf2a64[BPF_REG_FP];
	const u8 fpb = bpf2a64[FP_BOTTOM];
	const u8 arena_vm_base = bpf2a64[ARENA_VM_START];
	const s16 off = insn->off;
	const s32 imm = insn->imm;
@@ -1339,9 +1405,9 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
			emit(A64_ADD(1, tmp2, src, arena_vm_base), ctx);
			src = tmp2;
		}
		if (ctx->fpb_offset > 0 && src == fp && BPF_MODE(insn->code) != BPF_PROBE_MEM32) {
			src_adj = fpb;
			off_adj = off + ctx->fpb_offset;
		if (src == fp) {
			src_adj = A64_SP;
			off_adj = off + ctx->stack_size;
		} else {
			src_adj = src;
			off_adj = off;
@@ -1432,9 +1498,9 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
			emit(A64_ADD(1, tmp2, dst, arena_vm_base), ctx);
			dst = tmp2;
		}
		if (ctx->fpb_offset > 0 && dst == fp && BPF_MODE(insn->code) != BPF_PROBE_MEM32) {
			dst_adj = fpb;
			off_adj = off + ctx->fpb_offset;
		if (dst == fp) {
			dst_adj = A64_SP;
			off_adj = off + ctx->stack_size;
		} else {
			dst_adj = dst;
			off_adj = off;
@@ -1494,9 +1560,9 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
			emit(A64_ADD(1, tmp2, dst, arena_vm_base), ctx);
			dst = tmp2;
		}
		if (ctx->fpb_offset > 0 && dst == fp && BPF_MODE(insn->code) != BPF_PROBE_MEM32) {
			dst_adj = fpb;
			off_adj = off + ctx->fpb_offset;
		if (dst == fp) {
			dst_adj = A64_SP;
			off_adj = off + ctx->stack_size;
		} else {
			dst_adj = dst;
			off_adj = off;
@@ -1565,79 +1631,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
	return 0;
}

/*
 * Find the lowest (most negative) offset used with BPF_REG_FP in plain
 * BPF_MEM loads and stores, and return it as a positive number aligned
 * down to 8 bytes.
 *
 * Returns 0 when no such offset exists, or when some instruction may write
 * to BPF_REG_FP (load/ALU destination, or source of a fetching atomic),
 * i.e. when FP may change at runtime.
 */
static int find_fpb_offset(struct bpf_prog *prog)
{
	int lowest = 0;
	int n;

	for (n = 0; n < prog->len; n++) {
		const struct bpf_insn *insn = &prog->insnsi[n];
		const u8 class = BPF_CLASS(insn->code);
		const u8 mode = BPF_MODE(insn->code);
		const s32 imm = insn->imm;

		if (class == BPF_JMP || class == BPF_JMP32)
			continue;

		if (class == BPF_STX || class == BPF_ST) {
			/* fp holds atomic operation result */
			if (class == BPF_STX && mode == BPF_ATOMIC &&
			    insn->src_reg == BPF_REG_FP &&
			    (imm == BPF_XCHG ||
			     imm == (BPF_FETCH | BPF_ADD) ||
			     imm == (BPF_FETCH | BPF_AND) ||
			     imm == (BPF_FETCH | BPF_XOR) ||
			     imm == (BPF_FETCH | BPF_OR)))
				return 0;

			if (mode == BPF_MEM && insn->dst_reg == BPF_REG_FP &&
			    insn->off < lowest)
				lowest = insn->off;
			continue;
		}

		/*
		 * Everything left (loads, ALU, and any other class) makes
		 * the result unusable once it writes to fp.
		 */
		if (insn->dst_reg == BPF_REG_FP)
			return 0;

		if (class == BPF_LDX && mode == BPF_MEM &&
		    insn->src_reg == BPF_REG_FP && insn->off < lowest)
			lowest = insn->off;
	}

	if (lowest >= 0)
		return lowest;

	/*
	 * Negating is safe for an 'int' because insn->off is 's16'; then
	 * align down to 8 bytes.
	 */
	lowest = -lowest;
	return ALIGN_DOWN(lowest, 8);
}

static int build_body(struct jit_ctx *ctx, bool extra_pass)
{
	const struct bpf_prog *prog = ctx->prog;
@@ -1726,7 +1719,6 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
	bool tmp_blinded = false;
	bool extra_pass = false;
	struct jit_ctx ctx;
	u64 arena_vm_start;
	u8 *image_ptr;
	u8 *ro_image_ptr;

@@ -1744,7 +1736,6 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
		prog = tmp;
	}

	arena_vm_start = bpf_arena_get_kern_vm_start(prog->aux->arena);
	jit_data = prog->aux->jit_data;
	if (!jit_data) {
		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
@@ -1774,8 +1765,8 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
		goto out_off;
	}

	ctx.fpb_offset = find_fpb_offset(prog);
	ctx.user_vm_start = bpf_arena_get_user_vm_start(prog->aux->arena);
	ctx.arena_vm_start = bpf_arena_get_kern_vm_start(prog->aux->arena);

	/*
	 * 1. Initial fake pass to compute ctx->idx and ctx->offset.
@@ -1783,8 +1774,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
	 * BPF line info needs ctx->offset[i] to be the offset of
	 * instruction[i] in jited image, so build prologue first.
	 */
	if (build_prologue(&ctx, was_classic, prog->aux->exception_cb,
			   arena_vm_start)) {
	if (build_prologue(&ctx, was_classic)) {
		prog = orig_prog;
		goto out_off;
	}
@@ -1795,7 +1785,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
	}

	ctx.epilogue_offset = ctx.idx;
	build_epilogue(&ctx, prog->aux->exception_cb);
	build_epilogue(&ctx);
	build_plt(&ctx);

	extable_align = __alignof__(struct exception_table_entry);
@@ -1832,14 +1822,14 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
	ctx.idx = 0;
	ctx.exentry_idx = 0;

	build_prologue(&ctx, was_classic, prog->aux->exception_cb, arena_vm_start);
	build_prologue(&ctx, was_classic);

	if (build_body(&ctx, extra_pass)) {
		prog = orig_prog;
		goto out_free_hdr;
	}

	build_epilogue(&ctx, prog->aux->exception_cb);
	build_epilogue(&ctx);
	build_plt(&ctx);

	/* 3. Extra pass to validate JITed code. */