Commit da4ab5dc authored by Alexei Starovoitov's avatar Alexei Starovoitov
Browse files

Merge branch 'bpf-recognize-special-arithmetic-shift-in-the-verifier'

Puranjay Mohan says:

====================
bpf: Recognize special arithmetic shift in the verifier

v3: https://lore.kernel.org/all/20260103022310.935686-1-puranjay@kernel.org/
Changes in v3->v4:
- Fork verifier state while processing BPF_OR when src_reg has [-1,0]
  range and 2nd operand is a constant. This is to detect the following pattern:
	i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
- Add selftests for above.
- Remove __description("s>>=63") (Eduard in another patchset)

v2: https://lore.kernel.org/bpf/20251115022611.64898-1-alexei.starovoitov@gmail.com/
Changes in v2->v3:
- fork verifier state while processing BPF_AND when src_reg has [-1,0]
  range and 2nd operand is a constant.

v1->v2:
Use __mark_reg32_known() or __mark_reg_known() for zero too.
Add comment to selftest.

v1:
https://lore.kernel.org/bpf/20251114031039.63852-1-alexei.starovoitov@gmail.com/
====================

Link: https://patch.msgid.link/20260112201424.816836-1-puranjay@kernel.org


Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parents 1fffe1f4 91603353
Loading
Loading
Loading
Loading
+39 −0
Original line number Diff line number Diff line
@@ -15491,6 +15491,35 @@ static bool is_safe_to_compute_dst_reg_range(struct bpf_insn *insn,
	}
}
static int maybe_fork_scalars(struct bpf_verifier_env *env, struct bpf_insn *insn,
			      struct bpf_reg_state *dst_reg)
{
	struct bpf_verifier_state *branch;
	struct bpf_reg_state *regs;
	bool alu32;
	if (dst_reg->smin_value == -1 && dst_reg->smax_value == 0)
		alu32 = false;
	else if (dst_reg->s32_min_value == -1 && dst_reg->s32_max_value == 0)
		alu32 = true;
	else
		return 0;
	branch = push_stack(env, env->insn_idx + 1, env->insn_idx, false);
	if (IS_ERR(branch))
		return PTR_ERR(branch);
	regs = branch->frame[branch->curframe]->regs;
	if (alu32) {
		__mark_reg32_known(&regs[insn->dst_reg], 0);
		__mark_reg32_known(dst_reg, -1ull);
	} else {
		__mark_reg_known(&regs[insn->dst_reg], 0);
		__mark_reg_known(dst_reg, -1ull);
	}
	return 0;
}
/* WARNING: This function does calculations on 64-bit values, but the actual
 * execution may occur on 32-bit values. Therefore, things like bitshifts
 * need extra checks in the 32-bit case.
@@ -15553,11 +15582,21 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
		scalar_min_max_mul(dst_reg, &src_reg);
		break;
	case BPF_AND:
		if (tnum_is_const(src_reg.var_off)) {
			ret = maybe_fork_scalars(env, insn, dst_reg);
			if (ret)
				return ret;
		}
		dst_reg->var_off = tnum_and(dst_reg->var_off, src_reg.var_off);
		scalar32_min_max_and(dst_reg, &src_reg);
		scalar_min_max_and(dst_reg, &src_reg);
		break;
	case BPF_OR:
		if (tnum_is_const(src_reg.var_off)) {
			ret = maybe_fork_scalars(env, insn, dst_reg);
			if (ret)
				return ret;
		}
		dst_reg->var_off = tnum_or(dst_reg->var_off, src_reg.var_off);
		scalar32_min_max_or(dst_reg, &src_reg);
		scalar_min_max_or(dst_reg, &src_reg);
+85 −0
Original line number Diff line number Diff line
@@ -738,4 +738,89 @@ __naked void ldx_w_zero_extend_check(void)
	: __clobber_all);
}

SEC("socket")
__success __success_unpriv __retval(0)
__naked void arsh_31_and(void)
{
	/* Below is what LLVM generates in cilium's bpf_wiregard.o */
	asm volatile ("					\
	call %[bpf_get_prandom_u32];			\
	w2 = w0;					\
	w2 s>>= 31;					\
	w2 &= -134; /* w2 becomes 0 or -134 */		\
	if w2 s> -1 goto +2;				\
	/* Branch always taken because w2 = -134 */	\
	if w2 != -136 goto +1;				\
	w0 /= 0;					\
	w0 = 0;						\
	exit;						\
"	:
	: __imm(bpf_get_prandom_u32)
	: __clobber_all);
}

SEC("socket")
__success __success_unpriv __retval(0)
__naked void arsh_63_and(void)
{
	/* Copy of arsh_31 with s/w/r/ */
	asm volatile ("					\
	call %[bpf_get_prandom_u32];			\
	r2 = r0;					\
	r2 <<= 32;					\
	r2 s>>= 63;					\
	r2 &= -134;					\
	if r2 s> -1 goto +2;				\
	/* Branch always taken because w2 = -134 */	\
	if r2 != -136 goto +1;				\
	r0 /= 0;					\
	r0 = 0;						\
	exit;						\
"	:
	: __imm(bpf_get_prandom_u32)
	: __clobber_all);
}

SEC("socket")
__success __success_unpriv __retval(0)
__naked void arsh_31_or(void)
{
	asm volatile ("					\
	call %[bpf_get_prandom_u32];			\
	w2 = w0;					\
	w2 s>>= 31;					\
	w2 |= 134; /* w2 becomes -1 or 134 */		\
	if w2 s> -1 goto +2;				\
	/* Branch always taken because w2 = -1 */	\
	if w2 == -1 goto +1;				\
	w0 /= 0;					\
	w0 = 0;						\
	exit;						\
"	:
	: __imm(bpf_get_prandom_u32)
	: __clobber_all);
}

SEC("socket")
__success __success_unpriv __retval(0)
__naked void arsh_63_or(void)
{
	/* Copy of arsh_31 with s/w/r/ */
	asm volatile ("					\
	call %[bpf_get_prandom_u32];			\
	r2 = r0;					\
	r2 <<= 32;					\
	r2 s>>= 63;					\
	r2 |= 134; /* r2 becomes -1 or 134 */		\
	if r2 s> -1 goto +2;				\
	/* Branch always taken because w2 = -1 */	\
	if r2 == -1 goto +1;				\
	r0 /= 0;					\
	r0 = 0;						\
	exit;						\
"	:
	: __imm(bpf_get_prandom_u32)
	: __clobber_all);
}

char _license[] SEC("license") = "GPL";