Commit 3c6b97a9 authored by Paolo Abeni
Browse files

Merge branch 'inet-frags-fully-use-rcu'

Eric Dumazet says:

====================
inet: frags: fully use RCU

While inet reassembly uses RCU, it acquires and releases
a refcount on struct inet_frag_queue in the fast path,
for no good reason.

This was mentioned in one patch changelog seven years ago :/

This series removes these refcount changes by extending
the RCU sections.
====================

Link: https://patch.msgid.link/20250312082250.1803501-1-edumazet@google.com


Signed-off-by: Paolo Abeni <pabeni@redhat.com>
parents 24faa63b ca0359df
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -137,7 +137,7 @@ static inline void fqdir_pre_exit(struct fqdir *fqdir)
}
void fqdir_exit(struct fqdir *fqdir);

void inet_frag_kill(struct inet_frag_queue *q);
void inet_frag_kill(struct inet_frag_queue *q, int *refs);
void inet_frag_destroy(struct inet_frag_queue *q);
struct inet_frag_queue *inet_frag_find(struct fqdir *fqdir, void *key);

@@ -145,9 +145,9 @@ struct inet_frag_queue *inet_frag_find(struct fqdir *fqdir, void *key);
unsigned int inet_frag_rbtree_purge(struct rb_root *root,
				    enum skb_drop_reason reason);

static inline void inet_frag_put(struct inet_frag_queue *q)
static inline void inet_frag_putn(struct inet_frag_queue *q, int refs)
{
	if (refcount_dec_and_test(&q->refcnt))
	if (refs && refcount_sub_and_test(refs, &q->refcnt))
		inet_frag_destroy(q);
}

+3 −2
Original line number Diff line number Diff line
@@ -66,6 +66,7 @@ ip6frag_expire_frag_queue(struct net *net, struct frag_queue *fq)
{
	struct net_device *dev = NULL;
	struct sk_buff *head;
	int refs = 1;

	rcu_read_lock();
	/* Paired with the WRITE_ONCE() in fqdir_pre_exit(). */
@@ -77,7 +78,7 @@ ip6frag_expire_frag_queue(struct net *net, struct frag_queue *fq)
		goto out;

	fq->q.flags |= INET_FRAG_DROP;
	inet_frag_kill(&fq->q);
	inet_frag_kill(&fq->q, &refs);

	dev = dev_get_by_index_rcu(net, fq->iif);
	if (!dev)
@@ -109,7 +110,7 @@ ip6frag_expire_frag_queue(struct net *net, struct frag_queue *fq)
	spin_unlock(&fq->q.lock);
out_rcu_unlock:
	rcu_read_unlock();
	inet_frag_put(&fq->q);
	inet_frag_putn(&fq->q, refs);
}

/* Check if the upper layer header is truncated in the first fragment. */
+17 −10
Original line number Diff line number Diff line
@@ -31,7 +31,8 @@ static const char lowpan_frags_cache_name[] = "lowpan-frags";
static struct inet_frags lowpan_frags;

static int lowpan_frag_reasm(struct lowpan_frag_queue *fq, struct sk_buff *skb,
			     struct sk_buff *prev,  struct net_device *ldev);
			     struct sk_buff *prev, struct net_device *ldev,
			     int *refs);

