Commit 6895e1d7 authored by Alexei Starovoitov's avatar Alexei Starovoitov
Browse files

Merge branch 'bpf-fix-u32-s32-bounds-when-ranges-cross-min-max-boundary'

Eduard Zingerman says:

====================
bpf: Fix u32/s32 bounds when ranges cross min/max boundary

Cover the following cases in range refinement logic for 32-bit ranges:
- s32 range crosses U32_MAX/0 boundary, positive part of the s32 range
  overlaps with u32 range.
- s32 range crosses U32_MAX/0 boundary, negative part of the s32 range
  overlaps with u32 range.

These cases are already handled for 64-bit range refinement.

Without the fix the test in patch 2 is rejected by the verifier.
The test was reduced from sched-ext program.

Changelog:
- v2 -> v3:
  - Reverted da653de2 (Paul)
  - Removed !BPF_F_TEST_REG_INVARIANTS flag from
    crossing_32_bit_signed_boundary_2() (Paul)
- v1 -> v2:
  - Extended commit message and comments (Emil)
  - Targeting 'bpf' tree instead of bpf-next (Alexei)

v1: https://lore.kernel.org/bpf/9a23fbacdc6d33ec8fcb3f6988395b5129f75369.camel@gmail.com/T
v2: https://lore.kernel.org/bpf/20260305-bpf-32-bit-range-overflow-v2-0-7169206a3041@gmail.com/
---
====================

Link: https://patch.msgid.link/20260306-bpf-32-bit-range-overflow-v3-0-f7f67e060a6b@gmail.com


Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parents 56145d23 d87c9305
Loading
Loading
Loading
Loading
+24 −0
Original line number Diff line number Diff line
@@ -2511,6 +2511,30 @@ static void __reg32_deduce_bounds(struct bpf_reg_state *reg)
	if ((u32)reg->s32_min_value <= (u32)reg->s32_max_value) {
		reg->u32_min_value = max_t(u32, reg->s32_min_value, reg->u32_min_value);
		reg->u32_max_value = min_t(u32, reg->s32_max_value, reg->u32_max_value);
	} else {
		if (reg->u32_max_value < (u32)reg->s32_min_value) {
			/* See __reg64_deduce_bounds() for detailed explanation.
			 * Refine ranges in the following situation:
			 *
			 * 0                                                   U32_MAX
			 * |  [xxxxxxxxxxxxxx u32 range xxxxxxxxxxxxxx]              |
			 * |----------------------------|----------------------------|
			 * |xxxxx s32 range xxxxxxxxx]                       [xxxxxxx|
			 * 0                     S32_MAX S32_MIN                    -1
			 */
			reg->s32_min_value = (s32)reg->u32_min_value;
			reg->u32_max_value = min_t(u32, reg->u32_max_value, reg->s32_max_value);
		} else if ((u32)reg->s32_max_value < reg->u32_min_value) {
			/*
			 * 0                                                   U32_MAX
			 * |              [xxxxxxxxxxxxxx u32 range xxxxxxxxxxxxxx]  |
			 * |----------------------------|----------------------------|
			 * |xxxxxxxxx]                       [xxxxxxxxxxxx s32 range |
			 * 0                     S32_MAX S32_MIN                    -1
			 */
			reg->s32_max_value = (s32)reg->u32_max_value;
			reg->u32_min_value = max_t(u32, reg->u32_min_value, reg->s32_min_value);
		}
	}
}
+58 −18
Original line number Diff line number Diff line
@@ -422,15 +422,69 @@ static bool is_valid_range(enum num_t t, struct range x)
	}
}

static struct range range_improve(enum num_t t, struct range old, struct range new)
static struct range range_intersection(enum num_t t, struct range old, struct range new)
{
	return range(t, max_t(t, old.a, new.a), min_t(t, old.b, new.b));
}

/*
 * Result is precise when 'x' and 'y' overlap or form a continuous range,
 * result is an over-approximation if 'x' and 'y' do not overlap.
 */
static struct range range_union(enum num_t t, struct range x, struct range y)
{
	if (!is_valid_range(t, x))
		return y;
	if (!is_valid_range(t, y))
		return x;
	return range(t, min_t(t, x.a, y.a), max_t(t, x.b, y.b));
}

/*
 * This function attempts to improve x range intersecting it with y.
 * range_cast(... to_t ...) looses precision for ranges that pass to_t
 * min/max boundaries. To avoid such precision loses this function
 * splits both x and y into halves corresponding to non-overflowing
 * sub-ranges: [0, smin] and [smax, -1].
 * Final result is computed as follows:
 *
 *   ((x ∩ [0, smax]) ∩ (y ∩ [0, smax])) ∪
 *   ((x ∩ [smin,-1]) ∩ (y ∩ [smin,-1]))
 *
 * Precision might still be lost if final union is not a continuous range.
 */
static struct range range_refine_in_halves(enum num_t x_t, struct range x,
					   enum num_t y_t, struct range y)
{
	struct range x_pos, x_neg, y_pos, y_neg, r_pos, r_neg;
	u64 smax, smin, neg_one;

