Commit 566771af authored by Linus Torvalds's avatar Linus Torvalds
Browse files

Merge tag 'v6.18-rc2-smb-server-fixes' of git://git.samba.org/ksmbd

Pull smb server fixes from Steve French:
 "smbdirect (RDMA) fixes in order to avoid potential submission queue
  overflows:

   - free transport teardown fix

   - credit related fixes (five server related, one client related)"

* tag 'v6.18-rc2-smb-server-fixes' of git://git.samba.org/ksmbd:
  smb: server: let free_transport() wait for SMBDIRECT_SOCKET_DISCONNECTED
  smb: client: make use of smbdirect_socket.send_io.lcredits.*
  smb: server: make use of smbdirect_socket.send_io.lcredits.*
  smb: server: simplify sibling_list handling in smb_direct_flush_send_list/send_done
  smb: server: smb_direct_disconnect_rdma_connection() already wakes all waiters on error
  smb: smbdirect: introduce smbdirect_socket.send_io.lcredits.*
  smb: server: allocate enough space for RW WRs and ib_drain_qp()
parents 53abe3e1 dd6940f5
Loading
Loading
Loading
Loading
+42 −25
Original line number Diff line number Diff line
@@ -172,6 +172,7 @@ static void smbd_disconnect_wake_up_all(struct smbdirect_socket *sc)
	 * in order to notice the broken connection.
	 */
	wake_up_all(&sc->status_wait);
	wake_up_all(&sc->send_io.lcredits.wait_queue);
	wake_up_all(&sc->send_io.credits.wait_queue);
	wake_up_all(&sc->send_io.pending.dec_wait_queue);
	wake_up_all(&sc->send_io.pending.zero_wait_queue);
@@ -495,6 +496,7 @@ static void send_done(struct ib_cq *cq, struct ib_wc *wc)
	struct smbdirect_send_io *request =
		container_of(wc->wr_cqe, struct smbdirect_send_io, cqe);
	struct smbdirect_socket *sc = request->socket;
	int lcredits = 0;

	log_rdma_send(INFO, "smbdirect_send_io 0x%p completed wc->status=%s\n",
		request, ib_wc_status_msg(wc->status));
@@ -504,22 +506,24 @@ static void send_done(struct ib_cq *cq, struct ib_wc *wc)
			request->sge[i].addr,
			request->sge[i].length,
			DMA_TO_DEVICE);
	mempool_free(request, sc->send_io.mem.pool);
	lcredits += 1;

	if (wc->status != IB_WC_SUCCESS || wc->opcode != IB_WC_SEND) {
		if (wc->status != IB_WC_WR_FLUSH_ERR)
			log_rdma_send(ERR, "wc->status=%s wc->opcode=%d\n",
				ib_wc_status_msg(wc->status), wc->opcode);
		mempool_free(request, sc->send_io.mem.pool);
		smbd_disconnect_rdma_connection(sc);
		return;
	}

	atomic_add(lcredits, &sc->send_io.lcredits.count);
	wake_up(&sc->send_io.lcredits.wait_queue);

	if (atomic_dec_and_test(&sc->send_io.pending.count))
		wake_up(&sc->send_io.pending.zero_wait_queue);

	wake_up(&sc->send_io.pending.dec_wait_queue);

	mempool_free(request, sc->send_io.mem.pool);
}

