Commit 325d1ba3 authored by Alexei Starovoitov
Browse files

Merge branch 'bpf-fix-precision-backtracking-bug-with-linked-registers'

Eduard Zingerman says:

====================
bpf: Fix precision backtracking bug with linked registers

Emil Tsalapatis reported a verifier bug hit by the scx_lavd sched_ext
scheduler. The essential part of the verifier log looks as follows:

  436: ...
  // checkpoint hit for 438: (1d) if r7 == r8 goto ...
  frame 3: propagating r2,r7,r8
  frame 2: propagating r6
  mark_precise: frame3: last_idx ...
  mark_precise: frame3: regs=r2,r7,r8 stack= before 436: ...
  mark_precise: frame3: regs=r2,r7 stack= before 435: ...
  mark_precise: frame3: regs=r2,r7 stack= before 434: (85) call bpf_trace_vprintk#177
  verifier bug: backtracking call unexpected regs 84

The log complains that registers r2 and r7 are tracked as precise
while processing the bpf_trace_vprintk() call in precision backtracking.
This can't be right, as r2 is reset by the call and there is nothing
to backtrack it to. The precision propagation is triggered when
a checkpoint is hit at instruction 438, and r2 is dead at that instruction.

This happens because of the following sequence of events:
- Instruction 438 is first reached with registers r2 and r7 having
  the same id via a path that does not call bpf_trace_vprintk():
  - Checkpoint is created at 438.
  - The jump at 438 is predicted, hence r7 and registers linked to it
    (r2) are propagated as precise, marking r2 and r7 precise in the
    checkpoint.
- Instruction 438 is reached a second time with r2 undefined and via
  a path that calls bpf_trace_vprintk():
  - Checkpoint is hit.
  - propagate_precision() picks registers r2 and r7 and propagates
    precision marks for those up to the helper call.

The root cause is the fact that states_equal() and
propagate_precision() assume that the precision flag can't be set for a
dead register (as computed by compute_live_registers()).
However, this is not the case when linked registers are at play.
Fix this by making collect_linked_regs() consult the live register
flags and skip registers that are not live in their frame.
---
====================

Link: https://patch.msgid.link/20260306-linked-regs-and-propagate-precision-v1-0-18e859be570d@gmail.com


