Commit 4d3a453b authored by Ilya Leoshkevich's avatar Ilya Leoshkevich Committed by Daniel Borkmann
Browse files

s390/bpf: Support BPF_PROBE_MEM32



BPF_PROBE_MEM32 is a new mode for LDX, ST and STX instructions. The JIT
is supposed to add the start address of the kernel arena mapping to the
%dst register, and use a probing variant of the respective memory
access.

Reuse the existing probing infrastructure for that. Put the arena
address into the literal pool, load it into %r1 and use that as an
index register. Do not clear any registers in ex_handler_bpf() for
failing ST and STX instructions.

Signed-off-by: default avatarIlya Leoshkevich <iii@linux.ibm.com>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
Link: https://lore.kernel.org/bpf/20240701234304.14336-7-iii@linux.ibm.com
parent a1c04bcc
Loading
Loading
Loading
Loading
+110 −27
Original line number Diff line number Diff line
@@ -53,6 +53,7 @@ struct bpf_jit {
	int excnt;		/* Number of exception table entries */
	int prologue_plt_ret;	/* Return address for prologue hotpatch PLT */
	int prologue_plt;	/* Start of prologue hotpatch PLT */
	int kern_arena;		/* Pool offset of kernel arena address */
};

#define SEEN_MEM	BIT(0)		/* use mem[] for temporary storage */
@@ -670,6 +671,7 @@ static void bpf_jit_epilogue(struct bpf_jit *jit, u32 stack_depth)
bool ex_handler_bpf(const struct exception_table_entry *x, struct pt_regs *regs)
{
	regs->psw.addr = extable_fixup(x);
	if (x->data != -1)
		regs->gprs[x->data] = 0;
	return true;
}
@@ -681,6 +683,7 @@ struct bpf_jit_probe {
	int prg;	/* JITed instruction offset */
	int nop_prg;	/* JITed nop offset */
	int reg;	/* Register to clear on exception */
	int arena_reg;	/* Register to use for arena addressing */
};

static void bpf_jit_probe_init(struct bpf_jit_probe *probe)
@@ -688,6 +691,7 @@ static void bpf_jit_probe_init(struct bpf_jit_probe *probe)
	probe->prg = -1;
	probe->nop_prg = -1;
	probe->reg = -1;
	probe->arena_reg = REG_0;
}

/*
@@ -708,13 +712,31 @@ static void bpf_jit_probe_load_pre(struct bpf_jit *jit, struct bpf_insn *insn,
				   struct bpf_jit_probe *probe)
{
	if (BPF_MODE(insn->code) != BPF_PROBE_MEM &&
	    BPF_MODE(insn->code) != BPF_PROBE_MEMSX)
	    BPF_MODE(insn->code) != BPF_PROBE_MEMSX &&
	    BPF_MODE(insn->code) != BPF_PROBE_MEM32)
		return;

	if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) {
		/* lgrl %r1,kern_arena */
		EMIT6_PCREL_RILB(0xc4080000, REG_W1, jit->kern_arena);
		probe->arena_reg = REG_W1;
	}
	probe->prg = jit->prg;
	probe->reg = reg2hex[insn->dst_reg];
}

static void bpf_jit_probe_store_pre(struct bpf_jit *jit, struct bpf_insn *insn,
				    struct bpf_jit_probe *probe)
{
	if (BPF_MODE(insn->code) != BPF_PROBE_MEM32)
		return;

	/* lgrl %r1,kern_arena */
	EMIT6_PCREL_RILB(0xc4080000, REG_W1, jit->kern_arena);
	probe->arena_reg = REG_W1;
	probe->prg = jit->prg;
}