	if (t_is_32(x_t)) {
		smax = (u64)(u32)S32_MAX;
		smin = (u64)(u32)S32_MIN;
		neg_one = (u64)(u32)(s32)(-1);
	} else {
		smax = (u64)S64_MAX;
		smin = (u64)S64_MIN;
		neg_one = U64_MAX;
	}
	x_pos = range_intersection(x_t, x, range(x_t, 0, smax));
	x_neg = range_intersection(x_t, x, range(x_t, smin, neg_one));
	y_pos = range_intersection(y_t, y, range(x_t, 0, smax));
	y_neg = range_intersection(y_t, y, range(y_t, smin, neg_one));
	r_pos = range_intersection(x_t, x_pos, range_cast(y_t, x_t, y_pos));
	r_neg = range_intersection(x_t, x_neg, range_cast(y_t, x_t, y_neg));
	return range_union(x_t, r_pos, r_neg);

}

static struct range range_refine(enum num_t x_t, struct range x, enum num_t y_t, struct range y)
{
	struct range y_cast;

	if (t_is_32(x_t) == t_is_32(y_t))
		x = range_refine_in_halves(x_t, x, y_t, y);

	y_cast = range_cast(y_t, x_t, y);

	/* If we know that
@@ -444,7 +498,7 @@ static struct range range_refine(enum num_t x_t, struct range x, enum num_t y_t,
	 */
	if (x_t == S64 && y_t == S32 && y_cast.a <= S32_MAX  && y_cast.b <= S32_MAX &&
	    (s64)x.a >= S32_MIN && (s64)x.b <= S32_MAX)
		return range_improve(x_t, x, y_cast);
		return range_intersection(x_t, x, y_cast);

	/* the case when new range knowledge, *y*, is a 32-bit subregister
	 * range, while previous range knowledge, *x*, is a full register
@@ -462,25 +516,11 @@ static struct range range_refine(enum num_t x_t, struct range x, enum num_t y_t,
		x_swap = range(x_t, swap_low32(x.a, y_cast.a), swap_low32(x.b, y_cast.b));
		if (!is_valid_range(x_t, x_swap))
			return x;
		return range_improve(x_t, x, x_swap);
	}

	if (!t_is_32(x_t) && !t_is_32(y_t) && x_t != y_t) {
		if (x_t == S64 && x.a > x.b) {
			if (x.b < y.a && x.a <= y.b)
				return range(x_t, x.a, y.b);
			if (x.a > y.b && x.b >= y.a)
				return range(x_t, y.a, x.b);
		} else if (x_t == U64 && y.a > y.b) {
			if (y.b < x.a && y.a <= x.b)
				return range(x_t, y.a, x.b);
			if (y.a > x.b && y.b >= x.a)
				return range(x_t, x.a, y.b);
		}
		return range_intersection(x_t, x, x_swap);
	}

	/* otherwise, plain range cast and intersection works */
	return range_improve(x_t, x, y_cast);
	return range_intersection(x_t, x, y_cast);
}

/* =======================
+38 −1
Original line number Diff line number Diff line
@@ -1148,7 +1148,7 @@ l0_%=: r0 = 0; \
SEC("xdp")
__description("bound check with JMP32_JSLT for crossing 32-bit signed boundary")
__success __retval(0)
__flag(!BPF_F_TEST_REG_INVARIANTS) /* known invariants violation */
__flag(BPF_F_TEST_REG_INVARIANTS)
__naked void crossing_32_bit_signed_boundary_2(void)
{
	asm volatile ("					\
@@ -2000,4 +2000,41 @@ __naked void bounds_refinement_multiple_overlaps(void *ctx)
	: __clobber_all);
}

SEC("socket")
__success
__flag(BPF_F_TEST_REG_INVARIANTS)
__naked void signed_unsigned_intersection32_case1(void *ctx)
{
	asm volatile("									\
	call %[bpf_get_prandom_u32];							\
	w0 &= 0xffffffff;								\
	if w0 < 0x3 goto 1f;		/* on fall-through u32 range [3..U32_MAX]  */	\
	if w0 s> 0x1 goto 1f;		/* on fall-through s32 range [S32_MIN..1]  */	\
	if w0 s< 0x0 goto 1f;		/* range can be narrowed to  [S32_MIN..-1] */	\
	r10 = 0;			/* thus predicting the jump. */			\
1:	exit;										\
"	:
	: __imm(bpf_get_prandom_u32)
	: __clobber_all);
}

SEC("socket")
__success
__flag(BPF_F_TEST_REG_INVARIANTS)
__naked void signed_unsigned_intersection32_case2(void *ctx)
{
	asm volatile("									\
	call %[bpf_get_prandom_u32];							\
	w0 &= 0xffffffff;								\
	if w0 > 0x80000003 goto 1f;	/* on fall-through u32 range [0..S32_MIN+3] */	\
	if w0 s< -3 goto 1f;		/* on fall-through s32 range [-3..S32_MAX] */	\
	if w0 s> 5 goto 1f;		/* on fall-through s32 range [-3..5] */		\
	if w0 <= 5 goto 1f;		/* range can be narrowed to  [0..5] */		\
	r10 = 0;			/* thus predicting the jump */			\
1:	exit;										\
"	:
	: __imm(bpf_get_prandom_u32)
	: __clobber_all);
}

char _license[] SEC("license") = "GPL";