Commit 32c563d1 authored by Alexei Starovoitov's avatar Alexei Starovoitov
Browse files

Merge branch 'bpf-riscv64-support-load-acquire-and-store-release-instructions'

Peilin Ye says:

====================
bpf, riscv64: Support load-acquire and store-release instructions

Hi all!

Patchset [1] introduced BPF load-acquire (BPF_LOAD_ACQ) and
store-release (BPF_STORE_REL) instructions, and added x86-64 and arm64
JIT compiler support.  As a follow-up, this v2 patchset supports
load-acquire and store-release instructions for the riscv64 JIT
compiler, and introduces some related selftests/ changes.

Specifically:

 * PATCH 1 makes insn_def_regno() handle load-acquires properly for
   bpf_jit_needs_zext() (true for riscv64) architectures
 * PATCH 2, 3 from Andrea Parri add the actual support to the riscv64
   JIT compiler
 * PATCH 4 optimizes code emission by skipping redundant zext
   instructions inserted by the verifier
 * PATCH 5, 6 and 7 are minor selftest/ improvements
 * PATCH 8 enables (non-arena) load-acquire/store-release selftests for
   riscv64

v1: https://lore.kernel.org/bpf/cover.1745970908.git.yepeilin@google.com/
Changes since v1:

 * add Acked-by:, Reviewed-by: and Tested-by: tags from Lehui and Björn
 * simplify code logic in PATCH 1 (Lehui)
 * in PATCH 3, avoid changing 'return 0;' to 'return ret;' at the end of
   bpf_jit_emit_insn() (Lehui)

Please refer to individual patches for details.  Thanks!

[1] https://lore.kernel.org/all/cover.1741049567.git.yepeilin@google.com/
====================

Link: https://patch.msgid.link/cover.1746588351.git.yepeilin@google.com


Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parents b69d4413 d3131466
Loading
Loading
Loading
Loading
+15 −0
Original line number Diff line number Diff line
@@ -608,6 +608,21 @@ static inline u32 rv_fence(u8 pred, u8 succ)
	return rv_i_insn(imm11_0, 0, 0, 0, 0xf);
}

static inline void emit_fence_r_rw(struct rv_jit_context *ctx)
{
	emit(rv_fence(0x2, 0x3), ctx);
}

static inline void emit_fence_rw_w(struct rv_jit_context *ctx)
{
	emit(rv_fence(0x3, 0x1), ctx);
}

static inline void emit_fence_rw_rw(struct rv_jit_context *ctx)
{
	emit(rv_fence(0x3, 0x3), ctx);
}

static inline u32 rv_nop(void)
{
	return rv_i_insn(0, 0, 0, 0, 0x13);
+227 −105
Original line number Diff line number Diff line
@@ -473,11 +473,212 @@ static inline void emit_kcfi(u32 hash, struct rv_jit_context *ctx)
		emit(hash, ctx);
}

static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,
static int emit_load_8(bool sign_ext, u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx)
{
	int insns_start;

	if (is_12b_int(off)) {
		insns_start = ctx->ninsns;
		if (sign_ext)
			emit(rv_lb(rd, off, rs), ctx);
		else
			emit(rv_lbu(rd, off, rs), ctx);
		return ctx->ninsns - insns_start;
	}

	emit_imm(RV_REG_T1, off, ctx);
	emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
	insns_start = ctx->ninsns;
	if (sign_ext)
		emit(rv_lb(rd, 0, RV_REG_T1), ctx);
	else
		emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
	return ctx->ninsns - insns_start;
}

static int emit_load_16(bool sign_ext, u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx)
{
	int insns_start;

	if (is_12b_int(off)) {
		insns_start = ctx->ninsns;
		if (sign_ext)
			emit(rv_lh(rd, off, rs), ctx);
		else
			emit(rv_lhu(rd, off, rs), ctx);
		return ctx->ninsns - insns_start;
	}

	emit_imm(RV_REG_T1, off, ctx);
	emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
	insns_start = ctx->ninsns;
	if (sign_ext)
		emit(rv_lh(rd, 0, RV_REG_T1), ctx);
	else
		emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
	return ctx->ninsns - insns_start;
}