static void dump_smbdirect_negotiate_resp(struct smbdirect_negotiate_resp *resp)
@@ -567,6 +571,7 @@ static bool process_negotiation_response(
		log_rdma_event(ERR, "error: credits_granted==0\n");
		return false;
	}
	atomic_set(&sc->send_io.lcredits.count, sp->send_credit_target);
	atomic_set(&sc->send_io.credits.count, le16_to_cpu(packet->credits_granted));

	if (le32_to_cpu(packet->preferred_send_size) > sp->max_recv_size) {
@@ -1114,6 +1119,24 @@ static int smbd_post_send_iter(struct smbdirect_socket *sc,
	struct smbdirect_data_transfer *packet;
	int new_credits = 0;

wait_lcredit:
	/* Wait for local send credits */
	rc = wait_event_interruptible(sc->send_io.lcredits.wait_queue,
		atomic_read(&sc->send_io.lcredits.count) > 0 ||
		sc->status != SMBDIRECT_SOCKET_CONNECTED);
	if (rc)
		goto err_wait_lcredit;

	if (sc->status != SMBDIRECT_SOCKET_CONNECTED) {
		log_outgoing(ERR, "disconnected not sending on wait_credit\n");
		rc = -EAGAIN;
		goto err_wait_lcredit;
	}
	if (unlikely(atomic_dec_return(&sc->send_io.lcredits.count) < 0)) {
		atomic_inc(&sc->send_io.lcredits.count);
		goto wait_lcredit;
	}

wait_credit:
	/* Wait for send credits. A SMBD packet needs one credit */
	rc = wait_event_interruptible(sc->send_io.credits.wait_queue,
@@ -1132,23 +1155,6 @@ static int smbd_post_send_iter(struct smbdirect_socket *sc,
		goto wait_credit;
	}

wait_send_queue:
	wait_event(sc->send_io.pending.dec_wait_queue,
		atomic_read(&sc->send_io.pending.count) < sp->send_credit_target ||
		sc->status != SMBDIRECT_SOCKET_CONNECTED);

	if (sc->status != SMBDIRECT_SOCKET_CONNECTED) {
		log_outgoing(ERR, "disconnected not sending on wait_send_queue\n");
		rc = -EAGAIN;
		goto err_wait_send_queue;
	}

	if (unlikely(atomic_inc_return(&sc->send_io.pending.count) >
				sp->send_credit_target)) {
		atomic_dec(&sc->send_io.pending.count);
		goto wait_send_queue;
	}

	request = mempool_alloc(sc->send_io.mem.pool, GFP_KERNEL);
	if (!request) {
		rc = -ENOMEM;
@@ -1229,10 +1235,21 @@ static int smbd_post_send_iter(struct smbdirect_socket *sc,
		     le32_to_cpu(packet->data_length),
		     le32_to_cpu(packet->remaining_data_length));

	/*
	 * Now that we got a local and a remote credit
	 * we add us as pending
	 */
	atomic_inc(&sc->send_io.pending.count);

	rc = smbd_post_send(sc, request);
	if (!rc)
		return 0;

	if (atomic_dec_and_test(&sc->send_io.pending.count))
		wake_up(&sc->send_io.pending.zero_wait_queue);

	wake_up(&sc->send_io.pending.dec_wait_queue);

err_dma:
	for (i = 0; i < request->num_sge; i++)
		if (request->sge[i].addr)
@@ -1246,14 +1263,14 @@ static int smbd_post_send_iter(struct smbdirect_socket *sc,
	atomic_sub(new_credits, &sc->recv_io.credits.count);

err_alloc:
	if (atomic_dec_and_test(&sc->send_io.pending.count))
		wake_up(&sc->send_io.pending.zero_wait_queue);

err_wait_send_queue:
	/* roll back send credits and pending */
	atomic_inc(&sc->send_io.credits.count);
	wake_up(&sc->send_io.credits.wait_queue);

err_wait_credit:
	atomic_inc(&sc->send_io.lcredits.count);
	wake_up(&sc->send_io.lcredits.wait_queue);

err_wait_lcredit:
	return rc;
}

+12 −1
Original line number Diff line number Diff line
@@ -142,7 +142,15 @@ struct smbdirect_socket {
		} mem;

		/*
		 * The credit state for the send side
		 * The local credit state for ib_post_send()
		 */
		struct {
			atomic_t count;
			wait_queue_head_t wait_queue;
		} lcredits;

		/*
		 * The remote credit state for the send side
		 */
		struct {
			atomic_t count;
@@ -337,6 +345,9 @@ static __always_inline void smbdirect_socket_init(struct smbdirect_socket *sc)
	INIT_DELAYED_WORK(&sc->idle.timer_work, __smbdirect_socket_disabled_work);
	disable_delayed_work_sync(&sc->idle.timer_work);

	atomic_set(&sc->send_io.lcredits.count, 0);
	init_waitqueue_head(&sc->send_io.lcredits.wait_queue);

	atomic_set(&sc->send_io.credits.count, 0);
	init_waitqueue_head(&sc->send_io.credits.wait_queue);

+219 −125
Original line number Diff line number Diff line
@@ -219,6 +219,7 @@ static void smb_direct_disconnect_wake_up_all(struct smbdirect_socket *sc)
	 * in order to notice the broken connection.
	 */
	wake_up_all(&sc->status_wait);
	wake_up_all(&sc->send_io.lcredits.wait_queue);
	wake_up_all(&sc->send_io.credits.wait_queue);
	wake_up_all(&sc->send_io.pending.zero_wait_queue);
	wake_up_all(&sc->recv_io.reassembly.wait_queue);
@@ -450,11 +451,10 @@ static void free_transport(struct smb_direct_transport *t)
	struct smbdirect_recv_io *recvmsg;

	disable_work_sync(&sc->disconnect_work);
	if (sc->status < SMBDIRECT_SOCKET_DISCONNECTING) {
	if (sc->status < SMBDIRECT_SOCKET_DISCONNECTING)
		smb_direct_disconnect_rdma_work(&sc->disconnect_work);
		wait_event_interruptible(sc->status_wait,
					 sc->status == SMBDIRECT_SOCKET_DISCONNECTED);
	}
	if (sc->status < SMBDIRECT_SOCKET_DISCONNECTED)
		wait_event(sc->status_wait, sc->status == SMBDIRECT_SOCKET_DISCONNECTED);

	/*
	 * Wake up all waiters in all wait queues
@@ -471,7 +471,6 @@ static void free_transport(struct smb_direct_transport *t)

	if (sc->ib.qp) {
		ib_drain_qp(sc->ib.qp);
		ib_mr_pool_destroy(sc->ib.qp, &sc->ib.qp->rdma_mrs);
		sc->ib.qp = NULL;
		rdma_destroy_qp(sc->rdma.cm_id);
	}
@@ -524,6 +523,12 @@ static void smb_direct_free_sendmsg(struct smbdirect_socket *sc,
{
	int i;

	/*
	 * The list needs to be empty!
	 * The caller should take care of it.
	 */
	WARN_ON_ONCE(!list_empty(&msg->sibling_list));

	if (msg->num_sge > 0) {
		ib_dma_unmap_single(sc->ib.dev,
				    msg->sge[0].addr, msg->sge[0].length,
@@ -909,9 +914,9 @@ static void smb_direct_post_recv_credits(struct work_struct *work)

static void send_done(struct ib_cq *cq, struct ib_wc *wc)
{
	struct smbdirect_send_io *sendmsg, *sibling;
	struct smbdirect_send_io *sendmsg, *sibling, *next;
	struct smbdirect_socket *sc;
	struct list_head *pos, *prev, *end;
	int lcredits = 0;

	sendmsg = container_of(wc->wr_cqe, struct smbdirect_send_io, cqe);
	sc = sendmsg->socket;
@@ -920,27 +925,31 @@ static void send_done(struct ib_cq *cq, struct ib_wc *wc)
		    ib_wc_status_msg(wc->status), wc->status,
		    wc->opcode);

	/*
	 * Free possible siblings and then the main send_io
	 */
	list_for_each_entry_safe(sibling, next, &sendmsg->sibling_list, sibling_list) {
		list_del_init(&sibling->sibling_list);
		smb_direct_free_sendmsg(sc, sibling);
		lcredits += 1;
	}
	/* Note this frees wc->wr_cqe, but not wc */
	smb_direct_free_sendmsg(sc, sendmsg);
	lcredits += 1;

	if (wc->status != IB_WC_SUCCESS || wc->opcode != IB_WC_SEND) {
		pr_err("Send error. status='%s (%d)', opcode=%d\n",
		       ib_wc_status_msg(wc->status), wc->status,
		       wc->opcode);
		smb_direct_disconnect_rdma_connection(sc);
		return;
	}

	atomic_add(lcredits, &sc->send_io.lcredits.count);
	wake_up(&sc->send_io.lcredits.wait_queue);

	if (atomic_dec_and_test(&sc->send_io.pending.count))
		wake_up(&sc->send_io.pending.zero_wait_queue);

	/* iterate and free the list of messages in reverse. the list's head
	 * is invalid.
	 */
	for (pos = &sendmsg->sibling_list, prev = pos->prev, end = sendmsg->sibling_list.next;
	     prev != end; pos = prev, prev = prev->prev) {
		sibling = container_of(pos, struct smbdirect_send_io, sibling_list);
		smb_direct_free_sendmsg(sc, sibling);
	}

	sibling = container_of(pos, struct smbdirect_send_io, sibling_list);
	smb_direct_free_sendmsg(sc, sibling);
}

static int manage_credits_prior_sending(struct smbdirect_socket *sc)
@@ -988,8 +997,6 @@ static int smb_direct_post_send(struct smbdirect_socket *sc,
	ret = ib_post_send(sc->ib.qp, wr, NULL);
	if (ret) {
		pr_err("failed to post send: %d\n", ret);
		if (atomic_dec_and_test(&sc->send_io.pending.count))
			wake_up(&sc->send_io.pending.zero_wait_queue);
		smb_direct_disconnect_rdma_connection(sc);
	}
	return ret;
@@ -1032,19 +1039,29 @@ static int smb_direct_flush_send_list(struct smbdirect_socket *sc,
	last->wr.send_flags = IB_SEND_SIGNALED;
	last->wr.wr_cqe = &last->cqe;

	/*
	 * Remove last from send_ctx->msg_list
	 * and splice the rest of send_ctx->msg_list
	 * to last->sibling_list.
	 *
	 * send_ctx->msg_list is a valid empty list
	 * at the end.
	 */
	list_del_init(&last->sibling_list);
	list_splice_tail_init(&send_ctx->msg_list, &last->sibling_list);
	send_ctx->wr_cnt = 0;

	ret = smb_direct_post_send(sc, &first->wr);
	if (!ret) {
		smb_direct_send_ctx_init(send_ctx,
					 send_ctx->need_invalidate_rkey,
					 send_ctx->remote_key);
	} else {
		atomic_add(send_ctx->wr_cnt, &sc->send_io.credits.count);
		wake_up(&sc->send_io.credits.wait_queue);
		list_for_each_entry_safe(first, last, &send_ctx->msg_list,
					 sibling_list) {
			smb_direct_free_sendmsg(sc, first);
	if (ret) {
		struct smbdirect_send_io *sibling, *next;

		list_for_each_entry_safe(sibling, next, &last->sibling_list, sibling_list) {
			list_del_init(&sibling->sibling_list);
			smb_direct_free_sendmsg(sc, sibling);
		}
		smb_direct_free_sendmsg(sc, last);
	}

	return ret;
}

@@ -1070,6 +1087,23 @@ static int wait_for_credits(struct smbdirect_socket *sc,
	} while (true);
}

/*
 * Wait for one local send credit (sc->send_io.lcredits).
 *
 * When the caller batches its sends (send_ctx is not NULL) and at most
 * one local credit is left, the batch is flushed first through
 * smb_direct_flush_send_list(); a failure of that flush is returned
 * unchanged and no wait takes place.
 * NOTE(review): presumably the flush is needed so the batched sends get
 * posted and their completions can hand local credits back - confirm.
 *
 * The wait itself is delegated to wait_for_credits() on the lcredits
 * wait queue and counter, asking for a single credit.
 *
 * Returns the result of smb_direct_flush_send_list() if it is non-zero,
 * otherwise the result of wait_for_credits().
 */
static int wait_for_send_lcredit(struct smbdirect_socket *sc,
				 struct smbdirect_send_batch *send_ctx)
{
	int rc;

	/* No batch, or more than one local credit left: nothing to flush */
	if (!send_ctx || atomic_read(&sc->send_io.lcredits.count) > 1)
		goto wait;

	rc = smb_direct_flush_send_list(sc, send_ctx, false);
	if (rc)
		return rc;

wait:
	return wait_for_credits(sc,
				&sc->send_io.lcredits.wait_queue,
				&sc->send_io.lcredits.count,
				1);
}

static int wait_for_send_credits(struct smbdirect_socket *sc,
				 struct smbdirect_send_batch *send_ctx)
{
@@ -1257,9 +1291,13 @@ static int smb_direct_post_send_data(struct smbdirect_socket *sc,
	int data_length;
	struct scatterlist sg[SMBDIRECT_SEND_IO_MAX_SGE - 1];

	ret = wait_for_send_lcredit(sc, send_ctx);
	if (ret)
		goto lcredit_failed;

	ret = wait_for_send_credits(sc, send_ctx);
	if (ret)
		return ret;
		goto credit_failed;

	data_length = 0;
	for (i = 0; i < niov; i++)
@@ -1267,10 +1305,8 @@ static int smb_direct_post_send_data(struct smbdirect_socket *sc,

	ret = smb_direct_create_header(sc, data_length, remaining_data_length,
				       &msg);
	if (ret) {
		atomic_inc(&sc->send_io.credits.count);
		return ret;
	}
	if (ret)
		goto header_failed;

	for (i = 0; i < niov; i++) {
		struct ib_sge *sge;
@@ -1308,7 +1344,11 @@ static int smb_direct_post_send_data(struct smbdirect_socket *sc,
	return 0;
err:
	smb_direct_free_sendmsg(sc, msg);
header_failed:
	atomic_inc(&sc->send_io.credits.count);
credit_failed:
	atomic_inc(&sc->send_io.lcredits.count);
lcredit_failed:
	return ret;
}

@@ -1871,20 +1911,11 @@ static int smb_direct_prepare_negotiation(struct smbdirect_socket *sc)
	return ret;
}

static unsigned int smb_direct_get_max_fr_pages(struct smbdirect_socket *sc)
{
	return min_t(unsigned int,
		     sc->ib.dev->attrs.max_fast_reg_page_list_len,
		     256);
}

static int smb_direct_init_params(struct smbdirect_socket *sc,
				  struct ib_qp_cap *cap)
static int smb_direct_init_params(struct smbdirect_socket *sc)
{
	struct smbdirect_socket_parameters *sp = &sc->parameters;
	struct ib_device *device = sc->ib.dev;
	int max_send_sges, max_rw_wrs, max_send_wrs;
	unsigned int max_sge_per_wr, wrs_per_credit;
	int max_send_sges;
	unsigned int maxpages;

	/* need 3 more sge. because a SMB_DIRECT header, SMB2 header,
	 * SMB2 response could be mapped.
@@ -1895,67 +1926,20 @@ static int smb_direct_init_params(struct smbdirect_socket *sc,
		return -EINVAL;
	}

	/* Calculate the number of work requests for RDMA R/W.
	 * The maximum number of pages which can be registered
	 * with one Memory region can be transferred with one
	 * R/W credit. And at least 4 work requests for each credit
	 * are needed for MR registration, RDMA R/W, local & remote
	 * MR invalidation.
	 */
	sc->rw_io.credits.num_pages = smb_direct_get_max_fr_pages(sc);
	sc->rw_io.credits.max = DIV_ROUND_UP(sp->max_read_write_size,
					 (sc->rw_io.credits.num_pages - 1) *
					 PAGE_SIZE);

	max_sge_per_wr = min_t(unsigned int, device->attrs.max_send_sge,
			       device->attrs.max_sge_rd);
	max_sge_per_wr = max_t(unsigned int, max_sge_per_wr,
			       max_send_sges);
	wrs_per_credit = max_t(unsigned int, 4,
			       DIV_ROUND_UP(sc->rw_io.credits.num_pages,
					    max_sge_per_wr) + 1);
	max_rw_wrs = sc->rw_io.credits.max * wrs_per_credit;

	max_send_wrs = sp->send_credit_target + max_rw_wrs;
	if (max_send_wrs > device->attrs.max_cqe ||
	    max_send_wrs > device->attrs.max_qp_wr) {
		pr_err("consider lowering send_credit_target = %d\n",
		       sp->send_credit_target);
		pr_err("Possible CQE overrun, device reporting max_cqe %d max_qp_wr %d\n",
		       device->attrs.max_cqe, device->attrs.max_qp_wr);
		return -EINVAL;
	}
	atomic_set(&sc->send_io.lcredits.count, sp->send_credit_target);

	if (sp->recv_credit_max > device->attrs.max_cqe ||
	    sp->recv_credit_max > device->attrs.max_qp_wr) {
		pr_err("consider lowering receive_credit_max = %d\n",
		       sp->recv_credit_max);
		pr_err("Possible CQE overrun, device reporting max_cpe %d max_qp_wr %d\n",
		       device->attrs.max_cqe, device->attrs.max_qp_wr);
		return -EINVAL;
	}

	if (device->attrs.max_send_sge < SMBDIRECT_SEND_IO_MAX_SGE) {
		pr_err("warning: device max_send_sge = %d too small\n",
		       device->attrs.max_send_sge);
		return -EINVAL;
	}
	if (device->attrs.max_recv_sge < SMBDIRECT_RECV_IO_MAX_SGE) {
		pr_err("warning: device max_recv_sge = %d too small\n",
		       device->attrs.max_recv_sge);
		return -EINVAL;
	}
	maxpages = DIV_ROUND_UP(sp->max_read_write_size, PAGE_SIZE);
	sc->rw_io.credits.max = rdma_rw_mr_factor(sc->ib.dev,
						  sc->rdma.cm_id->port_num,
						  maxpages);
	sc->rw_io.credits.num_pages = DIV_ROUND_UP(maxpages, sc->rw_io.credits.max);
	/* add one extra in order to handle unaligned pages */
	sc->rw_io.credits.max += 1;

	sc->recv_io.credits.target = 1;

	atomic_set(&sc->rw_io.credits.count, sc->rw_io.credits.max);

	cap->max_send_wr = max_send_wrs;
	cap->max_recv_wr = sp->recv_credit_max;
	cap->max_send_sge = SMBDIRECT_SEND_IO_MAX_SGE;
	cap->max_recv_sge = SMBDIRECT_RECV_IO_MAX_SGE;
	cap->max_inline_data = 0;
	cap->max_rdma_ctxs = sc->rw_io.credits.max;
	return 0;
}

@@ -2029,13 +2013,129 @@ static int smb_direct_create_pools(struct smbdirect_socket *sc)
	return -ENOMEM;
}

static int smb_direct_create_qpair(struct smbdirect_socket *sc,
				   struct ib_qp_cap *cap)
static u32 smb_direct_rdma_rw_send_wrs(struct ib_device *dev, const struct ib_qp_init_attr *attr)
{
	/*
	 * This could be split out of rdma_rw_init_qp()
	 * and be a helper function next to rdma_rw_mr_factor()
	 *
	 * We can't check unlikely(rdma_rw_force_mr) here,
	 * but that is most likely 0 anyway.
	 */
	u32 factor;

	WARN_ON_ONCE(attr->port_num == 0);

	/*
	 * Each context needs at least one RDMA READ or WRITE WR.
	 *
	 * For some hardware we might need more, eventually we should ask the
	 * HCA driver for a multiplier here.
	 */
	factor = 1;

	/*
	 * If the device needs MRs to perform RDMA READ or WRITE operations,
	 * we'll need two additional MRs for the registrations and the
	 * invalidation.
	 */
	if (rdma_protocol_iwarp(dev, attr->port_num) || dev->attrs.max_sgl_rd)
		factor += 2;	/* inv + reg */

	return factor * attr->cap.max_rdma_ctxs;
}

static int smb_direct_create_qpair(struct smbdirect_socket *sc)
{
	struct smbdirect_socket_parameters *sp = &sc->parameters;
	int ret;
	struct ib_qp_cap qp_cap;
	struct ib_qp_init_attr qp_attr;
	int pages_per_rw;
	u32 max_send_wr;
	u32 rdma_send_wr;

	/*
	 * Note that {rdma,ib}_create_qp() will call
	 * rdma_rw_init_qp() if cap->max_rdma_ctxs is not 0.
	 * It will adjust cap->max_send_wr to the required
	 * number of additional WRs for the RDMA RW operations.
	 * It will cap cap->max_send_wr to the device limit.
	 *
	 * +1 for ib_drain_qp
	 */
	qp_cap.max_send_wr = sp->send_credit_target + 1;
	qp_cap.max_recv_wr = sp->recv_credit_max + 1;
	qp_cap.max_send_sge = SMBDIRECT_SEND_IO_MAX_SGE;
	qp_cap.max_recv_sge = SMBDIRECT_RECV_IO_MAX_SGE;
	qp_cap.max_inline_data = 0;
	qp_cap.max_rdma_ctxs = sc->rw_io.credits.max;

	/*
	 * Find out the number of max_send_wr
	 * after rdma_rw_init_qp() adjusted it.
	 *
	 * We only do it on a temporary variable,
	 * as rdma_create_qp() will trigger
	 * rdma_rw_init_qp() again.
	 */
	memset(&qp_attr, 0, sizeof(qp_attr));
	qp_attr.cap = qp_cap;
	qp_attr.port_num = sc->rdma.cm_id->port_num;
	rdma_send_wr = smb_direct_rdma_rw_send_wrs(sc->ib.dev, &qp_attr);
	max_send_wr = qp_cap.max_send_wr + rdma_send_wr;

	if (qp_cap.max_send_wr > sc->ib.dev->attrs.max_cqe ||
	    qp_cap.max_send_wr > sc->ib.dev->attrs.max_qp_wr) {
		pr_err("Possible CQE overrun: max_send_wr %d\n",
		       qp_cap.max_send_wr);
		pr_err("device %.*s reporting max_cqe %d max_qp_wr %d\n",
		       IB_DEVICE_NAME_MAX,
		       sc->ib.dev->name,
		       sc->ib.dev->attrs.max_cqe,
		       sc->ib.dev->attrs.max_qp_wr);
		pr_err("consider lowering send_credit_target = %d\n",
		       sp->send_credit_target);
		return -EINVAL;
	}

	if (qp_cap.max_rdma_ctxs &&
	    (max_send_wr >= sc->ib.dev->attrs.max_cqe ||
	     max_send_wr >= sc->ib.dev->attrs.max_qp_wr)) {
		pr_err("Possible CQE overrun: rdma_send_wr %d + max_send_wr %d = %d\n",
		       rdma_send_wr, qp_cap.max_send_wr, max_send_wr);
		pr_err("device %.*s reporting max_cqe %d max_qp_wr %d\n",
		       IB_DEVICE_NAME_MAX,
		       sc->ib.dev->name,
		       sc->ib.dev->attrs.max_cqe,
		       sc->ib.dev->attrs.max_qp_wr);
		pr_err("consider lowering send_credit_target = %d, max_rdma_ctxs = %d\n",
		       sp->send_credit_target, qp_cap.max_rdma_ctxs);
		return -EINVAL;
	}

	if (qp_cap.max_recv_wr > sc->ib.dev->attrs.max_cqe ||
	    qp_cap.max_recv_wr > sc->ib.dev->attrs.max_qp_wr) {
		pr_err("Possible CQE overrun: max_recv_wr %d\n",
		       qp_cap.max_recv_wr);
		pr_err("device %.*s reporting max_cqe %d max_qp_wr %d\n",
		       IB_DEVICE_NAME_MAX,
		       sc->ib.dev->name,
		       sc->ib.dev->attrs.max_cqe,
		       sc->ib.dev->attrs.max_qp_wr);
		pr_err("consider lowering receive_credit_max = %d\n",
		       sp->recv_credit_max);
		return -EINVAL;
	}

	if (qp_cap.max_send_sge > sc->ib.dev->attrs.max_send_sge ||
	    qp_cap.max_recv_sge > sc->ib.dev->attrs.max_recv_sge) {
		pr_err("device %.*s max_send_sge/max_recv_sge = %d/%d too small\n",
		       IB_DEVICE_NAME_MAX,
		       sc->ib.dev->name,
		       sc->ib.dev->attrs.max_send_sge,
		       sc->ib.dev->attrs.max_recv_sge);
		return -EINVAL;
	}

	sc->ib.pd = ib_alloc_pd(sc->ib.dev, 0);
	if (IS_ERR(sc->ib.pd)) {
@@ -2046,8 +2146,7 @@ static int smb_direct_create_qpair(struct smbdirect_socket *sc,
	}

	sc->ib.send_cq = ib_alloc_cq_any(sc->ib.dev, sc,
					 sp->send_credit_target +
					 cap->max_rdma_ctxs,
					 max_send_wr,
					 IB_POLL_WORKQUEUE);
	if (IS_ERR(sc->ib.send_cq)) {
		pr_err("Can't create RDMA send CQ\n");
@@ -2057,7 +2156,7 @@ static int smb_direct_create_qpair(struct smbdirect_socket *sc,
	}

	sc->ib.recv_cq = ib_alloc_cq_any(sc->ib.dev, sc,
					 sp->recv_credit_max,
					 qp_cap.max_recv_wr,
					 IB_POLL_WORKQUEUE);
	if (IS_ERR(sc->ib.recv_cq)) {
		pr_err("Can't create RDMA recv CQ\n");
@@ -2066,10 +2165,18 @@ static int smb_direct_create_qpair(struct smbdirect_socket *sc,
		goto err;
	}

	/*
	 * We reset completely here!
	 * As the above use was just temporary
	 * to calc max_send_wr and rdma_send_wr.
	 *
	 * rdma_create_qp() will trigger rdma_rw_init_qp()
	 * again if max_rdma_ctxs is not 0.
	 */
	memset(&qp_attr, 0, sizeof(qp_attr));
	qp_attr.event_handler = smb_direct_qpair_handler;
	qp_attr.qp_context = sc;
	qp_attr.cap = *cap;
	qp_attr.cap = qp_cap;
	qp_attr.sq_sig_type = IB_SIGNAL_REQ_WR;
	qp_attr.qp_type = IB_QPT_RC;
	qp_attr.send_cq = sc->ib.send_cq;
@@ -2085,18 +2192,6 @@ static int smb_direct_create_qpair(struct smbdirect_socket *sc,
	sc->ib.qp = sc->rdma.cm_id->qp;
	sc->rdma.cm_id->event_handler = smb_direct_cm_handler;

	pages_per_rw = DIV_ROUND_UP(sp->max_read_write_size, PAGE_SIZE) + 1;
	if (pages_per_rw > sc->ib.dev->attrs.max_sgl_rd) {
		ret = ib_mr_pool_init(sc->ib.qp, &sc->ib.qp->rdma_mrs,
				      sc->rw_io.credits.max, IB_MR_TYPE_MEM_REG,
				      sc->rw_io.credits.num_pages, 0);
		if (ret) {
			pr_err("failed to init mr pool count %zu pages %zu\n",
			       sc->rw_io.credits.max, sc->rw_io.credits.num_pages);
			goto err;
		}
	}

	return 0;
err:
	if (sc->ib.qp) {
@@ -2183,10 +2278,9 @@ static int smb_direct_prepare(struct ksmbd_transport *t)

static int smb_direct_connect(struct smbdirect_socket *sc)
{
	struct ib_qp_cap qp_cap;
	int ret;

	ret = smb_direct_init_params(sc, &qp_cap);
	ret = smb_direct_init_params(sc);
	if (ret) {
		pr_err("Can't configure RDMA parameters\n");
		return ret;
@@ -2198,7 +2292,7 @@ static int smb_direct_connect(struct smbdirect_socket *sc)
		return ret;
	}

	ret = smb_direct_create_qpair(sc, &qp_cap);
	ret = smb_direct_create_qpair(sc);
	if (ret) {
		pr_err("Can't accept RDMA client: %d\n", ret);
		return ret;