Commit 471025c9 authored by Linus Torvalds's avatar Linus Torvalds
Browse files

Merge tag 'v6.17rc-part2-ksmbd-server-fixes' of git://git.samba.org/ksmbd

Pull smb server fixes from Steve French:

 - Fix limiting repeated connections from same IP

 - Fix for extracting shortname when name begins with a dot

 - Four smbdirect fixes:
     - three fixes to the receive path: potential unmap bug, potential
       resource leaks and stale connections, and also potential use
       after free race
     - cleanup to remove unneeded queue

* tag 'v6.17rc-part2-ksmbd-server-fixes' of git://git.samba.org/ksmbd:
  smb: server: Fix extension string in ksmbd_extract_shortname()
  ksmbd: limit repeated connections from clients with the same IP
  smb: server: let recv_done() avoid touching data_transfer after cleanup/move
  smb: server: let recv_done() consistently call put_recvmsg/smb_direct_disconnect_rdma_connection
  smb: server: make sure we call ib_dma_unmap_single() only if we called ib_dma_map_single already
  smb: server: remove separate empty_recvmsg_queue
parents 37816488 8e7d178d
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -46,6 +46,7 @@ struct ksmbd_conn {
	struct mutex			srv_mutex;
	int				status;
	unsigned int			cli_cap;
	__be32				inet_addr;
	char				*request_buf;
	struct ksmbd_transport		*transport;
	struct nls_table		*local_nls;
+1 −1
Original line number Diff line number Diff line
@@ -515,7 +515,7 @@ int ksmbd_extract_shortname(struct ksmbd_conn *conn, const char *longname,

	p = strrchr(longname, '.');
	if (p == longname) { /*name starts with a dot*/
		strscpy(extension, "___", strlen("___"));
		strscpy(extension, "___", sizeof(extension));
	} else {
		if (p) {
			p++;
+35 −62
Original line number Diff line number Diff line
@@ -129,9 +129,6 @@ struct smb_direct_transport {
	spinlock_t		recvmsg_queue_lock;
	struct list_head	recvmsg_queue;

	spinlock_t		empty_recvmsg_queue_lock;
	struct list_head	empty_recvmsg_queue;

	int			send_credit_target;
	atomic_t		send_credits;
	spinlock_t		lock_new_recv_credits;
@@ -268,40 +265,19 @@ smb_direct_recvmsg *get_free_recvmsg(struct smb_direct_transport *t)
static void put_recvmsg(struct smb_direct_transport *t,
			struct smb_direct_recvmsg *recvmsg)
{
	ib_dma_unmap_single(t->cm_id->device, recvmsg->sge.addr,
			    recvmsg->sge.length, DMA_FROM_DEVICE);
	if (likely(recvmsg->sge.length != 0)) {
		ib_dma_unmap_single(t->cm_id->device,
				    recvmsg->sge.addr,
				    recvmsg->sge.length,
				    DMA_FROM_DEVICE);
		recvmsg->sge.length = 0;
	}

	spin_lock(&t->recvmsg_queue_lock);
	list_add(&recvmsg->list, &t->recvmsg_queue);
	spin_unlock(&t->recvmsg_queue_lock);
}

static struct
smb_direct_recvmsg *get_empty_recvmsg(struct smb_direct_transport *t)
{
	struct smb_direct_recvmsg *recvmsg = NULL;

	spin_lock(&t->empty_recvmsg_queue_lock);
	if (!list_empty(&t->empty_recvmsg_queue)) {
		recvmsg = list_first_entry(&t->empty_recvmsg_queue,
					   struct smb_direct_recvmsg, list);
		list_del(&recvmsg->list);
	}
	spin_unlock(&t->empty_recvmsg_queue_lock);
	return recvmsg;
}

static void put_empty_recvmsg(struct smb_direct_transport *t,
			      struct smb_direct_recvmsg *recvmsg)
{
	ib_dma_unmap_single(t->cm_id->device, recvmsg->sge.addr,
			    recvmsg->sge.length, DMA_FROM_DEVICE);

	spin_lock(&t->empty_recvmsg_queue_lock);
	list_add_tail(&recvmsg->list, &t->empty_recvmsg_queue);
	spin_unlock(&t->empty_recvmsg_queue_lock);
}

static void enqueue_reassembly(struct smb_direct_transport *t,
			       struct smb_direct_recvmsg *recvmsg,
			       int data_length)
@@ -386,9 +362,6 @@ static struct smb_direct_transport *alloc_transport(struct rdma_cm_id *cm_id)
	spin_lock_init(&t->recvmsg_queue_lock);
	INIT_LIST_HEAD(&t->recvmsg_queue);

	spin_lock_init(&t->empty_recvmsg_queue_lock);
	INIT_LIST_HEAD(&t->empty_recvmsg_queue);

	init_waitqueue_head(&t->wait_send_pending);
	atomic_set(&t->send_pending, 0);

@@ -548,13 +521,13 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
	t = recvmsg->transport;

	if (wc->status != IB_WC_SUCCESS || wc->opcode != IB_WC_RECV) {
		put_recvmsg(t, recvmsg);
		if (wc->status != IB_WC_WR_FLUSH_ERR) {
			pr_err("Recv error. status='%s (%d)' opcode=%d\n",
			       ib_wc_status_msg(wc->status), wc->status,
			       wc->opcode);
			smb_direct_disconnect_rdma_connection(t);
		}
		put_empty_recvmsg(t, recvmsg);
		return;
	}

@@ -568,7 +541,8 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
	switch (recvmsg->type) {
	case SMB_DIRECT_MSG_NEGOTIATE_REQ:
		if (wc->byte_len < sizeof(struct smb_direct_negotiate_req)) {
			put_empty_recvmsg(t, recvmsg);
			put_recvmsg(t, recvmsg);
			smb_direct_disconnect_rdma_connection(t);
			return;
		}
		t->negotiation_requested = true;
@@ -576,7 +550,7 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
		t->status = SMB_DIRECT_CS_CONNECTED;
		enqueue_reassembly(t, recvmsg, 0);
		wake_up_interruptible(&t->wait_status);
		break;
		return;
	case SMB_DIRECT_MSG_DATA_TRANSFER: {
		struct smb_direct_data_transfer *data_transfer =
			(struct smb_direct_data_transfer *)recvmsg->packet;
@@ -585,7 +559,8 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)

		if (wc->byte_len <
		    offsetof(struct smb_direct_data_transfer, padding)) {
			put_empty_recvmsg(t, recvmsg);
			put_recvmsg(t, recvmsg);
			smb_direct_disconnect_rdma_connection(t);
			return;
		}

@@ -593,7 +568,8 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
		if (data_length) {
			if (wc->byte_len < sizeof(struct smb_direct_data_transfer) +
			    (u64)data_length) {
				put_empty_recvmsg(t, recvmsg);
				put_recvmsg(t, recvmsg);
				smb_direct_disconnect_rdma_connection(t);
				return;
			}

@@ -605,16 +581,11 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
			else
				t->full_packet_received = true;

			enqueue_reassembly(t, recvmsg, (int)data_length);
			wake_up_interruptible(&t->wait_reassembly_queue);

			spin_lock(&t->receive_credit_lock);
			receive_credits = --(t->recv_credits);
			avail_recvmsg_count = t->count_avail_recvmsg;
			spin_unlock(&t->receive_credit_lock);
		} else {
			put_empty_recvmsg(t, recvmsg);

			spin_lock(&t->receive_credit_lock);
			receive_credits = --(t->recv_credits);
			avail_recvmsg_count = ++(t->count_avail_recvmsg);
@@ -636,11 +607,23 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
		if (is_receive_credit_post_required(receive_credits, avail_recvmsg_count))
			mod_delayed_work(smb_direct_wq,
					 &t->post_recv_credits_work, 0);
		break;

		if (data_length) {
			enqueue_reassembly(t, recvmsg, (int)data_length);
			wake_up_interruptible(&t->wait_reassembly_queue);
		} else
			put_recvmsg(t, recvmsg);

		return;
	}
	default:
		break;
	}

	/*
	 * This is an internal error!
	 */
	WARN_ON_ONCE(recvmsg->type != SMB_DIRECT_MSG_DATA_TRANSFER);
	put_recvmsg(t, recvmsg);
	smb_direct_disconnect_rdma_connection(t);
}

static int smb_direct_post_recv(struct smb_direct_transport *t,
@@ -670,6 +653,7 @@ static int smb_direct_post_recv(struct smb_direct_transport *t,
		ib_dma_unmap_single(t->cm_id->device,
				    recvmsg->sge.addr, recvmsg->sge.length,
				    DMA_FROM_DEVICE);
		recvmsg->sge.length = 0;
		smb_direct_disconnect_rdma_connection(t);
		return ret;
	}
@@ -811,7 +795,6 @@ static void smb_direct_post_recv_credits(struct work_struct *work)
	struct smb_direct_recvmsg *recvmsg;
	int receive_credits, credits = 0;
	int ret;
	int use_free = 1;

	spin_lock(&t->receive_credit_lock);
	receive_credits = t->recv_credits;
@@ -819,18 +802,9 @@ static void smb_direct_post_recv_credits(struct work_struct *work)

	if (receive_credits < t->recv_credit_target) {
		while (true) {
			if (use_free)
			recvmsg = get_free_recvmsg(t);
			else
				recvmsg = get_empty_recvmsg(t);
			if (!recvmsg) {
				if (use_free) {
					use_free = 0;
					continue;
				} else {
			if (!recvmsg)
				break;
				}
			}

			recvmsg->type = SMB_DIRECT_MSG_DATA_TRANSFER;
			recvmsg->first_segment = false;
@@ -1806,8 +1780,6 @@ static void smb_direct_destroy_pools(struct smb_direct_transport *t)

	while ((recvmsg = get_free_recvmsg(t)))
		mempool_free(recvmsg, t->recvmsg_mempool);
	while ((recvmsg = get_empty_recvmsg(t)))
		mempool_free(recvmsg, t->recvmsg_mempool);

	mempool_destroy(t->recvmsg_mempool);
	t->recvmsg_mempool = NULL;
@@ -1863,6 +1835,7 @@ static int smb_direct_create_pools(struct smb_direct_transport *t)
		if (!recvmsg)
			goto err;
		recvmsg->transport = t;
		recvmsg->sge.length = 0;
		list_add(&recvmsg->list, &t->recvmsg_queue);
	}
	t->count_avail_recvmsg = t->recv_credit_max;
+17 −0
Original line number Diff line number Diff line
@@ -85,6 +85,7 @@ static struct tcp_transport *alloc_transport(struct socket *client_sk)
		return NULL;
	}

	conn->inet_addr = inet_sk(client_sk->sk)->inet_daddr;
	conn->transport = KSMBD_TRANS(t);
	KSMBD_TRANS(t)->conn = conn;
	KSMBD_TRANS(t)->ops = &ksmbd_tcp_transport_ops;
@@ -228,6 +229,8 @@ static int ksmbd_kthread_fn(void *p)
{
	struct socket *client_sk = NULL;
	struct interface *iface = (struct interface *)p;
	struct inet_sock *csk_inet;
	struct ksmbd_conn *conn;
	int ret;

	while (!kthread_should_stop()) {
@@ -246,6 +249,20 @@ static int ksmbd_kthread_fn(void *p)
			continue;
		}

		/*
		 * Limits repeated connections from clients with the same IP.
		 */
		csk_inet = inet_sk(client_sk->sk);
		down_read(&conn_list_lock);
		list_for_each_entry(conn, &conn_list, conns_list)
			if (csk_inet->inet_daddr == conn->inet_addr) {
				ret = -EAGAIN;
				break;
			}
		up_read(&conn_list_lock);
		if (ret == -EAGAIN)
			continue;

		if (server_conf.max_connections &&
		    atomic_inc_return(&active_num_conn) >= server_conf.max_connections) {
			pr_info_ratelimited("Limit the maximum number of connections(%u)\n",