static int emit_load_32(bool sign_ext, u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx)
{
	int insns_start;

	if (is_12b_int(off)) {
		insns_start = ctx->ninsns;
		if (sign_ext)
			emit(rv_lw(rd, off, rs), ctx);
		else
			emit(rv_lwu(rd, off, rs), ctx);
		return ctx->ninsns - insns_start;
	}

	emit_imm(RV_REG_T1, off, ctx);
	emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
	insns_start = ctx->ninsns;
	if (sign_ext)
		emit(rv_lw(rd, 0, RV_REG_T1), ctx);
	else
		emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
	return ctx->ninsns - insns_start;
}

static int emit_load_64(bool sign_ext, u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx)
{
	int insns_start;

	if (is_12b_int(off)) {
		insns_start = ctx->ninsns;
		emit_ld(rd, off, rs, ctx);
		return ctx->ninsns - insns_start;
	}

	emit_imm(RV_REG_T1, off, ctx);
	emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
	insns_start = ctx->ninsns;
	emit_ld(rd, 0, RV_REG_T1, ctx);
	return ctx->ninsns - insns_start;
}

static void emit_store_8(u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx)
{
	if (is_12b_int(off)) {
		emit(rv_sb(rd, off, rs), ctx);
		return;
	}

	emit_imm(RV_REG_T1, off, ctx);
	emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
	emit(rv_sb(RV_REG_T1, 0, rs), ctx);
}

static void emit_store_16(u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx)
{
	if (is_12b_int(off)) {
		emit(rv_sh(rd, off, rs), ctx);
		return;
	}

	emit_imm(RV_REG_T1, off, ctx);
	emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
	emit(rv_sh(RV_REG_T1, 0, rs), ctx);
}

static void emit_store_32(u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx)
{
	if (is_12b_int(off)) {
		emit_sw(rd, off, rs, ctx);
		return;
	}

	emit_imm(RV_REG_T1, off, ctx);
	emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
	emit_sw(RV_REG_T1, 0, rs, ctx);
}

static void emit_store_64(u8 rd, s32 off, u8 rs, struct rv_jit_context *ctx)
{
	if (is_12b_int(off)) {
		emit_sd(rd, off, rs, ctx);
		return;
	}

	emit_imm(RV_REG_T1, off, ctx);
	emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
	emit_sd(RV_REG_T1, 0, rs, ctx);
}

static int emit_atomic_ld_st(u8 rd, u8 rs, const struct bpf_insn *insn,
			     struct rv_jit_context *ctx)
{
	u8 r0;
	u8 code = insn->code;
	s32 imm = insn->imm;
	s16 off = insn->off;

	switch (imm) {
	/* dst_reg = load_acquire(src_reg + off16) */
	case BPF_LOAD_ACQ:
		switch (BPF_SIZE(code)) {
		case BPF_B:
			emit_load_8(false, rd, off, rs, ctx);
			break;
		case BPF_H:
			emit_load_16(false, rd, off, rs, ctx);
			break;
		case BPF_W:
			emit_load_32(false, rd, off, rs, ctx);
			break;
		case BPF_DW:
			emit_load_64(false, rd, off, rs, ctx);
			break;
		}
		emit_fence_r_rw(ctx);

		/* If our next insn is a redundant zext, return 1 to tell
		 * build_body() to skip it.
		 */
		if (BPF_SIZE(code) != BPF_DW && insn_is_zext(&insn[1]))
			return 1;
		break;
	/* store_release(dst_reg + off16, src_reg) */
	case BPF_STORE_REL:
		emit_fence_rw_w(ctx);
		switch (BPF_SIZE(code)) {
		case BPF_B:
			emit_store_8(rd, off, rs, ctx);
			break;
		case BPF_H:
			emit_store_16(rd, off, rs, ctx);
			break;
		case BPF_W:
			emit_store_32(rd, off, rs, ctx);
			break;
		case BPF_DW:
			emit_store_64(rd, off, rs, ctx);
			break;
		}
		break;
	default:
		pr_err_once("bpf-jit: invalid atomic load/store opcode %02x\n", imm);
		return -EINVAL;
	}

