Commit bc6d6a41 authored by Eric Biggers's avatar Eric Biggers
Browse files

lib/crypto: x86/sha256: Add support for 2-way interleaved hashing



Add an implementation of sha256_finup_2x_arch() for x86_64.  It
interleaves the computation of two SHA-256 hashes using the x86 SHA-NI
instructions.  dm-verity and fs-verity will take advantage of this for
greatly improved performance on capable CPUs.

This increases the throughput of SHA-256 hashing 4096-byte messages by
the following amounts on the following CPUs:

    Intel Ice Lake (server):        4%
    Intel Sapphire Rapids:          38%
    Intel Emerald Rapids:           38%
    AMD Zen 1 (Threadripper 1950X): 84%
    AMD Zen 4 (EPYC 9B14):          98%
    AMD Zen 5 (Ryzen 9 9950X):      64%

For now, this seems to benefit AMD more than Intel.  This seems to be
because current AMD CPUs support concurrent execution of the SHA-NI
instructions, but unfortunately current Intel CPUs don't, except for the
sha256msg2 instruction.  Hopefully future Intel CPUs will support SHA-NI
on more execution ports.  Zen 1 supports 2 concurrent sha256rnds2, and
Zen 4 supports 4 concurrent sha256rnds2, which suggests that even better
performance may be achievable on Zen 4 by interleaving more than two
hashes.  However, doing so poses a number of trade-offs, and furthermore
Zen 5 goes back to supporting "only" 2 concurrent sha256rnds2.

Reviewed-by: default avatarArd Biesheuvel <ardb@kernel.org>
Link: https://lore.kernel.org/r/20250915160819.140019-4-ebiggers@kernel.org


Signed-off-by: default avatarEric Biggers <ebiggers@kernel.org>
parent 34c3f1e3
Loading
Loading
Loading
Loading
+368 −0
Original line number Diff line number Diff line
@@ -165,6 +165,374 @@ SYM_FUNC_START(sha256_ni_transform)
	RET
SYM_FUNC_END(sha256_ni_transform)

#undef DIGEST_PTR
#undef DATA_PTR
#undef NUM_BLKS
#undef SHA256CONSTANTS
#undef MSG
#undef STATE0
#undef STATE1
#undef MSG0
#undef MSG1
#undef MSG2
#undef MSG3
#undef TMP
#undef SHUF_MASK
#undef ABEF_SAVE
#undef CDGH_SAVE

// parameters for sha256_ni_finup2x()
#define CTX		%rdi
#define DATA1		%rsi
#define DATA2		%rdx
#define LEN		%ecx
#define LEN8		%cl
#define LEN64		%rcx
#define OUT1		%r8
#define OUT2		%r9

// other scalar variables
#define SHA256CONSTANTS	%rax
#define COUNT		%r10
#define COUNT32		%r10d
#define FINAL_STEP	%r11d

// rbx is used as a temporary.

#define MSG		%xmm0	// sha256rnds2 implicit operand
#define STATE0_A	%xmm1
#define STATE1_A	%xmm2
#define STATE0_B	%xmm3
#define STATE1_B	%xmm4
#define TMP_A		%xmm5
#define TMP_B		%xmm6
#define MSG0_A		%xmm7
#define MSG1_A		%xmm8
#define MSG2_A		%xmm9
#define MSG3_A		%xmm10
#define MSG0_B		%xmm11
#define MSG1_B		%xmm12
#define MSG2_B		%xmm13
#define MSG3_B		%xmm14
#define SHUF_MASK	%xmm15

#define OFFSETOF_STATE		0  // offsetof(struct __sha256_ctx, state)
#define OFFSETOF_BYTECOUNT	32 // offsetof(struct __sha256_ctx, bytecount)
#define OFFSETOF_BUF		40 // offsetof(struct __sha256_ctx, buf)

// Do 4 rounds of SHA-256 for each of two messages (interleaved).  m0_a and m0_b
// contain the current 4 message schedule words for the first and second message
// respectively.
//
// If not all the message schedule words have been computed yet, then this also
// computes 4 more message schedule words for each message.  m1_a-m3_a contain
// the next 3 groups of 4 message schedule words for the first message, and
// likewise m1_b-m3_b for the second.  After consuming the current value of
// m0_a, this macro computes the group after m3_a and writes it to m0_a, and
// likewise for *_b.  This means that the next (m0_a, m1_a, m2_a, m3_a) is the
// current (m1_a, m2_a, m3_a, m0_a), and likewise for *_b, so the caller must
// cycle through the registers accordingly.
.macro	do_4rounds_2x	i, m0_a, m1_a, m2_a, m3_a,  m0_b, m1_b, m2_b, m3_b
	movdqa		(\i-32)*4(SHA256CONSTANTS), TMP_A
	movdqa		TMP_A, TMP_B
	paddd		\m0_a, TMP_A
	paddd		\m0_b, TMP_B
