Commit 779bcdd4 authored by Jason Wang, committed by Jakub Kicinski
Browse files

vhost: rewind next_avail_head while discarding descriptors



When discarding descriptors with IN_ORDER, we should rewind
next_avail_head; otherwise it would run out of sync with
last_avail_idx. This would cause the driver to report
"id X is not a head".

Fix this by returning the number of descriptors that are used for
each buffer via vhost_get_vq_desc_n() so the caller can use the value
while discarding descriptors.

Fixes: 67a873df ("vhost: basic in order support")
Cc: stable@vger.kernel.org
Signed-off-by: Jason Wang <jasowang@redhat.com>
Acked-by: Michael S. Tsirkin <mst@redhat.com>
Link: https://patch.msgid.link/20251120022950.10117-1-jasowang@redhat.com


Signed-off-by: Jakub Kicinski <kuba@kernel.org>
parent 0ebc27a4
Loading
Loading
Loading
Loading
+32 −21
Original line number Diff line number Diff line
@@ -592,14 +592,15 @@ static void vhost_net_busy_poll(struct vhost_net *net,
static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
				    struct vhost_net_virtqueue *tnvq,
				    unsigned int *out_num, unsigned int *in_num,
				    struct msghdr *msghdr, bool *busyloop_intr)
				    struct msghdr *msghdr, bool *busyloop_intr,
				    unsigned int *ndesc)
{
	struct vhost_net_virtqueue *rnvq = &net->vqs[VHOST_NET_VQ_RX];
	struct vhost_virtqueue *rvq = &rnvq->vq;
	struct vhost_virtqueue *tvq = &tnvq->vq;

	int r = vhost_get_vq_desc(tvq, tvq->iov, ARRAY_SIZE(tvq->iov),
				  out_num, in_num, NULL, NULL);
	int r = vhost_get_vq_desc_n(tvq, tvq->iov, ARRAY_SIZE(tvq->iov),
				    out_num, in_num, NULL, NULL, ndesc);

	if (r == tvq->num && tvq->busyloop_timeout) {
		/* Flush batched packets first */
@@ -610,8 +611,8 @@ static int vhost_net_tx_get_vq_desc(struct vhost_net *net,

		vhost_net_busy_poll(net, rvq, tvq, busyloop_intr, false);

		r = vhost_get_vq_desc(tvq, tvq->iov, ARRAY_SIZE(tvq->iov),
				      out_num, in_num, NULL, NULL);
		r = vhost_get_vq_desc_n(tvq, tvq->iov, ARRAY_SIZE(tvq->iov),
					out_num, in_num, NULL, NULL, ndesc);
	}

	return r;
@@ -642,12 +643,14 @@ static int get_tx_bufs(struct vhost_net *net,
		       struct vhost_net_virtqueue *nvq,
		       struct msghdr *msg,
		       unsigned int *out, unsigned int *in,
		       size_t *len, bool *busyloop_intr)
		       size_t *len, bool *busyloop_intr,
		       unsigned int *ndesc)
{
	struct vhost_virtqueue *vq = &nvq->vq;
	int ret;

	ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, msg, busyloop_intr);
	ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, msg,
				       busyloop_intr, ndesc);

	if (ret < 0 || ret == vq->num)
		return ret;
@@ -766,6 +769,7 @@ static void handle_tx_copy(struct vhost_net *net, struct socket *sock)
	int sent_pkts = 0;
	bool sock_can_batch = (sock->sk->sk_sndbuf == INT_MAX);
	bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);
	unsigned int ndesc = 0;

	do {
		bool busyloop_intr = false;
@@ -774,7 +778,7 @@ static void handle_tx_copy(struct vhost_net *net, struct socket *sock)
			vhost_tx_batch(net, nvq, sock, &msg);

		head = get_tx_bufs(net, nvq, &msg, &out, &in, &len,
				   &busyloop_intr);
				   &busyloop_intr, &ndesc);
		/* On error, stop handling until the next kick. */
		if (unlikely(head < 0))
			break;
@@ -806,7 +810,7 @@ static void handle_tx_copy(struct vhost_net *net, struct socket *sock)
				goto done;
			} else if (unlikely(err != -ENOSPC)) {
				vhost_tx_batch(net, nvq, sock, &msg);
				vhost_discard_vq_desc(vq, 1);
				vhost_discard_vq_desc(vq, 1, ndesc);
				vhost_net_enable_vq(net, vq);
				break;
			}
@@ -829,7 +833,7 @@ static void handle_tx_copy(struct vhost_net *net, struct socket *sock)
		err = sock->ops->sendmsg(sock, &msg, len);
		if (unlikely(err < 0)) {
			if (err == -EAGAIN || err == -ENOMEM || err == -ENOBUFS) {
				vhost_discard_vq_desc(vq, 1);
				vhost_discard_vq_desc(vq, 1, ndesc);
				vhost_net_enable_vq(net, vq);
				break;
			}
@@ -868,6 +872,7 @@ static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock)
	int err;
	struct vhost_net_ubuf_ref *ubufs;
	struct ubuf_info_msgzc *ubuf;
	unsigned int ndesc = 0;
	bool zcopy_used;
	int sent_pkts = 0;

