Commit 55e8757c authored by Paolo Abeni's avatar Paolo Abeni
Browse files

Merge branch 'net-mctp-improved-bind-handling'



Matt Johnston says:

====================
net: mctp: Improved bind handling

This series improves a couple of aspects of MCTP bind() handling.

MCTP wasn't checking whether the same MCTP type was bound by multiple
sockets. That would result in messages being received by an arbitrary
socket, which isn't useful behaviour. Instead it makes more sense to
have the duplicate binds fail, the same as other network protocols.
An exception is made for more-specific binds to particular MCTP
addresses.

It is also useful to be able to limit a bind to only receive incoming
request messages (MCTP TO bit set) from a specific peer+type, so that
individual processes can communicate with separate MCTP peers. One
example is a PLDM firmware update requester, which will initiate
communication with a device, and then the device will connect back to the
requester process.

These limited binds are implemented by a connect() call on the socket
prior to bind. connect() isn't used in the general case for MCTP, since
a plain send() wouldn't provide the required MCTP tag argument for
addressing.

Signed-off-by: default avatarMatt Johnston <matt@codeconstruct.com.au>
====================

Link: https://patch.msgid.link/20250710-mctp-bind-v4-0-8ec2f6460c56@codeconstruct.com.au


Signed-off-by: default avatarPaolo Abeni <pabeni@redhat.com>
parents a8594c95 e6d8e7db
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -69,7 +69,10 @@ struct mctp_sock {

	/* bind() params */
	unsigned int	bind_net;
	mctp_eid_t	bind_addr;
	mctp_eid_t	bind_local_addr;
	mctp_eid_t	bind_peer_addr;
	unsigned int	bind_peer_net;
	bool		bind_peer_set;
	__u8		bind_type;

	/* sendmsg()/recvmsg() uses struct sockaddr_mctp_ext */
+16 −4
Original line number Diff line number Diff line
@@ -6,19 +6,25 @@
#ifndef __NETNS_MCTP_H__
#define __NETNS_MCTP_H__

#include <linux/hash.h>
#include <linux/hashtable.h>
#include <linux/mutex.h>
#include <linux/types.h>

#define MCTP_BINDS_BITS 7

struct netns_mctp {
	/* Only updated under RTNL, entries freed via RCU */
	struct list_head routes;

	/* Bound sockets: list of sockets bound by type.
	 * This list is updated from non-atomic contexts (under bind_lock),
	 * and read (under rcu) in packet rx
	/* Bound sockets: hash table of sockets, keyed by
	 * (type, src_eid, dest_eid).
	 * Specific src_eid/dest_eid entries also have an entry for
	 * MCTP_ADDR_ANY. This list is updated from non-atomic contexts
	 * (under bind_lock), and read (under rcu) in packet rx.
	 */
	struct mutex bind_lock;
	struct hlist_head binds;
	DECLARE_HASHTABLE(binds, MCTP_BINDS_BITS);

	/* tag allocations. This list is read and updated from atomic contexts,
	 * but elements are free()ed after a RCU grace-period
@@ -34,4 +40,10 @@ struct netns_mctp {
	struct list_head neighbours;
};

static inline u32 mctp_bind_hash(u8 type, u8 local_addr, u8 peer_addr)
{
	return hash_32(type | (u32)local_addr << 8 | (u32)peer_addr << 16,
		       MCTP_BINDS_BITS);
}

#endif /* __NETNS_MCTP_H__ */
+139 −9
Original line number Diff line number Diff line
@@ -53,6 +53,7 @@ static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
{
	struct sock *sk = sock->sk;
	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
	struct net *net = sock_net(&msk->sk);
	struct sockaddr_mctp *smctp;
	int rc;

@@ -73,14 +74,48 @@ static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen)

	lock_sock(sk);

	/* TODO: allow rebind */
	if (sk_hashed(sk)) {
		rc = -EADDRINUSE;
		goto out_release;
	}

	msk->bind_local_addr = smctp->smctp_addr.s_addr;

	/* MCTP_NET_ANY with a specific EID is resolved to the default net
	 * at bind() time.
	 * For bind_addr=MCTP_ADDR_ANY it is handled specially at route
	 * lookup time.
	 */
	if (smctp->smctp_network == MCTP_NET_ANY &&
	    msk->bind_local_addr != MCTP_ADDR_ANY) {
		msk->bind_net = mctp_default_net(net);
	} else {
		msk->bind_net = smctp->smctp_network;
	msk->bind_addr = smctp->smctp_addr.s_addr;
	msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */
	}

	/* ignore the IC bit */
	smctp->smctp_type &= 0x7f;

	if (msk->bind_peer_set) {
		if (msk->bind_type != smctp->smctp_type) {
			/* Prior connect() had a different type */
			rc = -EINVAL;
			goto out_release;
		}

		if (msk->bind_net == MCTP_NET_ANY) {
			/* Restrict to the network passed to connect() */
			msk->bind_net = msk->bind_peer_net;
		}

		if (msk->bind_net != msk->bind_peer_net) {
			/* connect() had a different net to bind() */
			rc = -EINVAL;
			goto out_release;
		}
	} else {
		msk->bind_type = smctp->smctp_type;
	}

	rc = sk->sk_prot->hash(sk);

@@ -90,6 +125,67 @@ static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
	return rc;
}

/* Used to set a specific peer prior to bind. Not used for outbound
 * connections (Tag Owner set) since MCTP is a datagram protocol.
 */
static int mctp_connect(struct socket *sock, struct sockaddr *addr,
			int addrlen, int flags)
{
	struct sock *sk = sock->sk;
	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
	struct net *net = sock_net(&msk->sk);
	struct sockaddr_mctp *smctp;
	int rc;

	if (addrlen != sizeof(*smctp))
		return -EINVAL;

	if (addr->sa_family != AF_MCTP)
		return -EAFNOSUPPORT;

	/* It's a valid sockaddr for MCTP, cast and do protocol checks */
	smctp = (struct sockaddr_mctp *)addr;

	if (!mctp_sockaddr_is_ok(smctp))
		return -EINVAL;

	/* Can't bind by tag */
	if (smctp->smctp_tag)
		return -EINVAL;

	/* IC bit must be unset */
	if (smctp->smctp_type & 0x80)
		return -EINVAL;

	lock_sock(sk);

	if (sk_hashed(sk)) {
		/* bind() already */
		rc = -EADDRINUSE;
		goto out_release;
	}

	if (msk->bind_peer_set) {
		/* connect() already */
		rc = -EADDRINUSE;
		goto out_release;
	}

	msk->bind_peer_set = true;
	msk->bind_peer_addr = smctp->smctp_addr.s_addr;
	msk->bind_type = smctp->smctp_type;
	if (smctp->smctp_network == MCTP_NET_ANY)
		msk->bind_peer_net = mctp_default_net(net);
	else
		msk->bind_peer_net = smctp->smctp_network;

	rc = 0;

out_release:
	release_sock(sk);
	return rc;
}

static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
{
	DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
@@ -533,7 +629,7 @@ static const struct proto_ops mctp_dgram_ops = {
	.family		= PF_MCTP,
	.release	= mctp_release,
	.bind		= mctp_bind,
	.connect	= sock_no_connect,
	.connect	= mctp_connect,
	.socketpair	= sock_no_socketpair,
	.accept		= sock_no_accept,
	.getname	= sock_no_getname,
@@ -600,6 +696,7 @@ static int mctp_sk_init(struct sock *sk)

	INIT_HLIST_HEAD(&msk->keys);
	timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0);
	msk->bind_peer_set = false;
	return 0;
}

@@ -611,15 +708,48 @@ static void mctp_sk_close(struct sock *sk, long timeout)
static int mctp_sk_hash(struct sock *sk)
{
	struct net *net = sock_net(sk);
	struct sock *existing;
	struct mctp_sock *msk;
	mctp_eid_t remote;
	u32 hash;
	int rc;

	msk = container_of(sk, struct mctp_sock, sk);

	if (msk->bind_peer_set)
		remote = msk->bind_peer_addr;
	else
		remote = MCTP_ADDR_ANY;
	hash = mctp_bind_hash(msk->bind_type, msk->bind_local_addr, remote);

	mutex_lock(&net->mctp.bind_lock);

	/* Prevent duplicate binds. */
	sk_for_each(existing, &net->mctp.binds[hash]) {
		struct mctp_sock *mex =
			container_of(existing, struct mctp_sock, sk);

		bool same_peer = (mex->bind_peer_set && msk->bind_peer_set &&
				  mex->bind_peer_addr == msk->bind_peer_addr) ||
				 (!mex->bind_peer_set && !msk->bind_peer_set);

		if (mex->bind_type == msk->bind_type &&
		    mex->bind_local_addr == msk->bind_local_addr && same_peer &&
		    mex->bind_net == msk->bind_net) {
			rc = -EADDRINUSE;
			goto out;
		}
	}

	/* Bind lookup runs under RCU, remain live during that. */
	sock_set_flag(sk, SOCK_RCU_FREE);

	mutex_lock(&net->mctp.bind_lock);
	sk_add_node_rcu(sk, &net->mctp.binds);
	mutex_unlock(&net->mctp.bind_lock);
	sk_add_node_rcu(sk, &net->mctp.binds[hash]);
	rc = 0;

	return 0;
out:
	mutex_unlock(&net->mctp.bind_lock);
	return rc;
}

static void mctp_sk_unhash(struct sock *sk)
+65 −14
Original line number Diff line number Diff line
@@ -40,33 +40,36 @@ static int mctp_dst_discard(struct mctp_dst *dst, struct sk_buff *skb)
	return 0;
}

