Commit 1df7dad4 authored by Nandakumar Edamana, committed by Andrii Nakryiko
Browse files

bpf: Improve the general precision of tnum_mul



Drop the value-mask decomposition technique and adopt straightforward
long-multiplication with a twist: when LSB(a) is uncertain, find the
two partial products (for LSB(a) = known 0 and LSB(a) = known 1) and
take a union.

Experiment shows that applying this technique in long multiplication
improves the precision in a significant number of cases (at the cost
of losing precision in a relatively lower number of cases).

Signed-off-by: Nandakumar Edamana <nandakumar@nandakumar.co.in>
Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
Tested-by: Harishankar Vishwanathan <harishankar.vishwanathan@gmail.com>
Reviewed-by: Harishankar Vishwanathan <harishankar.vishwanathan@gmail.com>
Acked-by: Eduard Zingerman <eddyz87@gmail.com>
Link: https://lore.kernel.org/bpf/20250826034524.2159515-1-nandakumar@nandakumar.co.in
parent 2465bb83
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -57,6 +57,9 @@ bool tnum_overlap(struct tnum a, struct tnum b);
/* Return a tnum representing numbers satisfying both @a and @b */
struct tnum tnum_intersect(struct tnum a, struct tnum b);

/* Return a tnum representing numbers satisfying either @t1 or @t2 */
struct tnum tnum_union(struct tnum t1, struct tnum t2);

/* Return @a with all but the lowest @size bytes cleared */
struct tnum tnum_cast(struct tnum a, u8 size);

+42 −13
Original line number Diff line number Diff line
@@ -116,31 +116,47 @@ struct tnum tnum_xor(struct tnum a, struct tnum b)
	return TNUM(v & ~mu, mu);
}

/* Perform long multiplication, iterating through the bits in a using rshift:
 * - if LSB(a) is a known 0, keep current accumulator
 * - if LSB(a) is a known 1, add b to current accumulator
 * - if LSB(a) is unknown, take a union of the above cases.
 *
 * For example:
 *
 *               acc_0:        acc_1:
 *
 *     11 *  ->      11 *  ->      11 *  -> union(0011, 1001) == x0x1
 *     x1            01            11
 * ------        ------        ------
 *     11            11            11
 *    xx            00            11
 * ------        ------        ------
 *   ????          0011          1001
 *
 * @a: multiplier; consumed one bit per iteration, LSB first.
 * @b: multiplicand; shifted left by one per iteration so that it lines up
 *     with the bit of @a currently being examined.
 *
 * Returns the accumulated tnum once @a has no known-1 or unknown bits left.
 */
struct tnum tnum_mul(struct tnum a, struct tnum b)
{
	struct tnum acc = TNUM(0, 0);

	/* Stop as soon as every remaining bit of a is a known 0: such bits
	 * contribute nothing to the product.
	 */
	while (a.value || a.mask) {
		if (a.value & 1) {
			/* LSB of tnum a is a certain 1 */
			acc = tnum_add(acc, b);
		} else if (a.mask & 1) {
			/* LSB of tnum a is uncertain.
			 *
			 * acc = tnum_union(acc_0, acc_1), where acc_0 and
			 * acc_1 are partial accumulators for cases
			 * LSB(a) = certain 0 and LSB(a) = certain 1.
			 * acc_0 = acc + 0 * b = acc.
			 * acc_1 = acc + 1 * b = tnum_add(acc, b).
			 */
			acc = tnum_union(acc, tnum_add(acc, b));
		}
		/* Note: no case for LSB is certain 0 */
		a = tnum_rshift(a, 1);
		b = tnum_lshift(b, 1);
	}
	return acc;
}

bool tnum_overlap(struct tnum a, struct tnum b)
@@ -163,6 +179,19 @@ struct tnum tnum_intersect(struct tnum a, struct tnum b)
	return TNUM(v & ~mu, mu);
}

/* Combine a and b into a single tnum covering both.
 *
 * A bit is unknown in the result if it is unknown in a, unknown in b, or if
 * the known values of a and b differ at that position. Every remaining bit is
 * known and takes the value both inputs agree on. The result represents a
 * superset of the union of the concrete sets of a and b; despite the
 * overapproximation, it is optimal.
 */
struct tnum tnum_union(struct tnum a, struct tnum b)
{
	/* Positions where the two value words disagree. */
	u64 differing = a.value ^ b.value;
	/* Uncertainty inherited from either input, plus the disagreements. */
	u64 unknown = a.mask | b.mask | differing;
	/* Keep only the 1-bits common to both inputs and still known. */
	u64 known = (a.value & b.value) & ~unknown;

	return TNUM(known, unknown);
}

struct tnum tnum_cast(struct tnum a, u8 size)
{
	a.value &= (1ULL << (size * 8)) - 1;