Commit b3b6babf authored by Kuniyuki Iwashima, committed by Jakub Kicinski
Browse files

ipmr: Free mr_table after RCU grace period.



With CONFIG_IP_MROUTE_MULTIPLE_TABLES=n, ipmr_fib_lookup()
does not check if net->ipv4.mrt is NULL.

Since default_device_exit_batch() is called after ->exit_rtnl(),
a device could receive IGMP packets and access net->ipv4.mrt
during/after ipmr_rules_exit_rtnl().

If ipmr_rules_exit_rtnl() had already cleared it and freed the
memory, the access would trigger null-ptr-deref or use-after-free.

Let's fix it by using RCU helper and free mrt after RCU grace
period.

In addition, check_net(net) is added to mroute_clean_tables()
and ipmr_cache_unresolved() to synchronise via mfc_unres_lock.
This prevents ipmr_cache_unresolved() from putting skb into
c->_c.mfc_un.unres.unresolved after mroute_clean_tables()
purges it.

For the same reason, timer_shutdown_sync() is moved after
mroute_clean_tables().

Since rhltable_destroy() holds mutex internally, rcu_work is
used, and it is placed as the first member because rcu_head
must be placed within <4K offset.  mr_table is already 3864
bytes without rcu_work.

Note that IP6MR is not yet converted to ->exit_rtnl(), so this
change is not needed for now but will be.

Fixes: b22b0186 ("ipmr: Convert ipmr_net_exit_batch() to ->exit_rtnl().")
Signed-off-by: Kuniyuki Iwashima <kuniyu@google.com>
Link: https://patch.msgid.link/20260423053456.4097409-1-kuniyu@google.com


