Commit c8bf850e authored by Ard Biesheuvel's avatar Ard Biesheuvel Committed by Herbert Xu
Browse files

crypto: arm/aes-neonbs-ctr - deal with non-multiples of AES block size



Instead of falling back to C code to deal with the final bit of input
that is not a round multiple of the block size, handle this in the asm
code, permitting us to use overlapping loads and stores for performance,
and implement the 16-byte wide XOR using a single NEON instruction.

Since NEON loads and stores have a natural width of 16 bytes, we need to
handle inputs of less than 16 bytes in a special way, but this rarely
occurs in practice so it does not impact performance. All other input
sizes can be consumed directly by the NEON asm code, although it should
be noted that the core AES transform can still only process 128 bytes (8
AES blocks) at a time.

Signed-off-by: default avatarArd Biesheuvel <ardb@kernel.org>
Signed-off-by: default avatarHerbert Xu <herbert@gondor.apana.org.au>
parent 8daa399e
Loading
Loading
Loading
Loading
+63 −42
Original line number Diff line number Diff line
@@ -758,29 +758,24 @@ ENTRY(aesbs_cbc_decrypt)
ENDPROC(aesbs_cbc_decrypt)

	.macro		next_ctr, q
	vmov.32		\q\()h[1], r10
	vmov		\q\()h, r9, r10
	adds		r10, r10, #1
	vmov.32		\q\()h[0], r9
	adcs		r9, r9, #0
	vmov.32		\q\()l[1], r8
	vmov		\q\()l, r7, r8
	adcs		r8, r8, #0
	vmov.32		\q\()l[0], r7
	adc		r7, r7, #0
	vrev32.8	\q, \q
	.endm

	/*
	 * aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
	 *		     int rounds, int blocks, u8 ctr[], u8 final[])
	 *		     int rounds, int bytes, u8 ctr[])
	 */
ENTRY(aesbs_ctr_encrypt)
	mov		ip, sp
	push		{r4-r10, lr}

	ldm		ip, {r5-r7}		// load args 4-6
	teq		r7, #0
	addne		r5, r5, #1		// one extra block if final != 0

	ldm		ip, {r5, r6}		// load args 4-5
	vld1.8		{q0}, [r6]		// load counter
	vrev32.8	q1, q0
	vmov		r9, r10, d3
@@ -792,20 +787,19 @@ ENTRY(aesbs_ctr_encrypt)
	adc		r7, r7, #0

99:	vmov		q1, q0
	sub		lr, r5, #1
	vmov		q2, q0
	adr		ip, 0f
	vmov		q3, q0
	and		lr, lr, #112
	vmov		q4, q0
	cmp		r5, #112
	vmov		q5, q0
	sub		ip, ip, lr, lsl #1
	vmov		q6, q0
	add		ip, ip, lr, lsr #2
	vmov		q7, q0

	adr		ip, 0f
	sub		lr, r5, #1
	and		lr, lr, #7
	cmp		r5, #8
	sub		ip, ip, lr, lsl #5
	sub		ip, ip, lr, lsl #2
	movlt		pc, ip			// computed goto if blocks < 8
	movle		pc, ip			// computed goto if bytes < 112

	next_ctr	q1
	next_ctr	q2
