Commit 592a93fe authored by Linus Torvalds's avatar Linus Torvalds
Browse files

Merge tag '6.17-rc6-ksmbd-fixes' of git://git.samba.org/ksmbd

Pull smb server fixes from Steve French:

 - Two fixes for remaining_data_length and offset checks in receive path

 - Don't go over max SGEs which caused smbdirect send to fail (and
   trigger disconnect)

* tag '6.17-rc6-ksmbd-fixes' of git://git.samba.org/ksmbd:
  ksmbd: smbdirect: verify remaining_data_length respects max_fragmented_recv_size
  ksmbd: smbdirect: validate data_offset and data_length field of smb_direct_data_transfer
  smb: server: let smb_direct_writev() respect SMB_DIRECT_MAX_SEND_SGES
parents 992d4e48 e1868ba3
Loading
Loading
Loading
Loading
+125 −58
Original line number Diff line number Diff line
@@ -554,7 +554,7 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
	case SMB_DIRECT_MSG_DATA_TRANSFER: {
		struct smb_direct_data_transfer *data_transfer =
			(struct smb_direct_data_transfer *)recvmsg->packet;
		unsigned int data_length;
		u32 remaining_data_length, data_offset, data_length;
		int avail_recvmsg_count, receive_credits;

		if (wc->byte_len <
@@ -564,15 +564,25 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
			return;
		}

		remaining_data_length = le32_to_cpu(data_transfer->remaining_data_length);
		data_length = le32_to_cpu(data_transfer->data_length);
		if (data_length) {
			if (wc->byte_len < sizeof(struct smb_direct_data_transfer) +
			    (u64)data_length) {
		data_offset = le32_to_cpu(data_transfer->data_offset);
		if (wc->byte_len < data_offset ||
		    wc->byte_len < (u64)data_offset + data_length) {
			put_recvmsg(t, recvmsg);
			smb_direct_disconnect_rdma_connection(t);
			return;
		}
		if (remaining_data_length > t->max_fragmented_recv_size ||
		    data_length > t->max_fragmented_recv_size ||
		    (u64)remaining_data_length + (u64)data_length >
		    (u64)t->max_fragmented_recv_size) {
			put_recvmsg(t, recvmsg);
			smb_direct_disconnect_rdma_connection(t);
			return;
		}

		if (data_length) {
			if (t->full_packet_received)
				recvmsg->first_segment = true;

@@ -1209,78 +1219,130 @@ static int smb_direct_writev(struct ksmbd_transport *t,
			     bool need_invalidate, unsigned int remote_key)
{
	struct smb_direct_transport *st = smb_trans_direct_transfort(t);
	int remaining_data_length;
	int start, i, j;
	int max_iov_size = st->max_send_size -
	size_t remaining_data_length;
	size_t iov_idx;
	size_t iov_ofs;
	size_t max_iov_size = st->max_send_size -
			sizeof(struct smb_direct_data_transfer);
	int ret;
	struct kvec vec;
	struct smb_direct_send_ctx send_ctx;
	int error = 0;

	if (st->status != SMB_DIRECT_CS_CONNECTED)
		return -ENOTCONN;

	//FIXME: skip RFC1002 header..
	if (WARN_ON_ONCE(niovs <= 1 || iov[0].iov_len != 4))
		return -EINVAL;
	buflen -= 4;
	iov_idx = 1;
	iov_ofs = 0;

	remaining_data_length = buflen;
	ksmbd_debug(RDMA, "Sending smb (RDMA): smb_len=%u\n", buflen);

	smb_direct_send_ctx_init(st, &send_ctx, need_invalidate, remote_key);
	start = i = 1;
	buflen = 0;
	while (true) {
		buflen += iov[i].iov_len;
		if (buflen > max_iov_size) {
			if (i > start) {
				remaining_data_length -=
					(buflen - iov[i].iov_len);
				ret = smb_direct_post_send_data(st, &send_ctx,
								&iov[start], i - start,
								remaining_data_length);
				if (ret)
	while (remaining_data_length) {
		struct kvec vecs[SMB_DIRECT_MAX_SEND_SGES - 1]; /* minus smbdirect hdr */
		size_t possible_bytes = max_iov_size;
		size_t possible_vecs;
		size_t bytes = 0;
		size_t nvecs = 0;

		/*
		 * For the last message remaining_data_length should be
		 * have been 0 already!
		 */
		if (WARN_ON_ONCE(iov_idx >= niovs)) {
			error = -EINVAL;
			goto done;
			} else {
				/* iov[start] is too big, break it */
				int nvec  = (buflen + max_iov_size - 1) /
						max_iov_size;

				for (j = 0; j < nvec; j++) {
					vec.iov_base =
						(char *)iov[start].iov_base +
						j * max_iov_size;
					vec.iov_len =
						min_t(int, max_iov_size,
						      buflen - max_iov_size * j);
					remaining_data_length -= vec.iov_len;
					ret = smb_direct_post_send_data(st, &send_ctx, &vec, 1,
									remaining_data_length);
					if (ret)
		}

		/*
		 * We have 2 factors which limit the arguments we pass
		 * to smb_direct_post_send_data():
		 *
		 * 1. The number of supported sges for the send,
		 *    while one is reserved for the smbdirect header.
		 *    And we currently need one SGE per page.
		 * 2. The number of negotiated payload bytes per send.
		 */
		possible_vecs = min_t(size_t, ARRAY_SIZE(vecs), niovs - iov_idx);

		while (iov_idx < niovs && possible_vecs && possible_bytes) {
			struct kvec *v = &vecs[nvecs];
			int page_count;

			v->iov_base = ((u8 *)iov[iov_idx].iov_base) + iov_ofs;
			v->iov_len = min_t(size_t,
					   iov[iov_idx].iov_len - iov_ofs,
					   possible_bytes);
			page_count = get_buf_page_count(v->iov_base, v->iov_len);
			if (page_count > possible_vecs) {
				/*
				 * If the number of pages in the buffer
				 * is to much (because we currently require
				 * one SGE per page), we need to limit the
				 * length.
				 *
				 * We know possible_vecs is at least 1,
				 * so we always keep the first page.
				 *
				 * We need to calculate the number extra
				 * pages (epages) we can also keep.
				 *
				 * We calculate the number of bytes in the
				 * first page (fplen), this should never be
				 * larger than v->iov_len because page_count is
				 * at least 2, but adding a limitation feels
				 * better.
				 *
				 * Then we calculate the number of bytes (elen)
				 * we can keep for the extra pages.
				 */
				size_t epages = possible_vecs - 1;
				size_t fpofs = offset_in_page(v->iov_base);
				size_t fplen = min_t(size_t, PAGE_SIZE - fpofs, v->iov_len);
				size_t elen = min_t(size_t, v->iov_len - fplen, epages*PAGE_SIZE);

				v->iov_len = fplen + elen;
				page_count = get_buf_page_count(v->iov_base, v->iov_len);
				if (WARN_ON_ONCE(page_count > possible_vecs)) {
					/*
					 * Something went wrong in the above
					 * logic...
					 */
					error = -EINVAL;
					goto done;
				}
				i++;
				if (i == niovs)
					break;
			}
			start = i;
			buflen = 0;
		} else {
			i++;
			if (i == niovs) {
				/* send out all remaining vecs */
				remaining_data_length -= buflen;
			possible_vecs -= page_count;
			nvecs += 1;
			possible_bytes -= v->iov_len;
			bytes += v->iov_len;

			iov_ofs += v->iov_len;
			if (iov_ofs >= iov[iov_idx].iov_len) {
				iov_idx += 1;
				iov_ofs = 0;
			}
		}

		remaining_data_length -= bytes;

		ret = smb_direct_post_send_data(st, &send_ctx,
								&iov[start], i - start,
						vecs, nvecs,
						remaining_data_length);
				if (ret)
		if (unlikely(ret)) {
			error = ret;
			goto done;
				break;
			}
		}
	}

done:
	ret = smb_direct_flush_send_list(st, &send_ctx, true);
	if (unlikely(!ret && error))
		ret = error;

	/*
	 * As an optimization, we don't wait for individual I/O to finish
@@ -1744,6 +1806,11 @@ static int smb_direct_init_params(struct smb_direct_transport *t,
		return -EINVAL;
	}

	if (device->attrs.max_send_sge < SMB_DIRECT_MAX_SEND_SGES) {
		pr_err("warning: device max_send_sge = %d too small\n",
		       device->attrs.max_send_sge);
		return -EINVAL;
	}
	if (device->attrs.max_recv_sge < SMB_DIRECT_MAX_RECV_SGES) {
		pr_err("warning: device max_recv_sge = %d too small\n",
		       device->attrs.max_recv_sge);
@@ -1767,7 +1834,7 @@ static int smb_direct_init_params(struct smb_direct_transport *t,

	cap->max_send_wr = max_send_wrs;
	cap->max_recv_wr = t->recv_credit_max;
	cap->max_send_sge = max_sge_per_wr;
	cap->max_send_sge = SMB_DIRECT_MAX_SEND_SGES;
	cap->max_recv_sge = SMB_DIRECT_MAX_RECV_SGES;
	cap->max_inline_data = 0;
	cap->max_rdma_ctxs = t->max_rw_credits;