Commit e833eb25 authored by Kuniyuki Iwashima's avatar Kuniyuki Iwashima Committed by Jakub Kicinski
Browse files

mpls: Protect net->mpls.platform_label with a per-netns mutex.



MPLS (re)uses RTNL to protect net->mpls.platform_label,
but the lock does not need to be RTNL at all.

Let's protect net->mpls.platform_label with a dedicated
per-netns mutex.

Signed-off-by: default avatarKuniyuki Iwashima <kuniyu@google.com>
Reviewed-by: default avatarGuillaume Nault <gnault@redhat.com>
Link: https://patch.msgid.link/20251029173344.2934622-13-kuniyu@google.com


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parent fb2b77b9
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ struct netns_mpls {
	int default_ttl;
	size_t platform_labels;
	struct mpls_route __rcu * __rcu *platform_label;
	struct mutex platform_mutex;

	struct ctl_table_header *ctl;
};
+36 −19
Original line number Diff line number Diff line
@@ -79,8 +79,8 @@ static struct mpls_route *mpls_route_input(struct net *net, unsigned int index)
{
	struct mpls_route __rcu **platform_label;

	platform_label = rtnl_dereference(net->mpls.platform_label);
	return rtnl_dereference(platform_label[index]);
	platform_label = mpls_dereference(net, net->mpls.platform_label);
	return mpls_dereference(net, platform_label[index]);
}

static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned int index)
@@ -578,10 +578,8 @@ static void mpls_route_update(struct net *net, unsigned index,
	struct mpls_route __rcu **platform_label;
	struct mpls_route *rt;

	ASSERT_RTNL();

	platform_label = rtnl_dereference(net->mpls.platform_label);
	rt = rtnl_dereference(platform_label[index]);
	platform_label = mpls_dereference(net, net->mpls.platform_label);
	rt = mpls_dereference(net, platform_label[index]);
	rcu_assign_pointer(platform_label[index], new);

	mpls_notify_route(net, index, rt, new, info);
@@ -1472,8 +1470,6 @@ static struct mpls_dev *mpls_add_dev(struct net_device *dev)
	int err = -ENOMEM;
	int i;

	ASSERT_RTNL();

	mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
	if (!mdev)
		return ERR_PTR(err);
@@ -1633,6 +1629,8 @@ static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
	unsigned int flags;
	int err;

	mutex_lock(&net->mpls.platform_mutex);

	if (event == NETDEV_REGISTER) {
		mdev = mpls_add_dev(dev);
		if (IS_ERR(mdev)) {
@@ -1695,9 +1693,11 @@ static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
	}

out:
	mutex_unlock(&net->mpls.platform_mutex);
	return NOTIFY_OK;

err:
	mutex_unlock(&net->mpls.platform_mutex);
	return notifier_from_errno(err);
}

@@ -1973,6 +1973,7 @@ static int rtm_to_route_config(struct sk_buff *skb,
static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
			     struct netlink_ext_ack *extack)
{
	struct net *net = sock_net(skb->sk);
	struct mpls_route_config *cfg;
	int err;

@@ -1984,7 +1985,9 @@ static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
	if (err < 0)
		goto out;

	mutex_lock(&net->mpls.platform_mutex);
	err = mpls_route_del(cfg, extack);
	mutex_unlock(&net->mpls.platform_mutex);
out:
	kfree(cfg);

@@ -1995,6 +1998,7 @@ static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
			     struct netlink_ext_ack *extack)
{
	struct net *net = sock_net(skb->sk);
	struct mpls_route_config *cfg;
	int err;

@@ -2006,7 +2010,9 @@ static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
	if (err < 0)
		goto out;

	mutex_lock(&net->mpls.platform_mutex);
	err = mpls_route_add(cfg, extack);
	mutex_unlock(&net->mpls.platform_mutex);
out:
	kfree(cfg);

@@ -2407,6 +2413,8 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
	u8 n_labels;
	int err;

	mutex_lock(&net->mpls.platform_mutex);

	err = mpls_valid_getroute_req(in_skb, in_nlh, tb, extack);
	if (err < 0)
		goto errout;
@@ -2450,7 +2458,8 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
			goto errout_free;
		}

		return rtnl_unicast(skb, net, portid);
		err = rtnl_unicast(skb, net, portid);
		goto errout;
	}

	if (tb[RTA_NEWDST]) {
@@ -2542,12 +2551,14 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,

	err = rtnl_unicast(skb, net, portid);
errout:
	mutex_unlock(&net->mpls.platform_mutex);
	return err;

nla_put_failure:
	nlmsg_cancel(skb, nlh);
	err = -EMSGSIZE;
errout_free:
	mutex_unlock(&net->mpls.platform_mutex);
	kfree_skb(skb);
	return err;
}
@@ -2603,9 +2614,10 @@ static int resize_platform_label_table(struct net *net, size_t limit)
		       lo->addr_len);
	}

	rtnl_lock();
	mutex_lock(&net->mpls.platform_mutex);

	/* Remember the original table */
	old = rtnl_dereference(net->mpls.platform_label);
	old = mpls_dereference(net, net->mpls.platform_label);
	old_limit = net->mpls.platform_labels;

	/* Free any labels beyond the new table */