static int bpf_jit_probe_post(struct bpf_jit *jit, struct bpf_prog *fp,
			      struct bpf_jit_probe *probe)
{
@@ -1384,51 +1406,99 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
	 * BPF_ST(X)
	 */
	case BPF_STX | BPF_MEM | BPF_B: /* *(u8 *)(dst + off) = src_reg */
		/* stcy %src,off(%dst) */
		EMIT6_DISP_LH(0xe3000000, 0x0072, src_reg, dst_reg, REG_0, off);
	case BPF_STX | BPF_PROBE_MEM32 | BPF_B:
		bpf_jit_probe_store_pre(jit, insn, &probe);
		/* stcy %src,off(%dst,%arena) */
		EMIT6_DISP_LH(0xe3000000, 0x0072, src_reg, dst_reg,
			      probe.arena_reg, off);
		err = bpf_jit_probe_post(jit, fp, &probe);
		if (err < 0)
			return err;
		jit->seen |= SEEN_MEM;
		break;
	case BPF_STX | BPF_MEM | BPF_H: /* (u16 *)(dst + off) = src */
		/* sthy %src,off(%dst) */
		EMIT6_DISP_LH(0xe3000000, 0x0070, src_reg, dst_reg, REG_0, off);
	case BPF_STX | BPF_PROBE_MEM32 | BPF_H:
		bpf_jit_probe_store_pre(jit, insn, &probe);
		/* sthy %src,off(%dst,%arena) */
		EMIT6_DISP_LH(0xe3000000, 0x0070, src_reg, dst_reg,
			      probe.arena_reg, off);
		err = bpf_jit_probe_post(jit, fp, &probe);
		if (err < 0)
			return err;
		jit->seen |= SEEN_MEM;
		break;
	case BPF_STX | BPF_MEM | BPF_W: /* *(u32 *)(dst + off) = src */
		/* sty %src,off(%dst) */
		EMIT6_DISP_LH(0xe3000000, 0x0050, src_reg, dst_reg, REG_0, off);
	case BPF_STX | BPF_PROBE_MEM32 | BPF_W:
		bpf_jit_probe_store_pre(jit, insn, &probe);
		/* sty %src,off(%dst,%arena) */
		EMIT6_DISP_LH(0xe3000000, 0x0050, src_reg, dst_reg,
			      probe.arena_reg, off);
		err = bpf_jit_probe_post(jit, fp, &probe);
		if (err < 0)
			return err;
		jit->seen |= SEEN_MEM;
		break;
	case BPF_STX | BPF_MEM | BPF_DW: /* (u64 *)(dst + off) = src */
		/* stg %src,off(%dst) */
		EMIT6_DISP_LH(0xe3000000, 0x0024, src_reg, dst_reg, REG_0, off);
	case BPF_STX | BPF_PROBE_MEM32 | BPF_DW:
		bpf_jit_probe_store_pre(jit, insn, &probe);
		/* stg %src,off(%dst,%arena) */
		EMIT6_DISP_LH(0xe3000000, 0x0024, src_reg, dst_reg,
			      probe.arena_reg, off);
		err = bpf_jit_probe_post(jit, fp, &probe);
		if (err < 0)
			return err;
		jit->seen |= SEEN_MEM;
		break;
	case BPF_ST | BPF_MEM | BPF_B: /* *(u8 *)(dst + off) = imm */
	case BPF_ST | BPF_PROBE_MEM32 | BPF_B:
		/* lhi %w0,imm */
		EMIT4_IMM(0xa7080000, REG_W0, (u8) imm);
		/* stcy %w0,off(dst) */
		EMIT6_DISP_LH(0xe3000000, 0x0072, REG_W0, dst_reg, REG_0, off);
		bpf_jit_probe_store_pre(jit, insn, &probe);
		/* stcy %w0,off(%dst,%arena) */
		EMIT6_DISP_LH(0xe3000000, 0x0072, REG_W0, dst_reg,
			      probe.arena_reg, off);
		err = bpf_jit_probe_post(jit, fp, &probe);
		if (err < 0)
			return err;
		jit->seen |= SEEN_MEM;
		break;
	case BPF_ST | BPF_MEM | BPF_H: /* (u16 *)(dst + off) = imm */
	case BPF_ST | BPF_PROBE_MEM32 | BPF_H:
		/* lhi %w0,imm */
		EMIT4_IMM(0xa7080000, REG_W0, (u16) imm);
		/* sthy %w0,off(dst) */
		EMIT6_DISP_LH(0xe3000000, 0x0070, REG_W0, dst_reg, REG_0, off);
		bpf_jit_probe_store_pre(jit, insn, &probe);
		/* sthy %w0,off(%dst,%arena) */
		EMIT6_DISP_LH(0xe3000000, 0x0070, REG_W0, dst_reg,
			      probe.arena_reg, off);
		err = bpf_jit_probe_post(jit, fp, &probe);
		if (err < 0)
			return err;
		jit->seen |= SEEN_MEM;
		break;
	case BPF_ST | BPF_MEM | BPF_W: /* *(u32 *)(dst + off) = imm */
	case BPF_ST | BPF_PROBE_MEM32 | BPF_W:
		/* llilf %w0,imm  */
		EMIT6_IMM(0xc00f0000, REG_W0, (u32) imm);
		/* sty %w0,off(%dst) */
		EMIT6_DISP_LH(0xe3000000, 0x0050, REG_W0, dst_reg, REG_0, off);
		bpf_jit_probe_store_pre(jit, insn, &probe);
		/* sty %w0,off(%dst,%arena) */
		EMIT6_DISP_LH(0xe3000000, 0x0050, REG_W0, dst_reg,
			      probe.arena_reg, off);
		err = bpf_jit_probe_post(jit, fp, &probe);
		if (err < 0)
			return err;
		jit->seen |= SEEN_MEM;
		break;
	case BPF_ST | BPF_MEM | BPF_DW: /* *(u64 *)(dst + off) = imm */
	case BPF_ST | BPF_PROBE_MEM32 | BPF_DW:
		/* lgfi %w0,imm */
		EMIT6_IMM(0xc0010000, REG_W0, imm);
		/* stg %w0,off(%dst) */
		EMIT6_DISP_LH(0xe3000000, 0x0024, REG_W0, dst_reg, REG_0, off);
		bpf_jit_probe_store_pre(jit, insn, &probe);
		/* stg %w0,off(%dst,%arena) */
		EMIT6_DISP_LH(0xe3000000, 0x0024, REG_W0, dst_reg,
			      probe.arena_reg, off);
		err = bpf_jit_probe_post(jit, fp, &probe);
		if (err < 0)
			return err;
		jit->seen |= SEEN_MEM;
		break;
	/*
@@ -1506,9 +1576,11 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
	 */
	case BPF_LDX | BPF_MEM | BPF_B: /* dst = *(u8 *)(ul) (src + off) */
	case BPF_LDX | BPF_PROBE_MEM | BPF_B:
	case BPF_LDX | BPF_PROBE_MEM32 | BPF_B:
		bpf_jit_probe_load_pre(jit, insn, &probe);
		/* llgc %dst,0(off,%src) */
		EMIT6_DISP_LH(0xe3000000, 0x0090, dst_reg, src_reg, REG_0, off);
		/* llgc %dst,off(%src,%arena) */
		EMIT6_DISP_LH(0xe3000000, 0x0090, dst_reg, src_reg,
			      probe.arena_reg, off);
		err = bpf_jit_probe_post(jit, fp, &probe);
		if (err < 0)
			return err;
@@ -1519,7 +1591,7 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
	case BPF_LDX | BPF_MEMSX | BPF_B: /* dst = *(s8 *)(ul) (src + off) */
	case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
		bpf_jit_probe_load_pre(jit, insn, &probe);
		/* lgb %dst,0(off,%src) */
		/* lgb %dst,off(%src) */
		EMIT6_DISP_LH(0xe3000000, 0x0077, dst_reg, src_reg, REG_0, off);
		err = bpf_jit_probe_post(jit, fp, &probe);
		if (err < 0)
@@ -1528,9 +1600,11 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
		break;
	case BPF_LDX | BPF_MEM | BPF_H: /* dst = *(u16 *)(ul) (src + off) */
	case BPF_LDX | BPF_PROBE_MEM | BPF_H:
	case BPF_LDX | BPF_PROBE_MEM32 | BPF_H:
		bpf_jit_probe_load_pre(jit, insn, &probe);
		/* llgh %dst,0(off,%src) */
		EMIT6_DISP_LH(0xe3000000, 0x0091, dst_reg, src_reg, REG_0, off);
		/* llgh %dst,off(%src,%arena) */
		EMIT6_DISP_LH(0xe3000000, 0x0091, dst_reg, src_reg,
			      probe.arena_reg, off);
		err = bpf_jit_probe_post(jit, fp, &probe);
		if (err < 0)
			return err;
@@ -1541,7 +1615,7 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
	case BPF_LDX | BPF_MEMSX | BPF_H: /* dst = *(s16 *)(ul) (src + off) */
	case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
		bpf_jit_probe_load_pre(jit, insn, &probe);
		/* lgh %dst,0(off,%src) */
		/* lgh %dst,off(%src) */
		EMIT6_DISP_LH(0xe3000000, 0x0015, dst_reg, src_reg, REG_0, off);
		err = bpf_jit_probe_post(jit, fp, &probe);
		if (err < 0)
@@ -1550,10 +1624,12 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
		break;
	case BPF_LDX | BPF_MEM | BPF_W: /* dst = *(u32 *)(ul) (src + off) */
	case BPF_LDX | BPF_PROBE_MEM | BPF_W:
	case BPF_LDX | BPF_PROBE_MEM32 | BPF_W:
		bpf_jit_probe_load_pre(jit, insn, &probe);
		/* llgf %dst,off(%src) */
		jit->seen |= SEEN_MEM;
		EMIT6_DISP_LH(0xe3000000, 0x0016, dst_reg, src_reg, REG_0, off);
		EMIT6_DISP_LH(0xe3000000, 0x0016, dst_reg, src_reg,
			      probe.arena_reg, off);
		err = bpf_jit_probe_post(jit, fp, &probe);
		if (err < 0)
			return err;
@@ -1572,10 +1648,12 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
		break;
	case BPF_LDX | BPF_MEM | BPF_DW: /* dst = *(u64 *)(ul) (src + off) */
	case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
	case BPF_LDX | BPF_PROBE_MEM32 | BPF_DW:
		bpf_jit_probe_load_pre(jit, insn, &probe);
		/* lg %dst,0(off,%src) */
		/* lg %dst,off(%src,%arena) */
		jit->seen |= SEEN_MEM;
		EMIT6_DISP_LH(0xe3000000, 0x0004, dst_reg, src_reg, REG_0, off);
		EMIT6_DISP_LH(0xe3000000, 0x0004, dst_reg, src_reg,
			      probe.arena_reg, off);
		err = bpf_jit_probe_post(jit, fp, &probe);
		if (err < 0)
			return err;
@@ -1988,12 +2066,17 @@ static int bpf_jit_prog(struct bpf_jit *jit, struct bpf_prog *fp,
			bool extra_pass, u32 stack_depth)
{
	int i, insn_count, lit32_size, lit64_size;
	u64 kern_arena;

	jit->lit32 = jit->lit32_start;
	jit->lit64 = jit->lit64_start;
	jit->prg = 0;
	jit->excnt = 0;

	kern_arena = bpf_arena_get_kern_vm_start(fp->aux->arena);
	if (kern_arena)
		jit->kern_arena = _EMIT_CONST_U64(kern_arena);

	bpf_jit_prologue(jit, fp, stack_depth);
	if (bpf_set_addr(jit, 0) < 0)
		return -1;