@@ -879,7 +884,7 @@ static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock)

		busyloop_intr = false;
		head = get_tx_bufs(net, nvq, &msg, &out, &in, &len,
				   &busyloop_intr);
				   &busyloop_intr, &ndesc);
		/* On error, stop handling until the next kick. */
		if (unlikely(head < 0))
			break;
@@ -941,7 +946,7 @@ static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock)
					vq->heads[ubuf->desc].len = VHOST_DMA_DONE_LEN;
			}
			if (retry) {
				vhost_discard_vq_desc(vq, 1);
				vhost_discard_vq_desc(vq, 1, ndesc);
				vhost_net_enable_vq(net, vq);
				break;
			}
@@ -1045,11 +1050,12 @@ static int get_rx_bufs(struct vhost_net_virtqueue *nvq,
		       unsigned *iovcount,
		       struct vhost_log *log,
		       unsigned *log_num,
		       unsigned int quota)
		       unsigned int quota,
		       unsigned int *ndesc)
{
	struct vhost_virtqueue *vq = &nvq->vq;
	bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);
	unsigned int out, in;
	unsigned int out, in, desc_num, n = 0;
	int seg = 0;
	int headcount = 0;
	unsigned d;
@@ -1064,9 +1070,9 @@ static int get_rx_bufs(struct vhost_net_virtqueue *nvq,
			r = -ENOBUFS;
			goto err;
		}
		r = vhost_get_vq_desc(vq, vq->iov + seg,
		r = vhost_get_vq_desc_n(vq, vq->iov + seg,
					ARRAY_SIZE(vq->iov) - seg, &out,
				      &in, log, log_num);
					&in, log, log_num, &desc_num);
		if (unlikely(r < 0))
			goto err;

@@ -1093,6 +1099,7 @@ static int get_rx_bufs(struct vhost_net_virtqueue *nvq,
		++headcount;
		datalen -= len;
		seg += in;
		n += desc_num;
	}

	*iovcount = seg;
@@ -1113,9 +1120,11 @@ static int get_rx_bufs(struct vhost_net_virtqueue *nvq,
		nheads[0] = headcount;
	}

	*ndesc = n;

	return headcount;
err:
	vhost_discard_vq_desc(vq, headcount);
	vhost_discard_vq_desc(vq, headcount, n);
	return r;
}

@@ -1151,6 +1160,7 @@ static void handle_rx(struct vhost_net *net)
	struct iov_iter fixup;
	__virtio16 num_buffers;
	int recv_pkts = 0;
	unsigned int ndesc;

	mutex_lock_nested(&vq->mutex, VHOST_NET_VQ_RX);
	sock = vhost_vq_get_backend(vq);
@@ -1182,7 +1192,8 @@ static void handle_rx(struct vhost_net *net)
		headcount = get_rx_bufs(nvq, vq->heads + count,
					vq->nheads + count,
					vhost_len, &in, vq_log, &log,
					likely(mergeable) ? UIO_MAXIOV : 1);
					likely(mergeable) ? UIO_MAXIOV : 1,
					&ndesc);
		/* On error, stop handling until the next kick. */
		if (unlikely(headcount < 0))
			goto out;
@@ -1228,7 +1239,7 @@ static void handle_rx(struct vhost_net *net)
		if (unlikely(err != sock_len)) {
			pr_debug("Discarded rx packet: "
				 " len %d, expected %zd\n", err, sock_len);
			vhost_discard_vq_desc(vq, headcount);
			vhost_discard_vq_desc(vq, headcount, ndesc);
			continue;
		}
		/* Supply virtio_net_hdr if VHOST_NET_F_VIRTIO_NET_HDR */
@@ -1252,7 +1263,7 @@ static void handle_rx(struct vhost_net *net)
		    copy_to_iter(&num_buffers, sizeof num_buffers,
				 &fixup) != sizeof num_buffers) {
			vq_err(vq, "Failed num_buffers write");
			vhost_discard_vq_desc(vq, headcount);
			vhost_discard_vq_desc(vq, headcount, ndesc);
			goto out;
		}
		nvq->done_idx += headcount;
+62 −14
Original line number Diff line number Diff line
@@ -2792,18 +2792,34 @@ static int get_indirect(struct vhost_virtqueue *vq,
	return 0;
}

/* This looks in the virtqueue and for the first available buffer, and converts
 * it to an iovec for convenient access.  Since descriptors consist of some
 * number of output then some number of input descriptors, it's actually two
 * iovecs, but we pack them into one and note how many of each there were.
/**
 * vhost_get_vq_desc_n - Fetch the next available descriptor chain and build iovecs
 * @vq: target virtqueue
 * @iov: array that receives the scatter/gather segments
 * @iov_size: capacity of @iov in elements
 * @out_num: the number of output segments
 * @in_num: the number of input segments
 * @log: optional array to record addr/len for each writable segment; NULL if unused
 * @log_num: optional output; number of entries written to @log when provided
 * @ndesc: optional output; number of descriptors consumed from the available ring
 *         (useful for rollback via vhost_discard_vq_desc)
 *
 * This function returns the descriptor number found, or vq->num (which is
 * never a valid descriptor number) if none was found.  A negative code is
 * returned on error. */
