Commit cf51d617 authored by Jakub Kicinski's avatar Jakub Kicinski
Browse files

Merge branch 'tls-misc-bugfixes'

Sabrina Dubroca says:

====================
tls: misc bugfixes

Jann Horn reported multiple bugs in kTLS. This series addresses them,
and adds some corresponding selftests for those that are reproducible
(and without failure injection).
====================

Link: https://patch.msgid.link/cover.1760432043.git.sd@queasysnail.net


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parents 0c3f2e62 3667e9b4
Loading
Loading
Loading
Loading
+2 −5
Original line number Diff line number Diff line
@@ -255,12 +255,9 @@ int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
			if (msg->msg_flags & MSG_MORE)
				return -EINVAL;

			rc = tls_handle_open_record(sk, msg->msg_flags);
			if (rc)
				return rc;

			*record_type = *(unsigned char *)CMSG_DATA(cmsg);
			rc = 0;

			rc = tls_handle_open_record(sk, msg->msg_flags);
			break;
		default:
			return -EINVAL;
+25 −6
Original line number Diff line number Diff line
@@ -1054,7 +1054,7 @@ static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg,
			if (ret == -EINPROGRESS)
				num_async++;
			else if (ret != -EAGAIN)
				goto send_end;
				goto end;
		}
	}

@@ -1112,8 +1112,11 @@ static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg,
				goto send_end;
			tls_ctx->pending_open_record_frags = true;

			if (sk_msg_full(msg_pl))
			if (sk_msg_full(msg_pl)) {
				full_record = true;
				sk_msg_trim(sk, msg_en,
					    msg_pl->sg.size + prot->overhead_size);
			}

			if (full_record || eor)
				goto copied;
@@ -1149,6 +1152,13 @@ static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg,
				} else if (ret != -EAGAIN)
					goto send_end;
			}

			/* Transmit if any encryptions have completed */
			if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
				cancel_delayed_work(&ctx->tx_work.work);
				tls_tx_records(sk, msg->msg_flags);
			}

			continue;
rollback_iter:
			copied -= try_to_copy;
@@ -1204,6 +1214,12 @@ static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg,
					goto send_end;
				}
			}

			/* Transmit if any encryptions have completed */
			if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
				cancel_delayed_work(&ctx->tx_work.work);
				tls_tx_records(sk, msg->msg_flags);
			}
		}

		continue;
@@ -1223,8 +1239,9 @@ static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg,
			goto alloc_encrypted;
	}

send_end:
	if (!num_async) {
		goto send_end;
		goto end;
	} else if (num_zc || eor) {
		int err;

@@ -1242,7 +1259,7 @@ static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg,
		tls_tx_records(sk, msg->msg_flags);
	}

send_end:
end:
	ret = sk_stream_error(sk, msg->msg_flags, ret);
	return copied > 0 ? copied : ret;
}
@@ -1637,8 +1654,10 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,

	if (unlikely(darg->async)) {
		err = tls_strp_msg_hold(&ctx->strp, &ctx->async_hold);
		if (err)
			__skb_queue_tail(&ctx->async_hold, darg->skb);
		if (err) {
			err = tls_decrypt_async_wait(ctx);
			darg->async = false;
		}
		return err;
	}

+65 −0
Original line number Diff line number Diff line
@@ -564,6 +564,40 @@ TEST_F(tls, msg_more)
	EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
}

TEST_F(tls, cmsg_msg_more)
{
	char *test_str =  "test_read";
	char record_type = 100;
	int send_len = 10;

	/* we don't allow MSG_MORE with non-DATA records */
	EXPECT_EQ(tls_send_cmsg(self->fd, record_type, test_str, send_len,
				MSG_MORE), -1);
	EXPECT_EQ(errno, EINVAL);
}

TEST_F(tls, msg_more_then_cmsg)
{
	char *test_str = "test_read";
	char record_type = 100;
	int send_len = 10;
	char buf[10 * 2];
	int ret;

	EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
	EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);

	ret = tls_send_cmsg(self->fd, record_type, test_str, send_len, 0);
	EXPECT_EQ(ret, send_len);

	/* initial DATA record didn't get merged with the non-DATA record */
	EXPECT_EQ(recv(self->cfd, buf, send_len * 2, 0), send_len);

	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
				buf, sizeof(buf), MSG_WAITALL),
		  send_len);
}

TEST_F(tls, msg_more_unsent)
{
	char const *test_str = "test_read";
@@ -912,6 +946,37 @@ TEST_F(tls, peek_and_splice)
	EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
}

#define MAX_FRAGS 48
TEST_F(tls, splice_short)
{
	struct iovec sendchar_iov;
	char read_buf[0x10000];
	char sendbuf[0x100];
	char sendchar = 'S';
	int pipefds[2];
	int i;

	sendchar_iov.iov_base = &sendchar;
	sendchar_iov.iov_len = 1;

	memset(sendbuf, 's', sizeof(sendbuf));

	ASSERT_GE(pipe2(pipefds, O_NONBLOCK), 0);
	ASSERT_GE(fcntl(pipefds[0], F_SETPIPE_SZ, (MAX_FRAGS + 1) * 0x1000), 0);

	for (i = 0; i < MAX_FRAGS; i++)
		ASSERT_GE(vmsplice(pipefds[1], &sendchar_iov, 1, 0), 0);

	ASSERT_EQ(write(pipefds[1], sendbuf, sizeof(sendbuf)), sizeof(sendbuf));

	EXPECT_EQ(splice(pipefds[0], NULL, self->fd, NULL, MAX_FRAGS + 0x1000, 0),
		  MAX_FRAGS + sizeof(sendbuf));
	EXPECT_EQ(recv(self->cfd, read_buf, sizeof(read_buf), 0), MAX_FRAGS + sizeof(sendbuf));
	EXPECT_EQ(recv(self->cfd, read_buf, sizeof(read_buf), MSG_DONTWAIT), -1);
	EXPECT_EQ(errno, EAGAIN);
}
#undef MAX_FRAGS

TEST_F(tls, recvmsg_single)
{
	char const *test_str = "test_recvmsg_single";