Commit e218ae40 authored by Jakub Kicinski's avatar Jakub Kicinski
Browse files

Merge branch 'net-fix-uaf-of-sk_dst_get-sk-dev'

Kuniyuki Iwashima says:

====================
net: Fix UAF of sk_dst_get(sk)->dev.

syzbot caught use-after-free of sk_dst_get(sk)->dev,
which was not fetched under RCU nor RTNL. [0]

Patch 1 ~ 5, 7 fix UAF in smc, tcp, ktls, mptcp
Patch 6 fixes dst ref leak in mptcp

[0]: https://lore.kernel.org/68c237c7.050a0220.3c6139.0036.GAE@google.com

v1: https://lore.kernel.org/20250911030620.1284754-1-kuniyu@google.com
====================

Link: https://patch.msgid.link/20250916214758.650211-1-kuniyu@google.com


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parents 6b957c0a 893c49a7
Loading
Loading
Loading
Loading
+7 −2
Original line number Diff line number Diff line
@@ -501,10 +501,15 @@ void mptcp_active_enable(struct sock *sk)
	struct mptcp_pernet *pernet = mptcp_get_pernet(sock_net(sk));

	if (atomic_read(&pernet->active_disable_times)) {
		struct dst_entry *dst = sk_dst_get(sk);
		struct net_device *dev;
		struct dst_entry *dst;

		if (dst && dst->dev && (dst->dev->flags & IFF_LOOPBACK))
		rcu_read_lock();
		dst = __sk_dst_get(sk);
		dev = dst ? dst_dev_rcu(dst) : NULL;
		if (dev && (dev->flags & IFF_LOOPBACK))
			atomic_set(&pernet->active_disable_times, 0);
		rcu_read_unlock();
	}
}

+35 −32
Original line number Diff line number Diff line
@@ -509,10 +509,10 @@ static bool smc_clc_msg_hdr_valid(struct smc_clc_msg_hdr *clcm, bool check_trl)
}