static void lowpan_frag_init(struct inet_frag_queue *q, const void *a)
{
@@ -45,6 +46,7 @@ static void lowpan_frag_expire(struct timer_list *t)
{
	struct inet_frag_queue *frag = from_timer(frag, t, timer);
	struct frag_queue *fq;
	int refs = 1;

	fq = container_of(frag, struct frag_queue, q);

@@ -53,10 +55,10 @@ static void lowpan_frag_expire(struct timer_list *t)
	if (fq->q.flags & INET_FRAG_COMPLETE)
		goto out;

	inet_frag_kill(&fq->q);
	inet_frag_kill(&fq->q, &refs);
out:
	spin_unlock(&fq->q.lock);
	inet_frag_put(&fq->q);
	inet_frag_putn(&fq->q, refs);
}

static inline struct lowpan_frag_queue *
@@ -82,7 +84,8 @@ fq_find(struct net *net, const struct lowpan_802154_cb *cb,
}

static int lowpan_frag_queue(struct lowpan_frag_queue *fq,
			     struct sk_buff *skb, u8 frag_type)
			     struct sk_buff *skb, u8 frag_type,
			     int *refs)
{
	struct sk_buff *prev_tail;
	struct net_device *ldev;
@@ -143,7 +146,7 @@ static int lowpan_frag_queue(struct lowpan_frag_queue *fq,
		unsigned long orefdst = skb->_skb_refdst;

		skb->_skb_refdst = 0UL;
		res = lowpan_frag_reasm(fq, skb, prev_tail, ldev);
		res = lowpan_frag_reasm(fq, skb, prev_tail, ldev, refs);
		skb->_skb_refdst = orefdst;
		return res;
	}
@@ -162,11 +165,12 @@ static int lowpan_frag_queue(struct lowpan_frag_queue *fq,
 *	the last and the first frames arrived and all the bits are here.
 */
static int lowpan_frag_reasm(struct lowpan_frag_queue *fq, struct sk_buff *skb,
			     struct sk_buff *prev_tail, struct net_device *ldev)
			     struct sk_buff *prev_tail, struct net_device *ldev,
			     int *refs)
{
	void *reasm_data;

	inet_frag_kill(&fq->q);
	inet_frag_kill(&fq->q, refs);

	reasm_data = inet_frag_reasm_prepare(&fq->q, skb, prev_tail);
	if (!reasm_data)
@@ -300,17 +304,20 @@ int lowpan_frag_rcv(struct sk_buff *skb, u8 frag_type)
		goto err;
	}

	rcu_read_lock();
	fq = fq_find(net, cb, &hdr.source, &hdr.dest);
	if (fq != NULL) {
		int ret;
		int ret, refs = 0;

		spin_lock(&fq->q.lock);
		ret = lowpan_frag_queue(fq, skb, frag_type);
		ret = lowpan_frag_queue(fq, skb, frag_type, &refs);
		spin_unlock(&fq->q.lock);

		inet_frag_put(&fq->q);
		rcu_read_unlock();
		inet_frag_putn(&fq->q, refs);
		return ret;
	}
	rcu_read_unlock();

err:
	kfree_skb(skb);
+15 −16
Original line number Diff line number Diff line
@@ -145,8 +145,7 @@ static void inet_frags_free_cb(void *ptr, void *arg)
	}
	spin_unlock_bh(&fq->lock);

	if (refcount_sub_and_test(count, &fq->refcnt))
		inet_frag_destroy(fq);
	inet_frag_putn(fq, count);
}

static LLIST_HEAD(fqdir_free_list);
@@ -226,10 +225,10 @@ void fqdir_exit(struct fqdir *fqdir)
}
EXPORT_SYMBOL(fqdir_exit);