Signed-off-by: Alexei Starovoitov <ast@kernel.org>
parents 6895e1d7 223ffb6a
Loading
Loading
Loading
Loading
+10 −3
Original line number Diff line number Diff line
@@ -17359,17 +17359,24 @@ static void __collect_linked_regs(struct linked_regs *reg_set, struct bpf_reg_st
 * in verifier state, save R in linked_regs if R->id == id.
 * If there are too many Rs sharing same id, reset id for leftover Rs.
 */
static void collect_linked_regs(struct bpf_verifier_state *vstate, u32 id,
static void collect_linked_regs(struct bpf_verifier_env *env,
				struct bpf_verifier_state *vstate,
				u32 id,
				struct linked_regs *linked_regs)
{
	struct bpf_insn_aux_data *aux = env->insn_aux_data;
	struct bpf_func_state *func;
	struct bpf_reg_state *reg;
	u16 live_regs;
	int i, j;
	id = id & ~BPF_ADD_CONST;
	for (i = vstate->curframe; i >= 0; i--) {
		live_regs = aux[frame_insn_idx(vstate, i)].live_regs_before;
		func = vstate->frame[i];
		for (j = 0; j < BPF_REG_FP; j++) {
			if (!(live_regs & BIT(j)))
				continue;
			reg = &func->regs[j];
			__collect_linked_regs(linked_regs, reg, id, i, j, true);
		}
@@ -17584,9 +17591,9 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
	 * if parent state is created.
	 */
	if (BPF_SRC(insn->code) == BPF_X && src_reg->type == SCALAR_VALUE && src_reg->id)
		collect_linked_regs(this_branch, src_reg->id, &linked_regs);
		collect_linked_regs(env, this_branch, src_reg->id, &linked_regs);
	if (dst_reg->type == SCALAR_VALUE && dst_reg->id)
		collect_linked_regs(this_branch, dst_reg->id, &linked_regs);
		collect_linked_regs(env, this_branch, dst_reg->id, &linked_regs);
	if (linked_regs.cnt > 1) {
		err = push_jmp_history(env, this_branch, 0, linked_regs_pack(&linked_regs));
		if (err)
+17 −17
Original line number Diff line number Diff line
@@ -18,43 +18,43 @@
		return *(u64 *)num;					\
	}

__msg(": R0=0xffffffff80000000")
__msg("R{{.}}=0xffffffff80000000")
check_assert(s64, ==, eq_int_min, INT_MIN);
__msg(": R0=0x7fffffff")
__msg("R{{.}}=0x7fffffff")
check_assert(s64, ==, eq_int_max, INT_MAX);
__msg(": R0=0")
__msg("R{{.}}=0")
check_assert(s64, ==, eq_zero, 0);
__msg(": R0=0x8000000000000000 R1=0x8000000000000000")
__msg("R{{.}}=0x8000000000000000")
check_assert(s64, ==, eq_llong_min, LLONG_MIN);
__msg(": R0=0x7fffffffffffffff R1=0x7fffffffffffffff")
__msg("R{{.}}=0x7fffffffffffffff")
check_assert(s64, ==, eq_llong_max, LLONG_MAX);

__msg(": R0=scalar(id=1,smax=0x7ffffffe)")
__msg("R{{.}}=scalar(id=1,smax=0x7ffffffe)")
check_assert(s64, <, lt_pos, INT_MAX);
__msg(": R0=scalar(id=1,smax=-1,umin=0x8000000000000000,var_off=(0x8000000000000000; 0x7fffffffffffffff))")
__msg("R{{.}}=scalar(id=1,smax=-1,umin=0x8000000000000000,var_off=(0x8000000000000000; 0x7fffffffffffffff))")
check_assert(s64, <, lt_zero, 0);
__msg(": R0=scalar(id=1,smax=0xffffffff7fffffff")
__msg("R{{.}}=scalar(id=1,smax=0xffffffff7fffffff")
check_assert(s64, <, lt_neg, INT_MIN);

__msg(": R0=scalar(id=1,smax=0x7fffffff)")
__msg("R{{.}}=scalar(id=1,smax=0x7fffffff)")
check_assert(s64, <=, le_pos, INT_MAX);
__msg(": R0=scalar(id=1,smax=0)")
__msg("R{{.}}=scalar(id=1,smax=0)")
check_assert(s64, <=, le_zero, 0);
__msg(": R0=scalar(id=1,smax=0xffffffff80000000")
__msg("R{{.}}=scalar(id=1,smax=0xffffffff80000000")
check_assert(s64, <=, le_neg, INT_MIN);

__msg(": R0=scalar(id=1,smin=umin=0x80000000,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))")
__msg("R{{.}}=scalar(id=1,smin=umin=0x80000000,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))")
check_assert(s64, >, gt_pos, INT_MAX);
__msg(": R0=scalar(id=1,smin=umin=1,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))")
__msg("R{{.}}=scalar(id=1,smin=umin=1,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))")
check_assert(s64, >, gt_zero, 0);
__msg(": R0=scalar(id=1,smin=0xffffffff80000001")
__msg("R{{.}}=scalar(id=1,smin=0xffffffff80000001")
check_assert(s64, >, gt_neg, INT_MIN);

__msg(": R0=scalar(id=1,smin=umin=0x7fffffff,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))")
__msg("R{{.}}=scalar(id=1,smin=umin=0x7fffffff,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))")
check_assert(s64, >=, ge_pos, INT_MAX);
__msg(": R0=scalar(id=1,smin=0,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))")
__msg("R{{.}}=scalar(id=1,smin=0,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))")
check_assert(s64, >=, ge_zero, 0);
__msg(": R0=scalar(id=1,smin=0xffffffff80000000")
__msg("R{{.}}=scalar(id=1,smin=0xffffffff80000000")
check_assert(s64, >=, ge_neg, INT_MIN);

SEC("?tc")
+64 −0
Original line number Diff line number Diff line
@@ -363,4 +363,68 @@ void alu32_negative_offset(void)
	__sink(path[0]);
}

/*
 * Placeholder references to the numeric-iterator functions that the inline
 * assembly in spurious_precision_marks() names via __imm().
 *
 * All arguments are zero placeholders.
 * NOTE(review): presumably this function is never executed and exists only
 * so the compiler emits declarations/relocations for these functions that
 * the inline asm can resolve -- confirm against the selftest build.
 */
void dummy_calls(void)
{
	bpf_iter_num_new(0, 0, 0);
	bpf_iter_num_next(0);
	bpf_iter_num_destroy(0);
}

/*
 * Regression test: precision marks must not be propagated to a register
 * that was only linked (shared an id) with a precise register on a
 * different path.
 *
 * The loop body has two paths that meet at label 3:
 *  - the fall-through path executes "r2 = r7", tying r2 and r7 to one id;
 *  - the path through label 2 makes a helper call instead and never ties
 *    r2 to r7.
 * After the join, two conditional jumps on r7 act as a checkpoint and as a
 * precision trigger (see the comments inside the asm).
 *
 * The program is expected to load successfully (__success).
 * NOTE(review): BPF_F_TEST_STATE_FREQ presumably makes the verifier create
 * checkpoints often enough for label 3 to become one -- confirm.
 */
SEC("socket")
__success
__flag(BPF_F_TEST_STATE_FREQ)
int spurious_precision_marks(void *ctx)
{
	/* Iterator storage; its address is passed to the asm as %[iter]. */
	struct bpf_iter_num iter;

	asm volatile(
		/* bpf_iter_num_new() with r1 = &iter, r2 = 0, r3 = 10. */
		"r1 = %[iter];"
		"r2 = 0;"
		"r3 = 10;"
		"call %[bpf_iter_num_new];"
	"1:"
		/* Loop head: advance the iterator; r0 == 0 leaves the loop. */
		"r1 = %[iter];"
		"call %[bpf_iter_num_next];"
		"if r0 == 0 goto 4f;"
		/* r7 and r8 are both loaded from the same u32 that r0 points to. */
		"r7 = *(u32 *)(r0 + 0);"
		"r8 = *(u32 *)(r0 + 0);"
		/* This jump can't be predicted and does not change r7 or r8 state. */
		"if r7 > r8 goto 2f;"
		/* Branch explored first ties r2 and r7 as having the same id. */
		"r2 = r7;"
		"goto 3f;"
	"2:"
		/* Branch explored second does not tie r2 and r7 but has a function call. */
		"call %[bpf_get_prandom_u32];"
	"3:"
		/*
		 * A checkpoint.
		 * When first branch is explored, this would inject linked registers
		 * r2 and r7 into the jump history.
		 * When second branch is explored, this would be a cache hit point,
		 * triggering propagate_precision().
		 */
		"if r7 <= 42 goto +0;"
		/*
		 * Mark r7 as precise using an if condition that is always true.
		 * When reached via the second branch, this triggered a bug in the backtrack_insn()
		 * because r2 (tied to r7) was propagated as precise to a call.
		 */
		"if r7 <= 0xffffFFFF goto +0;"
		"goto 1b;"
	"4:"
		/* Loop exit: destroy the iterator created above. */
		"r1 = %[iter];"
		"call %[bpf_iter_num_destroy];"
		:
		: __imm_ptr(iter),
		  __imm(bpf_iter_num_new),
		  __imm(bpf_iter_num_next),
		  __imm(bpf_iter_num_destroy),
		  __imm(bpf_get_prandom_u32)
		/* r7 and r8 are written by the asm above, so list them explicitly. */
		: __clobber_common, "r7", "r8"
	);

	return 0;
}

char _license[] SEC("license") = "GPL";
+42 −14
Original line number Diff line number Diff line
@@ -40,6 +40,9 @@ __naked void linked_regs_bpf_k(void)
	 */
	"r3 = r10;"
	"r3 += r0;"
	/* Mark r1 and r2 as alive. */
	"r1 = r1;"
	"r2 = r2;"
	"r0 = 0;"
	"exit;"
	:
@@ -73,6 +76,9 @@ __naked void linked_regs_bpf_x_src(void)
	 */
	"r4 = r10;"
	"r4 += r0;"
	/* Mark r1 and r2 as alive. */
	"r1 = r1;"
	"r2 = r2;"
	"r0 = 0;"
	"exit;"
	:
@@ -106,6 +112,10 @@ __naked void linked_regs_bpf_x_dst(void)
	 */
	"r4 = r10;"
	"r4 += r3;"
	/* Mark r0, r1 and r2 as alive. */
	"r0 = r0;"
	"r1 = r1;"
	"r2 = r2;"
	"r0 = 0;"
	"exit;"
	:
@@ -143,6 +153,9 @@ __naked void linked_regs_broken_link(void)
	 */
	"r3 = r10;"
	"r3 += r0;"
	/* Mark r1 and r2 as alive. */
	"r1 = r1;"
	"r2 = r2;"
	"r0 = 0;"
	"exit;"
	:
@@ -156,16 +169,16 @@ __naked void linked_regs_broken_link(void)
 */
SEC("socket")
__success __log_level(2)
__msg("12: (0f) r2 += r1")
__msg("17: (0f) r2 += r1")
/* Current state */
__msg("frame2: last_idx 12 first_idx 11 subseq_idx -1 ")
__msg("frame2: regs=r1 stack= before 11: (bf) r2 = r10")
__msg("frame2: last_idx 17 first_idx 14 subseq_idx -1 ")
__msg("frame2: regs=r1 stack= before 16: (bf) r2 = r10")
__msg("frame2: parent state regs=r1 stack=")
__msg("frame1: parent state regs= stack=")
__msg("frame0: parent state regs= stack=")
/* Parent state */
__msg("frame2: last_idx 10 first_idx 10 subseq_idx 11 ")
__msg("frame2: regs=r1 stack= before 10: (25) if r1 > 0x7 goto pc+0")
__msg("frame2: last_idx 13 first_idx 13 subseq_idx 14 ")
__msg("frame2: regs=r1 stack= before 13: (25) if r1 > 0x7 goto pc+0")
__msg("frame2: parent state regs=r1 stack=")
/* frame1.r{6,7} are marked because mark_precise_scalar_ids()
 * looks for all registers with frame2.r1.id in the current state
@@ -173,20 +186,20 @@ __msg("frame2: parent state regs=r1 stack=")
__msg("frame1: parent state regs=r6,r7 stack=")
__msg("frame0: parent state regs=r6 stack=")
/* Parent state */
__msg("frame2: last_idx 8 first_idx 8 subseq_idx 10")
__msg("frame2: regs=r1 stack= before 8: (85) call pc+1")
__msg("frame2: last_idx 9 first_idx 9 subseq_idx 13")
__msg("frame2: regs=r1 stack= before 9: (85) call pc+3")
/* frame1.r1 is marked because of backtracking of call instruction */
__msg("frame1: parent state regs=r1,r6,r7 stack=")
__msg("frame0: parent state regs=r6 stack=")
/* Parent state */
__msg("frame1: last_idx 7 first_idx 6 subseq_idx 8")
__msg("frame1: regs=r1,r6,r7 stack= before 7: (bf) r7 = r1")
__msg("frame1: regs=r1,r6 stack= before 6: (bf) r6 = r1")
__msg("frame1: last_idx 8 first_idx 7 subseq_idx 9")
__msg("frame1: regs=r1,r6,r7 stack= before 8: (bf) r7 = r1")
__msg("frame1: regs=r1,r6 stack= before 7: (bf) r6 = r1")
__msg("frame1: parent state regs=r1 stack=")
__msg("frame0: parent state regs=r6 stack=")
/* Parent state */
__msg("frame1: last_idx 4 first_idx 4 subseq_idx 6")
__msg("frame1: regs=r1 stack= before 4: (85) call pc+1")
__msg("frame1: last_idx 4 first_idx 4 subseq_idx 7")
__msg("frame1: regs=r1 stack= before 4: (85) call pc+2")
__msg("frame0: parent state regs=r1,r6 stack=")
/* Parent state */
__msg("frame0: last_idx 3 first_idx 1 subseq_idx 4")
@@ -204,6 +217,7 @@ __naked void precision_many_frames(void)
	"r1 = r0;"
	"r6 = r0;"
	"call precision_many_frames__foo;"
	"r6 = r6;" /* mark r6 as live */
	"exit;"
	:
	: __imm(bpf_ktime_get_ns)
@@ -220,6 +234,8 @@ void precision_many_frames__foo(void)
	"r6 = r1;"
	"r7 = r1;"
	"call precision_many_frames__bar;"
	"r6 = r6;" /* mark r6 as live */
	"r7 = r7;" /* mark r7 as live */
	"exit"
	::: __clobber_all);
}
@@ -229,6 +245,8 @@ void precision_many_frames__bar(void)
{
	asm volatile (
	"if r1 > 7 goto +0;"
	"r6 = 0;" /* mark r6 as live */
	"r7 = 0;" /* mark r7 as live */
	/* force r1 to be precise, this eventually marks:
	 * - bar frame r1
	 * - foo frame r{1,6,7}
@@ -340,6 +358,8 @@ __naked void precision_two_ids(void)
	"r3 += r7;"
	/* force r9 to be precise, this also marks r8 */
	"r3 += r9;"
	"r6 = r6;" /* mark r6 as live */
	"r8 = r8;" /* mark r8 as live */
	"exit;"
	:
	: __imm(bpf_ktime_get_ns)
@@ -353,7 +373,7 @@ __flag(BPF_F_TEST_STATE_FREQ)
 * collect_linked_regs() can't tie more than 6 registers for a single insn.
 */
__msg("8: (25) if r0 > 0x7 goto pc+0         ; R0=scalar(id=1")
__msg("9: (bf) r6 = r6                       ; R6=scalar(id=2")
__msg("14: (bf) r6 = r6                      ; R6=scalar(id=2")
/* check that r{0-5} are marked precise after 'if' */
__msg("frame0: regs=r0 stack= before 8: (25) if r0 > 0x7 goto pc+0")
__msg("frame0: parent state regs=r0,r1,r2,r3,r4,r5 stack=:")
@@ -372,6 +392,12 @@ __naked void linked_regs_too_many_regs(void)
	"r6 = r0;"
	/* propagate range for r{0-6} */
	"if r0 > 7 goto +0;"
	/* keep r{1-5} live */
	"r1 = r1;"
	"r2 = r2;"
	"r3 = r3;"
	"r4 = r4;"
	"r5 = r5;"
	/* make r6 appear in the log */
	"r6 = r6;"
	/* force r0 to be precise,
@@ -517,7 +543,7 @@ __naked void check_ids_in_regsafe_2(void)
	"*(u64*)(r10 - 8) = r1;"
	/* r9 = pointer to stack */
	"r9 = r10;"
	"r9 += -8;"
	"r9 += -16;"
	/* r8 = ktime_get_ns() */
	"call %[bpf_ktime_get_ns];"
	"r8 = r0;"
@@ -538,6 +564,8 @@ __naked void check_ids_in_regsafe_2(void)
	"if r7 > 4 goto l2_%=;"
	/* Access memory at r9[r6] */
	"r9 += r6;"
	"r9 += r7;"
	"r9 += r8;"
	"r0 = *(u8*)(r9 + 0);"
"l2_%=:"
	"r0 = 0;"
+4 −4
Original line number Diff line number Diff line
@@ -44,9 +44,9 @@
	mark_precise: frame0: regs=r2 stack= before 23\
	mark_precise: frame0: regs=r2 stack= before 22\
	mark_precise: frame0: regs=r2 stack= before 20\
	mark_precise: frame0: parent state regs=r2,r9 stack=:\
	mark_precise: frame0: parent state regs=r2 stack=:\
	mark_precise: frame0: last_idx 19 first_idx 10\
	mark_precise: frame0: regs=r2,r9 stack= before 19\
	mark_precise: frame0: regs=r2 stack= before 19\
	mark_precise: frame0: regs=r9 stack= before 18\
	mark_precise: frame0: regs=r8,r9 stack= before 17\
	mark_precise: frame0: regs=r0,r9 stack= before 15\
@@ -107,9 +107,9 @@
	mark_precise: frame0: parent state regs=r2 stack=:\
	mark_precise: frame0: last_idx 20 first_idx 20\
	mark_precise: frame0: regs=r2 stack= before 20\
	mark_precise: frame0: parent state regs=r2,r9 stack=:\
	mark_precise: frame0: parent state regs=r2 stack=:\
	mark_precise: frame0: last_idx 19 first_idx 17\
	mark_precise: frame0: regs=r2,r9 stack= before 19\
	mark_precise: frame0: regs=r2 stack= before 19\
	mark_precise: frame0: regs=r9 stack= before 18\
	mark_precise: frame0: regs=r8,r9 stack= before 17\
	mark_precise: frame0: parent state regs= stack=:",