static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
static struct mctp_sock *mctp_lookup_bind_details(struct net *net,
						  struct sk_buff *skb,
						  u8 type, u8 dest,
						  u8 src, bool allow_net_any)
{
	struct mctp_skb_cb *cb = mctp_cb(skb);
	struct mctp_hdr *mh;
	struct sock *sk;
	u8 type;

	WARN_ON(!rcu_read_lock_held());

	/* TODO: look up in skb->cb? */
	mh = mctp_hdr(skb);
	u8 hash;

	if (!skb_headlen(skb))
		return NULL;
	WARN_ON_ONCE(!rcu_read_lock_held());

	type = (*(u8 *)skb->data) & 0x7f;
	hash = mctp_bind_hash(type, dest, src);

	sk_for_each_rcu(sk, &net->mctp.binds) {
	sk_for_each_rcu(sk, &net->mctp.binds[hash]) {
		struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);

		if (!allow_net_any && msk->bind_net == MCTP_NET_ANY)
			continue;

		if (msk->bind_net != MCTP_NET_ANY && msk->bind_net != cb->net)
			continue;

		if (msk->bind_type != type)
			continue;

		if (!mctp_address_matches(msk->bind_addr, mh->dest))
		if (msk->bind_peer_set &&
		    !mctp_address_matches(msk->bind_peer_addr, src))
			continue;

		if (!mctp_address_matches(msk->bind_local_addr, dest))
			continue;

		return msk;
@@ -75,6 +78,54 @@ static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
	return NULL;
}

static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
{
	struct mctp_sock *msk;
	struct mctp_hdr *mh;
	u8 type;

	/* TODO: look up in skb->cb? */
	mh = mctp_hdr(skb);

	if (!skb_headlen(skb))
		return NULL;

	type = (*(u8 *)skb->data) & 0x7f;

	/* Look for binds in order of widening scope. A given destination or
	 * source address also implies matching on a particular network.
	 *
	 * - Matching destination and source
	 * - Matching destination
	 * - Matching source
	 * - Matching network, any address
	 * - Any network or address
	 */

	msk = mctp_lookup_bind_details(net, skb, type, mh->dest, mh->src,
				       false);
	if (msk)
		return msk;
	msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, mh->src,
				       false);
	if (msk)
		return msk;
	msk = mctp_lookup_bind_details(net, skb, type, mh->dest, MCTP_ADDR_ANY,
				       false);
	if (msk)
		return msk;
	msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY,
				       MCTP_ADDR_ANY, false);
	if (msk)
		return msk;
	msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY,
				       MCTP_ADDR_ANY, true);
	if (msk)
		return msk;

	return NULL;
}