void inet_frag_kill(struct inet_frag_queue *fq)
void inet_frag_kill(struct inet_frag_queue *fq, int *refs)
{
	if (del_timer(&fq->timer))
		refcount_dec(&fq->refcnt);
		(*refs)++;

	if (!(fq->flags & INET_FRAG_COMPLETE)) {
		struct fqdir *fqdir = fq->fqdir;
@@ -244,7 +243,7 @@ void inet_frag_kill(struct inet_frag_queue *fq)
		if (!READ_ONCE(fqdir->dead)) {
			rhashtable_remove_fast(&fqdir->rhashtable, &fq->node,
					       fqdir->f->rhash_params);
			refcount_dec(&fq->refcnt);
			(*refs)++;
		} else {
			fq->flags |= INET_FRAG_HASH_DEAD;
		}
@@ -328,7 +327,8 @@ static struct inet_frag_queue *inet_frag_alloc(struct fqdir *fqdir,

	timer_setup(&q->timer, f->frag_expire, 0);
	spin_lock_init(&q->lock);
	refcount_set(&q->refcnt, 3);
	/* One reference for the timer, one for the hash table. */
	refcount_set(&q->refcnt, 2);

	return q;
}
@@ -350,15 +350,20 @@ static struct inet_frag_queue *inet_frag_create(struct fqdir *fqdir,
	*prev = rhashtable_lookup_get_insert_key(&fqdir->rhashtable, &q->key,
						 &q->node, f->rhash_params);
	if (*prev) {
		/* We could not insert in the hash table,
		 * we need to cancel what inet_frag_alloc()
		 * anticipated.
		 */
		int refs = 1;

		q->flags |= INET_FRAG_COMPLETE;
		inet_frag_kill(q);
		inet_frag_destroy(q);
		inet_frag_kill(q, &refs);
		inet_frag_putn(q, refs);
		return NULL;
	}
	return q;
}

/* TODO : call from rcu_read_lock() and no longer use refcount_inc_not_zero() */
struct inet_frag_queue *inet_frag_find(struct fqdir *fqdir, void *key)
{
	/* This pairs with WRITE_ONCE() in fqdir_pre_exit(). */
@@ -368,17 +373,11 @@ struct inet_frag_queue *inet_frag_find(struct fqdir *fqdir, void *key)
	if (!high_thresh || frag_mem_limit(fqdir) > high_thresh)
		return NULL;

	rcu_read_lock();

	prev = rhashtable_lookup(&fqdir->rhashtable, key, fqdir->f->rhash_params);
	if (!prev)
		fq = inet_frag_create(fqdir, key, &prev);
	if (!IS_ERR_OR_NULL(prev)) {
	if (!IS_ERR_OR_NULL(prev))
		fq = prev;
		if (!refcount_inc_not_zero(&fq->refcnt))
			fq = NULL;
	}
	rcu_read_unlock();
	return fq;
}
EXPORT_SYMBOL(inet_frag_find);
+19 −29
Original line number Diff line number Diff line
@@ -76,7 +76,8 @@ static u8 ip4_frag_ecn(u8 tos)
static struct inet_frags ip4_frags;

static int ip_frag_reasm(struct ipq *qp, struct sk_buff *skb,
			 struct sk_buff *prev_tail, struct net_device *dev);
			 struct sk_buff *prev_tail, struct net_device *dev,
			 int *refs);


static void ip4_frag_init(struct inet_frag_queue *q, const void *a)
@@ -107,22 +108,6 @@ static void ip4_frag_free(struct inet_frag_queue *q)
		inet_putpeer(qp->peer);
}


/* Destruction primitives. */

/* Release one reference on the fragment queue embedded in @ipq.
 * Thin wrapper that forwards &ipq->q to inet_frag_put().
 */
static void ipq_put(struct ipq *ipq)
{
	inet_frag_put(&ipq->q);
}

/* Kill the ipq entry by handing its embedded queue to inet_frag_kill().
 * The entry is not destroyed immediately, because the caller (and
 * possibly other users) still hold a reference on it.
 */
static void ipq_kill(struct ipq *ipq)
{
	inet_frag_kill(&ipq->q);
}

static bool frag_expire_skip_icmp(u32 user)
{
	return user == IP_DEFRAG_AF_PACKET ||
@@ -143,6 +128,7 @@ static void ip_expire(struct timer_list *t)
	struct sk_buff *head = NULL;
	struct net *net;
	struct ipq *qp;
	int refs = 1;

	qp = container_of(frag, struct ipq, q);
	net = qp->q.fqdir->net;
@@ -159,7 +145,7 @@ static void ip_expire(struct timer_list *t)
		goto out;

	qp->q.flags |= INET_FRAG_DROP;
	ipq_kill(qp);
	inet_frag_kill(&qp->q, &refs);
	__IP_INC_STATS(net, IPSTATS_MIB_REASMFAILS);
	__IP_INC_STATS(net, IPSTATS_MIB_REASMTIMEOUT);

@@ -202,7 +188,7 @@ static void ip_expire(struct timer_list *t)
out_rcu_unlock:
	rcu_read_unlock();
	kfree_skb_reason(head, reason);
	ipq_put(qp);
	inet_frag_putn(&qp->q, refs);
}

