Commit 1aeed732 authored by Matt Johnston's avatar Matt Johnston Committed by Paolo Abeni
Browse files

net: mctp: Use hashtable for binds



Ensure that a specific EID (remote or local) bind will match in
preference to a MCTP_ADDR_ANY bind.

This adds infrastructure for binding a socket to receive messages from a
specific remote peer address, a future commit will expose an API for
this.

Signed-off-by: default avatarMatt Johnston <matt@codeconstruct.com.au>
Link: https://patch.msgid.link/20250710-mctp-bind-v4-5-8ec2f6460c56@codeconstruct.com.au


Signed-off-by: default avatarPaolo Abeni <pabeni@redhat.com>
parent 4ec4b7fc
Loading
Loading
Loading
Loading
+16 −4
Original line number Diff line number Diff line
@@ -6,19 +6,25 @@
#ifndef __NETNS_MCTP_H__
#define __NETNS_MCTP_H__

#include <linux/hash.h>
#include <linux/hashtable.h>
#include <linux/mutex.h>
#include <linux/types.h>

#define MCTP_BINDS_BITS 7

struct netns_mctp {
	/* Only updated under RTNL, entries freed via RCU */
	struct list_head routes;

	/* Bound sockets: list of sockets bound by type.
	 * This list is updated from non-atomic contexts (under bind_lock),
	 * and read (under rcu) in packet rx
	/* Bound sockets: hash table of sockets, keyed by
	 * (type, src_eid, dest_eid).
	 * Specific src_eid/dest_eid entries also have an entry for
	 * MCTP_ADDR_ANY. This list is updated from non-atomic contexts
	 * (under bind_lock), and read (under rcu) in packet rx.
	 */
	struct mutex bind_lock;
	struct hlist_head binds;
	DECLARE_HASHTABLE(binds, MCTP_BINDS_BITS);

	/* tag allocations. This list is read and updated from atomic contexts,
	 * but elements are free()ed after a RCU grace-period
@@ -34,4 +40,10 @@ struct netns_mctp {
	struct list_head neighbours;
};

static inline u32 mctp_bind_hash(u8 type, u8 local_addr, u8 peer_addr)
{
	return hash_32(type | (u32)local_addr << 8 | (u32)peer_addr << 16,
		       MCTP_BINDS_BITS);
}

#endif /* __NETNS_MCTP_H__ */
+7 −4
Original line number Diff line number Diff line
@@ -626,17 +626,17 @@ static int mctp_sk_hash(struct sock *sk)
	struct net *net = sock_net(sk);
	struct sock *existing;
	struct mctp_sock *msk;
	u32 hash;
	int rc;

	msk = container_of(sk, struct mctp_sock, sk);

	/* Bind lookup runs under RCU, remain live during that. */
	sock_set_flag(sk, SOCK_RCU_FREE);
	hash = mctp_bind_hash(msk->bind_type, msk->bind_addr, MCTP_ADDR_ANY);

	mutex_lock(&net->mctp.bind_lock);

	/* Prevent duplicate binds. */
	sk_for_each(existing, &net->mctp.binds) {
	sk_for_each(existing, &net->mctp.binds[hash]) {
		struct mctp_sock *mex =
			container_of(existing, struct mctp_sock, sk);

@@ -648,7 +648,10 @@ static int mctp_sk_hash(struct sock *sk)
		}
	}

	sk_add_node_rcu(sk, &net->mctp.binds);
	/* Bind lookup runs under RCU, remain live during that. */
	sock_set_flag(sk, SOCK_RCU_FREE);

	sk_add_node_rcu(sk, &net->mctp.binds[hash]);
	rc = 0;

out:
+61 −14
Original line number Diff line number Diff line
@@ -40,33 +40,32 @@ static int mctp_dst_discard(struct mctp_dst *dst, struct sk_buff *skb)
	return 0;
}

static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
static struct mctp_sock *mctp_lookup_bind_details(struct net *net,
						  struct sk_buff *skb,
						  u8 type, u8 dest,
						  u8 src, bool allow_net_any)
{
	struct mctp_skb_cb *cb = mctp_cb(skb);
	struct mctp_hdr *mh;
	struct sock *sk;
	u8 type;

	WARN_ON(!rcu_read_lock_held());

	/* TODO: look up in skb->cb? */
	mh = mctp_hdr(skb);
	u8 hash;

	if (!skb_headlen(skb))
		return NULL;
	WARN_ON_ONCE(!rcu_read_lock_held());

	type = (*(u8 *)skb->data) & 0x7f;
	hash = mctp_bind_hash(type, dest, src);

	sk_for_each_rcu(sk, &net->mctp.binds) {
	sk_for_each_rcu(sk, &net->mctp.binds[hash]) {
		struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);

		if (!allow_net_any && msk->bind_net == MCTP_NET_ANY)
			continue;

		if (msk->bind_net != MCTP_NET_ANY && msk->bind_net != cb->net)
			continue;

		if (msk->bind_type != type)
			continue;

		if (!mctp_address_matches(msk->bind_addr, mh->dest))
		if (!mctp_address_matches(msk->bind_addr, dest))
			continue;

		return msk;
@@ -75,6 +74,54 @@ static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
	return NULL;
}

static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
{
	struct mctp_sock *msk;
	struct mctp_hdr *mh;
	u8 type;

	/* TODO: look up in skb->cb? */
	mh = mctp_hdr(skb);

	if (!skb_headlen(skb))
		return NULL;

	type = (*(u8 *)skb->data) & 0x7f;

	/* Look for binds in order of widening scope. A given destination or
	 * source address also implies matching on a particular network.
	 *
	 * - Matching destination and source
	 * - Matching destination
	 * - Matching source
	 * - Matching network, any address
	 * - Any network or address
	 */

	msk = mctp_lookup_bind_details(net, skb, type, mh->dest, mh->src,
				       false);
	if (msk)
		return msk;
	msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, mh->src,
				       false);
	if (msk)
		return msk;
	msk = mctp_lookup_bind_details(net, skb, type, mh->dest, MCTP_ADDR_ANY,
				       false);
	if (msk)
		return msk;
	msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY,
				       MCTP_ADDR_ANY, false);
	if (msk)
		return msk;
	msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY,
				       MCTP_ADDR_ANY, true);
	if (msk)
		return msk;

	return NULL;
}

/* A note on the key allocations.
 *
 * struct net->mctp.keys contains our set of currently-allocated keys for
@@ -1671,7 +1718,7 @@ static int __net_init mctp_routes_net_init(struct net *net)
	struct netns_mctp *ns = &net->mctp;

	INIT_LIST_HEAD(&ns->routes);
	INIT_HLIST_HEAD(&ns->binds);
	hash_init(ns->binds);
	mutex_init(&ns->bind_lock);
	INIT_HLIST_HEAD(&ns->keys);
	spin_lock_init(&ns->keys_lock);