int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 * Extracts one available descriptor chain from @vq and translates guest addresses
 * into host iovecs.
 *
 * On success, advances @vq->last_avail_idx by 1 and @vq->next_avail_head by the
 * number of descriptors consumed (also stored via @ndesc when non-NULL).
 *
 * Return:
 * - head index in [0, @vq->num) on success;
 * - @vq->num if no descriptor is currently available;
 * - negative errno on failure
 */
int vhost_get_vq_desc_n(struct vhost_virtqueue *vq,
			struct iovec iov[], unsigned int iov_size,
			unsigned int *out_num, unsigned int *in_num,
		      struct vhost_log *log, unsigned int *log_num)
			struct vhost_log *log, unsigned int *log_num,
			unsigned int *ndesc)
{
	bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);
	struct vring_desc desc;
@@ -2921,17 +2937,49 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
	vq->last_avail_idx++;
	vq->next_avail_head += c;

	if (ndesc)
		*ndesc = c;

	/* Assume notifications from guest are disabled at this point,
	 * if they aren't we would need to update avail_event index. */
	BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
	return head;
}
EXPORT_SYMBOL_GPL(vhost_get_vq_desc_n);

/**
 * vhost_get_vq_desc - Fetch the first available buffer as an iovec array
 * @vq: target virtqueue
 * @iov: array that receives the scatter/gather segments
 * @iov_size: capacity of @iov in elements
 * @out_num: the number of output segments
 * @in_num: the number of input segments
 * @log: optional array to record addr/len for each writable segment
 * @log_num: optional output; number of entries written to @log
 *
 * A buffer is made of some number of output descriptors followed by some
 * number of input descriptors.  Both groups are packed into the single
 * @iov array, and @out_num / @in_num tell how many of each there were.
 *
 * This is vhost_get_vq_desc_n() for callers that do not need the number
 * of descriptors consumed: it forwards every argument and passes NULL
 * for the descriptor-count output.
 *
 * Return: the descriptor number found, or vq->num (which is never a valid
 * descriptor number) if none was found.  A negative code is returned on
 * error.
 */
int vhost_get_vq_desc(struct vhost_virtqueue *vq,
		      struct iovec iov[], unsigned int iov_size,
		      unsigned int *out_num, unsigned int *in_num,
		      struct vhost_log *log, unsigned int *log_num)
{
	/* Callers of this variant do not ask for the descriptor count. */
	unsigned int *no_ndesc = NULL;
	int head;

	head = vhost_get_vq_desc_n(vq, iov, iov_size, out_num, in_num,
				   log, log_num, no_ndesc);

	return head;
}
EXPORT_SYMBOL_GPL(vhost_get_vq_desc);

/* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
/**
 * vhost_discard_vq_desc - Reverse the effect of vhost_get_vq_desc_n()
 * @vq: target virtqueue
 * @nbufs: number of buffers to roll back
 * @ndesc: number of descriptors to roll back
 *
 * Rewinds the internal consumer cursors after a failed attempt to use buffers
 * returned by vhost_get_vq_desc_n().
 */
void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int nbufs,
			   unsigned int ndesc)
{
	vq->last_avail_idx -= n;
	vq->next_avail_head -= ndesc;
	vq->last_avail_idx -= nbufs;
}
EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);

+9 −1
Original line number Diff line number Diff line
@@ -230,7 +230,15 @@ int vhost_get_vq_desc(struct vhost_virtqueue *,
		      struct iovec iov[], unsigned int iov_size,
		      unsigned int *out_num, unsigned int *in_num,
		      struct vhost_log *log, unsigned int *log_num);
void vhost_discard_vq_desc(struct vhost_virtqueue *, int n);

/*
 * Variant of vhost_get_vq_desc() with one extra output: when @ndesc is
 * non-NULL it receives the number of descriptors consumed by the returned
 * buffer, which the caller can hand back to vhost_discard_vq_desc() when
 * it needs to undo the fetch.
 */
int vhost_get_vq_desc_n(struct vhost_virtqueue *vq,
			struct iovec iov[], unsigned int iov_size,
			unsigned int *out_num, unsigned int *in_num,
			struct vhost_log *log, unsigned int *log_num,
			unsigned int *ndesc);

/*
 * Undo vhost_get_vq_desc_n(): roll back @nbufs buffers and @ndesc
 * descriptors on @vq.  Parameter names match the definition.
 */
void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int nbufs,
			   unsigned int ndesc);

bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work);
bool vhost_vq_has_work(struct vhost_virtqueue *vq);