Commit 78e563f2 authored by David S. Miller's avatar David S. Miller
Browse files

Merge branch 'tls-fixes'

Jakub Kicinski says:

====================
net: tls: fix some issues with async encryption

valis was reporting a race on socket close so I sat down to try to fix it.
I used Sabrina's async crypto debug patch to test... and in the process
run into some of the same issues, and created very similar fixes :(
I didn't realize how many of those patches weren't applied. Once I found
Sabrina's code [1] it turned out to be so similar in fact that I added
her S-o-b's and Co-develop'eds in a semi-haphazard way.

With this series in place all expected tests pass with async crypto.
Sabrina had a few more fixes, but I'll leave those to her, things are
not crashing anymore.

[1] https://lore.kernel.org/netdev/cover.1694018970.git.sd@queasysnail.net/


====================

Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 2fbdc5c6 ac437a51
Loading
Loading
Loading
Loading
+0 −5
Original line number Diff line number Diff line
@@ -97,9 +97,6 @@ struct tls_sw_context_tx {
	struct tls_rec *open_rec;
	struct list_head tx_list;
	atomic_t encrypt_pending;
	/* protect crypto_wait with encrypt_pending */
	spinlock_t encrypt_compl_lock;
	int async_notify;
	u8 async_capable:1;

#define BIT_TX_SCHEDULED	0
@@ -136,8 +133,6 @@ struct tls_sw_context_rx {
	struct tls_strparser strp;

	atomic_t decrypt_pending;
	/* protect crypto_wait with decrypt_pending*/
	spinlock_t decrypt_compl_lock;
	struct sk_buff_head async_hold;
	struct wait_queue_head wq;
};
+62 −73
Original line number Diff line number Diff line
@@ -63,6 +63,7 @@ struct tls_decrypt_ctx {
	u8 iv[TLS_MAX_IV_SIZE];
	u8 aad[TLS_MAX_AAD_SIZE];
	u8 tail;
	bool free_sgout;
	struct scatterlist sg[];
};

@@ -187,7 +188,6 @@ static void tls_decrypt_done(void *data, int err)
	struct aead_request *aead_req = data;
	struct crypto_aead *aead = crypto_aead_reqtfm(aead_req);
	struct scatterlist *sgout = aead_req->dst;
	struct scatterlist *sgin = aead_req->src;
	struct tls_sw_context_rx *ctx;
	struct tls_decrypt_ctx *dctx;
	struct tls_context *tls_ctx;
@@ -196,6 +196,17 @@ static void tls_decrypt_done(void *data, int err)
	struct sock *sk;
	int aead_size;

	/* If requests get too backlogged crypto API returns -EBUSY and calls
	 * ->complete(-EINPROGRESS) immediately followed by ->complete(0)
	 * to make waiting for backlog to flush with crypto_wait_req() easier.
	 * First wait converts -EBUSY -> -EINPROGRESS, and the second one
	 * -EINPROGRESS -> 0.
	 * We have a single struct crypto_async_request per direction, this
	 * scheme doesn't help us, so just ignore the first ->complete().
	 */
	if (err == -EINPROGRESS)
		return;

	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(aead);
	aead_size = ALIGN(aead_size, __alignof__(*dctx));
	dctx = (void *)((u8 *)aead_req + aead_size);
@@ -213,7 +224,7 @@ static void tls_decrypt_done(void *data, int err)
	}

	/* Free the destination pages if skb was not decrypted inplace */
	if (sgout != sgin) {
	if (dctx->free_sgout) {
		/* Skip the first S/G entry as it points to AAD */
		for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
			if (!sg)
@@ -224,10 +235,17 @@ static void tls_decrypt_done(void *data, int err)

	kfree(aead_req);

	spin_lock_bh(&ctx->decrypt_compl_lock);
	if (!atomic_dec_return(&ctx->decrypt_pending))
	if (atomic_dec_and_test(&ctx->decrypt_pending))
		complete(&ctx->async_wait.completion);
	spin_unlock_bh(&ctx->decrypt_compl_lock);
}

static int tls_decrypt_async_wait(struct tls_sw_context_rx *ctx)
{
	if (!atomic_dec_and_test(&ctx->decrypt_pending))
		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
	atomic_inc(&ctx->decrypt_pending);

	return ctx->async_wait.err;
}

static int tls_do_decryption(struct sock *sk,
@@ -253,6 +271,7 @@ static int tls_do_decryption(struct sock *sk,
		aead_request_set_callback(aead_req,
					  CRYPTO_TFM_REQ_MAY_BACKLOG,
					  tls_decrypt_done, aead_req);
		DEBUG_NET_WARN_ON_ONCE(atomic_read(&ctx->decrypt_pending) < 1);
		atomic_inc(&ctx->decrypt_pending);
	} else {
		aead_request_set_callback(aead_req,
@@ -261,6 +280,10 @@ static int tls_do_decryption(struct sock *sk,
	}

	ret = crypto_aead_decrypt(aead_req);
	if (ret == -EBUSY) {
		ret = tls_decrypt_async_wait(ctx);
		ret = ret ?: -EINPROGRESS;
	}
	if (ret == -EINPROGRESS) {
		if (darg->async)
			return 0;
@@ -439,9 +462,10 @@ static void tls_encrypt_done(void *data, int err)
	struct tls_rec *rec = data;
	struct scatterlist *sge;
	struct sk_msg *msg_en;
	bool ready = false;
	struct sock *sk;
	int pending;

	if (err == -EINPROGRESS) /* see the comment in tls_decrypt_done() */
		return;

	msg_en = &rec->msg_encrypted;

@@ -476,23 +500,25 @@ static void tls_encrypt_done(void *data, int err)
		/* If received record is at head of tx_list, schedule tx */
		first_rec = list_first_entry(&ctx->tx_list,
					     struct tls_rec, list);
		if (rec == first_rec)
			ready = true;
		if (rec == first_rec) {
			/* Schedule the transmission */
			if (!test_and_set_bit(BIT_TX_SCHEDULED,
					      &ctx->tx_bitmask))
				schedule_delayed_work(&ctx->tx_work.work, 1);
		}
	}

	spin_lock_bh(&ctx->encrypt_compl_lock);
	pending = atomic_dec_return(&ctx->encrypt_pending);

	if (!pending && ctx->async_notify)
	if (atomic_dec_and_test(&ctx->encrypt_pending))
		complete(&ctx->async_wait.completion);
	spin_unlock_bh(&ctx->encrypt_compl_lock);
}

	if (!ready)
		return;
static int tls_encrypt_async_wait(struct tls_sw_context_tx *ctx)
{
	if (!atomic_dec_and_test(&ctx->encrypt_pending))
		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
	atomic_inc(&ctx->encrypt_pending);

	/* Schedule the transmission */
	if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
		schedule_delayed_work(&ctx->tx_work.work, 1);
	return ctx->async_wait.err;
}

static int tls_do_encryption(struct sock *sk,
@@ -541,9 +567,14 @@ static int tls_do_encryption(struct sock *sk,

	/* Add the record in tx_list */
	list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
	DEBUG_NET_WARN_ON_ONCE(atomic_read(&ctx->encrypt_pending) < 1);
	atomic_inc(&ctx->encrypt_pending);

	rc = crypto_aead_encrypt(aead_req);
	if (rc == -EBUSY) {
		rc = tls_encrypt_async_wait(ctx);
		rc = rc ?: -EINPROGRESS;
	}
	if (!rc || rc != -EINPROGRESS) {
		atomic_dec(&ctx->encrypt_pending);
		sge->offset -= prot->prepend_size;
@@ -984,7 +1015,6 @@ static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg,
	int num_zc = 0;
	int orig_size;
	int ret = 0;
	int pending;

	if (!eor && (msg->msg_flags & MSG_EOR))
		return -EINVAL;
@@ -1163,24 +1193,12 @@ static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg,
	if (!num_async) {
		goto send_end;
	} else if (num_zc) {
		/* Wait for pending encryptions to get completed */
		spin_lock_bh(&ctx->encrypt_compl_lock);
		ctx->async_notify = true;

		pending = atomic_read(&ctx->encrypt_pending);
		spin_unlock_bh(&ctx->encrypt_compl_lock);
		if (pending)
			crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
		else
			reinit_completion(&ctx->async_wait.completion);

		/* There can be no concurrent accesses, since we have no
		 * pending encrypt operations
		 */
		WRITE_ONCE(ctx->async_notify, false);
		int err;

		if (ctx->async_wait.err) {
			ret = ctx->async_wait.err;
		/* Wait for pending encryptions to get completed */
		err = tls_encrypt_async_wait(ctx);
		if (err) {
			ret = err;
			copied = 0;
		}
	}
@@ -1229,7 +1247,6 @@ void tls_sw_splice_eof(struct socket *sock)
	ssize_t copied = 0;
	bool retrying = false;
	int ret = 0;
	int pending;

	if (!ctx->open_rec)
		return;
@@ -1264,22 +1281,7 @@ void tls_sw_splice_eof(struct socket *sock)
	}

	/* Wait for pending encryptions to get completed */
	spin_lock_bh(&ctx->encrypt_compl_lock);
	ctx->async_notify = true;

	pending = atomic_read(&ctx->encrypt_pending);
	spin_unlock_bh(&ctx->encrypt_compl_lock);
	if (pending)
		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
	else
		reinit_completion(&ctx->async_wait.completion);

	/* There can be no concurrent accesses, since we have no pending
	 * encrypt operations
	 */
	WRITE_ONCE(ctx->async_notify, false);

	if (ctx->async_wait.err)
	if (tls_encrypt_async_wait(ctx))
		goto unlock;

	/* Transmit if any encryptions have completed */
@@ -1581,6 +1583,7 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
	} else if (out_sg) {
		memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
	}
	dctx->free_sgout = !!pages;

	/* Prepare and submit AEAD request */
	err = tls_do_decryption(sk, sgin, sgout, dctx->iv,
@@ -2109,16 +2112,10 @@ int tls_sw_recvmsg(struct sock *sk,

recv_end:
	if (async) {
		int ret, pending;
		int ret;

		/* Wait for all previously submitted records to be decrypted */
		spin_lock_bh(&ctx->decrypt_compl_lock);
		reinit_completion(&ctx->async_wait.completion);
		pending = atomic_read(&ctx->decrypt_pending);
		spin_unlock_bh(&ctx->decrypt_compl_lock);
		ret = 0;
		if (pending)
			ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
		ret = tls_decrypt_async_wait(ctx);
		__skb_queue_purge(&ctx->async_hold);

		if (ret) {
@@ -2135,7 +2132,6 @@ int tls_sw_recvmsg(struct sock *sk,
		else
			err = process_rx_list(ctx, msg, &control, 0,
					      async_copy_bytes, is_peek);
		decrypted += max(err, 0);
	}

	copied += decrypted;
@@ -2435,16 +2431,9 @@ void tls_sw_release_resources_tx(struct sock *sk)
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec, *tmp;
	int pending;

	/* Wait for any pending async encryptions to complete */
	spin_lock_bh(&ctx->encrypt_compl_lock);
	ctx->async_notify = true;
	pending = atomic_read(&ctx->encrypt_pending);
	spin_unlock_bh(&ctx->encrypt_compl_lock);

	if (pending)
		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
	tls_encrypt_async_wait(ctx);

	tls_tx_records(sk, -1);

@@ -2607,7 +2596,7 @@ static struct tls_sw_context_tx *init_ctx_tx(struct tls_context *ctx, struct soc
	}

	crypto_init_wait(&sw_ctx_tx->async_wait);
	spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
	atomic_set(&sw_ctx_tx->encrypt_pending, 1);
	INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
	INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
	sw_ctx_tx->tx_work.sk = sk;
@@ -2628,7 +2617,7 @@ static struct tls_sw_context_rx *init_ctx_rx(struct tls_context *ctx)
	}

	crypto_init_wait(&sw_ctx_rx->async_wait);
	spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
	atomic_set(&sw_ctx_rx->decrypt_pending, 1);
	init_waitqueue_head(&sw_ctx_rx->wq);
	skb_queue_head_init(&sw_ctx_rx->rx_list);
	skb_queue_head_init(&sw_ctx_rx->async_hold);
+4 −4
Original line number Diff line number Diff line
@@ -1002,12 +1002,12 @@ TEST_F(tls, recv_partial)

	memset(recv_mem, 0, sizeof(recv_mem));
	EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
	EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_first),
		       MSG_WAITALL), -1);
	EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_first),
		       MSG_WAITALL), strlen(test_str_first));
	EXPECT_EQ(memcmp(test_str_first, recv_mem, strlen(test_str_first)), 0);
	memset(recv_mem, 0, sizeof(recv_mem));
	EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_second),
		       MSG_WAITALL), -1);
	EXPECT_EQ(recv(self->cfd, recv_mem, strlen(test_str_second),
		       MSG_WAITALL), strlen(test_str_second));
	EXPECT_EQ(memcmp(test_str_second, recv_mem, strlen(test_str_second)),
		  0);
}