Commit ab6559bd authored by Jens Axboe
Browse files

io_uring/kbuf: introduce struct io_br_sel

Rather than return addresses directly from buffer selection, add a
struct around it. No functional changes in this patch, it's in
preparation for storing more buffer related information locally, rather
than in struct io_kiocb.

Link: https://lore.kernel.org/r/20250821020750.598432-7-axboe@kernel.dk


Signed-off-by: Jens Axboe <axboe@kernel.dk>
parent 1b5add75
Loading
Loading
Loading
Loading
+13 −13
Original line number Diff line number Diff line
@@ -151,18 +151,18 @@ static int io_provided_buffers_select(struct io_kiocb *req, size_t *len,
	return 1;
}

static void __user *io_ring_buffer_select(struct io_kiocb *req, size_t *len,
static struct io_br_sel io_ring_buffer_select(struct io_kiocb *req, size_t *len,
					      struct io_buffer_list *bl,
					      unsigned int issue_flags)
{
	struct io_uring_buf_ring *br = bl->buf_ring;
	__u16 tail, head = bl->head;
	struct io_br_sel sel = { };
	struct io_uring_buf *buf;
	void __user *ret;

	tail = smp_load_acquire(&br->tail);
	if (unlikely(tail == head))
		return NULL;
		return sel;

	if (head + 1 == tail)
		req->flags |= REQ_F_BL_EMPTY;
@@ -173,7 +173,7 @@ static void __user *io_ring_buffer_select(struct io_kiocb *req, size_t *len,
	req->flags |= REQ_F_BUFFER_RING | REQ_F_BUFFERS_COMMIT;
	req->buf_list = bl;
	req->buf_index = buf->bid;
	ret = u64_to_user_ptr(buf->addr);
	sel.addr = u64_to_user_ptr(buf->addr);

	if (issue_flags & IO_URING_F_UNLOCKED || !io_file_can_poll(req)) {
		/*
@@ -189,27 +189,27 @@ static void __user *io_ring_buffer_select(struct io_kiocb *req, size_t *len,
		io_kbuf_commit(req, bl, *len, 1);
		req->buf_list = NULL;
	}
	return ret;
	return sel;
}

void __user *io_buffer_select(struct io_kiocb *req, size_t *len,
struct io_br_sel io_buffer_select(struct io_kiocb *req, size_t *len,
				  unsigned buf_group, unsigned int issue_flags)
{
	struct io_ring_ctx *ctx = req->ctx;
	struct io_br_sel sel = { };
	struct io_buffer_list *bl;
	void __user *ret = NULL;

	io_ring_submit_lock(req->ctx, issue_flags);

	bl = io_buffer_get_list(ctx, buf_group);
	if (likely(bl)) {
		if (bl->flags & IOBL_BUF_RING)
			ret = io_ring_buffer_select(req, len, bl, issue_flags);
			sel = io_ring_buffer_select(req, len, bl, issue_flags);
		else
			ret = io_provided_buffer_select(req, len, bl);
			sel.addr = io_provided_buffer_select(req, len, bl);
	}
	io_ring_submit_unlock(req->ctx, issue_flags);
	return ret;
	return sel;
}

/* cap it at a reasonable 256, will be one page even for 4K */
+17 −2
Original line number Diff line number Diff line
@@ -62,7 +62,22 @@ struct buf_sel_arg {
	unsigned short partial_map;
};

void __user *io_buffer_select(struct io_kiocb *req, size_t *len,
/*
 * Return value from io_buffer_list selection, handed back by value from
 * io_buffer_select(). Just returns the error or user address for now, will
 * be extended to return the buffer list in the future.
 */
struct io_br_sel {
	/*
	 * Some selection parts return the user address, others return an error.
	 * Both members share the same storage, so only the one the selection
	 * path actually filled in is meaningful to the caller.
	 */
	union {
		/*
		 * User address of the selected buffer. Callers of
		 * io_buffer_select() treat a NULL addr as "no buffer
		 * available" and fail the request with -ENOBUFS.
		 */
		void __user *addr;
		/*
		 * Error value, for the selection paths that report one.
		 * NOTE(review): no user of this member is visible in this
		 * patch; confirm the sign/errno convention at the call site
		 * that starts using it.
		 */
		ssize_t val;
	};
};

struct io_br_sel io_buffer_select(struct io_kiocb *req, size_t *len,
				  unsigned buf_group, unsigned int issue_flags);
int io_buffers_select(struct io_kiocb *req, struct buf_sel_arg *arg,
		      unsigned int issue_flags);
+9 −9
Original line number Diff line number Diff line
@@ -1035,22 +1035,22 @@ int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)

retry_multishot:
	if (io_do_buffer_select(req)) {
		void __user *buf;
		struct io_br_sel sel;
		size_t len = sr->len;

		buf = io_buffer_select(req, &len, sr->buf_group, issue_flags);
		if (!buf)
		sel = io_buffer_select(req, &len, sr->buf_group, issue_flags);
		if (!sel.addr)
			return -ENOBUFS;

		if (req->flags & REQ_F_APOLL_MULTISHOT) {
			ret = io_recvmsg_prep_multishot(kmsg, sr, &buf, &len);
			ret = io_recvmsg_prep_multishot(kmsg, sr, &sel.addr, &len);
			if (ret) {
				io_kbuf_recycle(req, req->buf_list, issue_flags);
				return ret;
			}
		}

		iov_iter_ubuf(&kmsg->msg.msg_iter, ITER_DEST, buf, len);
		iov_iter_ubuf(&kmsg->msg.msg_iter, ITER_DEST, sel.addr, len);
	}

	kmsg->msg.msg_get_inq = 1;
@@ -1153,13 +1153,13 @@ static int io_recv_buf_select(struct io_kiocb *req, struct io_async_msghdr *kmsg
		iov_iter_init(&kmsg->msg.msg_iter, ITER_DEST, arg.iovs, ret,
				arg.out_len);
	} else {
		void __user *buf;
		struct io_br_sel sel;

		*len = sr->len;
		buf = io_buffer_select(req, len, sr->buf_group, issue_flags);
		if (!buf)
		sel = io_buffer_select(req, len, sr->buf_group, issue_flags);
		if (!sel.addr)
			return -ENOBUFS;
		sr->buf = buf;
		sr->buf = sel.addr;
		sr->len = *len;
map_ubuf:
		ret = import_ubuf(ITER_DEST, sr->buf, sr->len,
+20 −14
Original line number Diff line number Diff line
@@ -107,34 +107,35 @@ static int io_import_vec(int ddir, struct io_kiocb *req,
}

static int __io_import_rw_buffer(int ddir, struct io_kiocb *req,
			     struct io_async_rw *io,
				 struct io_async_rw *io, struct io_br_sel *sel,
				 unsigned int issue_flags)
{
	const struct io_issue_def *def = &io_issue_defs[req->opcode];
	struct io_rw *rw = io_kiocb_to_cmd(req, struct io_rw);
	void __user *buf = u64_to_user_ptr(rw->addr);
	size_t sqe_len = rw->len;

	sel->addr = u64_to_user_ptr(rw->addr);
	if (def->vectored && !(req->flags & REQ_F_BUFFER_SELECT))
		return io_import_vec(ddir, req, io, buf, sqe_len);
		return io_import_vec(ddir, req, io, sel->addr, sqe_len);

	if (io_do_buffer_select(req)) {
		buf = io_buffer_select(req, &sqe_len, io->buf_group, issue_flags);
		if (!buf)
		*sel = io_buffer_select(req, &sqe_len, io->buf_group, issue_flags);
		if (!sel->addr)
			return -ENOBUFS;
		rw->addr = (unsigned long) buf;
		rw->addr = (unsigned long) sel->addr;
		rw->len = sqe_len;
	}
	return import_ubuf(ddir, buf, sqe_len, &io->iter);
	return import_ubuf(ddir, sel->addr, sqe_len, &io->iter);
}

static inline int io_import_rw_buffer(int rw, struct io_kiocb *req,
				      struct io_async_rw *io,
				      struct io_br_sel *sel,
				      unsigned int issue_flags)
{
	int ret;

	ret = __io_import_rw_buffer(rw, req, io, issue_flags);
	ret = __io_import_rw_buffer(rw, req, io, sel, issue_flags);
	if (unlikely(ret < 0))
		return ret;

@@ -306,10 +307,12 @@ static int __io_prep_rw(struct io_kiocb *req, const struct io_uring_sqe *sqe,

static int io_rw_do_import(struct io_kiocb *req, int ddir)
{
	struct io_br_sel sel = { };

	if (io_do_buffer_select(req))
		return 0;

	return io_import_rw_buffer(ddir, req, req->async_data, 0);
	return io_import_rw_buffer(ddir, req, req->async_data, &sel, 0);
}

static int io_prep_rw(struct io_kiocb *req, const struct io_uring_sqe *sqe,
@@ -899,7 +902,8 @@ static int io_rw_init_file(struct io_kiocb *req, fmode_t mode, int rw_type)
	return 0;
}

static int __io_read(struct io_kiocb *req, unsigned int issue_flags)
static int __io_read(struct io_kiocb *req, struct io_br_sel *sel,
		     unsigned int issue_flags)
{
	bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
	struct io_rw *rw = io_kiocb_to_cmd(req, struct io_rw);
@@ -913,7 +917,7 @@ static int __io_read(struct io_kiocb *req, unsigned int issue_flags)
		if (unlikely(ret))
			return ret;
	} else if (io_do_buffer_select(req)) {
		ret = io_import_rw_buffer(ITER_DEST, req, io, issue_flags);
		ret = io_import_rw_buffer(ITER_DEST, req, io, sel, issue_flags);
		if (unlikely(ret < 0))
			return ret;
	}
@@ -1015,9 +1019,10 @@ static int __io_read(struct io_kiocb *req, unsigned int issue_flags)

int io_read(struct io_kiocb *req, unsigned int issue_flags)
{
	struct io_br_sel sel = { };
	int ret;

	ret = __io_read(req, issue_flags);
	ret = __io_read(req, &sel, issue_flags);
	if (ret >= 0)
		return kiocb_done(req, ret, issue_flags);

@@ -1027,6 +1032,7 @@ int io_read(struct io_kiocb *req, unsigned int issue_flags)
int io_read_mshot(struct io_kiocb *req, unsigned int issue_flags)
{
	struct io_rw *rw = io_kiocb_to_cmd(req, struct io_rw);
	struct io_br_sel sel = { };
	unsigned int cflags = 0;
	int ret;

@@ -1038,7 +1044,7 @@ int io_read_mshot(struct io_kiocb *req, unsigned int issue_flags)

	/* make it sync, multishot doesn't support async execution */
	rw->kiocb.ki_complete = NULL;
	ret = __io_read(req, issue_flags);
	ret = __io_read(req, &sel, issue_flags);

	/*
	 * If we get -EAGAIN, recycle our buffer and just let normal poll