Commit decde258 authored by Dmitry Safonov's avatar Dmitry Safonov Committed by David S. Miller
Browse files

net/tcp: Add TCP-AO sign to twsk



Add support for sockets in time-wait state.
ao_info as well as all keys are inherited on transition to time-wait
socket. The lifetime of ao_info is now protected by ref counter, so
that tcp_ao_destroy_sock() will destruct it only when the last user is
gone.

Co-developed-by: default avatarFrancesco Ruggeri <fruggeri@arista.com>
Signed-off-by: default avatarFrancesco Ruggeri <fruggeri@arista.com>
Co-developed-by: default avatarSalam Noureddine <noureddine@arista.com>
Signed-off-by: default avatarSalam Noureddine <noureddine@arista.com>
Signed-off-by: default avatarDmitry Safonov <dima@arista.com>
Acked-by: default avatarDavid Ahern <dsahern@kernel.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent ba7783ad
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -514,6 +514,9 @@ struct tcp_timewait_sock {
#ifdef CONFIG_TCP_MD5SIG
	struct tcp_md5sig_key	  *tw_md5_key;
#endif
#ifdef CONFIG_TCP_AO
	struct tcp_ao_info	__rcu *ao_info;
#endif
};

static inline struct tcp_timewait_sock *tcp_twsk(const struct sock *sk)
+9 −2
Original line number Diff line number Diff line
@@ -85,6 +85,7 @@ struct tcp_ao_info {
				__unused	:31;
	__be32			lisn;
	__be32			risn;
	refcount_t		refcnt;		/* Protects twsk destruction */
	struct rcu_head		rcu;
};

@@ -124,7 +125,8 @@ struct tcp_ao_key *tcp_ao_established_key(struct tcp_ao_info *ao,
					  int sndid, int rcvid);
int tcp_ao_calc_traffic_key(struct tcp_ao_key *mkt, u8 *key, void *ctx,
			    unsigned int len, struct tcp_sigpool *hp);
void tcp_ao_destroy_sock(struct sock *sk);
void tcp_ao_destroy_sock(struct sock *sk, bool twsk);
void tcp_ao_time_wait(struct tcp_timewait_sock *tcptw, struct tcp_sock *tp);
struct tcp_ao_key *tcp_ao_do_lookup(const struct sock *sk,
				    const union tcp_ao_addr *addr,
				    int family, int sndid, int rcvid);
@@ -182,7 +184,7 @@ static inline struct tcp_ao_key *tcp_ao_do_lookup(const struct sock *sk,
	return NULL;
}

static inline void tcp_ao_destroy_sock(struct sock *sk)
static inline void tcp_ao_destroy_sock(struct sock *sk, bool twsk)
{
}

@@ -194,6 +196,11 @@ static inline void tcp_ao_finish_connect(struct sock *sk, struct sk_buff *skb)
{
}

static inline void tcp_ao_time_wait(struct tcp_timewait_sock *tcptw,
				    struct tcp_sock *tp)
{
}

static inline void tcp_ao_connect_init(struct sock *sk)
{
}
+41 −8
Original line number Diff line number Diff line
@@ -159,6 +159,7 @@ static struct tcp_ao_info *tcp_ao_alloc_info(gfp_t flags)
	if (!ao)
		return NULL;
	INIT_HLIST_HEAD(&ao->head);
	refcount_set(&ao->refcnt, 1);

	return ao;
}
@@ -176,20 +177,26 @@ static void tcp_ao_key_free_rcu(struct rcu_head *head)
	kfree_sensitive(key);
}

void tcp_ao_destroy_sock(struct sock *sk)
void tcp_ao_destroy_sock(struct sock *sk, bool twsk)
{
	struct tcp_ao_info *ao;
	struct tcp_ao_key *key;
	struct hlist_node *n;

	if (twsk) {
		ao = rcu_dereference_protected(tcp_twsk(sk)->ao_info, 1);
		tcp_twsk(sk)->ao_info = NULL;
	} else {
		ao = rcu_dereference_protected(tcp_sk(sk)->ao_info, 1);
		tcp_sk(sk)->ao_info = NULL;
	}

	if (!ao)
	if (!ao || !refcount_dec_and_test(&ao->refcnt))
		return;

	hlist_for_each_entry_safe(key, n, &ao->head, node) {
		hlist_del_rcu(&key->node);
		if (!twsk)
			atomic_sub(tcp_ao_sizeof_key(key), &sk->sk_omem_alloc);
		call_rcu(&key->rcu, tcp_ao_key_free_rcu);
	}
@@ -197,6 +204,27 @@ void tcp_ao_destroy_sock(struct sock *sk)
	kfree_rcu(ao, rcu);
}

void tcp_ao_time_wait(struct tcp_timewait_sock *tcptw, struct tcp_sock *tp)
{
	struct tcp_ao_info *ao_info = rcu_dereference_protected(tp->ao_info, 1);

	if (ao_info) {
		struct tcp_ao_key *key;
		struct hlist_node *n;
		int omem = 0;

		hlist_for_each_entry_safe(key, n, &ao_info->head, node) {
			omem += tcp_ao_sizeof_key(key);
		}

		refcount_inc(&ao_info->refcnt);
		atomic_sub(omem, &(((struct sock *)tp)->sk_omem_alloc));
		rcu_assign_pointer(tcptw->ao_info, ao_info);
	} else {
		tcptw->ao_info = NULL;
	}
}

