Commit 7462fe22 authored by Geliang Tang's avatar Geliang Tang Committed by Jakub Kicinski
Browse files

mptcp: pm: use addr entry for get_local_id



The following code in mptcp_userspace_pm_get_local_id() that assigns "skc"
to "new_entry" is not allowed in BPF if we use the same code to implement
the get_local_id() interface of a BFP path manager:

	memset(&new_entry, 0, sizeof(struct mptcp_pm_addr_entry));
	new_entry.addr = *skc;
	new_entry.addr.id = 0;
	new_entry.flags = MPTCP_PM_ADDR_FLAG_IMPLICIT;

To solve the issue, this patch moves this assignment to "new_entry" forward
to mptcp_pm_get_local_id(), and then passing "new_entry" as a parameter to
both mptcp_pm_nl_get_local_id() and mptcp_userspace_pm_get_local_id().

No behavioural changes intended.

Signed-off-by: default avatarGeliang Tang <tanggeliang@kylinos.cn>
Reviewed-by: default avatarMatthieu Baerts (NGI0) <matttbe@kernel.org>
Signed-off-by: default avatarMatthieu Baerts (NGI0) <matttbe@kernel.org>
Link: https://patch.msgid.link/20250307-net-next-mptcp-pm-reorg-v1-1-abef20ada03b@kernel.org


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parent 991a1b09
Loading
Loading
Loading
Loading
+6 −3
Original line number Diff line number Diff line
@@ -406,7 +406,7 @@ bool mptcp_pm_rm_addr_signal(struct mptcp_sock *msk, unsigned int remaining,

int mptcp_pm_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
{
	struct mptcp_addr_info skc_local;
	struct mptcp_pm_addr_entry skc_local = { 0 };
	struct mptcp_addr_info msk_local;

	if (WARN_ON_ONCE(!msk))
@@ -416,10 +416,13 @@ int mptcp_pm_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
	 * addr
	 */
	mptcp_local_address((struct sock_common *)msk, &msk_local);
	mptcp_local_address((struct sock_common *)skc, &skc_local);
	if (mptcp_addresses_equal(&msk_local, &skc_local, false))
	mptcp_local_address((struct sock_common *)skc, &skc_local.addr);
	if (mptcp_addresses_equal(&msk_local, &skc_local.addr, false))
		return 0;

	skc_local.addr.id = 0;
	skc_local.flags = MPTCP_PM_ADDR_FLAG_IMPLICIT;

	if (mptcp_pm_is_userspace(msk))
		return mptcp_userspace_pm_get_local_id(msk, &skc_local);
	return mptcp_pm_nl_get_local_id(msk, &skc_local);
+4 −7
Original line number Diff line number Diff line
@@ -1150,7 +1150,8 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
	return err;
}

int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct mptcp_addr_info *skc)
int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk,
			     struct mptcp_pm_addr_entry *skc)
{
	struct mptcp_pm_addr_entry *entry;
	struct pm_nl_pernet *pernet;
@@ -1159,7 +1160,7 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct mptcp_addr_info *skc
	pernet = pm_nl_get_pernet_from_msk(msk);

	rcu_read_lock();
	entry = __lookup_addr(pernet, skc);
	entry = __lookup_addr(pernet, &skc->addr);
	ret = entry ? entry->addr.id : -1;
	rcu_read_unlock();
	if (ret >= 0)
@@ -1170,12 +1171,8 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct mptcp_addr_info *skc
	if (!entry)
		return -ENOMEM;

	entry->addr = *skc;
	entry->addr.id = 0;
	*entry = *skc;
	entry->addr.port = 0;
	entry->ifindex = 0;
	entry->flags = MPTCP_PM_ADDR_FLAG_IMPLICIT;
	entry->lsk = NULL;
	ret = mptcp_pm_nl_append_new_local_addr(pernet, entry, true, false);
	if (ret < 0)
		kfree(entry);
+6 −11
Original line number Diff line number Diff line
@@ -130,27 +130,22 @@ mptcp_userspace_pm_lookup_addr_by_id(struct mptcp_sock *msk, unsigned int id)
}

int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk,
				    struct mptcp_addr_info *skc)
				    struct mptcp_pm_addr_entry *skc)
{
	struct mptcp_pm_addr_entry *entry = NULL, new_entry;
	__be16 msk_sport =  ((struct inet_sock *)
			     inet_sk((struct sock *)msk))->inet_sport;
	struct mptcp_pm_addr_entry *entry;

	spin_lock_bh(&msk->pm.lock);
	entry = mptcp_userspace_pm_lookup_addr(msk, skc);
	entry = mptcp_userspace_pm_lookup_addr(msk, &skc->addr);
	spin_unlock_bh(&msk->pm.lock);
	if (entry)
		return entry->addr.id;

	memset(&new_entry, 0, sizeof(struct mptcp_pm_addr_entry));
	new_entry.addr = *skc;
	new_entry.addr.id = 0;
	new_entry.flags = MPTCP_PM_ADDR_FLAG_IMPLICIT;

	if (new_entry.addr.port == msk_sport)
		new_entry.addr.port = 0;
	if (skc->addr.port == msk_sport)
		skc->addr.port = 0;

	return mptcp_userspace_pm_append_new_local_addr(msk, &new_entry, true);
	return mptcp_userspace_pm_append_new_local_addr(msk, skc, true);
}

bool mptcp_userspace_pm_is_backup(struct mptcp_sock *msk,
+4 −2
Original line number Diff line number Diff line
@@ -1121,8 +1121,10 @@ bool mptcp_pm_add_addr_signal(struct mptcp_sock *msk, const struct sk_buff *skb,
bool mptcp_pm_rm_addr_signal(struct mptcp_sock *msk, unsigned int remaining,
			     struct mptcp_rm_list *rm_list);
int mptcp_pm_get_local_id(struct mptcp_sock *msk, struct sock_common *skc);
int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct mptcp_addr_info *skc);
int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk, struct mptcp_addr_info *skc);
int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk,
			     struct mptcp_pm_addr_entry *skc);
int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk,
				    struct mptcp_pm_addr_entry *skc);
bool mptcp_pm_is_backup(struct mptcp_sock *msk, struct sock_common *skc);
bool mptcp_pm_nl_is_backup(struct mptcp_sock *msk, struct mptcp_addr_info *skc);
bool mptcp_userspace_pm_is_backup(struct mptcp_sock *msk, struct mptcp_addr_info *skc);