/* A note on the key allocations.
 *
 * struct net->mctp.keys contains our set of currently-allocated keys for
@@ -1671,7 +1722,7 @@ static int __net_init mctp_routes_net_init(struct net *net)
	struct netns_mctp *ns = &net->mctp;

	INIT_LIST_HEAD(&ns->routes);
	INIT_HLIST_HEAD(&ns->binds);
	hash_init(ns->binds);
	mutex_init(&ns->bind_lock);
	INIT_HLIST_HEAD(&ns->keys);
	spin_lock_init(&ns->keys_lock);
+190 −4
Original line number Diff line number Diff line
@@ -1164,8 +1164,6 @@ static void mctp_test_route_extaddr_input(struct kunit *test)
	rc = mctp_dst_input(&dst, skb);
	KUNIT_ASSERT_EQ(test, rc, 0);

	mctp_test_dst_release(&dst, &tpq);

	skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb2);
	KUNIT_ASSERT_EQ(test, skb2->len, len);
@@ -1179,8 +1177,8 @@ static void mctp_test_route_extaddr_input(struct kunit *test)
	KUNIT_EXPECT_EQ(test, cb2->halen, sizeof(haddr));
	KUNIT_EXPECT_MEMEQ(test, cb2->haddr, haddr, sizeof(haddr));

	skb_free_datagram(sock->sk, skb2);
	mctp_test_destroy_dev(dev);
	kfree_skb(skb2);
	__mctp_route_test_fini(test, dev, &dst, &tpq, sock);
}

static void mctp_test_route_gw_lookup(struct kunit *test)
@@ -1410,6 +1408,193 @@ static void mctp_test_route_gw_output(struct kunit *test)
	kfree_skb(skb);
}

struct mctp_bind_lookup_test {
	/* header of incoming message */
	struct mctp_hdr hdr;
	u8 ty;
	/* mctp network of incoming interface (smctp_network) */
	unsigned int net;

	/* expected socket, matches .name in lookup_binds, NULL for dropped */
	const char *expect;
};