/* 4 tuple and ISNs are expected in NBO */
static int tcp_v4_ao_calc_key(struct tcp_ao_key *mkt, u8 *key,
			      __be32 saddr, __be32 daddr,
@@ -514,10 +542,12 @@ int tcp_ao_prepare_reset(const struct sock *sk, struct sk_buff *skb,
	if (!sk)
		return -ENOTCONN;

	if ((1 << sk->sk_state) &
	    (TCPF_LISTEN | TCPF_NEW_SYN_RECV | TCPF_TIME_WAIT))
	if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_NEW_SYN_RECV)) {
		return -1;

	if (sk->sk_state == TCP_TIME_WAIT)
		ao_info = rcu_dereference(tcp_twsk(sk)->ao_info);
	else
		ao_info = rcu_dereference(tcp_sk(sk)->ao_info);
	if (!ao_info)
		return -ENOENT;
@@ -910,6 +940,9 @@ static struct tcp_ao_info *setsockopt_ao_info(struct sock *sk)
	if (sk_fullsock(sk)) {
		return rcu_dereference_protected(tcp_sk(sk)->ao_info,
						 lockdep_sock_is_held(sk));
	} else if (sk->sk_state == TCP_TIME_WAIT) {
		return rcu_dereference_protected(tcp_twsk(sk)->ao_info,
						 lockdep_sock_is_held(sk));
	}
	return ERR_PTR(-ESOCKTNOSUPPORT);
}
+73 −19
Original line number Diff line number Diff line
@@ -911,17 +911,13 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb)
static void tcp_v4_send_ack(const struct sock *sk,
			    struct sk_buff *skb, u32 seq, u32 ack,
			    u32 win, u32 tsval, u32 tsecr, int oif,
			    struct tcp_md5sig_key *key,
			    struct tcp_key *key,
			    int reply_flags, u8 tos, u32 txhash)
{
	const struct tcphdr *th = tcp_hdr(skb);
	struct {
		struct tcphdr th;
		__be32 opt[(TCPOLEN_TSTAMP_ALIGNED >> 2)
#ifdef CONFIG_TCP_MD5SIG
			   + (TCPOLEN_MD5SIG_ALIGNED >> 2)
#endif
			];
		__be32 opt[(MAX_TCP_OPTION_SPACE  >> 2)];
	} rep;
	struct net *net = sock_net(sk);
	struct ip_reply_arg arg;
@@ -952,7 +948,7 @@ static void tcp_v4_send_ack(const struct sock *sk,
	rep.th.window  = htons(win);

#ifdef CONFIG_TCP_MD5SIG
	if (key) {
	if (tcp_key_is_md5(key)) {
		int offset = (tsecr) ? 3 : 0;

		rep.opt[offset++] = htonl((TCPOPT_NOP << 24) |
@@ -963,9 +959,27 @@ static void tcp_v4_send_ack(const struct sock *sk,
		rep.th.doff = arg.iov[0].iov_len/4;

		tcp_v4_md5_hash_hdr((__u8 *) &rep.opt[offset],
				    key, ip_hdr(skb)->saddr,
				    key->md5_key, ip_hdr(skb)->saddr,
				    ip_hdr(skb)->daddr, &rep.th);
	}
#endif
#ifdef CONFIG_TCP_AO
	if (tcp_key_is_ao(key)) {
		int offset = (tsecr) ? 3 : 0;

		rep.opt[offset++] = htonl((TCPOPT_AO << 24) |
					  (tcp_ao_len(key->ao_key) << 16) |
					  (key->ao_key->sndid << 8) |
					  key->rcv_next);
		arg.iov[0].iov_len += round_up(tcp_ao_len(key->ao_key), 4);
		rep.th.doff = arg.iov[0].iov_len / 4;

		tcp_ao_hash_hdr(AF_INET, (char *)&rep.opt[offset],
				key->ao_key, key->traffic_key,
				(union tcp_ao_addr *)&ip_hdr(skb)->saddr,
				(union tcp_ao_addr *)&ip_hdr(skb)->daddr,
				&rep.th, key->sne);
	}
#endif
	arg.flags = reply_flags;
	arg.csum = csum_tcpudp_nofold(ip_hdr(skb)->daddr,
@@ -999,18 +1013,50 @@ static void tcp_v4_timewait_ack(struct sock *sk, struct sk_buff *skb)
{
	struct inet_timewait_sock *tw = inet_twsk(sk);
	struct tcp_timewait_sock *tcptw = tcp_twsk(sk);
	struct tcp_key key = {};
#ifdef CONFIG_TCP_AO
	struct tcp_ao_info *ao_info;

	/* FIXME: the segment to-be-acked is not verified yet */
	ao_info = rcu_dereference(tcptw->ao_info);
	if (ao_info) {
		const struct tcp_ao_hdr *aoh;

		if (tcp_parse_auth_options(tcp_hdr(skb), NULL, &aoh)) {
			inet_twsk_put(tw);
			return;
		}

		if (aoh)
			key.ao_key = tcp_ao_established_key(ao_info, aoh->rnext_keyid, -1);
	}
	if (key.ao_key) {
		struct tcp_ao_key *rnext_key;

		key.traffic_key = snd_other_key(key.ao_key);
		rnext_key = READ_ONCE(ao_info->rnext_key);
		key.rcv_next = rnext_key->rcvid;
		key.type = TCP_KEY_AO;
#else
	if (0) {
#endif
#ifdef CONFIG_TCP_MD5SIG
	} else if (static_branch_unlikely(&tcp_md5_needed.key)) {
		key.md5_key = tcp_twsk_md5_key(tcptw);
		if (key.md5_key)
			key.type = TCP_KEY_MD5;
#endif
	}

	tcp_v4_send_ack(sk, skb,
			tcptw->tw_snd_nxt, tcptw->tw_rcv_nxt,
			tcptw->tw_rcv_wnd >> tw->tw_rcv_wscale,
			tcp_tw_tsval(tcptw),
			tcptw->tw_ts_recent,
			tw->tw_bound_dev_if,
			tcp_twsk_md5_key(tcptw),
			tw->tw_bound_dev_if, &key,
			tw->tw_transparent ? IP_REPLY_ARG_NOSRCCHECK : 0,
			tw->tw_tos,
			tw->tw_txhash
			);
			tw->tw_txhash);

	inet_twsk_put(tw);
}
@@ -1018,8 +1064,7 @@ static void tcp_v4_timewait_ack(struct sock *sk, struct sk_buff *skb)
static void tcp_v4_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb,
				  struct request_sock *req)
{
	const union tcp_md5_addr *addr;
	int l3index;
	struct tcp_key key = {};

	/* sk->sk_state == TCP_LISTEN -> for regular TCP_SYN_RECV
	 * sk->sk_state == TCP_SYN_RECV -> for Fast Open.
@@ -1032,15 +1077,24 @@ static void tcp_v4_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb,
	 * exception of <SYN> segments, MUST be right-shifted by
	 * Rcv.Wind.Shift bits:
	 */
#ifdef CONFIG_TCP_MD5SIG
	if (static_branch_unlikely(&tcp_md5_needed.key)) {
		const union tcp_md5_addr *addr;
		int l3index;

		addr = (union tcp_md5_addr *)&ip_hdr(skb)->saddr;
		l3index = tcp_v4_sdif(skb) ? inet_iif(skb) : 0;
		key.md5_key = tcp_md5_do_lookup(sk, l3index, addr, AF_INET);
		if (key.md5_key)
			key.type = TCP_KEY_MD5;
	}
#endif
	tcp_v4_send_ack(sk, skb, seq,
			tcp_rsk(req)->rcv_nxt,
			req->rsk_rcv_wnd >> inet_rsk(req)->rcv_wscale,
			tcp_rsk_tsval(tcp_rsk(req)),
			READ_ONCE(req->ts_recent),
			0,
			tcp_md5_do_lookup(sk, l3index, addr, AF_INET),
			0, &key,
			inet_rsk(req)->no_srccheck ? IP_REPLY_ARG_NOSRCCHECK : 0,
			ip_hdr(skb)->tos,
			READ_ONCE(tcp_rsk(req)->txhash));
@@ -2404,7 +2458,7 @@ void tcp_v4_destroy_sock(struct sock *sk)
		rcu_assign_pointer(tp->md5sig_info, NULL);
	}
#endif
	tcp_ao_destroy_sock(sk);
	tcp_ao_destroy_sock(sk, false);

	/* Clean up a referenced TCP bind bucket. */
	if (inet_csk(sk)->icsk_bind_hash)
+3 −1
Original line number Diff line number Diff line
@@ -279,7 +279,7 @@ static void tcp_time_wait_init(struct sock *sk, struct tcp_timewait_sock *tcptw)
void tcp_time_wait(struct sock *sk, int state, int timeo)
{
	const struct inet_connection_sock *icsk = inet_csk(sk);
	const struct tcp_sock *tp = tcp_sk(sk);
	struct tcp_sock *tp = tcp_sk(sk);
	struct net *net = sock_net(sk);
	struct inet_timewait_sock *tw;

@@ -316,6 +316,7 @@ void tcp_time_wait(struct sock *sk, int state, int timeo)
#endif

		tcp_time_wait_init(sk, tcptw);
		tcp_ao_time_wait(tcptw, tp);

		/* Get the TIME_WAIT timeout firing. */
		if (timeo < rto)
@@ -370,6 +371,7 @@ void tcp_twsk_destructor(struct sock *sk)
			call_rcu(&twsk->tw_md5_key->rcu, tcp_md5_twsk_free_rcu);
	}
#endif
	tcp_ao_destroy_sock(sk, true);
}
EXPORT_SYMBOL_GPL(tcp_twsk_destructor);

Loading