.if \i < 48
	sha256msg1	\m1_a, \m0_a
	sha256msg1	\m1_b, \m0_b
.endif
	movdqa		TMP_A, MSG
	sha256rnds2	STATE0_A, STATE1_A
	movdqa		TMP_B, MSG
	sha256rnds2	STATE0_B, STATE1_B
	pshufd 		$0x0E, TMP_A, MSG
	sha256rnds2	STATE1_A, STATE0_A
	pshufd 		$0x0E, TMP_B, MSG
	sha256rnds2	STATE1_B, STATE0_B
.if \i < 48
	movdqa		\m3_a, TMP_A
	movdqa		\m3_b, TMP_B
	palignr		$4, \m2_a, TMP_A
	palignr		$4, \m2_b, TMP_B
	paddd		TMP_A, \m0_a
	paddd		TMP_B, \m0_b
	sha256msg2	\m3_a, \m0_a
	sha256msg2	\m3_b, \m0_b
.endif
.endm

//
// void sha256_ni_finup2x(const struct __sha256_ctx *ctx,
//			  const u8 *data1, const u8 *data2, int len,
//			  u8 out1[SHA256_DIGEST_SIZE],
//			  u8 out2[SHA256_DIGEST_SIZE]);
//
// This function computes the SHA-256 digests of two messages |data1| and
// |data2| that are both |len| bytes long, starting from the initial context
// |ctx|.  |len| must be at least SHA256_BLOCK_SIZE.
//
// The instructions for the two SHA-256 operations are interleaved.  On many
// CPUs, this is almost twice as fast as hashing each message individually due
// to taking better advantage of the CPU's SHA-256 and SIMD throughput.
//
SYM_FUNC_START(sha256_ni_finup2x)
	// Allocate 128 bytes of stack space, 16-byte aligned.
	push		%rbx
	push		%rbp
	mov		%rsp, %rbp
	sub		$128, %rsp
	and		$~15, %rsp

	// Load the shuffle mask for swapping the endianness of 32-bit words.
	movdqa		PSHUFFLE_BYTE_FLIP_MASK(%rip), SHUF_MASK

	// Set up pointer to the round constants.
	lea		K256+32*4(%rip), SHA256CONSTANTS

	// Initially we're not processing the final blocks.
	xor		FINAL_STEP, FINAL_STEP

	// Load the initial state from ctx->state.
	movdqu		OFFSETOF_STATE+0*16(CTX), STATE0_A	// DCBA
	movdqu		OFFSETOF_STATE+1*16(CTX), STATE1_A	// HGFE
	movdqa		STATE0_A, TMP_A
	punpcklqdq	STATE1_A, STATE0_A			// FEBA
	punpckhqdq	TMP_A, STATE1_A				// DCHG
	pshufd		$0x1B, STATE0_A, STATE0_A		// ABEF
	pshufd		$0xB1, STATE1_A, STATE1_A		// CDGH

	// Load ctx->bytecount.  Take the mod 64 of it to get the number of
	// bytes that are buffered in ctx->buf.  Also save it in a register with
	// LEN added to it.
	mov		LEN, LEN
	mov		OFFSETOF_BYTECOUNT(CTX), %rbx
	lea		(%rbx, LEN64, 1), COUNT
	and		$63, %ebx
	jz		.Lfinup2x_enter_loop	// No bytes buffered?

	// %ebx bytes (1 to 63) are currently buffered in ctx->buf.  Load them
	// followed by the first 64 - %ebx bytes of data.  Since LEN >= 64, we
	// just load 64 bytes from each of ctx->buf, DATA1, and DATA2
	// unconditionally and rearrange the data as needed.

	movdqu		OFFSETOF_BUF+0*16(CTX), MSG0_A
	movdqu		OFFSETOF_BUF+1*16(CTX), MSG1_A
	movdqu		OFFSETOF_BUF+2*16(CTX), MSG2_A
	movdqu		OFFSETOF_BUF+3*16(CTX), MSG3_A
	movdqa		MSG0_A, 0*16(%rsp)
	movdqa		MSG1_A, 1*16(%rsp)
	movdqa		MSG2_A, 2*16(%rsp)
	movdqa		MSG3_A, 3*16(%rsp)

	movdqu		0*16(DATA1), MSG0_A
	movdqu		1*16(DATA1), MSG1_A
	movdqu		2*16(DATA1), MSG2_A
	movdqu		3*16(DATA1), MSG3_A
	movdqu		MSG0_A, 0*16(%rsp,%rbx)
	movdqu		MSG1_A, 1*16(%rsp,%rbx)
	movdqu		MSG2_A, 2*16(%rsp,%rbx)
	movdqu		MSG3_A, 3*16(%rsp,%rbx)
	movdqa		0*16(%rsp), MSG0_A
	movdqa		1*16(%rsp), MSG1_A
	movdqa		2*16(%rsp), MSG2_A
	movdqa		3*16(%rsp), MSG3_A

	movdqu		0*16(DATA2), MSG0_B
	movdqu		1*16(DATA2), MSG1_B
	movdqu		2*16(DATA2), MSG2_B
	movdqu		3*16(DATA2), MSG3_B
	movdqu		MSG0_B, 0*16(%rsp,%rbx)
	movdqu		MSG1_B, 1*16(%rsp,%rbx)
	movdqu		MSG2_B, 2*16(%rsp,%rbx)
	movdqu		MSG3_B, 3*16(%rsp,%rbx)
	movdqa		0*16(%rsp), MSG0_B
	movdqa		1*16(%rsp), MSG1_B
	movdqa		2*16(%rsp), MSG2_B
	movdqa		3*16(%rsp), MSG3_B

	sub		$64, %rbx 	// rbx = buffered - 64
	sub		%rbx, DATA1	// DATA1 += 64 - buffered
	sub		%rbx, DATA2	// DATA2 += 64 - buffered
	add		%ebx, LEN	// LEN += buffered - 64
	movdqa		STATE0_A, STATE0_B
	movdqa		STATE1_A, STATE1_B
	jmp		.Lfinup2x_loop_have_data