/* Single-packet TO-set message */
#define LK(src, dst) RX_HDR(1, (src), (dst), FL_S | FL_E | FL_TO)

/* Input message test cases for bind lookup tests.
 *
 * 10 and 11 are local EIDs.
 * 20 and 21 are remote EIDs.
 */
static const struct mctp_bind_lookup_test mctp_bind_lookup_tests[] = {
	/* both local-eid and remote-eid binds, remote eid is preferenced */
	{ .hdr = LK(20, 10),  .ty = 1, .net = 1, .expect = "remote20" },

	{ .hdr = LK(20, 255), .ty = 1, .net = 1, .expect = "remote20" },
	{ .hdr = LK(20, 0),   .ty = 1, .net = 1, .expect = "remote20" },
	{ .hdr = LK(0, 255),  .ty = 1, .net = 1, .expect = "any" },
	{ .hdr = LK(0, 11),   .ty = 1, .net = 1, .expect = "any" },
	{ .hdr = LK(0, 0),    .ty = 1, .net = 1, .expect = "any" },
	{ .hdr = LK(0, 10),   .ty = 1, .net = 1, .expect = "local10" },
	{ .hdr = LK(21, 10),  .ty = 1, .net = 1, .expect = "local10" },
	{ .hdr = LK(21, 11),  .ty = 1, .net = 1, .expect = "remote21local11" },

	/* both src and dest set to eid=99. unusual, but accepted
	 * by MCTP stack currently.
	 */
	{ .hdr = LK(99, 99),  .ty = 1, .net = 1, .expect = "any" },

	/* unbound smctp_type */
	{ .hdr = LK(20, 10),  .ty = 3, .net = 1, .expect = NULL },

	/* smctp_network tests */

	{ .hdr = LK(0, 0),    .ty = 1, .net = 7, .expect = "any" },
	{ .hdr = LK(21, 10),  .ty = 1, .net = 2, .expect = "any" },

	/* remote EID 20 matches, but MCTP_NET_ANY in "remote20" resolved
	 * to net=1, so lookup doesn't match "remote20"
	 */
	{ .hdr = LK(20, 10),  .ty = 1, .net = 3, .expect = "any" },

	{ .hdr = LK(21, 10),  .ty = 1, .net = 3, .expect = "remote21net3" },
	{ .hdr = LK(21, 10),  .ty = 1, .net = 4, .expect = "remote21net4" },
	{ .hdr = LK(21, 10),  .ty = 1, .net = 5, .expect = "remote21net5" },

	{ .hdr = LK(21, 10),  .ty = 1, .net = 5, .expect = "remote21net5" },

	{ .hdr = LK(99, 10),  .ty = 1, .net = 8, .expect = "local10net8" },

	{ .hdr = LK(99, 10),  .ty = 1, .net = 9, .expect = "anynet9" },
	{ .hdr = LK(0, 0),    .ty = 1, .net = 9, .expect = "anynet9" },
	{ .hdr = LK(99, 99),  .ty = 1, .net = 9, .expect = "anynet9" },
	{ .hdr = LK(20, 10),  .ty = 1, .net = 9, .expect = "anynet9" },
};

