Commit ffc41ce5 authored by Alexei Starovoitov's avatar Alexei Starovoitov
Browse files

Merge branch 'support-passing-bpf-iterator-to-kfuncs'

Andrii Nakryiko says:

====================
Support passing BPF iterator to kfuncs

Add support for passing BPF iterator state to any kfunc. Such kfunc has to
declare such argument with valid `struct bpf_iter_<type> *` type and should
use "__iter" suffix in argument name, following the established suffix-based
convention. We add a simple test/demo iterator getter in bpf_testmod.
====================

Link: https://lore.kernel.org/r/20240808232230.2848712-1-andrii@kernel.org


Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parents 01ac89d0 b0cd726f
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -580,6 +580,7 @@ bool btf_is_prog_ctx_type(struct bpf_verifier_log *log, const struct btf *btf,
int get_kern_ctx_btf_id(struct bpf_verifier_log *log, enum bpf_prog_type prog_type);
bool btf_types_are_same(const struct btf *btf1, u32 id1,
			const struct btf *btf2, u32 id2);
int btf_check_iter_arg(struct btf *btf, const struct btf_type *func, int arg_idx);
#else
static inline const struct btf_type *btf_type_by_id(const struct btf *btf,
						    u32 type_id)
@@ -654,6 +655,10 @@ static inline bool btf_types_are_same(const struct btf *btf1, u32 id1,
{
	return false;
}
static inline int btf_check_iter_arg(struct btf *btf, const struct btf_type *func, int arg_idx)
{
	return -EOPNOTSUPP;
}
#endif

static inline bool btf_type_is_struct_ptr(struct btf *btf, const struct btf_type *t)
+36 −14
Original line number Diff line number Diff line
@@ -8047,15 +8047,44 @@ BTF_ID_LIST_GLOBAL(btf_tracing_ids, MAX_BTF_TRACING_TYPE)
BTF_TRACING_TYPE_xxx
#undef BTF_TRACING_TYPE

/* Validate well-formedness of iter argument type.
 * On success, return positive BTF ID of iter state's STRUCT type.
 * On error, negative error is returned.
 */
int btf_check_iter_arg(struct btf *btf, const struct btf_type *func, int arg_idx)
{
	const struct btf_param *arg;
	const struct btf_type *t;
	const char *name;
	int btf_id;

	if (btf_type_vlen(func) <= arg_idx)
		return -EINVAL;

	arg = &btf_params(func)[arg_idx];
	t = btf_type_skip_modifiers(btf, arg->type, NULL);
	if (!t || !btf_type_is_ptr(t))
		return -EINVAL;
	t = btf_type_skip_modifiers(btf, t->type, &btf_id);
	if (!t || !__btf_type_is_struct(t))
		return -EINVAL;

	name = btf_name_by_offset(btf, t->name_off);
	if (!name || strncmp(name, ITER_PREFIX, sizeof(ITER_PREFIX) - 1))
		return -EINVAL;

	return btf_id;
}

static int btf_check_iter_kfuncs(struct btf *btf, const char *func_name,
				 const struct btf_type *func, u32 func_flags)
{
	u32 flags = func_flags & (KF_ITER_NEW | KF_ITER_NEXT | KF_ITER_DESTROY);
	const char *name, *sfx, *iter_name;
	const struct btf_param *arg;
	const char *sfx, *iter_name;
	const struct btf_type *t;
	char exp_name[128];
	u32 nr_args;
	int btf_id;

	/* exactly one of KF_ITER_{NEW,NEXT,DESTROY} can be set */
	if (!flags || (flags & (flags - 1)))
@@ -8066,28 +8095,21 @@ static int btf_check_iter_kfuncs(struct btf *btf, const char *func_name,
	if (nr_args < 1)
		return -EINVAL;

	arg = &btf_params(func)[0];
	t = btf_type_skip_modifiers(btf, arg->type, NULL);
	if (!t || !btf_type_is_ptr(t))
		return -EINVAL;
	t = btf_type_skip_modifiers(btf, t->type, NULL);
	if (!t || !__btf_type_is_struct(t))
		return -EINVAL;

	name = btf_name_by_offset(btf, t->name_off);
	if (!name || strncmp(name, ITER_PREFIX, sizeof(ITER_PREFIX) - 1))
		return -EINVAL;
	btf_id = btf_check_iter_arg(btf, func, 0);
	if (btf_id < 0)
		return btf_id;

	/* sizeof(struct bpf_iter_<type>) should be a multiple of 8 to
	 * fit nicely in stack slots
	 */
	t = btf_type_by_id(btf, btf_id);
	if (t->size == 0 || (t->size % 8))
		return -EINVAL;

	/* validate bpf_iter_<type>_{new,next,destroy}(struct bpf_iter_<type> *)
	 * naming pattern
	 */
	iter_name = name + sizeof(ITER_PREFIX) - 1;
	iter_name = btf_name_by_offset(btf, t->name_off) + sizeof(ITER_PREFIX) - 1;
	if (flags & KF_ITER_NEW)
		sfx = "new";
	else if (flags & KF_ITER_NEXT)
+24 −11
Original line number Diff line number Diff line
@@ -7970,12 +7970,17 @@ static bool is_iter_destroy_kfunc(struct bpf_kfunc_call_arg_meta *meta)
	return meta->kfunc_flags & KF_ITER_DESTROY;
}
static bool is_kfunc_arg_iter(struct bpf_kfunc_call_arg_meta *meta, int arg)
static bool is_kfunc_arg_iter(struct bpf_kfunc_call_arg_meta *meta, int arg_idx,
			      const struct btf_param *arg)
{
	/* btf_check_iter_kfuncs() guarantees that first argument of any iter
	 * kfunc is iter state pointer
	 */
	return arg == 0 && is_iter_kfunc(meta);
	if (is_iter_kfunc(meta))
		return arg_idx == 0;
	/* iter passed as an argument to a generic kfunc */
	return btf_param_match_suffix(meta->btf, arg, "__iter");
}
static int process_iter_arg(struct bpf_verifier_env *env, int regno, int insn_idx,
@@ -7983,14 +7988,20 @@ static int process_iter_arg(struct bpf_verifier_env *env, int regno, int insn_id
{
	struct bpf_reg_state *regs = cur_regs(env), *reg = &regs[regno];
	const struct btf_type *t;
	const struct btf_param *arg;
	int spi, err, i, nr_slots;
	u32 btf_id;
	int spi, err, i, nr_slots, btf_id;
	/* btf_check_iter_kfuncs() ensures we don't need to validate anything here */
	arg = &btf_params(meta->func_proto)[0];
	t = btf_type_skip_modifiers(meta->btf, arg->type, NULL);	/* PTR */
	t = btf_type_skip_modifiers(meta->btf, t->type, &btf_id);	/* STRUCT */
	/* For iter_{new,next,destroy} functions, btf_check_iter_kfuncs()
	 * ensures struct convention, so we wouldn't need to do any BTF
	 * validation here. But given iter state can be passed as a parameter
	 * to any kfunc, if arg has "__iter" suffix, we need to be a bit more
	 * conservative here.
	 */
	btf_id = btf_check_iter_arg(meta->btf, meta->func_proto, regno - 1);
	if (btf_id < 0) {
		verbose(env, "expected valid iter pointer as arg #%d\n", regno);
		return -EINVAL;
	}
	t = btf_type_by_id(meta->btf, btf_id);
	nr_slots = t->size / BPF_REG_SIZE;
	if (is_iter_new_kfunc(meta)) {
@@ -8012,7 +8023,9 @@ static int process_iter_arg(struct bpf_verifier_env *env, int regno, int insn_id
		if (err)
			return err;
	} else {
		/* iter_next() or iter_destroy() expect initialized iter state*/
		/* iter_next() or iter_destroy(), as well as any kfunc
		 * accepting iter argument, expect initialized iter state
		 */
		err = is_iter_reg_valid_init(env, reg, meta->btf, btf_id, nr_slots);
		switch (err) {
		case 0:
@@ -11382,7 +11395,7 @@ get_kfunc_ptr_arg_type(struct bpf_verifier_env *env,
	if (is_kfunc_arg_dynptr(meta->btf, &args[argno]))
		return KF_ARG_PTR_TO_DYNPTR;
	if (is_kfunc_arg_iter(meta, argno))
	if (is_kfunc_arg_iter(meta, argno, &args[argno]))
		return KF_ARG_PTR_TO_ITER;
	if (is_kfunc_arg_list_head(meta->btf, &args[argno]))
+12 −4
Original line number Diff line number Diff line
@@ -141,13 +141,12 @@ bpf_testmod_test_mod_kfunc(int i)

__bpf_kfunc int bpf_iter_testmod_seq_new(struct bpf_iter_testmod_seq *it, s64 value, int cnt)
{
	if (cnt < 0) {
		it->cnt = 0;
	it->cnt = cnt;

	if (cnt < 0)
		return -EINVAL;
	}

	it->value = value;
	it->cnt = cnt;

	return 0;
}
@@ -162,6 +161,14 @@ __bpf_kfunc s64 *bpf_iter_testmod_seq_next(struct bpf_iter_testmod_seq* it)
	return &it->value;
}

__bpf_kfunc s64 bpf_iter_testmod_seq_value(int val, struct bpf_iter_testmod_seq* it__iter)
{
	if (it__iter->cnt < 0)
		return 0;

	return val + it__iter->value;
}

__bpf_kfunc void bpf_iter_testmod_seq_destroy(struct bpf_iter_testmod_seq *it)
{
	it->cnt = 0;
@@ -531,6 +538,7 @@ BTF_KFUNCS_START(bpf_testmod_common_kfunc_ids)
BTF_ID_FLAGS(func, bpf_iter_testmod_seq_new, KF_ITER_NEW)
BTF_ID_FLAGS(func, bpf_iter_testmod_seq_next, KF_ITER_NEXT | KF_RET_NULL)
BTF_ID_FLAGS(func, bpf_iter_testmod_seq_destroy, KF_ITER_DESTROY)
BTF_ID_FLAGS(func, bpf_iter_testmod_seq_value)
BTF_ID_FLAGS(func, bpf_kfunc_common_test)
BTF_ID_FLAGS(func, bpf_kfunc_dynptr_test)
BTF_ID_FLAGS(func, bpf_testmod_ctx_create, KF_ACQUIRE | KF_RET_NULL)
+50 −0
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@ struct bpf_iter_testmod_seq {

extern int bpf_iter_testmod_seq_new(struct bpf_iter_testmod_seq *it, s64 value, int cnt) __ksym;
extern s64 *bpf_iter_testmod_seq_next(struct bpf_iter_testmod_seq *it) __ksym;
extern s64 bpf_iter_testmod_seq_value(int blah, struct bpf_iter_testmod_seq *it) __ksym;
extern void bpf_iter_testmod_seq_destroy(struct bpf_iter_testmod_seq *it) __ksym;

const volatile __s64 exp_empty = 0 + 1;
@@ -76,4 +77,53 @@ int testmod_seq_truncated(const void *ctx)
	return 0;
}

SEC("?raw_tp")
__failure
__msg("expected an initialized iter_testmod_seq as arg #2")
int testmod_seq_getter_before_bad(const void *ctx)
{
	struct bpf_iter_testmod_seq it;

	return bpf_iter_testmod_seq_value(0, &it);
}

SEC("?raw_tp")
__failure
__msg("expected an initialized iter_testmod_seq as arg #2")
int testmod_seq_getter_after_bad(const void *ctx)
{
	struct bpf_iter_testmod_seq it;
	s64 sum = 0, *v;

	bpf_iter_testmod_seq_new(&it, 100, 100);

	while ((v = bpf_iter_testmod_seq_next(&it))) {
		sum += *v;
	}

	bpf_iter_testmod_seq_destroy(&it);

	return sum + bpf_iter_testmod_seq_value(0, &it);
}

SEC("?socket")
__success __retval(1000000)
int testmod_seq_getter_good(const void *ctx)
{
	struct bpf_iter_testmod_seq it;
	s64 sum = 0, *v;

	bpf_iter_testmod_seq_new(&it, 100, 100);

	while ((v = bpf_iter_testmod_seq_next(&it))) {
		sum += *v;
	}

	sum *= bpf_iter_testmod_seq_value(0, &it);

	bpf_iter_testmod_seq_destroy(&it);

	return sum;
}

char _license[] SEC("license") = "GPL";