.Lfinup2x_enter_loop:
	sub		$64, LEN
	movdqa		STATE0_A, STATE0_B
	movdqa		STATE1_A, STATE1_B
.Lfinup2x_loop:
	// Load the next two data blocks.
	movdqu		0*16(DATA1), MSG0_A
	movdqu		0*16(DATA2), MSG0_B
	movdqu		1*16(DATA1), MSG1_A
	movdqu		1*16(DATA2), MSG1_B
	movdqu		2*16(DATA1), MSG2_A
	movdqu		2*16(DATA2), MSG2_B
	movdqu		3*16(DATA1), MSG3_A
	movdqu		3*16(DATA2), MSG3_B
	add		$64, DATA1
	add		$64, DATA2
.Lfinup2x_loop_have_data:
	// Convert the words of the data blocks from big endian.
	pshufb		SHUF_MASK, MSG0_A
	pshufb		SHUF_MASK, MSG0_B
	pshufb		SHUF_MASK, MSG1_A
	pshufb		SHUF_MASK, MSG1_B
	pshufb		SHUF_MASK, MSG2_A
	pshufb		SHUF_MASK, MSG2_B
	pshufb		SHUF_MASK, MSG3_A
	pshufb		SHUF_MASK, MSG3_B
.Lfinup2x_loop_have_bswapped_data:

	// Save the original state for each block.
	movdqa		STATE0_A, 0*16(%rsp)
	movdqa		STATE0_B, 1*16(%rsp)
	movdqa		STATE1_A, 2*16(%rsp)
	movdqa		STATE1_B, 3*16(%rsp)

	// Do the SHA-256 rounds on each block.
.irp i, 0, 16, 32, 48
	do_4rounds_2x	(\i + 0),  MSG0_A, MSG1_A, MSG2_A, MSG3_A, \
				   MSG0_B, MSG1_B, MSG2_B, MSG3_B
	do_4rounds_2x	(\i + 4),  MSG1_A, MSG2_A, MSG3_A, MSG0_A, \
				   MSG1_B, MSG2_B, MSG3_B, MSG0_B
	do_4rounds_2x	(\i + 8),  MSG2_A, MSG3_A, MSG0_A, MSG1_A, \
				   MSG2_B, MSG3_B, MSG0_B, MSG1_B
	do_4rounds_2x	(\i + 12), MSG3_A, MSG0_A, MSG1_A, MSG2_A, \
				   MSG3_B, MSG0_B, MSG1_B, MSG2_B