@@ -2636,7 +2648,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
	net->mpls.platform_labels = limit;
	rcu_assign_pointer(net->mpls.platform_label, labels);

	rtnl_unlock();
	mutex_unlock(&net->mpls.platform_mutex);

	mpls_rt_free(rt2);
	mpls_rt_free(rt0);
@@ -2709,12 +2721,13 @@ static const struct ctl_table mpls_table[] = {
	},
};

static int mpls_net_init(struct net *net)
static __net_init int mpls_net_init(struct net *net)
{
	size_t table_size = ARRAY_SIZE(mpls_table);
	struct ctl_table *table;
	int i;

	mutex_init(&net->mpls.platform_mutex);
	net->mpls.platform_labels = 0;
	net->mpls.platform_label = NULL;
	net->mpls.ip_ttl_propagate = 1;
@@ -2740,7 +2753,7 @@ static int mpls_net_init(struct net *net)
	return 0;
}

static void mpls_net_exit(struct net *net)
static __net_exit void mpls_net_exit(struct net *net)
{
	struct mpls_route __rcu **platform_label;
	size_t platform_labels;
@@ -2760,16 +2773,20 @@ static void mpls_net_exit(struct net *net)
	 * As such no additional rcu synchronization is necessary when
	 * freeing the platform_label table.
	 */
	rtnl_lock();
	platform_label = rtnl_dereference(net->mpls.platform_label);
	mutex_lock(&net->mpls.platform_mutex);

	platform_label = mpls_dereference(net, net->mpls.platform_label);
	platform_labels = net->mpls.platform_labels;

	for (index = 0; index < platform_labels; index++) {
		struct mpls_route *rt = rtnl_dereference(platform_label[index]);
		RCU_INIT_POINTER(platform_label[index], NULL);
		struct mpls_route *rt;

		rt = mpls_dereference(net, platform_label[index]);
		mpls_notify_route(net, index, rt, NULL, NULL);
		mpls_rt_free(rt);
	}
	rtnl_unlock();

	mutex_unlock(&net->mpls.platform_mutex);

	kvfree(platform_label);
}
+6 −1
Original line number Diff line number Diff line
@@ -185,6 +185,11 @@ static inline struct mpls_entry_decoded mpls_entry_decode(struct mpls_shim_hdr *
	return result;
}

#define mpls_dereference(net, p)					\
	rcu_dereference_protected(					\
		(p),							\
		lockdep_is_held(&(net)->mpls.platform_mutex))

static inline struct mpls_dev *mpls_dev_rcu(const struct net_device *dev)
{
	return rcu_dereference(dev->mpls_ptr);
@@ -193,7 +198,7 @@ static inline struct mpls_dev *mpls_dev_rcu(const struct net_device *dev)
static inline struct mpls_dev *mpls_dev_get(const struct net *net,
					    const struct net_device *dev)
{
	return rcu_dereference_rtnl(dev->mpls_ptr);
	return mpls_dereference(net, dev->mpls_ptr);
}

int nla_put_labels(struct sk_buff *skb, int attrtype,  u8 labels,