Commit e6e14d58 authored by Kuniyuki Iwashima, committed by Jakub Kicinski
Browse files

ipv6: mcast: Don't hold RTNL for MCAST_ socket options.



In ip6_mc_source() and ip6_mc_msfilter(), per-socket mld data is
protected by lock_sock() and inet6_dev->mc_lock is also held for
some per-interface functions.

ip6_mc_find_dev_rtnl() only depends on RTNL.  If we want to remove
that dependency, we need to check inet6_dev->dead under mc_lock to
close the race with addrconf_ifdown(), as mentioned earlier.

Let's do that and drop RTNL for the rest of MCAST_ socket options.

Note that ip6_mc_msfilter() takes and releases mc_lock several times
unnecessarily; these are merged into a single critical section to
avoid the last-minute error and simplify the error handling.

Signed-off-by: Kuniyuki Iwashima <kuniyu@google.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Link: https://patch.msgid.link/20250702230210.3115355-10-kuni1840@gmail.com


Signed-off-by: Jakub Kicinski <kuba@kernel.org>
parent 1e589db3
Loading
Loading
Loading
Loading
+0 −5
Original line number Diff line number Diff line
@@ -123,11 +123,6 @@ static bool setsockopt_needs_rtnl(int optname)
	case IPV6_ADDRFORM:
	case IPV6_JOIN_ANYCAST:
	case IPV6_LEAVE_ANYCAST:
	case MCAST_JOIN_SOURCE_GROUP:
	case MCAST_LEAVE_SOURCE_GROUP:
	case MCAST_BLOCK_SOURCE:
	case MCAST_UNBLOCK_SOURCE:
	case MCAST_MSFILTER:
		return true;
	}
	return false;
+45 −29
Original line number Diff line number Diff line
@@ -302,31 +302,36 @@ int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr)
}
EXPORT_SYMBOL(ipv6_sock_mc_drop);