	return 0;
}

static int emit_atomic_rmw(u8 rd, u8 rs, const struct bpf_insn *insn,
			   struct rv_jit_context *ctx)
{
	u8 r0, code = insn->code;
	s16 off = insn->off;
	s32 imm = insn->imm;
	int jmp_offset;
	bool is64;

	if (BPF_SIZE(code) != BPF_W && BPF_SIZE(code) != BPF_DW) {
		pr_err_once("bpf-jit: 1- and 2-byte RMW atomics are not supported\n");
		return -EINVAL;
	}
	is64 = BPF_SIZE(code) == BPF_DW;

	if (off) {
		if (is_12b_int(off)) {
@@ -554,9 +755,14 @@ static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,
		     rv_sc_w(RV_REG_T3, rs, rd, 0, 1), ctx);
		jmp_offset = ninsns_rvoff(-6);
		emit(rv_bne(RV_REG_T3, 0, jmp_offset >> 1), ctx);
		emit(rv_fence(0x3, 0x3), ctx);
		emit_fence_rw_rw(ctx);
		break;
	default:
		pr_err_once("bpf-jit: invalid atomic RMW opcode %02x\n", imm);
		return -EINVAL;
	}

	return 0;
}

#define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
@@ -1650,8 +1856,8 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
	case BPF_LDX | BPF_PROBE_MEM32 | BPF_W:
	case BPF_LDX | BPF_PROBE_MEM32 | BPF_DW:
	{
		int insn_len, insns_start;
		bool sign_ext;
		int insn_len;

		sign_ext = BPF_MODE(insn->code) == BPF_MEMSX ||
			   BPF_MODE(insn->code) == BPF_PROBE_MEMSX;
@@ -1663,78 +1869,16 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,

		switch (BPF_SIZE(code)) {
		case BPF_B:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				if (sign_ext)
					emit(rv_lb(rd, off, rs), ctx);
				else
					emit(rv_lbu(rd, off, rs), ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T1, off, ctx);
			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
			insns_start = ctx->ninsns;
			if (sign_ext)
				emit(rv_lb(rd, 0, RV_REG_T1), ctx);
			else
				emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
			insn_len = ctx->ninsns - insns_start;
			insn_len = emit_load_8(sign_ext, rd, off, rs, ctx);
			break;
		case BPF_H:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				if (sign_ext)
					emit(rv_lh(rd, off, rs), ctx);
				else
					emit(rv_lhu(rd, off, rs), ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T1, off, ctx);
			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
			insns_start = ctx->ninsns;
			if (sign_ext)
				emit(rv_lh(rd, 0, RV_REG_T1), ctx);
			else
				emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
			insn_len = ctx->ninsns - insns_start;
			insn_len = emit_load_16(sign_ext, rd, off, rs, ctx);
			break;
		case BPF_W:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				if (sign_ext)
					emit(rv_lw(rd, off, rs), ctx);
				else
					emit(rv_lwu(rd, off, rs), ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T1, off, ctx);
			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
			insns_start = ctx->ninsns;
			if (sign_ext)
				emit(rv_lw(rd, 0, RV_REG_T1), ctx);
			else
				emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
			insn_len = ctx->ninsns - insns_start;
			insn_len = emit_load_32(sign_ext, rd, off, rs, ctx);
			break;
		case BPF_DW:
			if (is_12b_int(off)) {
				insns_start = ctx->ninsns;
				emit_ld(rd, off, rs, ctx);
				insn_len = ctx->ninsns - insns_start;
				break;
			}

			emit_imm(RV_REG_T1, off, ctx);
			emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
			insns_start = ctx->ninsns;
			emit_ld(rd, 0, RV_REG_T1, ctx);
			insn_len = ctx->ninsns - insns_start;
			insn_len = emit_load_64(sign_ext, rd, off, rs, ctx);
			break;
		}

@@ -1879,49 +2023,27 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,

	/* STX: *(size *)(dst + off) = src */
	case BPF_STX | BPF_MEM | BPF_B:
		if (is_12b_int(off)) {
			emit(rv_sb(rd, off, rs), ctx);
			break;
		}

		emit_imm(RV_REG_T1, off, ctx);
		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
		emit(rv_sb(RV_REG_T1, 0, rs), ctx);
		emit_store_8(rd, off, rs, ctx);
		break;
	case BPF_STX | BPF_MEM | BPF_H:
		if (is_12b_int(off)) {
			emit(rv_sh(rd, off, rs), ctx);
			break;
		}

		emit_imm(RV_REG_T1, off, ctx);
		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
		emit(rv_sh(RV_REG_T1, 0, rs), ctx);
		emit_store_16(rd, off, rs, ctx);
		break;
	case BPF_STX | BPF_MEM | BPF_W:
		if (is_12b_int(off)) {
			emit_sw(rd, off, rs, ctx);
			break;
		}

		emit_imm(RV_REG_T1, off, ctx);
		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
		emit_sw(RV_REG_T1, 0, rs, ctx);
		emit_store_32(rd, off, rs, ctx);
		break;
	case BPF_STX | BPF_MEM | BPF_DW:
		if (is_12b_int(off)) {
			emit_sd(rd, off, rs, ctx);
			break;
		}

		emit_imm(RV_REG_T1, off, ctx);
		emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
		emit_sd(RV_REG_T1, 0, rs, ctx);
		emit_store_64(rd, off, rs, ctx);
		break;
	case BPF_STX | BPF_ATOMIC | BPF_B:
	case BPF_STX | BPF_ATOMIC | BPF_H:
	case BPF_STX | BPF_ATOMIC | BPF_W:
	case BPF_STX | BPF_ATOMIC | BPF_DW:
		emit_atomic(rd, rs, off, imm,
			    BPF_SIZE(code) == BPF_DW, ctx);
		if (bpf_atomic_is_load_store(insn))
			ret = emit_atomic_ld_st(rd, rs, insn, ctx);
		else
			ret = emit_atomic_rmw(rd, rs, insn, ctx);
		if (ret)
			return ret;
		break;

	case BPF_STX | BPF_PROBE_MEM32 | BPF_B:
+1 −2
Original line number Diff line number Diff line
@@ -26,9 +26,8 @@ static int build_body(struct rv_jit_context *ctx, bool extra_pass, int *offset)
		int ret;

		ret = bpf_jit_emit_insn(insn, ctx, extra_pass);
		/* BPF_LD | BPF_IMM | BPF_DW: skip the next instruction. */
		if (ret > 0)
			i++;
			i++; /* skip the next instruction */
		if (offset)
			offset[i] = ctx->ninsns;
		if (ret < 0)
+6 −6
Original line number Diff line number Diff line
@@ -3649,16 +3649,16 @@ static int insn_def_regno(const struct bpf_insn *insn)
	case BPF_ST:
		return -1;
	case BPF_STX:
		if ((BPF_MODE(insn->code) == BPF_ATOMIC ||
		     BPF_MODE(insn->code) == BPF_PROBE_ATOMIC) &&
		    (insn->imm & BPF_FETCH)) {
		if (BPF_MODE(insn->code) == BPF_ATOMIC ||
		    BPF_MODE(insn->code) == BPF_PROBE_ATOMIC) {
			if (insn->imm == BPF_CMPXCHG)
				return BPF_REG_0;
			else
			else if (insn->imm == BPF_LOAD_ACQ)
				return insn->dst_reg;
			else if (insn->imm & BPF_FETCH)
				return insn->src_reg;
		} else {
			return -1;
		}
		return -1;
	default:
		return insn->dst_reg;
	}
+3 −2
Original line number Diff line number Diff line
@@ -226,7 +226,8 @@
#endif

#if __clang_major__ >= 18 && defined(ENABLE_ATOMICS_TESTS) &&		\
	(defined(__TARGET_ARCH_arm64) || defined(__TARGET_ARCH_x86))
	(defined(__TARGET_ARCH_arm64) || defined(__TARGET_ARCH_x86) ||	\
	 (defined(__TARGET_ARCH_riscv) && __riscv_xlen == 64))
#define CAN_USE_LOAD_ACQ_STORE_REL
#endif

Loading