/* find ipv4 addr on device and get the prefix len, fill CLC proposal msg */
static int smc_clc_prfx_set4_rcu(struct dst_entry *dst, __be32 ipv4,
static int smc_clc_prfx_set4_rcu(struct net_device *dev, __be32 ipv4,
				 struct smc_clc_msg_proposal_prefix *prop)
{
	struct in_device *in_dev = __in_dev_get_rcu(dst->dev);
	struct in_device *in_dev = __in_dev_get_rcu(dev);
	const struct in_ifaddr *ifa;

	if (!in_dev)
@@ -530,12 +530,12 @@ static int smc_clc_prfx_set4_rcu(struct dst_entry *dst, __be32 ipv4,
}

/* fill CLC proposal msg with ipv6 prefixes from device */
static int smc_clc_prfx_set6_rcu(struct dst_entry *dst,
static int smc_clc_prfx_set6_rcu(struct net_device *dev,
				 struct smc_clc_msg_proposal_prefix *prop,
				 struct smc_clc_ipv6_prefix *ipv6_prfx)
{
#if IS_ENABLED(CONFIG_IPV6)
	struct inet6_dev *in6_dev = __in6_dev_get(dst->dev);
	struct inet6_dev *in6_dev = __in6_dev_get(dev);
	struct inet6_ifaddr *ifa;
	int cnt = 0;

@@ -564,41 +564,44 @@ static int smc_clc_prfx_set(struct socket *clcsock,
			    struct smc_clc_msg_proposal_prefix *prop,
			    struct smc_clc_ipv6_prefix *ipv6_prfx)
{
	struct dst_entry *dst = sk_dst_get(clcsock->sk);
	struct sockaddr_storage addrs;
	struct sockaddr_in6 *addr6;
	struct sockaddr_in *addr;
	struct net_device *dev;
	struct dst_entry *dst;
	int rc = -ENOENT;

	if (!dst) {
		rc = -ENOTCONN;
		goto out;
	}
	if (!dst->dev) {
		rc = -ENODEV;
		goto out_rel;
	}
	/* get address to which the internal TCP socket is bound */
	if (kernel_getsockname(clcsock, (struct sockaddr *)&addrs) < 0)
		goto out_rel;
		goto out;

	/* analyze IP specific data of net_device belonging to TCP socket */
	addr6 = (struct sockaddr_in6 *)&addrs;

	rcu_read_lock();

	dst = __sk_dst_get(clcsock->sk);
	dev = dst ? dst_dev_rcu(dst) : NULL;
	if (!dev) {
		rc = -ENODEV;
		goto out_unlock;
	}

	if (addrs.ss_family == PF_INET) {
		/* IPv4 */
		addr = (struct sockaddr_in *)&addrs;
		rc = smc_clc_prfx_set4_rcu(dst, addr->sin_addr.s_addr, prop);
		rc = smc_clc_prfx_set4_rcu(dev, addr->sin_addr.s_addr, prop);
	} else if (ipv6_addr_v4mapped(&addr6->sin6_addr)) {
		/* mapped IPv4 address - peer is IPv4 only */
		rc = smc_clc_prfx_set4_rcu(dst, addr6->sin6_addr.s6_addr32[3],
		rc = smc_clc_prfx_set4_rcu(dev, addr6->sin6_addr.s6_addr32[3],
					   prop);
	} else {
		/* IPv6 */
		rc = smc_clc_prfx_set6_rcu(dst, prop, ipv6_prfx);
		rc = smc_clc_prfx_set6_rcu(dev, prop, ipv6_prfx);
	}

out_unlock:
	rcu_read_unlock();
out_rel:
	dst_release(dst);
out:
	return rc;
}
@@ -654,26 +657,26 @@ static int smc_clc_prfx_match6_rcu(struct net_device *dev,
int smc_clc_prfx_match(struct socket *clcsock,
		       struct smc_clc_msg_proposal_prefix *prop)
{
	struct dst_entry *dst = sk_dst_get(clcsock->sk);
	struct net_device *dev;
	struct dst_entry *dst;
	int rc;

	if (!dst) {
		rc = -ENOTCONN;
		goto out;
	}
	if (!dst->dev) {
	rcu_read_lock();

	dst = __sk_dst_get(clcsock->sk);
	dev = dst ? dst_dev_rcu(dst) : NULL;
	if (!dev) {
		rc = -ENODEV;
		goto out_rel;
		goto out;
	}
	rcu_read_lock();

	if (!prop->ipv6_prefixes_cnt)
		rc = smc_clc_prfx_match4_rcu(dst->dev, prop);
		rc = smc_clc_prfx_match4_rcu(dev, prop);
	else
		rc = smc_clc_prfx_match6_rcu(dst->dev, prop);
	rcu_read_unlock();
out_rel:
	dst_release(dst);
		rc = smc_clc_prfx_match6_rcu(dev, prop);
out:
	rcu_read_unlock();

	return rc;
}

+12 −15
Original line number Diff line number Diff line
@@ -1883,35 +1883,32 @@ static int smc_vlan_by_tcpsk_walk(struct net_device *lower_dev,
/* Determine vlan of internal TCP socket. */
int smc_vlan_by_tcpsk(struct socket *clcsock, struct smc_init_info *ini)
{
	struct dst_entry *dst = sk_dst_get(clcsock->sk);
	struct netdev_nested_priv priv;
	struct net_device *ndev;
	struct dst_entry *dst;
	int rc = 0;

	ini->vlan_id = 0;
	if (!dst) {
		rc = -ENOTCONN;
		goto out;
	}
	if (!dst->dev) {

	rcu_read_lock();

	dst = __sk_dst_get(clcsock->sk);
	ndev = dst ? dst_dev_rcu(dst) : NULL;
	if (!ndev) {
		rc = -ENODEV;
		goto out_rel;
		goto out;
	}

	ndev = dst->dev;
	if (is_vlan_dev(ndev)) {
		ini->vlan_id = vlan_dev_vlan_id(ndev);
		goto out_rel;
		goto out;
	}

	priv.data = (void *)&ini->vlan_id;
	rtnl_lock();
	netdev_walk_all_lower_dev(ndev, smc_vlan_by_tcpsk_walk, &priv);
	rtnl_unlock();

out_rel:
	dst_release(dst);
	netdev_walk_all_lower_dev_rcu(ndev, smc_vlan_by_tcpsk_walk, &priv);
out:
	rcu_read_unlock();

	return rc;
}

+22 −21
Original line number Diff line number Diff line
@@ -1126,37 +1126,38 @@ static void smc_pnet_find_ism_by_pnetid(struct net_device *ndev,
 */
void smc_pnet_find_roce_resource(struct sock *sk, struct smc_init_info *ini)
{
	struct dst_entry *dst = sk_dst_get(sk);

	if (!dst)
		goto out;
	if (!dst->dev)
		goto out_rel;
	struct net_device *dev;
	struct dst_entry *dst;

	smc_pnet_find_roce_by_pnetid(dst->dev, ini);
	rcu_read_lock();
	dst = __sk_dst_get(sk);
	dev = dst ? dst_dev_rcu(dst) : NULL;
	dev_hold(dev);
	rcu_read_unlock();

out_rel:
	dst_release(dst);
out:
	return;
	if (dev) {
		smc_pnet_find_roce_by_pnetid(dev, ini);
		dev_put(dev);
	}
}

void smc_pnet_find_ism_resource(struct sock *sk, struct smc_init_info *ini)
{
	struct dst_entry *dst = sk_dst_get(sk);
	struct net_device *dev;
	struct dst_entry *dst;

	ini->ism_dev[0] = NULL;
	if (!dst)
		goto out;
	if (!dst->dev)
		goto out_rel;

	smc_pnet_find_ism_by_pnetid(dst->dev, ini);
	rcu_read_lock();
	dst = __sk_dst_get(sk);
	dev = dst ? dst_dev_rcu(dst) : NULL;
	dev_hold(dev);
	rcu_read_unlock();

out_rel:
	dst_release(dst);
out:
	return;
	if (dev) {
		smc_pnet_find_ism_by_pnetid(dev, ini);
		dev_put(dev);
	}
}

/* Lookup and apply a pnet table entry to the given ib device.
+10 −8
Original line number Diff line number Diff line
@@ -123,17 +123,19 @@ static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
/* We assume that the socket is already connected */
static struct net_device *get_netdev_for_sock(struct sock *sk)
{
	struct dst_entry *dst = sk_dst_get(sk);
	struct net_device *netdev = NULL;
	struct net_device *dev, *lowest_dev = NULL;
	struct dst_entry *dst;

	if (likely(dst)) {
		netdev = netdev_sk_get_lowest_dev(dst->dev, sk);
		dev_hold(netdev);
	rcu_read_lock();
	dst = __sk_dst_get(sk);
	dev = dst ? dst_dev_rcu(dst) : NULL;
	if (likely(dev)) {
		lowest_dev = netdev_sk_get_lowest_dev(dev, sk);
		dev_hold(lowest_dev);
	}
	rcu_read_unlock();

	dst_release(dst);

	return netdev;
	return lowest_dev;
}

static void destroy_record(struct tls_record_info *record)