/* Binds to create during the lookup tests */
static const struct mctp_test_bind_setup lookup_binds[] = {
	/* any address and net, type 1 */
	{ .name = "any", .bind_addr = MCTP_ADDR_ANY,
		.bind_net = MCTP_NET_ANY, .bind_type = 1, },
	/* local eid 10, net 1 (resolved from MCTP_NET_ANY) */
	{ .name = "local10", .bind_addr = 10,
		.bind_net = MCTP_NET_ANY, .bind_type = 1, },
	/* local eid 10, net 8 */
	{ .name = "local10net8", .bind_addr = 10,
		.bind_net = 8, .bind_type = 1, },
	/* any EID, net 9 */
	{ .name = "anynet9", .bind_addr = MCTP_ADDR_ANY,
		.bind_net = 9, .bind_type = 1, },

	/* remote eid 20, net 1, any local eid */
	{ .name = "remote20", .bind_addr = MCTP_ADDR_ANY,
		.bind_net = MCTP_NET_ANY, .bind_type = 1,
		.have_peer = true, .peer_addr = 20, .peer_net = MCTP_NET_ANY, },

	/* remote eid 20, net 1, local eid 11 */
	{ .name = "remote21local11", .bind_addr = 11,
		.bind_net = MCTP_NET_ANY, .bind_type = 1,
		.have_peer = true, .peer_addr = 21, .peer_net = MCTP_NET_ANY, },

	/* remote eid 21, specific net=3 for connect() */
	{ .name = "remote21net3", .bind_addr = MCTP_ADDR_ANY,
		.bind_net = MCTP_NET_ANY, .bind_type = 1,
		.have_peer = true, .peer_addr = 21, .peer_net = 3, },

	/* remote eid 21, net 4 for bind, specific net=4 for connect() */
	{ .name = "remote21net4", .bind_addr = MCTP_ADDR_ANY,
		.bind_net = 4, .bind_type = 1,
		.have_peer = true, .peer_addr = 21, .peer_net = 4, },

	/* remote eid 21, net 5 for bind, specific net=5 for connect() */
	{ .name = "remote21net5", .bind_addr = MCTP_ADDR_ANY,
		.bind_net = 5, .bind_type = 1,
		.have_peer = true, .peer_addr = 21, .peer_net = 5, },
};

