Commit 45e36a8e authored by Jakub Kicinski's avatar Jakub Kicinski
Browse files

Merge branch 'selftests-net-mixed-select-polling-mode-for-tcp-ao-tests'

Dmitry Safonov via says:

====================
selftests/net: Mixed select()+polling mode for TCP-AO tests

Should fix flaky tcp-ao/connect-deny-ipv6 test.

v1: https://lore.kernel.org/20250312-tcp-ao-selftests-polling-v1-0-72a642b855d5@gmail.com
====================

Link: https://patch.msgid.link/20250319-tcp-ao-selftests-polling-v2-0-da48040153d1@gmail.com


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parents 3e25c1a7 edbac739
Loading
Loading
Loading
Loading
+34 −24
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@
#include "aolib.h"

#define fault(type)	(inj == FAULT_ ## type)
static volatile int sk_pair;

static inline int test_add_key_maclen(int sk, const char *key, uint8_t maclen,
				      union tcp_addr in_addr, uint8_t prefix,
@@ -34,10 +35,10 @@ static void try_accept(const char *tst_name, unsigned int port, const char *pwd,
		       const char *cnt_name, test_cnt cnt_expected,
		       fault_t inj)
{
	struct tcp_ao_counters ao_cnt1, ao_cnt2;
	struct tcp_counters cnt1, cnt2;
	uint64_t before_cnt = 0, after_cnt = 0; /* silence GCC */
	test_cnt poll_cnt = (cnt_expected == TEST_CNT_GOOD) ? 0 : cnt_expected;
	int lsk, err, sk = 0;
	time_t timeout;

	lsk = test_listen_socket(this_ip_addr, port, 1);

@@ -46,21 +47,24 @@ static void try_accept(const char *tst_name, unsigned int port, const char *pwd,

	if (cnt_name)
		before_cnt = netstat_get_one(cnt_name, NULL);
	if (pwd && test_get_tcp_ao_counters(lsk, &ao_cnt1))
		test_error("test_get_tcp_ao_counters()");
	if (pwd && test_get_tcp_counters(lsk, &cnt1))
		test_error("test_get_tcp_counters()");

	synchronize_threads(); /* preparations done */

	timeout = fault(TIMEOUT) ? TEST_RETRANSMIT_SEC : TEST_TIMEOUT_SEC;
	err = test_wait_fd(lsk, timeout, 0);
	err = test_skpair_wait_poll(lsk, 0, poll_cnt, &sk_pair);
	if (err == -ETIMEDOUT) {
		sk_pair = err;
		if (!fault(TIMEOUT))
			test_fail("timed out for accept()");
			test_fail("%s: timed out for accept()", tst_name);
	} else if (err == -EKEYREJECTED) {
		if (!fault(KEYREJECT))
			test_fail("%s: key was rejected", tst_name);
	} else if (err < 0) {
		test_error("test_wait_fd()");
		test_error("test_skpair_wait_poll()");
	} else {
		if (fault(TIMEOUT))
			test_fail("ready to accept");
			test_fail("%s: ready to accept", tst_name);

		sk = accept(lsk, NULL, NULL);
		if (sk < 0) {
@@ -72,13 +76,13 @@ static void try_accept(const char *tst_name, unsigned int port, const char *pwd,
	}

	synchronize_threads(); /* before counter checks */
	if (pwd && test_get_tcp_ao_counters(lsk, &ao_cnt2))
		test_error("test_get_tcp_ao_counters()");
	if (pwd && test_get_tcp_counters(lsk, &cnt2))
		test_error("test_get_tcp_counters()");

	close(lsk);

	if (pwd)
		test_tcp_ao_counters_cmp(tst_name, &ao_cnt1, &ao_cnt2, cnt_expected);
		test_assert_counters(tst_name, &cnt1, &cnt2, cnt_expected);

	if (!cnt_name)
		goto out;
@@ -109,7 +113,7 @@ static void *server_fn(void *arg)

	try_accept("Non-AO server + AO client", port++, NULL,
		   this_ip_dest, -1, 100, 100, 0,
		   "TCPAOKeyNotFound", 0, FAULT_TIMEOUT);
		   "TCPAOKeyNotFound", TEST_CNT_NS_KEY_NOT_FOUND, FAULT_TIMEOUT);

	try_accept("AO server + Non-AO client", port++, DEFAULT_TEST_PASSWORD,
		   this_ip_dest, -1, 100, 100, 0,
@@ -135,8 +139,9 @@ static void *server_fn(void *arg)
		   wrong_addr, -1, 100, 100, 0,
		   "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT);

	/* Key rejected by the other side, failing short through skpair */
	try_accept("Client: Wrong addr", port++, NULL,
		   this_ip_dest, -1, 100, 100, 0, NULL, 0, FAULT_TIMEOUT);
		   this_ip_dest, -1, 100, 100, 0, NULL, 0, FAULT_KEYREJECT);

	try_accept("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD,
		   this_ip_dest, -1, 200, 100, 0,
@@ -163,8 +168,7 @@ static void try_connect(const char *tst_name, unsigned int port,
			uint8_t sndid, uint8_t rcvid,
			test_cnt cnt_expected, fault_t inj)
{
	struct tcp_ao_counters ao_cnt1, ao_cnt2;
	time_t timeout;
	struct tcp_counters cnt1, cnt2;
	int sk, ret;

	sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
@@ -174,16 +178,15 @@ static void try_connect(const char *tst_name, unsigned int port,
	if (pwd && test_add_key(sk, pwd, addr, prefix, sndid, rcvid))
		test_error("setsockopt(TCP_AO_ADD_KEY)");

	if (pwd && test_get_tcp_ao_counters(sk, &ao_cnt1))
		test_error("test_get_tcp_ao_counters()");
	if (pwd && test_get_tcp_counters(sk, &cnt1))
		test_error("test_get_tcp_counters()");

	synchronize_threads(); /* preparations done */

	timeout = fault(TIMEOUT) ? TEST_RETRANSMIT_SEC : TEST_TIMEOUT_SEC;
	ret = _test_connect_socket(sk, this_ip_dest, port, timeout);

	ret = test_skpair_connect_poll(sk, this_ip_dest, port, cnt_expected, &sk_pair);
	synchronize_threads(); /* before counter checks */
	if (ret < 0) {
		sk_pair = ret;
		if (fault(KEYREJECT) && ret == -EKEYREJECTED) {
			test_ok("%s: connect() was prevented", tst_name);
		} else if (ret == -ETIMEDOUT && fault(TIMEOUT)) {
@@ -202,9 +205,11 @@ static void try_connect(const char *tst_name, unsigned int port,
	else
		test_ok("%s: connected", tst_name);
	if (pwd && ret > 0) {
		if (test_get_tcp_ao_counters(sk, &ao_cnt2))
			test_error("test_get_tcp_ao_counters()");
		test_tcp_ao_counters_cmp(tst_name, &ao_cnt1, &ao_cnt2, cnt_expected);
		if (test_get_tcp_counters(sk, &cnt2))
			test_error("test_get_tcp_counters()");
		test_assert_counters(tst_name, &cnt1, &cnt2, cnt_expected);
	} else if (pwd) {
		test_tcp_counters_free(&cnt1);
	}
out:
	synchronize_threads(); /* close() */
@@ -241,6 +246,11 @@ static void *client_fn(void *arg)
	try_connect("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD,
			this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);

	/*
	 * XXX: The test doesn't increase any counters, see tcp_make_synack().
	 * Potentially, it can be speed up by setting sk_pair = -ETIMEDOUT
	 * but the price would be increased complexity of the tracer thread.
	 */
	trace_ao_event_sk_expect(TCP_AO_SYNACK_NO_KEY, this_ip_dest, addr_any,
				 port, 0, 100, 100);
	try_connect("Wrong snd id", port++, DEFAULT_TEST_PASSWORD,
+11 −11
Original line number Diff line number Diff line
@@ -35,7 +35,7 @@ static void *client_fn(void *arg)
	uint64_t before_aogood, after_aogood;
	const size_t nr_packets = 20;
	struct netstat *ns_before, *ns_after;
	struct tcp_ao_counters ao1, ao2;
	struct tcp_counters ao1, ao2;

	if (sk < 0)
		test_error("socket()");
@@ -50,18 +50,18 @@ static void *client_fn(void *arg)

	ns_before = netstat_read();
	before_aogood = netstat_get(ns_before, "TCPAOGood", NULL);
	if (test_get_tcp_ao_counters(sk, &ao1))
		test_error("test_get_tcp_ao_counters()");
	if (test_get_tcp_counters(sk, &ao1))
		test_error("test_get_tcp_counters()");

	if (test_client_verify(sk, 100, nr_packets, TEST_TIMEOUT_SEC)) {
	if (test_client_verify(sk, 100, nr_packets)) {
		test_fail("verify failed");
		return NULL;
	}

	ns_after = netstat_read();
	after_aogood = netstat_get(ns_after, "TCPAOGood", NULL);
	if (test_get_tcp_ao_counters(sk, &ao2))
		test_error("test_get_tcp_ao_counters()");
	if (test_get_tcp_counters(sk, &ao2))
		test_error("test_get_tcp_counters()");
	netstat_print_diff(ns_before, ns_after);
	netstat_free(ns_before);
	netstat_free(ns_after);
@@ -71,14 +71,14 @@ static void *client_fn(void *arg)
				nr_packets, after_aogood, before_aogood);
		return NULL;
	}
	if (test_tcp_ao_counters_cmp("connect", &ao1, &ao2, TEST_CNT_GOOD))
	if (test_assert_counters("connect", &ao1, &ao2, TEST_CNT_GOOD))
		return NULL;

	test_ok("connect TCPAOGood %" PRIu64 "/%" PRIu64 "/%" PRIu64 " => %" PRIu64 "/%" PRIu64 "/%" PRIu64 ", sent %zu",
			before_aogood, ao1.ao_info_pkt_good,
			ao1.key_cnts[0].pkt_good,
			after_aogood, ao2.ao_info_pkt_good,
			ao2.key_cnts[0].pkt_good,
			before_aogood, ao1.ao.ao_info_pkt_good,
			ao1.ao.key_cnts[0].pkt_good,
			after_aogood, ao2.ao.ao_info_pkt_good,
			ao2.ao.key_cnts[0].pkt_good,
			nr_packets);
	return NULL;
}
+8 −9
Original line number Diff line number Diff line
@@ -53,7 +53,7 @@ static void serve_interfered(int sk)
	ssize_t test_quota = packet_size * packets_nr * 10;
	uint64_t dest_unreach_a, dest_unreach_b;
	uint64_t icmp_ignored_a, icmp_ignored_b;
	struct tcp_ao_counters ao_cnt1, ao_cnt2;
	struct tcp_counters cnt1, cnt2;
	bool counter_not_found;
	struct netstat *ns_after, *ns_before;
	ssize_t bytes;
@@ -61,16 +61,16 @@ static void serve_interfered(int sk)
	ns_before = netstat_read();
	dest_unreach_a = netstat_get(ns_before, dst_unreach, NULL);
	icmp_ignored_a = netstat_get(ns_before, tcpao_icmps, NULL);
	if (test_get_tcp_ao_counters(sk, &ao_cnt1))
		test_error("test_get_tcp_ao_counters()");
	if (test_get_tcp_counters(sk, &cnt1))
		test_error("test_get_tcp_counters()");
	bytes = test_server_run(sk, test_quota, 0);
	ns_after = netstat_read();
	netstat_print_diff(ns_before, ns_after);
	dest_unreach_b = netstat_get(ns_after, dst_unreach, NULL);
	icmp_ignored_b = netstat_get(ns_after, tcpao_icmps,
					&counter_not_found);
	if (test_get_tcp_ao_counters(sk, &ao_cnt2))
		test_error("test_get_tcp_ao_counters()");
	if (test_get_tcp_counters(sk, &cnt2))
		test_error("test_get_tcp_counters()");

	netstat_free(ns_before);
	netstat_free(ns_after);
@@ -91,9 +91,9 @@ static void serve_interfered(int sk)
		return;
	}
#ifdef TEST_ICMPS_ACCEPT
	test_tcp_ao_counters_cmp(NULL, &ao_cnt1, &ao_cnt2, TEST_CNT_GOOD);
	test_assert_counters(NULL, &cnt1, &cnt2, TEST_CNT_GOOD);
#else
	test_tcp_ao_counters_cmp(NULL, &ao_cnt1, &ao_cnt2, TEST_CNT_GOOD | TEST_CNT_AO_DROPPED_ICMP);
	test_assert_counters(NULL, &cnt1, &cnt2, TEST_CNT_GOOD | TEST_CNT_AO_DROPPED_ICMP);
#endif
	if (icmp_ignored_a >= icmp_ignored_b) {
		test_icmps_fail("%s counter didn't change: %" PRIu64 " >= %" PRIu64,
@@ -395,7 +395,6 @@ static void icmp_interfere(const size_t nr, uint32_t rcv_nxt, void *src, void *d

static void send_interfered(int sk)
{
	const unsigned int timeout = TEST_TIMEOUT_SEC;
	struct sockaddr_in6 src, dst;
	socklen_t addr_sz;

@@ -409,7 +408,7 @@ static void send_interfered(int sk)
	while (1) {
		uint32_t rcv_nxt;

		if (test_client_verify(sk, packet_size, packets_nr, timeout)) {
		if (test_client_verify(sk, packet_size, packets_nr)) {
			test_fail("client: connection is broken");
			return;
		}
+38 −38
Original line number Diff line number Diff line
@@ -629,11 +629,11 @@ static int key_collection_socket(bool server, unsigned int port)
}

static void verify_counters(const char *tst_name, bool is_listen_sk, bool server,
			    struct tcp_ao_counters *a, struct tcp_ao_counters *b)
			    struct tcp_counters *a, struct tcp_counters *b)
{
	unsigned int i;

	__test_tcp_ao_counters_cmp(tst_name, a, b, TEST_CNT_GOOD);
	test_assert_counters_sk(tst_name, a, b, TEST_CNT_GOOD);

	for (i = 0; i < collection.nr_keys; i++) {
		struct test_key *key = &collection.keys[i];
@@ -652,12 +652,12 @@ static void verify_counters(const char *tst_name, bool is_listen_sk, bool server
			rx_cnt_expected = key->used_on_server_tx;
		}

		test_tcp_ao_key_counters_cmp(tst_name, a, b,
		test_assert_counters_key(tst_name, &a->ao, &b->ao,
					 rx_cnt_expected ? TEST_CNT_KEY_GOOD : 0,
					 sndid, rcvid);
	}
	test_tcp_ao_counters_free(a);
	test_tcp_ao_counters_free(b);
	test_tcp_counters_free(a);
	test_tcp_counters_free(b);
	test_ok("%s: passed counters checks", tst_name);
}

@@ -791,17 +791,17 @@ static void verify_keys(const char *tst_name, int sk,
}

static int start_server(const char *tst_name, unsigned int port, size_t quota,
			struct tcp_ao_counters *begin,
			struct tcp_counters *begin,
			unsigned int current_index, unsigned int rnext_index)
{
	struct tcp_ao_counters lsk_c1, lsk_c2;
	struct tcp_counters lsk_c1, lsk_c2;
	ssize_t bytes;
	int sk, lsk;

	synchronize_threads(); /* 1: key collection initialized */
	lsk = key_collection_socket(true, port);
	if (test_get_tcp_ao_counters(lsk, &lsk_c1))
		test_error("test_get_tcp_ao_counters()");
	if (test_get_tcp_counters(lsk, &lsk_c1))
		test_error("test_get_tcp_counters()");
	synchronize_threads(); /* 2: MKTs added => connect() */
	if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
		test_error("test_wait_fd()");
@@ -809,12 +809,12 @@ static int start_server(const char *tst_name, unsigned int port, size_t quota,
	sk = accept(lsk, NULL, NULL);
	if (sk < 0)
		test_error("accept()");
	if (test_get_tcp_ao_counters(sk, begin))
		test_error("test_get_tcp_ao_counters()");
	if (test_get_tcp_counters(sk, begin))
		test_error("test_get_tcp_counters()");

	synchronize_threads(); /* 3: accepted => send data */
	if (test_get_tcp_ao_counters(lsk, &lsk_c2))
		test_error("test_get_tcp_ao_counters()");
	if (test_get_tcp_counters(lsk, &lsk_c2))
		test_error("test_get_tcp_counters()");
	verify_keys(tst_name, lsk, true, true);
	close(lsk);

@@ -830,12 +830,12 @@ static int start_server(const char *tst_name, unsigned int port, size_t quota,
}

static void end_server(const char *tst_name, int sk,
		       struct tcp_ao_counters *begin)
		       struct tcp_counters *begin)
{
	struct tcp_ao_counters end;
	struct tcp_counters end;

	if (test_get_tcp_ao_counters(sk, &end))
		test_error("test_get_tcp_ao_counters()");
	if (test_get_tcp_counters(sk, &end))
		test_error("test_get_tcp_counters()");
	verify_keys(tst_name, sk, false, true);

	synchronize_threads(); /* 4: verified => closed */
@@ -848,7 +848,7 @@ static void end_server(const char *tst_name, int sk,
static void try_server_run(const char *tst_name, unsigned int port, size_t quota,
			   unsigned int current_index, unsigned int rnext_index)
{
	struct tcp_ao_counters tmp;
	struct tcp_counters tmp;
	int sk;

	sk = start_server(tst_name, port, quota, &tmp,
@@ -860,7 +860,7 @@ static void server_rotations(const char *tst_name, unsigned int port,
			     size_t quota, unsigned int rotations,
			     unsigned int current_index, unsigned int rnext_index)
{
	struct tcp_ao_counters tmp;
	struct tcp_counters tmp;
	unsigned int i;
	int sk;

@@ -886,7 +886,7 @@ static void server_rotations(const char *tst_name, unsigned int port,

static int run_client(const char *tst_name, unsigned int port,
		      unsigned int nr_keys, int current_index, int rnext_index,
		      struct tcp_ao_counters *before,
		      struct tcp_counters *before,
		      const size_t msg_sz, const size_t msg_nr)
{
	int sk;
@@ -904,8 +904,8 @@ static int run_client(const char *tst_name, unsigned int port,
		if (test_set_key(sk, sndid, rcvid))
			test_error("failed to set current/rnext keys");
	}
	if (before && test_get_tcp_ao_counters(sk, before))
		test_error("test_get_tcp_ao_counters()");
	if (before && test_get_tcp_counters(sk, before))
		test_error("test_get_tcp_counters()");

	synchronize_threads(); /* 2: MKTs added => connect() */
	if (test_connect_socket(sk, this_ip_dest, port++) <= 0)
@@ -918,11 +918,11 @@ static int run_client(const char *tst_name, unsigned int port,
	collection.keys[rnext_index].used_on_server_tx = 1;

	synchronize_threads(); /* 3: accepted => send data */
	if (test_client_verify(sk, msg_sz, msg_nr, TEST_TIMEOUT_SEC)) {
	if (test_client_verify(sk, msg_sz, msg_nr)) {
		test_fail("verify failed");
		close(sk);
		if (before)
			test_tcp_ao_counters_free(before);
			test_tcp_counters_free(before);
		return -1;
	}

@@ -931,7 +931,7 @@ static int run_client(const char *tst_name, unsigned int port,

static int start_client(const char *tst_name, unsigned int port,
			unsigned int nr_keys, int current_index, int rnext_index,
			struct tcp_ao_counters *before,
			struct tcp_counters *before,
			const size_t msg_sz, const size_t msg_nr)
{
	if (init_default_key_collection(nr_keys, true))
@@ -943,9 +943,9 @@ static int start_client(const char *tst_name, unsigned int port,

static void end_client(const char *tst_name, int sk, unsigned int nr_keys,
		       int current_index, int rnext_index,
		       struct tcp_ao_counters *start)
		       struct tcp_counters *start)
{
	struct tcp_ao_counters end;
	struct tcp_counters end;

	/* Some application may become dependent on this kernel choice */
	if (current_index < 0)
@@ -955,8 +955,8 @@ static void end_client(const char *tst_name, int sk, unsigned int nr_keys,
	verify_current_rnext(tst_name, sk,
			     collection.keys[current_index].client_keyid,
			     collection.keys[rnext_index].server_keyid);
	if (start && test_get_tcp_ao_counters(sk, &end))
		test_error("test_get_tcp_ao_counters()");
	if (start && test_get_tcp_counters(sk, &end))
		test_error("test_get_tcp_counters()");
	verify_keys(tst_name, sk, false, false);
	synchronize_threads(); /* 4: verify => closed */
	close(sk);
@@ -1016,7 +1016,7 @@ static void try_unmatched_keys(int sk, int *rnext_index, unsigned int port)
	trace_ao_event_expect(TCP_AO_RNEXT_REQUEST, this_ip_addr, this_ip_dest,
			      -1, port, 0, -1, -1, -1, -1, -1,
			      -1, key->server_keyid, -1);
	if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC))
	if (test_client_verify(sk, msg_len, nr_packets))
		test_fail("verify failed");
	*rnext_index = i;
}
@@ -1048,7 +1048,7 @@ static void check_current_back(const char *tst_name, unsigned int port,
			       unsigned int current_index, unsigned int rnext_index,
			       unsigned int rotate_to_index)
{
	struct tcp_ao_counters tmp;
	struct tcp_counters tmp;
	int sk;

	sk = start_client(tst_name, port, nr_keys, current_index, rnext_index,
@@ -1061,7 +1061,7 @@ static void check_current_back(const char *tst_name, unsigned int port,
			      port, -1, 0, -1, -1, -1, -1, -1,
			      collection.keys[rotate_to_index].client_keyid,
			      collection.keys[current_index].client_keyid, -1);
	if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC))
	if (test_client_verify(sk, msg_len, nr_packets))
		test_fail("verify failed");
	/* There is a race here: between setting the current_key with
	 * setsockopt(TCP_AO_INFO) and starting to send some data - there
@@ -1081,7 +1081,7 @@ static void roll_over_keys(const char *tst_name, unsigned int port,
			   unsigned int nr_keys, unsigned int rotations,
			   unsigned int current_index, unsigned int rnext_index)
{
	struct tcp_ao_counters tmp;
	struct tcp_counters tmp;
	unsigned int i;
	int sk;

@@ -1099,10 +1099,10 @@ static void roll_over_keys(const char *tst_name, unsigned int port,
				collection.keys[i].server_keyid, -1);
		if (test_set_key(sk, -1, collection.keys[i].server_keyid))
			test_error("Can't change the Rnext key");
		if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC)) {
		if (test_client_verify(sk, msg_len, nr_packets)) {
			test_fail("verify failed");
			close(sk);
			test_tcp_ao_counters_free(&tmp);
			test_tcp_counters_free(&tmp);
			return;
		}
		verify_current_rnext(tst_name, sk, -1,
@@ -1116,7 +1116,7 @@ static void roll_over_keys(const char *tst_name, unsigned int port,
static void try_client_run(const char *tst_name, unsigned int port,
			   unsigned int nr_keys, int current_index, int rnext_index)
{
	struct tcp_ao_counters tmp;
	struct tcp_counters tmp;
	int sk;

	sk = start_client(tst_name, port, nr_keys, current_index, rnext_index,
+90 −24
Original line number Diff line number Diff line
@@ -289,7 +289,7 @@ extern int link_set_up(const char *intf);
extern const unsigned int test_server_port;
extern int test_wait_fd(int sk, time_t sec, bool write);
extern int __test_connect_socket(int sk, const char *device,
				 void *addr, size_t addr_sz, time_t timeout);
				 void *addr, size_t addr_sz, bool async);
extern int __test_listen_socket(int backlog, void *addr, size_t addr_sz);

static inline int test_listen_socket(const union tcp_addr taddr,
@@ -331,25 +331,26 @@ static inline int test_listen_socket(const union tcp_addr taddr,
 * If set to 0 - kernel will try to retransmit SYN number of times, set in
 * /proc/sys/net/ipv4/tcp_syn_retries
 * By default set to 1 to make tests pass faster on non-busy machine.
 * [in process of removal, don't use in new tests]
 */
#ifndef TEST_RETRANSMIT_SEC
#define TEST_RETRANSMIT_SEC	1
#endif

static inline int _test_connect_socket(int sk, const union tcp_addr taddr,
				       unsigned int port, time_t timeout)
				       unsigned int port, bool async)
{
	sockaddr_af addr;

	tcp_addr_to_sockaddr_in(&addr, &taddr, htons(port));
	return __test_connect_socket(sk, veth_name,
				     (void *)&addr, sizeof(addr), timeout);
				     (void *)&addr, sizeof(addr), async);
}

static inline int test_connect_socket(int sk, const union tcp_addr taddr,
				      unsigned int port)
{
	return _test_connect_socket(sk, taddr, port, TEST_TIMEOUT_SEC);
	return _test_connect_socket(sk, taddr, port, false);
}

extern int __test_set_md5(int sk, void *addr, size_t addr_sz,
@@ -483,10 +484,7 @@ static inline int test_set_ao_flags(int sk, bool ao_required, bool accept_icmps)
}

extern ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec);
extern ssize_t test_client_loop(int sk, char *buf, size_t buf_sz,
				const size_t msg_len, time_t timeout_sec);
extern int test_client_verify(int sk, const size_t msg_len, const size_t nr,
			      time_t timeout_sec);
extern int test_client_verify(int sk, const size_t msg_len, const size_t nr);

struct tcp_ao_key_counters {
	uint8_t sndid;
@@ -512,7 +510,15 @@ struct tcp_ao_counters {
	size_t nr_keys;
	struct tcp_ao_key_counters *key_cnts;
};
extern int test_get_tcp_ao_counters(int sk, struct tcp_ao_counters *out);

struct tcp_counters {
	struct tcp_ao_counters ao;
	uint64_t netns_md5_notfound;
	uint64_t netns_md5_unexpected;
	uint64_t netns_md5_failure;
};

extern int test_get_tcp_counters(int sk, struct tcp_counters *out);

#define TEST_CNT_KEY_GOOD		BIT(0)
#define TEST_CNT_KEY_BAD		BIT(1)
@@ -526,8 +532,31 @@ extern int test_get_tcp_ao_counters(int sk, struct tcp_ao_counters *out);
#define TEST_CNT_NS_KEY_NOT_FOUND	BIT(9)
#define TEST_CNT_NS_AO_REQUIRED		BIT(10)
#define TEST_CNT_NS_DROPPED_ICMP	BIT(11)
#define TEST_CNT_NS_MD5_NOT_FOUND	BIT(12)
#define TEST_CNT_NS_MD5_UNEXPECTED	BIT(13)
#define TEST_CNT_NS_MD5_FAILURE		BIT(14)
typedef uint16_t test_cnt;

#define _for_each_counter(f)						\
do {									\
	/* per-netns */							\
	f(ao.netns_ao_good,		TEST_CNT_NS_GOOD);		\
	f(ao.netns_ao_bad,		TEST_CNT_NS_BAD);		\
	f(ao.netns_ao_key_not_found,	TEST_CNT_NS_KEY_NOT_FOUND);	\
	f(ao.netns_ao_required,		TEST_CNT_NS_AO_REQUIRED);	\
	f(ao.netns_ao_dropped_icmp,	TEST_CNT_NS_DROPPED_ICMP);	\
	/* per-socket */						\
	f(ao.ao_info_pkt_good,		TEST_CNT_SOCK_GOOD);		\
	f(ao.ao_info_pkt_bad,		TEST_CNT_SOCK_BAD);		\
	f(ao.ao_info_pkt_key_not_found,	TEST_CNT_SOCK_KEY_NOT_FOUND);	\
	f(ao.ao_info_pkt_ao_required,	TEST_CNT_SOCK_AO_REQUIRED);	\
	f(ao.ao_info_pkt_dropped_icmp,	TEST_CNT_SOCK_DROPPED_ICMP);	\
	/* non-AO */							\
	f(netns_md5_notfound,		TEST_CNT_NS_MD5_NOT_FOUND);	\
	f(netns_md5_unexpected,		TEST_CNT_NS_MD5_UNEXPECTED);	\
	f(netns_md5_failure,		TEST_CNT_NS_MD5_FAILURE);	\
} while (0)

#define TEST_CNT_AO_GOOD		(TEST_CNT_SOCK_GOOD | TEST_CNT_NS_GOOD)
#define TEST_CNT_AO_BAD			(TEST_CNT_SOCK_BAD | TEST_CNT_NS_BAD)
#define TEST_CNT_AO_KEY_NOT_FOUND	(TEST_CNT_SOCK_KEY_NOT_FOUND | \
@@ -539,34 +568,71 @@ typedef uint16_t test_cnt;
#define TEST_CNT_GOOD			(TEST_CNT_KEY_GOOD | TEST_CNT_AO_GOOD)
#define TEST_CNT_BAD			(TEST_CNT_KEY_BAD | TEST_CNT_AO_BAD)

extern int __test_tcp_ao_counters_cmp(const char *tst_name,
		struct tcp_ao_counters *before, struct tcp_ao_counters *after,
extern test_cnt test_cmp_counters(struct tcp_counters *before,
				  struct tcp_counters *after);
extern int test_assert_counters_sk(const char *tst_name,
		struct tcp_counters *before, struct tcp_counters *after,
		test_cnt expected);
extern int test_tcp_ao_key_counters_cmp(const char *tst_name,
extern int test_assert_counters_key(const char *tst_name,
		struct tcp_ao_counters *before, struct tcp_ao_counters *after,
		test_cnt expected, int sndid, int rcvid);
extern void test_tcp_ao_counters_free(struct tcp_ao_counters *cnts);
extern void test_tcp_counters_free(struct tcp_counters *cnts);

/*
 * Polling for netns and socket counters during select()/connect() and also
 * client/server messaging. Instead of constant timeout on underlying select(),
 * check the counters and return early. This allows to pass the tests where
 * timeout is expected without waiting for that fixing timeout (tests speed-up).
 * Previously shorter timeouts were used for tests expecting to time out,
 * but that leaded to sporadic false positives on counter checks failures,
 * as one second timeouts aren't enough for TCP retransmit.
 *
 * Two sides of the socketpair (client/server) should synchronize failures
 * using a shared variable *err, so that they can detect the other side's
 * failure.
 */
extern int test_skpair_wait_poll(int sk, bool write, test_cnt cond,
				 volatile int *err);
extern int _test_skpair_connect_poll(int sk, const char *device,
				     void *addr, size_t addr_sz,
				     test_cnt cond, volatile int *err);
static inline int test_skpair_connect_poll(int sk, const union tcp_addr taddr,
					   unsigned int port,
					   test_cnt cond, volatile int *err)
{
	sockaddr_af addr;

	tcp_addr_to_sockaddr_in(&addr, &taddr, htons(port));
	return _test_skpair_connect_poll(sk, veth_name,
					 (void *)&addr, sizeof(addr), cond, err);
}

extern int test_skpair_client(int sk, const size_t msg_len, const size_t nr,
			      test_cnt cond, volatile int *err);
extern int test_skpair_server(int sk, ssize_t quota,
			      test_cnt cond, volatile int *err);

/*
 * Frees buffers allocated in test_get_tcp_ao_counters().
 * Frees buffers allocated in test_get_tcp_counters().
 * The function doesn't expect new keys or keys removed between calls
 * to test_get_tcp_ao_counters(). Check key counters manually if they
 * to test_get_tcp_counters(). Check key counters manually if they
 * may change.
 */
static inline int test_tcp_ao_counters_cmp(const char *tst_name,
					   struct tcp_ao_counters *before,
					   struct tcp_ao_counters *after,
static inline int test_assert_counters(const char *tst_name,
				       struct tcp_counters *before,
				       struct tcp_counters *after,
				       test_cnt expected)
{
	int ret;

	ret = __test_tcp_ao_counters_cmp(tst_name, before, after, expected);
	ret = test_assert_counters_sk(tst_name, before, after, expected);
	if (ret)
		goto out;
	ret = test_tcp_ao_key_counters_cmp(tst_name, before, after,
	ret = test_assert_counters_key(tst_name, &before->ao, &after->ao,
				       expected, -1, -1);
out:
	test_tcp_ao_counters_free(before);
	test_tcp_ao_counters_free(after);
	test_tcp_counters_free(before);
	test_tcp_counters_free(after);
	return ret;
}

Loading