Commit 2aeb71b2 authored by Jakub Kicinski's avatar Jakub Kicinski Committed by Paolo Abeni
Browse files

selftests: drv-net: add PSP responder



PSP tests need the remote system to support PSP, and some PSP capable
application to exchange data with. Create a simple PSP responder app
which we can build and deploy to the remote host. The tests themselves
can be written in Python but for ease of deploying the responder is in C
(using C YNL).

Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
Signed-off-by: default avatarDaniel Zahka <daniel.zahka@gmail.com>
Link: https://patch.msgid.link/20250927225420.1443468-4-kuba@kernel.org


Reviewed-by: default avatarWillem de Bruijn <willemb@google.com>
Signed-off-by: default avatarPaolo Abeni <pabeni@redhat.com>
parent 8a5f956a
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
# SPDX-License-Identifier: GPL-2.0-only
napi_id_helper
psp_responder
+9 −0
Original line number Diff line number Diff line
@@ -27,4 +27,13 @@ TEST_PROGS := \
	xdp.py \
# end of TEST_PROGS

# YNL files, must be before "include ..lib.mk"
YNL_GEN_FILES := psp_responder
TEST_GEN_FILES += $(YNL_GEN_FILES)

include ../../lib.mk

# YNL build
YNL_GENS := psp

include ../../net/ynl.mk
+483 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: GPL-2.0

#include <stdio.h>
#include <string.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <unistd.h>

#include <ynl.h>

#include "psp-user.h"

#define dbg(msg...)				\
do {						\
	if (opts->verbose)			\
		fprintf(stderr, "DEBUG: " msg);	\
} while (0)

static bool should_quit;

struct opts {
	int port;
	int devid;
	bool verbose;
};

enum accept_cfg {
	ACCEPT_CFG_NONE = 0,
	ACCEPT_CFG_CLEAR,
	ACCEPT_CFG_PSP,
};

static struct {
	unsigned char tx;
	unsigned char rx;
} psp_vers;

static int conn_setup_psp(struct ynl_sock *ys, struct opts *opts, int data_sock)
{
	struct psp_rx_assoc_rsp *rsp;
	struct psp_rx_assoc_req *req;
	struct psp_tx_assoc_rsp *tsp;
	struct psp_tx_assoc_req *teq;
	char info[300];
	int key_len;
	ssize_t sz;
	__u32 spi;

	dbg("create PSP connection\n");

	// Rx assoc alloc
	req = psp_rx_assoc_req_alloc();

	psp_rx_assoc_req_set_sock_fd(req, data_sock);
	psp_rx_assoc_req_set_version(req, psp_vers.rx);

	rsp = psp_rx_assoc(ys, req);
	psp_rx_assoc_req_free(req);

	if (!rsp) {
		perror("ERROR: failed to Rx assoc");
		return -1;
	}

	// SPI exchange
	key_len = rsp->rx_key._len.key;
	memcpy(info, &rsp->rx_key.spi, sizeof(spi));
	memcpy(&info[sizeof(spi)], rsp->rx_key.key, key_len);
	sz = sizeof(spi) + key_len;

	send(data_sock, info, sz, MSG_WAITALL);
	psp_rx_assoc_rsp_free(rsp);

	sz = recv(data_sock, info, sz, MSG_WAITALL);
	if (sz < 0) {
		perror("ERROR: failed to read PSP key from sock");
		return -1;
	}
	memcpy(&spi, info, sizeof(spi));

	// Setup Tx assoc
	teq = psp_tx_assoc_req_alloc();

	psp_tx_assoc_req_set_sock_fd(teq, data_sock);
	psp_tx_assoc_req_set_version(teq, psp_vers.tx);
	psp_tx_assoc_req_set_tx_key_spi(teq, spi);
	psp_tx_assoc_req_set_tx_key_key(teq, &info[sizeof(spi)], key_len);

	tsp = psp_tx_assoc(ys, teq);
	psp_tx_assoc_req_free(teq);
	if (!tsp) {
		perror("ERROR: failed to Tx assoc");
		return -1;
	}
	psp_tx_assoc_rsp_free(tsp);

	return 0;
}

static void send_ack(int sock)
{
	send(sock, "ack", 4, MSG_WAITALL);
}

static void send_err(int sock)
{
	send(sock, "err", 4, MSG_WAITALL);
}

static void send_str(int sock, int value)
{
	char buf[128];
	int ret;

	ret = snprintf(buf, sizeof(buf), "%d", value);
	send(sock, buf, ret + 1, MSG_WAITALL);
}

