Commit 1767bb2d authored by Kuniyuki Iwashima's avatar Kuniyuki Iwashima Committed by Jakub Kicinski
Browse files

ipv6: mcast: Don't hold RTNL for IPV6_ADD_MEMBERSHIP and MCAST_JOIN_GROUP.



In __ipv6_sock_mc_join(), per-socket mld data is protected by lock_sock(),
and only __dev_get_by_index() requires RTNL.

Let's use dev_get_by_index() and drop RTNL for IPV6_ADD_MEMBERSHIP and
MCAST_JOIN_GROUP.

Note that we must call rt6_lookup() and dev_hold() under RCU.

If rt6_lookup() returns an entry from the exception table, dst_dev_put()
could change rt->dev.dst to loopback concurrently, and the original device
could lose the refcount before dev_hold() and unblock device registration.

dst_dev_put() is called from NETDEV_UNREGISTER and synchronize_net() follows
it, so as long as rt6_lookup() and dev_hold() are called within the same
RCU critical section, the dev is alive.

Even if the race happens, they are synchronised by idev->dead and mcast
addresses are cleaned up.

For the racy access to rt->dst.dev, we use dst_dev().

Signed-off-by: default avatarKuniyuki Iwashima <kuniyu@google.com>
Reviewed-by: default avatarEric Dumazet <edumazet@google.com>
Link: https://patch.msgid.link/20250702230210.3115355-7-kuni1840@gmail.com


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parent e01b193e
Loading
Loading
Loading
Loading
+0 −2
Original line number Diff line number Diff line
@@ -121,11 +121,9 @@ static bool setsockopt_needs_rtnl(int optname)
{
	switch (optname) {
	case IPV6_ADDRFORM:
	case IPV6_ADD_MEMBERSHIP:
	case IPV6_DROP_MEMBERSHIP:
	case IPV6_JOIN_ANYCAST:
	case IPV6_LEAVE_ANYCAST:
	case MCAST_JOIN_GROUP:
	case MCAST_LEAVE_GROUP:
	case MCAST_JOIN_SOURCE_GROUP:
	case MCAST_LEAVE_SOURCE_GROUP:
+13 −11
Original line number Diff line number Diff line
@@ -175,14 +175,12 @@ static int unsolicited_report_interval(struct inet6_dev *idev)
static int __ipv6_sock_mc_join(struct sock *sk, int ifindex,
			       const struct in6_addr *addr, unsigned int mode)
{
	struct net_device *dev = NULL;
	struct ipv6_mc_socklist *mc_lst;
	struct ipv6_pinfo *np = inet6_sk(sk);
	struct ipv6_mc_socklist *mc_lst;
	struct net *net = sock_net(sk);
	struct net_device *dev = NULL;
	int err;

	ASSERT_RTNL();

	if (!ipv6_addr_is_multicast(addr))
		return -EINVAL;

@@ -202,13 +200,18 @@ static int __ipv6_sock_mc_join(struct sock *sk, int ifindex,

	if (ifindex == 0) {
		struct rt6_info *rt;

		rcu_read_lock();
		rt = rt6_lookup(net, addr, NULL, 0, NULL, 0);
		if (rt) {
			dev = rt->dst.dev;
			dev = dst_dev(&rt->dst);
			dev_hold(dev);
			ip6_rt_put(rt);
		}
	} else
		dev = __dev_get_by_index(net, ifindex);
		rcu_read_unlock();
	} else {
		dev = dev_get_by_index(net, ifindex);
	}

	if (!dev) {
		sock_kfree_s(sk, mc_lst, sizeof(*mc_lst));
@@ -219,12 +222,11 @@ static int __ipv6_sock_mc_join(struct sock *sk, int ifindex,
	mc_lst->sfmode = mode;
	RCU_INIT_POINTER(mc_lst->sflist, NULL);

	/*
	 *	now add/increase the group membership on the device
	 */

	/* now add/increase the group membership on the device */
	err = __ipv6_dev_mc_inc(dev, addr, mode);

	dev_put(dev);

	if (err) {
		sock_kfree_s(sk, mc_lst, sizeof(*mc_lst));
		return err;