static struct inet6_dev *ip6_mc_find_dev_rtnl(struct net *net,
static struct inet6_dev *ip6_mc_find_dev(struct net *net,
					 const struct in6_addr *group,
					 int ifindex)
{
	struct net_device *dev = NULL;
	struct inet6_dev *idev = NULL;
	struct inet6_dev *idev;

	if (ifindex == 0) {
		struct rt6_info *rt = rt6_lookup(net, group, NULL, 0, NULL, 0);
		struct rt6_info *rt;

		rcu_read_lock();
		rt = rt6_lookup(net, group, NULL, 0, NULL, 0);
		if (rt) {
			dev = rt->dst.dev;
			dev = dst_dev(&rt->dst);
			dev_hold(dev);
			ip6_rt_put(rt);
		}
		rcu_read_unlock();
	} else {
		dev = __dev_get_by_index(net, ifindex);
		dev = dev_get_by_index(net, ifindex);
	}

	if (!dev)
		return NULL;
	idev = __in6_dev_get(dev);

	idev = in6_dev_get(dev);
	dev_put(dev);

	if (!idev)
		return NULL;
	if (idev->dead)
		return NULL;

	return idev;
}

@@ -356,14 +361,14 @@ void ipv6_sock_mc_close(struct sock *sk)
int ip6_mc_source(int add, int omode, struct sock *sk,
		  struct group_source_req *pgsr)
{
	struct ipv6_pinfo *inet6 = inet6_sk(sk);
	struct in6_addr *source, *group;
	struct net *net = sock_net(sk);
	struct ipv6_mc_socklist *pmc;
	struct inet6_dev *idev;
	struct ipv6_pinfo *inet6 = inet6_sk(sk);
	struct ip6_sf_socklist *psl;
	struct net *net = sock_net(sk);
	int i, j, rv;
	struct inet6_dev *idev;
	int leavegroup = 0;
	int i, j, rv;
	int err;

	source = &((struct sockaddr_in6 *)&pgsr->gsr_source)->sin6_addr;
@@ -372,13 +377,19 @@ int ip6_mc_source(int add, int omode, struct sock *sk,
	if (!ipv6_addr_is_multicast(group))
		return -EINVAL;

	idev = ip6_mc_find_dev_rtnl(net, group, pgsr->gsr_interface);
	idev = ip6_mc_find_dev(net, group, pgsr->gsr_interface);
	if (!idev)
		return -ENODEV;

	mutex_lock(&idev->mc_lock);

	if (idev->dead) {
		err = -ENODEV;
		goto done;
	}

	err = -EADDRNOTAVAIL;

	mutex_lock(&idev->mc_lock);
	for_each_pmc_socklock(inet6, sk, pmc) {
		if (pgsr->gsr_interface && pmc->ifindex != pgsr->gsr_interface)
			continue;
@@ -475,6 +486,7 @@ int ip6_mc_source(int add, int omode, struct sock *sk,
	ip6_mc_add_src(idev, group, omode, 1, source, 1);
done:
	mutex_unlock(&idev->mc_lock);
	in6_dev_put(idev);
	if (leavegroup)
		err = ipv6_sock_mc_drop(sk, pgsr->gsr_interface, group);
	return err;
@@ -483,12 +495,12 @@ int ip6_mc_source(int add, int omode, struct sock *sk,
int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
		    struct sockaddr_storage *list)
{
	const struct in6_addr *group;
	struct ipv6_mc_socklist *pmc;
	struct inet6_dev *idev;
	struct ipv6_pinfo *inet6 = inet6_sk(sk);
	struct ip6_sf_socklist *newpsl, *psl;
	struct net *net = sock_net(sk);
	const struct in6_addr *group;
	struct ipv6_mc_socklist *pmc;
	struct inet6_dev *idev;
	int leavegroup = 0;
	int i, err;

@@ -500,10 +512,17 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
	    gsf->gf_fmode != MCAST_EXCLUDE)
		return -EINVAL;

	idev = ip6_mc_find_dev_rtnl(net, group, gsf->gf_interface);
	idev = ip6_mc_find_dev(net, group, gsf->gf_interface);
	if (!idev)
		return -ENODEV;

	mutex_lock(&idev->mc_lock);

	if (idev->dead) {
		err = -ENODEV;
		goto done;
	}

	err = 0;

	if (gsf->gf_fmode == MCAST_INCLUDE && gsf->gf_numsrc == 0) {
@@ -536,24 +555,19 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
			psin6 = (struct sockaddr_in6 *)list;
			newpsl->sl_addr[i] = psin6->sin6_addr;
		}
		mutex_lock(&idev->mc_lock);

		err = ip6_mc_add_src(idev, group, gsf->gf_fmode,
				     newpsl->sl_count, newpsl->sl_addr, 0);
		if (err) {
			mutex_unlock(&idev->mc_lock);
			sock_kfree_s(sk, newpsl, struct_size(newpsl, sl_addr,
							     newpsl->sl_max));
			goto done;
		}
		mutex_unlock(&idev->mc_lock);
	} else {
		newpsl = NULL;
		mutex_lock(&idev->mc_lock);
		ip6_mc_add_src(idev, group, gsf->gf_fmode, 0, NULL, 0);
		mutex_unlock(&idev->mc_lock);
	}

	mutex_lock(&idev->mc_lock);
	psl = sock_dereference(pmc->sflist, sk);
	if (psl) {
		ip6_mc_del_src(idev, group, pmc->sfmode,
@@ -563,12 +577,14 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
	} else {
		ip6_mc_del_src(idev, group, pmc->sfmode, 0, NULL, 0);
	}

	rcu_assign_pointer(pmc->sflist, newpsl);
	mutex_unlock(&idev->mc_lock);
	kfree_rcu(psl, rcu);
	pmc->sfmode = gsf->gf_fmode;
	err = 0;
done:
	mutex_unlock(&idev->mc_lock);
	in6_dev_put(idev);
	if (leavegroup)
		err = ipv6_sock_mc_drop(sk, gsf->gf_interface, group);
	return err;