static void
run_session(struct ynl_sock *ys, struct opts *opts,
	    int server_sock, int comm_sock)
{
	enum accept_cfg accept_cfg = ACCEPT_CFG_NONE;
	struct pollfd pfds[3];
	size_t data_read = 0;
	int data_sock = -1;

	while (true) {
		bool race_close = false;
		int nfds;

		memset(pfds, 0, sizeof(pfds));

		pfds[0].fd = server_sock;
		pfds[0].events = POLLIN;

		pfds[1].fd = comm_sock;
		pfds[1].events = POLLIN;

		nfds = 2;
		if (data_sock >= 0) {
			pfds[2].fd = data_sock;
			pfds[2].events = POLLIN;
			nfds++;
		}

		dbg(" ...\n");
		if (poll(pfds, nfds, -1) < 0) {
			perror("poll");
			break;
		}

		/* data sock */
		if (pfds[2].revents & POLLIN) {
			char buf[8192];
			ssize_t n;

			n = recv(data_sock, buf, sizeof(buf), 0);
			if (n <= 0) {
				if (n < 0)
					perror("data read");
				close(data_sock);
				data_sock = -1;
				dbg("data sock closed\n");
			} else {
				data_read += n;
				dbg("data read %zd\n", data_read);
			}
		}

		/* comm sock */
		if (pfds[1].revents & POLLIN) {
			static char buf[4096];
			static ssize_t off;
			bool consumed;
			ssize_t n;

			n = recv(comm_sock, &buf[off], sizeof(buf) - off, 0);
			if (n <= 0) {
				if (n < 0)
					perror("comm read");
				return;
			}

			off += n;
			n = off;

#define __consume(sz)						\
		({						\
			if (n == (sz)) {			\
				off = 0;			\
			} else {				\
				off -= (sz);			\
				memmove(buf, &buf[(sz)], off);	\
			}					\
		})

#define cmd(_name)							\
		({							\
			ssize_t sz = sizeof(_name);			\
			bool match = n >= sz &&	!memcmp(buf, _name, sz); \
									\
			if (match) {					\
				dbg("command: " _name "\n");		\
				__consume(sz);				\
			}						\
			consumed |= match;				\
			match;						\
		})

			do {
				consumed = false;

				if (cmd("read len"))
					send_str(comm_sock, data_read);

				if (cmd("data echo")) {
					if (data_sock >= 0)
						send(data_sock, "echo", 5,
						     MSG_WAITALL);
					else
						fprintf(stderr, "WARN: echo but no data sock\n");
					send_ack(comm_sock);
				}
				if (cmd("data close")) {
					if (data_sock >= 0) {
						close(data_sock);
						data_sock = -1;
						send_ack(comm_sock);
					} else {
						race_close = true;
					}
				}
				if (cmd("conn psp")) {
					if (accept_cfg != ACCEPT_CFG_NONE)
						fprintf(stderr, "WARN: old conn config still set!\n");
					accept_cfg = ACCEPT_CFG_PSP;
					send_ack(comm_sock);
					/* next two bytes are versions */
					if (off >= 2) {
						memcpy(&psp_vers, buf, 2);
						__consume(2);
					} else {
						fprintf(stderr, "WARN: short conn psp command!\n");
					}
				}
				if (cmd("conn clr")) {
					if (accept_cfg != ACCEPT_CFG_NONE)
						fprintf(stderr, "WARN: old conn config still set!\n");
					accept_cfg = ACCEPT_CFG_CLEAR;
					send_ack(comm_sock);
				}
				if (cmd("exit"))
					should_quit = true;
#undef cmd

				if (!consumed) {
					fprintf(stderr, "WARN: unknown cmd: [%zd] %s\n",
						off, buf);
				}
			} while (consumed && off);
		}

		/* server sock */
		if (pfds[0].revents & POLLIN) {
			if (data_sock >= 0) {
				fprintf(stderr, "WARN: new data sock but old one still here\n");
				close(data_sock);
				data_sock = -1;
			}
			data_sock = accept(server_sock, NULL, NULL);
			if (data_sock < 0) {
				perror("accept");
				continue;
			}
			data_read = 0;

			if (accept_cfg == ACCEPT_CFG_CLEAR) {
				dbg("new data sock: clear\n");
				/* nothing to do */
			} else if (accept_cfg == ACCEPT_CFG_PSP) {
				dbg("new data sock: psp\n");
				conn_setup_psp(ys, opts, data_sock);
			} else {
				fprintf(stderr, "WARN: new data sock but no config\n");
			}
			accept_cfg = ACCEPT_CFG_NONE;
		}

		if (race_close) {
			if (data_sock >= 0) {
				/* indeed, ordering problem, handle the close */
				close(data_sock);
				data_sock = -1;
				send_ack(comm_sock);
			} else {
				fprintf(stderr, "WARN: close but no data sock\n");
				send_err(comm_sock);
			}
		}
	}
	dbg("session ending\n");
}