static void mctp_bind_lookup_desc(const struct mctp_bind_lookup_test *t,
				  char *desc)
{
	snprintf(desc, KUNIT_PARAM_DESC_SIZE,
		 "{src %d dst %d ty %d net %d expect %s}",
		 t->hdr.src, t->hdr.dest, t->ty, t->net, t->expect);
}

KUNIT_ARRAY_PARAM(mctp_bind_lookup, mctp_bind_lookup_tests,
		  mctp_bind_lookup_desc);

static void mctp_test_bind_lookup(struct kunit *test)
{
	const struct mctp_bind_lookup_test *rx;
	struct socket *socks[ARRAY_SIZE(lookup_binds)];
	struct sk_buff *skb_pkt = NULL, *skb_sock = NULL;
	struct socket *sock_ty0, *sock_expect = NULL;
	struct mctp_test_pktqueue tpq;
	struct mctp_test_dev *dev;
	struct mctp_dst dst;
	int rc;

	rx = test->param_value;

	__mctp_route_test_init(test, &dev, &dst, &tpq, &sock_ty0, rx->net);
	/* Create all binds */
	for (size_t i = 0; i < ARRAY_SIZE(lookup_binds); i++) {
		mctp_test_bind_run(test, &lookup_binds[i],
				   &rc, &socks[i]);
		KUNIT_ASSERT_EQ(test, rc, 0);

		/* Record the expected receive socket */
		if (rx->expect &&
		    strcmp(rx->expect, lookup_binds[i].name) == 0) {
			KUNIT_ASSERT_NULL(test, sock_expect);
			sock_expect = socks[i];
		}
	}
	KUNIT_ASSERT_EQ(test, !!sock_expect, !!rx->expect);

	/* Create test message */
	skb_pkt = mctp_test_create_skb_data(&rx->hdr, &rx->ty);
	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb_pkt);
	mctp_test_skb_set_dev(skb_pkt, dev);
	mctp_test_pktqueue_init(&tpq);

	rc = mctp_dst_input(&dst, skb_pkt);
	if (rx->expect) {
		/* Test the message is received on the expected socket */
		KUNIT_EXPECT_EQ(test, rc, 0);
		skb_sock = skb_recv_datagram(sock_expect->sk,
					     MSG_DONTWAIT, &rc);
		if (!skb_sock) {
			/* Find which socket received it instead */
			for (size_t i = 0; i < ARRAY_SIZE(lookup_binds); i++) {
				skb_sock = skb_recv_datagram(socks[i]->sk,
							     MSG_DONTWAIT, &rc);
				if (skb_sock) {
					KUNIT_FAIL(test,
						   "received on incorrect socket '%s', expect '%s'",
						   lookup_binds[i].name,
						   rx->expect);
					goto cleanup;
				}
			}
			KUNIT_FAIL(test, "no message received");
		}
	} else {
		KUNIT_EXPECT_NE(test, rc, 0);
	}

cleanup:
	kfree_skb(skb_sock);
	kfree_skb(skb_pkt);

	/* Drop all binds */
	for (size_t i = 0; i < ARRAY_SIZE(lookup_binds); i++)
		sock_release(socks[i]);

	__mctp_route_test_fini(test, dev, &dst, &tpq, sock_ty0);
}

static struct kunit_case mctp_test_cases[] = {
	KUNIT_CASE_PARAM(mctp_test_fragment, mctp_frag_gen_params),
	KUNIT_CASE_PARAM(mctp_test_rx_input, mctp_rx_input_gen_params),
@@ -1431,6 +1616,7 @@ static struct kunit_case mctp_test_cases[] = {
	KUNIT_CASE(mctp_test_route_gw_loop),
	KUNIT_CASE_PARAM(mctp_test_route_gw_mtu, mctp_route_gw_mtu_gen_params),
	KUNIT_CASE(mctp_test_route_gw_output),
	KUNIT_CASE_PARAM(mctp_test_bind_lookup, mctp_bind_lookup_gen_params),
	{}
};

Loading