@@ -820,12 +814,14 @@ ENTRY(aesbs_ctr_encrypt)
	bl		aesbs_encrypt8

	adr		ip, 1f
	and		lr, r5, #7
	cmp		r5, #8
	movgt		r4, #0
	ldrle		r4, [sp, #40]		// load final in the last round
	sub		ip, ip, lr, lsl #2
	movlt		pc, ip			// computed goto if blocks < 8
	sub		lr, r5, #1
	cmp		r5, #128
	bic		lr, lr, #15
	ands		r4, r5, #15		// preserves C flag
	teqcs		r5, r5			// set Z flag if not last iteration
	sub		ip, ip, lr, lsr #2
	rsb		r4, r4, #16
	movcc		pc, ip			// computed goto if bytes < 128

	vld1.8		{q8}, [r1]!
	vld1.8		{q9}, [r1]!
@@ -834,46 +830,70 @@ ENTRY(aesbs_ctr_encrypt)
	vld1.8		{q12}, [r1]!
	vld1.8		{q13}, [r1]!
	vld1.8		{q14}, [r1]!
	teq		r4, #0			// skip last block if 'final'
1:	bne		2f
1:	subne		r1, r1, r4
	vld1.8		{q15}, [r1]!

2:	adr		ip, 3f
	cmp		r5, #8
	sub		ip, ip, lr, lsl #3
	movlt		pc, ip			// computed goto if blocks < 8
	add		ip, ip, #2f - 1b

	veor		q0, q0, q8
	vst1.8		{q0}, [r0]!
	veor		q1, q1, q9
	vst1.8		{q1}, [r0]!
	veor		q4, q4, q10
	vst1.8		{q4}, [r0]!
	veor		q6, q6, q11
	vst1.8		{q6}, [r0]!
	veor		q3, q3, q12
	vst1.8		{q3}, [r0]!
	veor		q7, q7, q13
	vst1.8		{q7}, [r0]!
	veor		q2, q2, q14
	bne		3f
	veor		q5, q5, q15

	movcc		pc, ip			// computed goto if bytes < 128

	vst1.8		{q0}, [r0]!
	vst1.8		{q1}, [r0]!
	vst1.8		{q4}, [r0]!
	vst1.8		{q6}, [r0]!
	vst1.8		{q3}, [r0]!
	vst1.8		{q7}, [r0]!
	vst1.8		{q2}, [r0]!
	teq		r4, #0			// skip last block if 'final'
	W(bne)		5f
3:	veor		q5, q5, q15
2:	subne		r0, r0, r4
	vst1.8		{q5}, [r0]!

4:	next_ctr	q0
	next_ctr	q0

	subs		r5, r5, #8
	subs		r5, r5, #128
	bgt		99b

	vst1.8		{q0}, [r6]
	pop		{r4-r10, pc}

5:	vst1.8		{q5}, [r4]
	b		4b
3:	adr		lr, .Lpermute_table + 16
	cmp		r5, #16			// Z flag remains cleared
	sub		lr, lr, r4
	vld1.8		{q8-q9}, [lr]
	vtbl.8		d16, {q5}, d16
	vtbl.8		d17, {q5}, d17
	veor		q5, q8, q15
	bcc		4f			// have to reload prev if R5 < 16
	vtbx.8		d10, {q2}, d18
	vtbx.8		d11, {q2}, d19
	mov		pc, ip			// branch back to VST sequence

4:	sub		r0, r0, r4
	vshr.s8		q9, q9, #7		// create mask for VBIF
	vld1.8		{q8}, [r0]		// reload
	vbif		q5, q8, q9
	vst1.8		{q5}, [r0]
	pop		{r4-r10, pc}
ENDPROC(aesbs_ctr_encrypt)

	.align		6
.Lpermute_table:
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.byte		0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07
	.byte		0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff

	.macro		next_tweak, out, in, const, tmp
	vshr.s64	\tmp, \in, #63
	vand		\tmp, \tmp, \const
@@ -888,6 +908,7 @@ ENDPROC(aesbs_ctr_encrypt)
	 * aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
	 *		     int blocks, u8 iv[], int reorder_last_tweak)
	 */
	.align		6
__xts_prepare8:
	vld1.8		{q14}, [r7]		// load iv
	vmov.i32	d30, #0x87		// compose tweak mask vector
+14 −21
Original line number Diff line number Diff line
@@ -37,7 +37,7 @@ asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 ctr[], u8 final[]);
				  int rounds, int blocks, u8 ctr[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);
@@ -243,32 +243,25 @@ static int ctr_encrypt(struct skcipher_request *req)
	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		int bytes = walk.nbytes;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			final = NULL;
		}
		if (unlikely(bytes < AES_BLOCK_SIZE))
			src = dst = memcpy(buf + sizeof(buf) - bytes,
					   src, bytes);
		else if (walk.nbytes < walk.total)
			bytes &= ~(8 * AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->rk, ctx->rounds, blocks, walk.iv, final);
		aesbs_ctr_encrypt(dst, src, ctx->rk, ctx->rounds, bytes, walk.iv);
		kernel_neon_end();

		if (final) {
			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;
		if (unlikely(bytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
			       buf + sizeof(buf) - bytes, bytes);

			crypto_xor_cpy(dst, src, final,
				       walk.total % AES_BLOCK_SIZE);

			err = skcipher_walk_done(&walk, 0);
			break;
		}
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
		err = skcipher_walk_done(&walk, walk.nbytes - bytes);
	}

	return err;