Commit 633a6e01 authored by Puranjay Mohan, committed by Daniel Borkmann
Browse files

bpf, riscv: Implement PROBE_MEM32 pseudo instructions



Add support for [LDX | STX | ST], PROBE_MEM32, [B | H | W | DW]
instructions. They are similar to PROBE_MEM instructions with the
following differences:

- PROBE_MEM32 supports store.
- PROBE_MEM32 relies on the verifier to clear upper 32-bit of the
  src/dst register
- PROBE_MEM32 adds 64-bit kern_vm_start address (which is stored in S7
  in the prologue). Due to bpf_arena constructions such S7 + reg +
  off16 access is guaranteed to be within arena virtual range, so no
  address check at run-time.
- S7 is a free callee-saved register, so it is used to store kern_vm_start
- PROBE_MEM32 allows STX and ST. If they fault the store is a nop. When
  LDX faults the destination register is zeroed.

To support these on riscv, we do tmp = S7 + src/dst reg and then use
tmp as the new src/dst register. This allows us to reuse most of the
code for normal [LDX | STX | ST].

Signed-off-by: Puranjay Mohan <puranjay12@gmail.com>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Tested-by: Björn Töpel <bjorn@rivosinc.com>
Tested-by: Pu Lehui <pulehui@huawei.com>
Reviewed-by: Pu Lehui <pulehui@huawei.com>
Acked-by: Björn Töpel <bjorn@kernel.org>
Link: https://lore.kernel.org/bpf/20240404114203.105970-2-puranjay12@gmail.com
parent af682b76
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -81,6 +81,7 @@ struct rv_jit_context {
	int nexentries;
	unsigned long flags;
	int stack_size;
	u64 arena_vm_start;
};

/* Convert from ninsns to bytes. */
+187 −2
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@

#define RV_REG_TCC RV_REG_A6
#define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
/*
 * Register holding arena_vm_start for PROBE_MEM32 accesses. It is only set up
 * when ctx->arena_vm_start is non-zero: the prologue then spills S7 to the
 * stack and loads ctx->arena_vm_start into it, and the epilogue reloads the
 * spilled value.
 */
#define RV_REG_ARENA RV_REG_S7 /* For storing arena_vm_start */

static const int regmap[] = {
	[BPF_REG_0] =	RV_REG_A5,
@@ -255,6 +256,10 @@ static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
		emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
		store_offset -= 8;
	}
	if (ctx->arena_vm_start) {
		emit_ld(RV_REG_ARENA, store_offset, RV_REG_SP, ctx);
		store_offset -= 8;
	}

	emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
	/* Set return value. */
@@ -548,6 +553,7 @@ static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,

/*
 * Layout of the exception-table fixup word as decoded by ex_handler_bpf():
 *   bits 26:0  - offset that is subtracted from &ex->fixup to obtain the new
 *                regs->epc
 *   bits 31:27 - index into pt_regmap selecting the register to zero
 */
#define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
#define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
/*
 * Value of the register field for which ex_handler_bpf() zeroes no register
 * and only updates regs->epc.
 */
#define REG_DONT_CLEAR_MARKER	0	/* RV_REG_ZERO unused in pt_regmap */

bool ex_handler_bpf(const struct exception_table_entry *ex,
		    struct pt_regs *regs)