static int spawn_server(struct opts *opts)
{
	struct sockaddr_in6 addr;
	int fd;

	fd = socket(AF_INET6, SOCK_STREAM, 0);
	if (fd < 0) {
		perror("can't open socket");
		return -1;
	}

	memset(&addr, 0, sizeof(addr));

	addr.sin6_family = AF_INET6;
	addr.sin6_addr = in6addr_any;
	addr.sin6_port = htons(opts->port);

	if (bind(fd, (struct sockaddr *)&addr, sizeof(addr))) {
		perror("can't bind socket");
		return -1;
	}

	if (listen(fd, 5)) {
		perror("can't listen");
		return -1;
	}

	return fd;
}

static int run_responder(struct ynl_sock *ys, struct opts *opts)
{
	int server_sock, comm;

	server_sock = spawn_server(opts);
	if (server_sock < 0)
		return 4;

	while (!should_quit) {
		comm = accept(server_sock, NULL, NULL);
		if (comm < 0) {
			perror("accept failed");
		} else {
			run_session(ys, opts, server_sock, comm);
			close(comm);
		}
	}

	return 0;
}

static void usage(const char *name, const char *miss)
{
	if (miss)
		fprintf(stderr, "Missing argument: %s\n", miss);

	fprintf(stderr, "Usage: %s -p port [-v] [-d psp-dev-id]\n", name);
	exit(EXIT_FAILURE);
}

static void parse_cmd_opts(int argc, char **argv, struct opts *opts)
{
	int opt;

	while ((opt = getopt(argc, argv, "vp:d:")) != -1) {
		switch (opt) {
		case 'v':
			opts->verbose = 1;
			break;
		case 'p':
			opts->port = atoi(optarg);
			break;
		case 'd':
			opts->devid = atoi(optarg);
			break;
		default:
			usage(argv[0], NULL);
		}
	}
}

static int psp_dev_set_ena(struct ynl_sock *ys, __u32 dev_id, __u32 versions)
{
	struct psp_dev_set_req *sreq;
	struct psp_dev_set_rsp *srsp;

	fprintf(stderr, "Set PSP enable on device %d to 0x%x\n",
		dev_id, versions);

	sreq = psp_dev_set_req_alloc();

	psp_dev_set_req_set_id(sreq, dev_id);
	psp_dev_set_req_set_psp_versions_ena(sreq, versions);

	srsp = psp_dev_set(ys, sreq);
	psp_dev_set_req_free(sreq);
	if (!srsp)
		return 10;

	psp_dev_set_rsp_free(srsp);
	return 0;
}

int main(int argc, char **argv)
{
	struct psp_dev_get_list *dev_list;
	bool devid_found = false;
	__u32 ver_ena, ver_cap;
	struct opts opts = {};
	struct ynl_error yerr;
	struct ynl_sock *ys;
	int first_id = 0;
	int ret;

	parse_cmd_opts(argc, argv, &opts);
	if (!opts.port)
		usage(argv[0], "port"); // exits

	ys = ynl_sock_create(&ynl_psp_family, &yerr);
	if (!ys) {
		fprintf(stderr, "YNL: %s\n", yerr.msg);
		return 1;
	}

	dev_list = psp_dev_get_dump(ys);
	if (ynl_dump_empty(dev_list)) {
		if (ys->err.code)
			goto err_close;
		fprintf(stderr, "No PSP devices\n");
		goto err_close_silent;
	}

	ynl_dump_foreach(dev_list, d) {
		if (opts.devid) {
			devid_found = true;
			ver_ena = d->psp_versions_ena;
			ver_cap = d->psp_versions_cap;
		} else if (!first_id) {
			first_id = d->id;
			ver_ena = d->psp_versions_ena;
			ver_cap = d->psp_versions_cap;
		} else {
			fprintf(stderr, "Multiple PSP devices found\n");
			goto err_close_silent;
		}
	}
	psp_dev_get_list_free(dev_list);

	if (opts.devid && !devid_found) {
		fprintf(stderr, "PSP device %d requested on cmdline, not found\n",
			opts.devid);
		goto err_close_silent;
	} else if (!opts.devid) {
		opts.devid = first_id;
	}

	if (ver_ena != ver_cap) {
		ret = psp_dev_set_ena(ys, opts.devid, ver_cap);
		if (ret)
			goto err_close;
	}

	ret = run_responder(ys, &opts);

	if (ver_ena != ver_cap && psp_dev_set_ena(ys, opts.devid, ver_ena))
		fprintf(stderr, "WARN: failed to set the PSP versions back\n");

	ynl_sock_destroy(ys);

	return ret;

err_close:
	fprintf(stderr, "YNL: %s\n", ys->err.msg);
err_close_silent:
	ynl_sock_destroy(ys);
	return 2;
}