.endr

	// Add the original state for each block.
	paddd		0*16(%rsp), STATE0_A
	paddd		1*16(%rsp), STATE0_B
	paddd		2*16(%rsp), STATE1_A
	paddd		3*16(%rsp), STATE1_B

	// Update LEN and loop back if more blocks remain.
	sub		$64, LEN
	jge		.Lfinup2x_loop

	// Check if any final blocks need to be handled.
	// FINAL_STEP = 2: all done
	// FINAL_STEP = 1: need to do count-only padding block
	// FINAL_STEP = 0: need to do the block with 0x80 padding byte
	cmp		$1, FINAL_STEP
	jg		.Lfinup2x_done
	je		.Lfinup2x_finalize_countonly
	add		$64, LEN
	jz		.Lfinup2x_finalize_blockaligned

	// Not block-aligned; 1 <= LEN <= 63 data bytes remain.  Pad the block.
	// To do this, write the padding starting with the 0x80 byte to
	// &sp[64].  Then for each message, copy the last 64 data bytes to sp
	// and load from &sp[64 - LEN] to get the needed padding block.  This
	// code relies on the data buffers being >= 64 bytes in length.
	mov		$64, %ebx
	sub		LEN, %ebx		// ebx = 64 - LEN
	sub		%rbx, DATA1		// DATA1 -= 64 - LEN
	sub		%rbx, DATA2		// DATA2 -= 64 - LEN
	mov		$0x80, FINAL_STEP   // using FINAL_STEP as a temporary
	movd		FINAL_STEP, MSG0_A
	pxor		MSG1_A, MSG1_A
	movdqa		MSG0_A, 4*16(%rsp)
	movdqa		MSG1_A, 5*16(%rsp)
	movdqa		MSG1_A, 6*16(%rsp)
	movdqa		MSG1_A, 7*16(%rsp)
	cmp		$56, LEN
	jge		1f	// will COUNT spill into its own block?
	shl		$3, COUNT
	bswap		COUNT
	mov		COUNT, 56(%rsp,%rbx)
	mov		$2, FINAL_STEP	// won't need count-only block
	jmp		2f
1:
	mov		$1, FINAL_STEP	// will need count-only block
2:
	movdqu		0*16(DATA1), MSG0_A
	movdqu		1*16(DATA1), MSG1_A
	movdqu		2*16(DATA1), MSG2_A
	movdqu		3*16(DATA1), MSG3_A
	movdqa		MSG0_A, 0*16(%rsp)
	movdqa		MSG1_A, 1*16(%rsp)
	movdqa		MSG2_A, 2*16(%rsp)
	movdqa		MSG3_A, 3*16(%rsp)
	movdqu		0*16(%rsp,%rbx), MSG0_A
	movdqu		1*16(%rsp,%rbx), MSG1_A
	movdqu		2*16(%rsp,%rbx), MSG2_A
	movdqu		3*16(%rsp,%rbx), MSG3_A

	movdqu		0*16(DATA2), MSG0_B
	movdqu		1*16(DATA2), MSG1_B
	movdqu		2*16(DATA2), MSG2_B
	movdqu		3*16(DATA2), MSG3_B
	movdqa		MSG0_B, 0*16(%rsp)
	movdqa		MSG1_B, 1*16(%rsp)
	movdqa		MSG2_B, 2*16(%rsp)
	movdqa		MSG3_B, 3*16(%rsp)
	movdqu		0*16(%rsp,%rbx), MSG0_B
	movdqu		1*16(%rsp,%rbx), MSG1_B
	movdqu		2*16(%rsp,%rbx), MSG2_B
	movdqu		3*16(%rsp,%rbx), MSG3_B
	jmp		.Lfinup2x_loop_have_data

	// Prepare a padding block, either:
	//
	//	{0x80, 0, 0, 0, ..., count (as __be64)}
	//	This is for a block aligned message.
	//
	//	{   0, 0, 0, 0, ..., count (as __be64)}
	//	This is for a message whose length mod 64 is >= 56.
	//
	// Pre-swap the endianness of the words.
.Lfinup2x_finalize_countonly:
	pxor		MSG0_A, MSG0_A
	jmp		1f

.Lfinup2x_finalize_blockaligned:
	mov		$0x80000000, %ebx
	movd		%ebx, MSG0_A
