Commit f03baece authored by Jens Axboe's avatar Jens Axboe
Browse files

io_uring: move cancelations to be io_uring_task based



Right now the task_struct pointer is used as the key to match a task,
but in preparation for some io_kiocb changes, move it to using struct
io_uring_task instead. No functional changes intended in this patch.

Signed-off-by: default avatarJens Axboe <axboe@kernel.dk>
parent 6f94cbc2
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -141,7 +141,7 @@ int io_futex_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
	return -ENOENT;
}

bool io_futex_remove_all(struct io_ring_ctx *ctx, struct task_struct *task,
bool io_futex_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
			 bool cancel_all)
{
	struct hlist_node *tmp;
@@ -151,7 +151,7 @@ bool io_futex_remove_all(struct io_ring_ctx *ctx, struct task_struct *task,
	lockdep_assert_held(&ctx->uring_lock);

	hlist_for_each_entry_safe(req, tmp, &ctx->futex_list, hash_node) {
		if (!io_match_task_safe(req, task, cancel_all))
		if (!io_match_task_safe(req, tctx, cancel_all))
			continue;
		hlist_del_init(&req->hash_node);
		__io_futex_cancel(ctx, req);
+2 −2
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ int io_futex_wake(struct io_kiocb *req, unsigned int issue_flags);
#if defined(CONFIG_FUTEX)
int io_futex_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
		    unsigned int issue_flags);
bool io_futex_remove_all(struct io_ring_ctx *ctx, struct task_struct *task,
bool io_futex_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
			 bool cancel_all);
bool io_futex_cache_init(struct io_ring_ctx *ctx);
void io_futex_cache_free(struct io_ring_ctx *ctx);
@@ -23,7 +23,7 @@ static inline int io_futex_cancel(struct io_ring_ctx *ctx,
	return 0;
}
static inline bool io_futex_remove_all(struct io_ring_ctx *ctx,
				       struct task_struct *task, bool cancel_all)
				       struct io_uring_task *tctx, bool cancel_all)
{
	return false;
}
+21 −21
Original line number Diff line number Diff line
@@ -142,7 +142,7 @@ struct io_defer_entry {
#define IO_CQ_WAKE_FORCE	(IO_CQ_WAKE_INIT >> 1)

static bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
					 struct task_struct *task,
					 struct io_uring_task *tctx,
					 bool cancel_all);

static void io_queue_sqe(struct io_kiocb *req);
@@ -201,12 +201,12 @@ static bool io_match_linked(struct io_kiocb *head)
 * As io_match_task() but protected against racing with linked timeouts.
 * User must not hold timeout_lock.
 */
bool io_match_task_safe(struct io_kiocb *head, struct task_struct *task,
bool io_match_task_safe(struct io_kiocb *head, struct io_uring_task *tctx,
			bool cancel_all)
{
	bool matched;

	if (task && head->task != task)
	if (tctx && head->task->io_uring != tctx)
		return false;
	if (cancel_all)
		return true;
@@ -2987,7 +2987,7 @@ static int io_uring_release(struct inode *inode, struct file *file)
}

struct io_task_cancel {
	struct task_struct *task;
	struct io_uring_task *tctx;
	bool all;
};

@@ -2996,11 +2996,11 @@ static bool io_cancel_task_cb(struct io_wq_work *work, void *data)
	struct io_kiocb *req = container_of(work, struct io_kiocb, work);
	struct io_task_cancel *cancel = data;

	return io_match_task_safe(req, cancel->task, cancel->all);
	return io_match_task_safe(req, cancel->tctx, cancel->all);
}

static __cold bool io_cancel_defer_files(struct io_ring_ctx *ctx,
					 struct task_struct *task,
					 struct io_uring_task *tctx,
					 bool cancel_all)
{
	struct io_defer_entry *de;
@@ -3008,7 +3008,7 @@ static __cold bool io_cancel_defer_files(struct io_ring_ctx *ctx,

	spin_lock(&ctx->completion_lock);
	list_for_each_entry_reverse(de, &ctx->defer_list, list) {
		if (io_match_task_safe(de->req, task, cancel_all)) {
		if (io_match_task_safe(de->req, tctx, cancel_all)) {
			list_cut_position(&list, &ctx->defer_list, &de->list);
			break;
		}
@@ -3051,11 +3051,10 @@ static __cold bool io_uring_try_cancel_iowq(struct io_ring_ctx *ctx)
}

static __cold bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
						struct task_struct *task,
						struct io_uring_task *tctx,
						bool cancel_all)
{
	struct io_task_cancel cancel = { .task = task, .all = cancel_all, };
	struct io_uring_task *tctx = task ? task->io_uring : NULL;
	struct io_task_cancel cancel = { .tctx = tctx, .all = cancel_all, };
	enum io_wq_cancel cret;
	bool ret = false;

@@ -3069,9 +3068,9 @@ static __cold bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
	if (!ctx->rings)
		return false;

	if (!task) {
	if (!tctx) {
		ret |= io_uring_try_cancel_iowq(ctx);
	} else if (tctx && tctx->io_wq) {
	} else if (tctx->io_wq) {
		/*
		 * Cancels requests of all rings, not only @ctx, but
		 * it's fine as the task is in exit/exec.
@@ -3094,15 +3093,15 @@ static __cold bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
	if ((ctx->flags & IORING_SETUP_DEFER_TASKRUN) &&
	    io_allowed_defer_tw_run(ctx))
		ret |= io_run_local_work(ctx, INT_MAX) > 0;
	ret |= io_cancel_defer_files(ctx, task, cancel_all);
	ret |= io_cancel_defer_files(ctx, tctx, cancel_all);
	mutex_lock(&ctx->uring_lock);
	ret |= io_poll_remove_all(ctx, task, cancel_all);
	ret |= io_waitid_remove_all(ctx, task, cancel_all);
	ret |= io_futex_remove_all(ctx, task, cancel_all);
	ret |= io_uring_try_cancel_uring_cmd(ctx, task, cancel_all);
	ret |= io_poll_remove_all(ctx, tctx, cancel_all);
	ret |= io_waitid_remove_all(ctx, tctx, cancel_all);
	ret |= io_futex_remove_all(ctx, tctx, cancel_all);
	ret |= io_uring_try_cancel_uring_cmd(ctx, tctx, cancel_all);
	mutex_unlock(&ctx->uring_lock);
	ret |= io_kill_timeouts(ctx, task, cancel_all);
	if (task)
	ret |= io_kill_timeouts(ctx, tctx, cancel_all);
	if (tctx)
		ret |= io_run_task_work() > 0;
	else
		ret |= flush_delayed_work(&ctx->fallback_work);
@@ -3155,12 +3154,13 @@ __cold void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd)
				if (node->ctx->sq_data)
					continue;
				loop |= io_uring_try_cancel_requests(node->ctx,
							current, cancel_all);
							current->io_uring,
							cancel_all);
			}
		} else {
			list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
				loop |= io_uring_try_cancel_requests(ctx,
								     current,
								     current->io_uring,
								     cancel_all);
		}

+1 −1
Original line number Diff line number Diff line
@@ -115,7 +115,7 @@ void io_queue_next(struct io_kiocb *req);
void io_task_refs_refill(struct io_uring_task *tctx);
bool __io_alloc_req_refill(struct io_ring_ctx *ctx);

bool io_match_task_safe(struct io_kiocb *head, struct task_struct *task,
bool io_match_task_safe(struct io_kiocb *head, struct io_uring_task *tctx,
			bool cancel_all);

void io_activate_pollwq(struct io_ring_ctx *ctx);
+2 −2
Original line number Diff line number Diff line
@@ -714,7 +714,7 @@ int io_arm_poll_handler(struct io_kiocb *req, unsigned issue_flags)
/*
 * Returns true if we found and killed one or more poll requests
 */
__cold bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk,
__cold bool io_poll_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
			       bool cancel_all)
{
	unsigned nr_buckets = 1U << ctx->cancel_table.hash_bits;
@@ -729,7 +729,7 @@ __cold bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk,
		struct io_hash_bucket *hb = &ctx->cancel_table.hbs[i];

		hlist_for_each_entry_safe(req, tmp, &hb->list, hash_node) {
			if (io_match_task_safe(req, tsk, cancel_all)) {
			if (io_match_task_safe(req, tctx, cancel_all)) {
				hlist_del_init(&req->hash_node);
				io_poll_cancel_req(req);
				found = true;
Loading