Commit c538f400 authored by Keith Busch's avatar Keith Busch Committed by Jens Axboe
Browse files

io_uring: consistently use rcu semantics with sqpoll thread



The sqpoll thread is dereferenced with rcu read protection in one place,
so it needs to be annotated as an __rcu type, and should consistently
use rcu helpers for access and assignment to make sparse happy.

Since most of the accesses occur under the sqd->lock, we can use
rcu_dereference_protected() without declaring an rcu read section.
Provide a simple helper to get the thread from a locked context.

Fixes: ac0b8b32 ("io_uring: fix use-after-free of sq->thread in __io_uring_show_fdinfo()")
Signed-off-by: default avatarKeith Busch <kbusch@kernel.org>
Link: https://lore.kernel.org/r/20250611205343.1821117-1-kbusch@meta.com


[axboe: fold in fix for register.c]
Signed-off-by: default avatarJens Axboe <axboe@kernel.dk>
parent ac0b8b32
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -2906,7 +2906,7 @@ static __cold void io_ring_exit_work(struct work_struct *work)
			struct task_struct *tsk;

			io_sq_thread_park(sqd);
			tsk = sqd->thread;
			tsk = sqpoll_task_locked(sqd);
			if (tsk && tsk->io_uring && tsk->io_uring->io_wq)
				io_wq_cancel_cb(tsk->io_uring->io_wq,
						io_cancel_ctx_cb, ctx, true);
@@ -3142,7 +3142,7 @@ __cold void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd)
	s64 inflight;
	DEFINE_WAIT(wait);

	WARN_ON_ONCE(sqd && sqd->thread != current);
	WARN_ON_ONCE(sqd && sqpoll_task_locked(sqd) != current);

	if (!current->io_uring)
		return;
+5 −2
Original line number Diff line number Diff line
@@ -273,6 +273,8 @@ static __cold int io_register_iowq_max_workers(struct io_ring_ctx *ctx,
	if (ctx->flags & IORING_SETUP_SQPOLL) {
		sqd = ctx->sq_data;
		if (sqd) {
			struct task_struct *tsk;

			/*
			 * Observe the correct sqd->lock -> ctx->uring_lock
			 * ordering. Fine to drop uring_lock here, we hold
@@ -282,8 +284,9 @@ static __cold int io_register_iowq_max_workers(struct io_ring_ctx *ctx,
			mutex_unlock(&ctx->uring_lock);
			mutex_lock(&sqd->lock);
			mutex_lock(&ctx->uring_lock);
			if (sqd->thread)
				tctx = sqd->thread->io_uring;
			tsk = sqpoll_task_locked(sqd);
			if (tsk)
				tctx = tsk->io_uring;
		}
	} else {
		tctx = current->io_uring;
+24 −10
Original line number Diff line number Diff line
@@ -30,7 +30,7 @@ enum {
void io_sq_thread_unpark(struct io_sq_data *sqd)
	__releases(&sqd->lock)
{
	WARN_ON_ONCE(sqd->thread == current);
	WARN_ON_ONCE(sqpoll_task_locked(sqd) == current);

	/*
	 * Do the dance but not conditional clear_bit() because it'd race with
@@ -46,24 +46,32 @@ void io_sq_thread_unpark(struct io_sq_data *sqd)
void io_sq_thread_park(struct io_sq_data *sqd)
	__acquires(&sqd->lock)
{
	WARN_ON_ONCE(data_race(sqd->thread) == current);
	struct task_struct *tsk;

	atomic_inc(&sqd->park_pending);
	set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
	mutex_lock(&sqd->lock);
	if (sqd->thread)
		wake_up_process(sqd->thread);

	tsk = sqpoll_task_locked(sqd);
	if (tsk) {
		WARN_ON_ONCE(tsk == current);
		wake_up_process(tsk);
	}
}

void io_sq_thread_stop(struct io_sq_data *sqd)
{
	WARN_ON_ONCE(sqd->thread == current);
	struct task_struct *tsk;

	WARN_ON_ONCE(test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state));

	set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
	mutex_lock(&sqd->lock);
	if (sqd->thread)
		wake_up_process(sqd->thread);
	tsk = sqpoll_task_locked(sqd);
	if (tsk) {
		WARN_ON_ONCE(tsk == current);
		wake_up_process(tsk);
	}
	mutex_unlock(&sqd->lock);
	wait_for_completion(&sqd->exited);
}
@@ -486,7 +494,10 @@ __cold int io_sq_offload_create(struct io_ring_ctx *ctx,
			goto err_sqpoll;
		}

		sqd->thread = tsk;
		mutex_lock(&sqd->lock);
		rcu_assign_pointer(sqd->thread, tsk);
		mutex_unlock(&sqd->lock);

		task_to_put = get_task_struct(tsk);
		ret = io_uring_alloc_task_context(tsk, ctx);
		wake_up_new_task(tsk);
@@ -514,10 +525,13 @@ __cold int io_sqpoll_wq_cpu_affinity(struct io_ring_ctx *ctx,
	int ret = -EINVAL;

	if (sqd) {
		struct task_struct *tsk;

		io_sq_thread_park(sqd);
		/* Don't set affinity for a dying thread */
		if (sqd->thread)
			ret = io_wq_cpu_affinity(sqd->thread->io_uring, mask);
		tsk = sqpoll_task_locked(sqd);
		if (tsk)
			ret = io_wq_cpu_affinity(tsk->io_uring, mask);
		io_sq_thread_unpark(sqd);
	}

+7 −1
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ struct io_sq_data {
	/* ctx's that are using this sqd */
	struct list_head	ctx_list;

	struct task_struct	*thread;
	struct task_struct __rcu *thread;
	struct wait_queue_head	wait;

	unsigned		sq_thread_idle;
@@ -29,3 +29,9 @@ void io_sq_thread_unpark(struct io_sq_data *sqd);
void io_put_sq_data(struct io_sq_data *sqd);
void io_sqpoll_wait_sq(struct io_ring_ctx *ctx);
int io_sqpoll_wq_cpu_affinity(struct io_ring_ctx *ctx, cpumask_var_t mask);

static inline struct task_struct *sqpoll_task_locked(struct io_sq_data *sqd)
{
	return rcu_dereference_protected(sqd->thread,
					 lockdep_is_held(&sqd->lock));
}