Commit e5bdd6a8 authored by Eduard Zingerman's avatar Eduard Zingerman Committed by Alexei Starovoitov
Browse files

selftests/bpf: validate jit behaviour for tail calls



A program calling sub-program which does a tail call.
The idea is to verify instructions generated by jit for tail calls:
- in program and sub-program prologues;
- for subprogram call instruction;
- for tail call itself.

Signed-off-by: default avatarEduard Zingerman <eddyz87@gmail.com>
Link: https://lore.kernel.org/r/20240820102357.3372779-9-eddyz87@gmail.com


Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parent 7d743e4c
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -75,6 +75,7 @@
#include "verifier_stack_ptr.skel.h"
#include "verifier_subprog_precision.skel.h"
#include "verifier_subreg.skel.h"
#include "verifier_tailcall_jit.skel.h"
#include "verifier_typedef.skel.h"
#include "verifier_uninit.skel.h"
#include "verifier_unpriv.skel.h"
@@ -198,6 +199,7 @@ void test_verifier_spin_lock(void) { RUN(verifier_spin_lock); }
void test_verifier_stack_ptr(void)            { RUN(verifier_stack_ptr); }
void test_verifier_subprog_precision(void)    { RUN(verifier_subprog_precision); }
void test_verifier_subreg(void)               { RUN(verifier_subreg); }
void test_verifier_tailcall_jit(void)         { RUN(verifier_tailcall_jit); }
void test_verifier_typedef(void)              { RUN(verifier_typedef); }
void test_verifier_uninit(void)               { RUN(verifier_uninit); }
void test_verifier_unpriv(void)               { RUN(verifier_unpriv); }
+105 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: GPL-2.0
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include "bpf_misc.h"

int main(void);

struct {
	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
	__uint(max_entries, 1);
	__uint(key_size, sizeof(__u32));
	__array(values, void (void));
} jmp_table SEC(".maps") = {
	.values = {
		[0] = (void *) &main,
	},
};

__noinline __auxiliary
static __naked int sub(void)
{
	asm volatile (
	"r2 = %[jmp_table] ll;"
	"r3 = 0;"
	"call 12;"
	"exit;"
	:
	: __imm_addr(jmp_table)
	: __clobber_all);
}

__success
__arch_x86_64
/* program entry for main(), regular function prologue */
__jited("	endbr64")
__jited("	nopl	(%rax,%rax)")
__jited("	xorq	%rax, %rax")
__jited("	pushq	%rbp")
__jited("	movq	%rsp, %rbp")
/* tail call prologue for program:
 * - establish memory location for tail call counter at &rbp[-8];
 * - spill tail_call_cnt_ptr at &rbp[-16];
 * - expect tail call counter to be passed in rax;
 * - for entry program rax is a raw counter, value < 33;
 * - for tail called program rax is tail_call_cnt_ptr (value > 33).
 */
__jited("	endbr64")
__jited("	cmpq	$0x21, %rax")
__jited("	ja	L0")
__jited("	pushq	%rax")
__jited("	movq	%rsp, %rax")
__jited("	jmp	L1")
__jited("L0:	pushq	%rax")			/* rbp[-8]  = rax         */
__jited("L1:	pushq	%rax")			/* rbp[-16] = rax         */
/* on subprogram call restore rax to be tail_call_cnt_ptr from rbp[-16]
 * (cause original rax might be clobbered by this point)
 */
__jited("	movq	-0x10(%rbp), %rax")
__jited("	callq	0x{{.*}}")		/* call to sub()          */
__jited("	xorl	%eax, %eax")
__jited("	leave")
__jited("	retq")
__jited("...")
/* subprogram entry for sub(), regular function prologue */
__jited("	endbr64")
__jited("	nopl	(%rax,%rax)")
__jited("	nopl	(%rax)")
__jited("	pushq	%rbp")
__jited("	movq	%rsp, %rbp")
/* tail call prologue for subprogram address of tail call counter
 * stored at rbp[-16].
 */
__jited("	endbr64")
__jited("	pushq	%rax")			/* rbp[-8]  = rax          */
__jited("	pushq	%rax")			/* rbp[-16] = rax          */
__jited("	movabsq	${{.*}}, %rsi")		/* r2 = &jmp_table         */
__jited("	xorl	%edx, %edx")		/* r3 = 0                  */
/* bpf_tail_call implementation:
 * - load tail_call_cnt_ptr from rbp[-16];
 * - if *tail_call_cnt_ptr < 33, increment it and jump to target;
 * - otherwise do nothing.
 */
__jited("	movq	-0x10(%rbp), %rax")
__jited("	cmpq	$0x21, (%rax)")
__jited("	jae	L0")
__jited("	nopl	(%rax,%rax)")
__jited("	addq	$0x1, (%rax)")		/* *tail_call_cnt_ptr += 1 */
__jited("	popq	%rax")
__jited("	popq	%rax")
__jited("	jmp	{{.*}}")		/* jump to tail call tgt   */
__jited("L0:	leave")
__jited("	retq")
SEC("tc")
__naked int main(void)
{
	asm volatile (
	"call %[sub];"
	"r0 = 0;"
	"exit;"
	:
	: __imm(sub)
	: __clobber_all);
}

char __license[] SEC("license") = "GPL";