@@ -555,6 +561,7 @@ bool ex_handler_bpf(const struct exception_table_entry *ex,
	off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
	int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);

	if (regs_offset != REG_DONT_CLEAR_MARKER)
		*(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
	regs->epc = (unsigned long)&ex->fixup - offset;

@@ -572,7 +579,8 @@ static int add_exception_handler(const struct bpf_insn *insn,
	off_t fixup_offset;

	if (!ctx->insns || !ctx->ro_insns || !ctx->prog->aux->extable ||
	    (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX))
	    (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX &&
	     BPF_MODE(insn->code) != BPF_PROBE_MEM32))
		return 0;

	if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
@@ -1539,6 +1547,11 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
	case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
	case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
	case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
	/* LDX | PROBE_MEM32: dst = *(unsigned size *)(src + RV_REG_ARENA + off) */
	case BPF_LDX | BPF_PROBE_MEM32 | BPF_B:
	case BPF_LDX | BPF_PROBE_MEM32 | BPF_H:
	case BPF_LDX | BPF_PROBE_MEM32 | BPF_W:
	case BPF_LDX | BPF_PROBE_MEM32 | BPF_DW:
	{
		int insn_len, insns_start;
		bool sign_ext;
@@ -1546,6 +1559,11 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
		sign_ext = BPF_MODE(insn->code) == BPF_MEMSX ||
			   BPF_MODE(insn->code) == BPF_PROBE_MEMSX;

		if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) {
			emit_add(RV_REG_T2, rs, RV_REG_ARENA, ctx);
			rs = RV_REG_T2;
		}

		switch (BPF_SIZE(code)) {
		case BPF_B:
			if (is_12b_int(off)) {
@@ -1682,6 +1700,86 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
		emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
		break;

	/* ST | PROBE_MEM32: *(size *)(dst + RV_REG_ARENA + off) = imm */
	case BPF_ST | BPF_PROBE_MEM32 | BPF_B:
	case BPF_ST | BPF_PROBE_MEM32 | BPF_H:
	case BPF_ST | BPF_PROBE_MEM32 | BPF_W:
	case BPF_ST | BPF_PROBE_MEM32 | BPF_DW:
	{
		/*
		 * insns_start/insn_len bracket only the store itself: insn_len
		 * is the growth of ctx->ninsns across the store emission and is
		 * what gets handed to add_exception_handler() below.
		 */
		int insn_len, insns_start;

		/* T3 = dst + RV_REG_ARENA; T3 replaces dst as the base register. */
		emit_add(RV_REG_T3, rd, RV_REG_ARENA, ctx);
		rd = RV_REG_T3;

		/* Load imm to a register then store it */
		emit_imm(RV_REG_T1, imm, ctx);

		/*
		 * Every size follows the same pattern: if off fits in a signed
		 * 12-bit immediate, store straight to off(rd); otherwise build
		 * T2 = off + rd first and store to 0(T2).
		 */
		switch (BPF_SIZE(code)) {
		case BPF_B:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				emit(rv_sb(rd, off, RV_REG_T1), ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T2, off, ctx);
			emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
			insns_start = ctx->ninsns;
			emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
			insn_len = ctx->ninsns - insns_start;
			break;
		case BPF_H:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				emit(rv_sh(rd, off, RV_REG_T1), ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T2, off, ctx);
			emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
			insns_start = ctx->ninsns;
			emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
			insn_len = ctx->ninsns - insns_start;
			break;
		case BPF_W:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				emit_sw(rd, off, RV_REG_T1, ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T2, off, ctx);
			emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
			insns_start = ctx->ninsns;
			emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
			insn_len = ctx->ninsns - insns_start;
			break;
		case BPF_DW:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				emit_sd(rd, off, RV_REG_T1, ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T2, off, ctx);
			emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
			insns_start = ctx->ninsns;
			emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
			insn_len = ctx->ninsns - insns_start;
			break;
		}

		/*
		 * A store has no destination register to zero, so the marker is
		 * passed in place of a register. NOTE(review): presumably this
		 * ends up in the fixup's register field, where ex_handler_bpf()
		 * skips the clear - confirm against add_exception_handler().
		 */
		ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER,
					    insn_len);
		if (ret)
			return ret;

		break;
	}

	/* STX: *(size *)(dst + off) = src */
	case BPF_STX | BPF_MEM | BPF_B:
		if (is_12b_int(off)) {
@@ -1728,6 +1826,84 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
		emit_atomic(rd, rs, off, imm,
			    BPF_SIZE(code) == BPF_DW, ctx);
		break;

	/* STX | PROBE_MEM32: *(size *)(dst + RV_REG_ARENA + off) = src */
	case BPF_STX | BPF_PROBE_MEM32 | BPF_B:
	case BPF_STX | BPF_PROBE_MEM32 | BPF_H:
	case BPF_STX | BPF_PROBE_MEM32 | BPF_W:
	case BPF_STX | BPF_PROBE_MEM32 | BPF_DW:
	{
		/*
		 * insns_start/insn_len bracket only the store itself: insn_len
		 * is the growth of ctx->ninsns across the store emission and is
		 * what gets handed to add_exception_handler() below.
		 */
		int insn_len, insns_start;

		/* T2 = dst + RV_REG_ARENA; T2 replaces dst as the base register. */
		emit_add(RV_REG_T2, rd, RV_REG_ARENA, ctx);
		rd = RV_REG_T2;

		/*
		 * Every size follows the same pattern: if off fits in a signed
		 * 12-bit immediate, store src straight to off(rd); otherwise
		 * build T1 = off + rd first and store to 0(T1).
		 */
		switch (BPF_SIZE(code)) {
		case BPF_B:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				emit(rv_sb(rd, off, rs), ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T1, off, ctx);
			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
			insns_start = ctx->ninsns;
			emit(rv_sb(RV_REG_T1, 0, rs), ctx);
			insn_len = ctx->ninsns - insns_start;
			break;
		case BPF_H:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				emit(rv_sh(rd, off, rs), ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T1, off, ctx);
			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
			insns_start = ctx->ninsns;
			emit(rv_sh(RV_REG_T1, 0, rs), ctx);
			insn_len = ctx->ninsns - insns_start;
			break;
		case BPF_W:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				emit_sw(rd, off, rs, ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T1, off, ctx);
			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
			insns_start = ctx->ninsns;
			emit_sw(RV_REG_T1, 0, rs, ctx);
			insn_len = ctx->ninsns - insns_start;
			break;
		case BPF_DW:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				emit_sd(rd, off, rs, ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T1, off, ctx);
			emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
			insns_start = ctx->ninsns;
			emit_sd(RV_REG_T1, 0, rs, ctx);
			insn_len = ctx->ninsns - insns_start;
			break;
		}

		/*
		 * A store has no destination register to zero, so the marker is
		 * passed in place of a register. NOTE(review): presumably this
		 * ends up in the fixup's register field, where ex_handler_bpf()
		 * skips the clear - confirm against add_exception_handler().
		 */
		ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER,
					    insn_len);
		if (ret)
			return ret;

		break;
	}

	default:
		pr_err("bpf-jit: unknown opcode %02x\n", code);
		return -EINVAL;
@@ -1759,6 +1935,8 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
		stack_adjust += 8;
	if (seen_reg(RV_REG_S6, ctx))
		stack_adjust += 8;
	if (ctx->arena_vm_start)
		stack_adjust += 8;

	stack_adjust = round_up(stack_adjust, 16);
	stack_adjust += bpf_stack_adjust;
@@ -1810,6 +1988,10 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
		emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
		store_offset -= 8;
	}
	if (ctx->arena_vm_start) {
		emit_sd(RV_REG_SP, store_offset, RV_REG_ARENA, ctx);
		store_offset -= 8;
	}

	emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);

@@ -1823,6 +2005,9 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
		emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);

	ctx->stack_size = stack_adjust;

	if (ctx->arena_vm_start)
		emit_imm(RV_REG_ARENA, ctx->arena_vm_start, ctx);
}

void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
+1 −0
Original line number Diff line number Diff line
@@ -80,6 +80,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
		goto skip_init_ctx;
	}

	ctx->arena_vm_start = bpf_arena_get_kern_vm_start(prog->aux->arena);
	ctx->prog = prog;
	ctx->offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL);
	if (!ctx->offset) {