Commit 73e40539 authored by Kuniyuki Iwashima's avatar Kuniyuki Iwashima Committed by Jakub Kicinski
Browse files

mpls: Add mpls_route_input().



mpls_route_input_rcu() is called from mpls_forward() and
mpls_getroute().

The former is under RCU, and the latter is under RTNL, so
mpls_route_input_rcu() uses rcu_dereference_rtnl().

Let's use rcu_dereference() in mpls_route_input_rcu() and
add an RTNL variant for mpls_getroute().

Later, we will remove rtnl_dereference() there.

Signed-off-by: default avatarKuniyuki Iwashima <kuniyu@google.com>
Reviewed-by: default avatarGuillaume Nault <gnault@redhat.com>
Link: https://patch.msgid.link/20251029173344.2934622-9-kuniyu@google.com


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parent 1fb462de
Loading
Loading
Loading
Loading
+18 −10
Original line number Diff line number Diff line
@@ -75,16 +75,23 @@ static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
		       struct nlmsghdr *nlh, struct net *net, u32 portid,
		       unsigned int nlm_flags);

static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index)
static struct mpls_route *mpls_route_input(struct net *net, unsigned int index)
{
	struct mpls_route *rt = NULL;
	struct mpls_route __rcu **platform_label;

	if (index < net->mpls.platform_labels) {
		struct mpls_route __rcu **platform_label =
			rcu_dereference_rtnl(net->mpls.platform_label);
		rt = rcu_dereference_rtnl(platform_label[index]);
	platform_label = rtnl_dereference(net->mpls.platform_label);
	return rtnl_dereference(platform_label[index]);
}
	return rt;

static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned int index)
{
	struct mpls_route __rcu **platform_label;

	if (index >= net->mpls.platform_labels)
		return NULL;

	platform_label = rcu_dereference(net->mpls.platform_label);
	return rcu_dereference(platform_label[index]);
}

bool mpls_output_possible(const struct net_device *dev)
@@ -2373,12 +2380,12 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
	u32 portid = NETLINK_CB(in_skb).portid;
	u32 in_label = LABEL_NOT_SPECIFIED;
	struct nlattr *tb[RTA_MAX + 1];
	struct mpls_route *rt = NULL;
	u32 labels[MAX_NEW_LABELS];
	struct mpls_shim_hdr *hdr;
	unsigned int hdr_size = 0;
	const struct mpls_nh *nh;
	struct net_device *dev;
	struct mpls_route *rt;
	struct rtmsg *rtm, *r;
	struct nlmsghdr *nlh;
	struct sk_buff *skb;
@@ -2406,7 +2413,8 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
		}
	}

	rt = mpls_route_input_rcu(net, in_label);
	if (in_label < net->mpls.platform_labels)
		rt = mpls_route_input(net, in_label);
	if (!rt) {
		err = -ENETUNREACH;
		goto errout;