Signed-off-by: Jakub Kicinski <kuba@kernel.org>
parent 5b0c911b
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -226,6 +226,7 @@ struct mr_table_ops {

/**
 * struct mr_table - a multicast routing table
 * @work: used for table destruction
 * @list: entry within a list of multicast routing tables
 * @net: net where this table belongs
 * @ops: protocol specific operations
@@ -243,6 +244,7 @@ struct mr_table_ops {
 * @mroute_reg_vif_num: PIM-device vif index
 */
struct mr_table {
	struct rcu_work		work;
	struct list_head	list;
	possible_net_t		net;
	struct mr_table_ops	ops;
@@ -274,6 +276,7 @@ void vif_device_init(struct vif_device *v,
		     unsigned short flags,
		     unsigned short get_iflink_mask);

void mr_table_free(struct mr_table *mrt);
struct mr_table *
mr_table_alloc(struct net *net, u32 id,
	       struct mr_table_ops *ops,
+58 −50
Original line number Diff line number Diff line
@@ -151,16 +151,6 @@ static struct mr_table *__ipmr_get_table(struct net *net, u32 id)
	return NULL;
}

/*
 * Look up the multicast routing table @id in @net.
 *
 * Calls __ipmr_get_table() inside an RCU read-side critical section and
 * returns its result (NULL when no table matches) after leaving it.
 */
static struct mr_table *ipmr_get_table(struct net *net, u32 id)
{
	struct mr_table *mrt;

	rcu_read_lock();
	mrt = __ipmr_get_table(net, id);
	rcu_read_unlock();
	return mrt;
}

static int ipmr_fib_lookup(struct net *net, struct flowi4 *flp4,
			   struct mr_table **mrt)
{
@@ -293,7 +283,7 @@ static void __net_exit ipmr_rules_exit_rtnl(struct net *net,
	struct mr_table *mrt, *next;

	list_for_each_entry_safe(mrt, next, &net->ipv4.mr_tables, list) {
		list_del(&mrt->list);
		list_del_rcu(&mrt->list);
		ipmr_free_table(mrt, dev_kill_list);
	}
}
@@ -315,28 +305,30 @@ bool ipmr_rule_default(const struct fib_rule *rule)
}
EXPORT_SYMBOL(ipmr_rule_default);
#else
#define ipmr_for_each_table(mrt, net) \
	for (mrt = net->ipv4.mrt; mrt; mrt = NULL)

static struct mr_table *ipmr_mr_table_iter(struct net *net,
					   struct mr_table *mrt)
{
	if (!mrt)
		return net->ipv4.mrt;
		return rcu_dereference(net->ipv4.mrt);
	return NULL;
}

static struct mr_table *ipmr_get_table(struct net *net, u32 id)
static struct mr_table *__ipmr_get_table(struct net *net, u32 id)
{
	return net->ipv4.mrt;
	return rcu_dereference_check(net->ipv4.mrt,
				     lockdep_rtnl_is_held() ||
				     !rcu_access_pointer(net->ipv4.mrt));
}

#define __ipmr_get_table ipmr_get_table
#define ipmr_for_each_table(mrt, net)				\
	for (mrt = __ipmr_get_table(net, 0); mrt; mrt = NULL)

static int ipmr_fib_lookup(struct net *net, struct flowi4 *flp4,
			   struct mr_table **mrt)
{
	*mrt = net->ipv4.mrt;
	*mrt = rcu_dereference(net->ipv4.mrt);
	if (!*mrt)
		return -EAGAIN;
	return 0;
}

@@ -347,7 +339,8 @@ static int __net_init ipmr_rules_init(struct net *net)
	mrt = ipmr_new_table(net, RT_TABLE_DEFAULT);
	if (IS_ERR(mrt))
		return PTR_ERR(mrt);
	net->ipv4.mrt = mrt;

	rcu_assign_pointer(net->ipv4.mrt, mrt);
	return 0;
}

@@ -358,9 +351,10 @@ static void __net_exit ipmr_rules_exit(struct net *net)
static void __net_exit ipmr_rules_exit_rtnl(struct net *net,
					    struct list_head *dev_kill_list)
{
	ipmr_free_table(net->ipv4.mrt, dev_kill_list);
	struct mr_table *mrt = rcu_dereference_protected(net->ipv4.mrt, 1);

	net->ipv4.mrt = NULL;
	RCU_INIT_POINTER(net->ipv4.mrt, NULL);
	ipmr_free_table(mrt, dev_kill_list);
}

static int ipmr_rules_dump(struct net *net, struct notifier_block *nb,
@@ -381,6 +375,17 @@ bool ipmr_rule_default(const struct fib_rule *rule)
EXPORT_SYMBOL(ipmr_rule_default);
#endif

/*
 * Look up the multicast routing table @id in @net.
 *
 * Wraps __ipmr_get_table() in an RCU read-side critical section. The
 * pointer is returned after rcu_read_unlock(), so the RCU section alone
 * does not keep the table alive for the caller.
 * NOTE(review): callers presumably hold RTNL or otherwise pin the table -
 * confirm against the call sites.
 */
static struct mr_table *ipmr_get_table(struct net *net, u32 id)
{
	struct mr_table *mrt;

	rcu_read_lock();
	mrt = __ipmr_get_table(net, id);
	rcu_read_unlock();

	return mrt;
}

static inline int ipmr_hash_cmp(struct rhashtable_compare_arg *arg,
				const void *ptr)
{
@@ -441,12 +446,11 @@ static void ipmr_free_table(struct mr_table *mrt, struct list_head *dev_kill_lis

	WARN_ON_ONCE(!mr_can_free_table(net));

	timer_shutdown_sync(&mrt->ipmr_expire_timer);
	mroute_clean_tables(mrt, MRT_FLUSH_VIFS | MRT_FLUSH_VIFS_STATIC |
			    MRT_FLUSH_MFC | MRT_FLUSH_MFC_STATIC,
			    &ipmr_dev_kill_list);
	rhltable_destroy(&mrt->mfc_hash);
	kfree(mrt);
	timer_shutdown_sync(&mrt->ipmr_expire_timer);
	mr_table_free(mrt);

	WARN_ON_ONCE(!net_initialized(net) && !list_empty(&ipmr_dev_kill_list));
	list_splice(&ipmr_dev_kill_list, dev_kill_list);
@@ -1135,12 +1139,19 @@ static int ipmr_cache_report(const struct mr_table *mrt,
static int ipmr_cache_unresolved(struct mr_table *mrt, vifi_t vifi,
				 struct sk_buff *skb, struct net_device *dev)
{
	struct net *net = read_pnet(&mrt->net);
	const struct iphdr *iph = ip_hdr(skb);
	struct mfc_cache *c;
	struct mfc_cache *c = NULL;
	bool found = false;
	int err;

	spin_lock_bh(&mfc_unres_lock);

	if (!check_net(net)) {
		err = -EINVAL;
		goto err;
	}

	list_for_each_entry(c, &mrt->mfc_unres_queue, _c.list) {
		if (c->mfc_mcastgrp == iph->daddr &&
		    c->mfc_origin == iph->saddr) {
@@ -1153,10 +1164,8 @@ static int ipmr_cache_unresolved(struct mr_table *mrt, vifi_t vifi,
		/* Create a new entry if allowable */
		c = ipmr_cache_alloc_unres();
		if (!c) {
			spin_unlock_bh(&mfc_unres_lock);

			kfree_skb(skb);
			return -ENOBUFS;
			err = -ENOBUFS;
			goto err;
		}

		/* Fill in the new cache entry */
@@ -1166,17 +1175,8 @@ static int ipmr_cache_unresolved(struct mr_table *mrt, vifi_t vifi,

		/* Reflect first query at mrouted. */
		err = ipmr_cache_report(mrt, skb, vifi, IGMPMSG_NOCACHE);

		if (err < 0) {
			/* If the report failed throw the cache entry
			   out - Brad Parker
			 */
			spin_unlock_bh(&mfc_unres_lock);

			ipmr_cache_free(c);
			kfree_skb(skb);
			return err;
		}
		if (err < 0)
			goto err;

		atomic_inc(&mrt->cache_resolve_queue_len);
		list_add(&c->_c.list, &mrt->mfc_unres_queue);
@@ -1189,18 +1189,26 @@ static int ipmr_cache_unresolved(struct mr_table *mrt, vifi_t vifi,

	/* See if we can append the packet */
	if (c->_c.mfc_un.unres.unresolved.qlen > 3) {
		kfree_skb(skb);
		c = NULL;
		err = -ENOBUFS;
	} else {
		goto err;
	}

	if (dev) {
		skb->dev = dev;
		skb->skb_iif = dev->ifindex;
	}

	skb_queue_tail(&c->_c.mfc_un.unres.unresolved, skb);
		err = 0;
	}

	spin_unlock_bh(&mfc_unres_lock);
	return 0;

err:
	spin_unlock_bh(&mfc_unres_lock);
	if (c)
		ipmr_cache_free(c);
	kfree_skb(skb);
	return err;
}

@@ -1346,7 +1354,7 @@ static void mroute_clean_tables(struct mr_table *mrt, int flags,
	}

	if (flags & MRT_FLUSH_MFC) {
		if (atomic_read(&mrt->cache_resolve_queue_len) != 0) {
		if (atomic_read(&mrt->cache_resolve_queue_len) != 0 || !check_net(net)) {
			spin_lock_bh(&mfc_unres_lock);
			list_for_each_entry_safe(c, tmp, &mrt->mfc_unres_queue, list) {
				list_del(&c->list);
+16 −0
Original line number Diff line number Diff line
@@ -28,6 +28,20 @@ void vif_device_init(struct vif_device *v,
		v->link = dev->ifindex;
}

/*
 * Deferred destructor for a multicast routing table; installed as the
 * rcu_work function by INIT_RCU_WORK() in mr_table_alloc().
 *
 * This is done from a work item because rhltable_destroy() takes a mutex
 * internally (see the commit description).
 */
static void __mr_free_table(struct work_struct *work)
{
	/* Recover the table from the rcu_work embedded in it. */
	struct mr_table *mrt = container_of(to_rcu_work(work),
					    struct mr_table, work);

	/* Tear down the MFC hash before releasing the table that holds it. */
	rhltable_destroy(&mrt->mfc_hash);
	kfree(mrt);
}

/**
 * mr_table_free - schedule destruction of a multicast routing table
 * @mrt: table to destroy
 *
 * Queues @mrt->work on system_unbound_wq with queue_rcu_work(), so the
 * work function set up in mr_table_alloc() runs after an RCU grace
 * period. Nothing is freed synchronously here.
 */
void mr_table_free(struct mr_table *mrt)
{
	queue_rcu_work(system_unbound_wq, &mrt->work);
}

struct mr_table *
mr_table_alloc(struct net *net, u32 id,
	       struct mr_table_ops *ops,
@@ -50,6 +64,8 @@ mr_table_alloc(struct net *net, u32 id,
		kfree(mrt);
		return ERR_PTR(err);
	}

	INIT_RCU_WORK(&mrt->work, __mr_free_table);
	INIT_LIST_HEAD(&mrt->mfc_cache_list);
	INIT_LIST_HEAD(&mrt->mfc_unres_queue);