Commit c5e9f6a9 authored by Jens Axboe's avatar Jens Axboe
Browse files

io_uring: unify getting ctx from passed in file descriptor



io_uring_enter() and io_uring_register() end up having duplicated code
for getting a ctx from a passed in file descriptor, for either a
registered ring descriptor or a normal file descriptor. Move the
io_uring_register_get_file() into io_uring.c and name it a bit more
generically, and use it from both callsites rather than have that logic
and handling duplicated.

Signed-off-by: default avatarJens Axboe <axboe@kernel.dk>
parent b4d893d6
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -181,7 +181,7 @@ static int bpf_io_reg(void *kdata, struct bpf_link *link)
	struct file *file;
	int ret = -EBUSY;

	file = io_uring_register_get_file(ops->ring_fd, false);
	file = io_uring_ctx_get_file(ops->ring_fd, false);
	if (IS_ERR(file))
		return PTR_ERR(file);
	ctx = file->private_data;
+36 −21
Original line number Diff line number Diff line
@@ -2543,39 +2543,54 @@ static int io_get_ext_arg(struct io_ring_ctx *ctx, unsigned flags,
#endif
}

SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
		u32, min_complete, u32, flags, const void __user *, argp,
		size_t, argsz)
/*
 * Given an 'fd' value, return the ctx associated with if. If 'registered' is
 * true, then the registered index is used. Otherwise, the normal fd table.
 * Caller must call fput() on the returned file if it isn't a registered file,
 * unless it's an ERR_PTR.
 */
struct file *io_uring_ctx_get_file(unsigned int fd, bool registered)
{
	struct io_ring_ctx *ctx;
	struct file *file;
	long ret;

	if (unlikely(flags & ~IORING_ENTER_FLAGS))
		return -EINVAL;

	if (registered) {
		/*
		 * Ring fd has been registered via IORING_REGISTER_RING_FDS, we
		 * need only dereference our task private array to find it.
		 */
	if (flags & IORING_ENTER_REGISTERED_RING) {
		struct io_uring_task *tctx = current->io_uring;

		if (unlikely(!tctx || fd >= IO_RINGFD_REG_MAX))
			return -EINVAL;
			return ERR_PTR(-EINVAL);
		fd = array_index_nospec(fd, IO_RINGFD_REG_MAX);
		file = tctx->registered_rings[fd];
		if (unlikely(!file))
			return -EBADF;
	} else {
		file = fget(fd);
	}

	if (unlikely(!file))
			return -EBADF;
		ret = -EOPNOTSUPP;
		if (unlikely(!io_is_uring_fops(file)))
			goto out;
		return ERR_PTR(-EBADF);
	if (io_is_uring_fops(file))
		return file;
	fput(file);
	return ERR_PTR(-EOPNOTSUPP);
}


SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
		u32, min_complete, u32, flags, const void __user *, argp,
		size_t, argsz)
{
	struct io_ring_ctx *ctx;
	struct file *file;
	long ret;

	if (unlikely(flags & ~IORING_ENTER_FLAGS))
		return -EINVAL;

	file = io_uring_ctx_get_file(fd, flags & IORING_ENTER_REGISTERED_RING);
	if (IS_ERR(file))
		return PTR_ERR(file);
	ctx = file->private_data;
	ret = -EBADFD;
	/*
+1 −0
Original line number Diff line number Diff line
@@ -173,6 +173,7 @@ void io_req_track_inflight(struct io_kiocb *req);
struct file *io_file_get_normal(struct io_kiocb *req, int fd);
struct file *io_file_get_fixed(struct io_kiocb *req, int fd,
			       unsigned issue_flags);
struct file *io_uring_ctx_get_file(unsigned int fd, bool registered);

void io_req_task_queue(struct io_kiocb *req);
void io_req_task_complete(struct io_tw_req tw_req, io_tw_token_t tw);
+1 −34
Original line number Diff line number Diff line
@@ -938,39 +938,6 @@ static int __io_uring_register(struct io_ring_ctx *ctx, unsigned opcode,
	return ret;
}

/*
 * Given an 'fd' value, return the ctx associated with if. If 'registered' is
 * true, then the registered index is used. Otherwise, the normal fd table.
 * Caller must call fput() on the returned file if it isn't a registered file,
 * unless it's an ERR_PTR.
 */
struct file *io_uring_register_get_file(unsigned int fd, bool registered)
{
	struct file *file;

	if (registered) {
		/*
		 * Ring fd has been registered via IORING_REGISTER_RING_FDS, we
		 * need only dereference our task private array to find it.
		 */
		struct io_uring_task *tctx = current->io_uring;

		if (unlikely(!tctx || fd >= IO_RINGFD_REG_MAX))
			return ERR_PTR(-EINVAL);
		fd = array_index_nospec(fd, IO_RINGFD_REG_MAX);
		file = tctx->registered_rings[fd];
	} else {
		file = fget(fd);
	}

	if (unlikely(!file))
		return ERR_PTR(-EBADF);
	if (io_is_uring_fops(file))
		return file;
	fput(file);
	return ERR_PTR(-EOPNOTSUPP);
}

static int io_uring_register_send_msg_ring(void __user *arg, unsigned int nr_args)
{
	struct io_uring_sqe sqe;
@@ -1025,7 +992,7 @@ SYSCALL_DEFINE4(io_uring_register, unsigned int, fd, unsigned int, opcode,
	if (fd == -1)
		return io_uring_register_blind(opcode, arg, nr_args);

	file = io_uring_register_get_file(fd, use_registered_ring);
	file = io_uring_ctx_get_file(fd, use_registered_ring);
	if (IS_ERR(file))
		return PTR_ERR(file);
	ctx = file->private_data;
+0 −1
Original line number Diff line number Diff line
@@ -4,6 +4,5 @@

int io_eventfd_unregister(struct io_ring_ctx *ctx);
int io_unregister_personality(struct io_ring_ctx *ctx, unsigned id);
struct file *io_uring_register_get_file(unsigned int fd, bool registered);

#endif
Loading