1:
	pxor		MSG1_A, MSG1_A
	pxor		MSG2_A, MSG2_A
	ror		$29, COUNT
	movq		COUNT, MSG3_A
	pslldq		$8, MSG3_A
	movdqa		MSG0_A, MSG0_B
	pxor		MSG1_B, MSG1_B
	pxor		MSG2_B, MSG2_B
	movdqa		MSG3_A, MSG3_B
	mov		$2, FINAL_STEP
	jmp		.Lfinup2x_loop_have_bswapped_data

.Lfinup2x_done:
	// Write the two digests with all bytes in the correct order.
	movdqa		STATE0_A, TMP_A
	movdqa		STATE0_B, TMP_B
	punpcklqdq	STATE1_A, STATE0_A		// GHEF
	punpcklqdq	STATE1_B, STATE0_B
	punpckhqdq	TMP_A, STATE1_A			// ABCD
	punpckhqdq	TMP_B, STATE1_B
	pshufd		$0xB1, STATE0_A, STATE0_A	// HGFE
	pshufd		$0xB1, STATE0_B, STATE0_B
	pshufd		$0x1B, STATE1_A, STATE1_A	// DCBA
	pshufd		$0x1B, STATE1_B, STATE1_B
	pshufb		SHUF_MASK, STATE0_A
	pshufb		SHUF_MASK, STATE0_B
	pshufb		SHUF_MASK, STATE1_A
	pshufb		SHUF_MASK, STATE1_B
	movdqu		STATE0_A, 1*16(OUT1)
	movdqu		STATE0_B, 1*16(OUT2)
	movdqu		STATE1_A, 0*16(OUT1)
	movdqu		STATE1_B, 0*16(OUT2)

	mov		%rbp, %rsp
	pop		%rbp
	pop		%rbx
	RET
SYM_FUNC_END(sha256_ni_finup2x)

.section	.rodata.cst256.K256, "aM", @progbits, 256
.align 64
K256:
+39 −0
Original line number Diff line number Diff line
@@ -8,6 +8,8 @@
#include <crypto/internal/simd.h>
#include <linux/static_call.h>

static __ro_after_init DEFINE_STATIC_KEY_FALSE(have_sha_ni);

DEFINE_STATIC_CALL(sha256_blocks_x86, sha256_blocks_generic);

#define DEFINE_X86_SHA256_FN(c_fn, asm_fn)                                 \
@@ -36,11 +38,48 @@ static void sha256_blocks(struct sha256_block_state *state,
	static_call(sha256_blocks_x86)(state, data, nblocks);
}

static_assert(offsetof(struct __sha256_ctx, state) == 0);
static_assert(offsetof(struct __sha256_ctx, bytecount) == 32);
static_assert(offsetof(struct __sha256_ctx, buf) == 40);
asmlinkage void sha256_ni_finup2x(const struct __sha256_ctx *ctx,
				  const u8 *data1, const u8 *data2, int len,
				  u8 out1[SHA256_DIGEST_SIZE],
				  u8 out2[SHA256_DIGEST_SIZE]);

#define sha256_finup_2x_arch sha256_finup_2x_arch
static bool sha256_finup_2x_arch(const struct __sha256_ctx *ctx,
				 const u8 *data1, const u8 *data2, size_t len,
				 u8 out1[SHA256_DIGEST_SIZE],
				 u8 out2[SHA256_DIGEST_SIZE])
{
	/*
	 * The assembly requires len >= SHA256_BLOCK_SIZE && len <= INT_MAX.
	 * Further limit len to 65536 to avoid spending too long with preemption
	 * disabled.  (Of course, in practice len is nearly always 4096 anyway.)
	 */
	if (static_branch_likely(&have_sha_ni) && len >= SHA256_BLOCK_SIZE &&
	    len <= 65536 && likely(irq_fpu_usable())) {
		kernel_fpu_begin();
		sha256_ni_finup2x(ctx, data1, data2, len, out1, out2);
		kernel_fpu_end();
		kmsan_unpoison_memory(out1, SHA256_DIGEST_SIZE);
		kmsan_unpoison_memory(out2, SHA256_DIGEST_SIZE);
		return true;
	}
	return false;
}

static bool sha256_finup_2x_is_optimized_arch(void)
{
	return static_key_enabled(&have_sha_ni);
}

#define sha256_mod_init_arch sha256_mod_init_arch
static inline void sha256_mod_init_arch(void)
{
	if (boot_cpu_has(X86_FEATURE_SHA_NI)) {
		static_call_update(sha256_blocks_x86, sha256_blocks_ni);
		static_branch_enable(&have_sha_ni);
	} else if (cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
				     NULL) &&
		   boot_cpu_has(X86_FEATURE_AVX)) {