Commit 0eddb802 authored by Raed Salem's avatar Raed Salem Committed by Paolo Abeni
Browse files

psp: provide decapsulation and receive helper for drivers



Create psp_dev_rcv(), which drivers can call to psp decapsulate and attach
a psp_skb_ext to an skb.

psp_dev_rcv() only supports what the PSP architecture specification
refers to as "transport mode" packets, where the L3 header is either
IPv6 or IPv4.

Reviewed-by: default avatarWillem de Bruijn <willemb@google.com>
Signed-off-by: default avatarRaed Salem <raeds@nvidia.com>
Signed-off-by: default avatarRahul Rameshbabu <rrameshbabu@nvidia.com>
Signed-off-by: default avatarCosmin Ratiu <cratiu@nvidia.com>
Co-developed-by: default avatarDaniel Zahka <daniel.zahka@gmail.com>
Signed-off-by: default avatarDaniel Zahka <daniel.zahka@gmail.com>
Reviewed-by: default avatarEric Dumazet <edumazet@google.com>
Link: https://patch.msgid.link/20250917000954.859376-18-daniel.zahka@gmail.com


Signed-off-by: default avatarPaolo Abeni <pabeni@redhat.com>
parent 2b6e450b
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops,
void psp_dev_unregister(struct psp_dev *psd);
bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
			 u8 ver, __be16 sport);
int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv);

/* Kernel-facing API */
void psp_assoc_put(struct psp_assoc *pas);
+88 −0
Original line number Diff line number Diff line
@@ -223,6 +223,94 @@ bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
}
EXPORT_SYMBOL(psp_dev_encapsulate);

/* Receive handler for PSP packets.
 *
 * Presently it accepts only already-authenticated packets and does not
 * support optional fields, such as virtualization cookies. The caller should
 * ensure that skb->data is pointing to the mac header, and that skb->mac_len
 * is set.
 */
int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv)
{
	int l2_hlen = 0, l3_hlen, encap;
	struct psp_skb_ext *pse;
	struct psphdr *psph;
	struct ethhdr *eth;
	struct udphdr *uh;
	__be16 proto;
	bool is_udp;

	eth = (struct ethhdr *)skb->data;
	proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen);
	if (proto == htons(ETH_P_IP))
		l3_hlen = sizeof(struct iphdr);
	else if (proto == htons(ETH_P_IPV6))
		l3_hlen = sizeof(struct ipv6hdr);
	else
		return -EINVAL;

	if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)))
		return -EINVAL;

	if (proto == htons(ETH_P_IP)) {
		struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);

		is_udp = iph->protocol == IPPROTO_UDP;
		l3_hlen = iph->ihl * 4;
		if (l3_hlen != sizeof(struct iphdr) &&
		    !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))
			return -EINVAL;
	} else {
		struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);

		is_udp = ipv6h->nexthdr == IPPROTO_UDP;
	}

	if (unlikely(!is_udp))
		return -EINVAL;

	uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen);
	if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT)))
		return -EINVAL;

	pse = skb_ext_add(skb, SKB_EXT_PSP);
	if (!pse)
		return -EINVAL;

	psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen +
				 sizeof(struct udphdr));
	pse->spi = psph->spi;
	pse->dev_id = dev_id;
	pse->generation = generation;
	pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl);

	encap = PSP_ENCAP_HLEN;
	encap += strip_icv ? PSP_TRL_SIZE : 0;

	if (proto == htons(ETH_P_IP)) {
		struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);

		iph->protocol = psph->nexthdr;
		iph->tot_len = htons(ntohs(iph->tot_len) - encap);
		iph->check = 0;
		iph->check = ip_fast_csum((u8 *)iph, iph->ihl);
	} else {
		struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);

		ipv6h->nexthdr = psph->nexthdr;
		ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap);
	}

	memmove(skb->data + PSP_ENCAP_HLEN, skb->data, l2_hlen + l3_hlen);
	skb_pull(skb, PSP_ENCAP_HLEN);

	if (strip_icv)
		pskb_trim(skb, skb->len - PSP_TRL_SIZE);

	return 0;
}
EXPORT_SYMBOL(psp_dev_rcv);

static int __init psp_init(void)
{
	mutex_init(&psp_devs_lock);