/* Find the correct entry in the "incomplete datagrams" queue for
@@ -278,7 +264,7 @@ static int ip_frag_reinit(struct ipq *qp)
}

/* Add new segment to existing queue. */
static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb)
static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb, int *refs)
{
	struct net *net = qp->q.fqdir->net;
	int ihl, end, flags, offset;
@@ -298,7 +284,7 @@ static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb)
	if (!(IPCB(skb)->flags & IPSKB_FRAG_COMPLETE) &&
	    unlikely(ip_frag_too_far(qp)) &&
	    unlikely(err = ip_frag_reinit(qp))) {
		ipq_kill(qp);
		inet_frag_kill(&qp->q, refs);
		goto err;
	}

@@ -382,10 +368,10 @@ static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb)
		unsigned long orefdst = skb->_skb_refdst;

		skb->_skb_refdst = 0UL;
		err = ip_frag_reasm(qp, skb, prev_tail, dev);
		err = ip_frag_reasm(qp, skb, prev_tail, dev, refs);
		skb->_skb_refdst = orefdst;
		if (err)
			inet_frag_kill(&qp->q);
			inet_frag_kill(&qp->q, refs);
		return err;
	}

@@ -402,7 +388,7 @@ static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb)
	err = -EINVAL;
	__IP_INC_STATS(net, IPSTATS_MIB_REASM_OVERLAPS);
discard_qp:
	inet_frag_kill(&qp->q);
	inet_frag_kill(&qp->q, refs);
	__IP_INC_STATS(net, IPSTATS_MIB_REASMFAILS);
err:
	kfree_skb_reason(skb, reason);
@@ -416,7 +402,8 @@ static bool ip_frag_coalesce_ok(const struct ipq *qp)

/* Build a new IP datagram from all its fragments. */
static int ip_frag_reasm(struct ipq *qp, struct sk_buff *skb,
			 struct sk_buff *prev_tail, struct net_device *dev)
			 struct sk_buff *prev_tail, struct net_device *dev,
			 int *refs)
{
	struct net *net = qp->q.fqdir->net;
	struct iphdr *iph;
@@ -424,7 +411,7 @@ static int ip_frag_reasm(struct ipq *qp, struct sk_buff *skb,
	int len, err;
	u8 ecn;

	ipq_kill(qp);
	inet_frag_kill(&qp->q, refs);

	ecn = ip_frag_ecn_table[qp->ecn];
	if (unlikely(ecn == 0xff)) {
@@ -496,18 +483,21 @@ int ip_defrag(struct net *net, struct sk_buff *skb, u32 user)
	__IP_INC_STATS(net, IPSTATS_MIB_REASMREQDS);

	/* Lookup (or create) queue header */
	rcu_read_lock();
	qp = ip_find(net, ip_hdr(skb), user, vif);
	if (qp) {
		int ret;
		int ret, refs = 0;

		spin_lock(&qp->q.lock);

		ret = ip_frag_queue(qp, skb);
		ret = ip_frag_queue(qp, skb, &refs);

		spin_unlock(&qp->q.lock);
		ipq_put(qp);
		rcu_read_unlock();
		inet_frag_putn(&qp->q, refs);
		return ret;
	}
	rcu_read_unlock();

	__IP_INC_STATS(net, IPSTATS_MIB_REASMFAILS);
	kfree_skb(skb);
Loading