Commit f617d246 authored by Linus Torvalds's avatar Linus Torvalds
Browse files

Merge tag 'fpsimd-on-stack-for-linus' of...

Merge tag 'fpsimd-on-stack-for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/ebiggers/linux

Pull arm64 FPSIMD on-stack buffer updates from Eric Biggers:
 "This is a core arm64 change. However, I was asked to take this because
  most uses of kernel-mode FPSIMD are in crypto or CRC code.

  In v6.8, the size of task_struct on arm64 increased by 528 bytes due
  to the new 'kernel_fpsimd_state' field. This field was added to allow
  kernel-mode FPSIMD code to be preempted.

  Unfortunately, 528 bytes is kind of a lot for task_struct. This
  regression in the task_struct size was noticed and reported.

  Recover that space by making this state be allocated on the stack at
  the beginning of each kernel-mode FPSIMD section.

  To make it easier for all the users of kernel-mode FPSIMD to do that
  correctly, introduce and use a 'scoped_ksimd' abstraction"

* tag 'fpsimd-on-stack-for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/ebiggers/linux: (23 commits)
  lib/crypto: arm64: Move remaining algorithms to scoped ksimd API
  lib/crypto: arm/blake2b: Move to scoped ksimd API
  arm64/fpsimd: Allocate kernel mode FP/SIMD buffers on the stack
  arm64/fpu: Enforce task-context only for generic kernel mode FPU
  net/mlx5: Switch to more abstract scoped ksimd guard API on arm64
  arm64/xorblocks:  Switch to 'ksimd' scoped guard API
  crypto/arm64: sm4 - Switch to 'ksimd' scoped guard API
  crypto/arm64: sm3 - Switch to 'ksimd' scoped guard API
  crypto/arm64: sha3 - Switch to 'ksimd' scoped guard API
  crypto/arm64: polyval - Switch to 'ksimd' scoped guard API
  crypto/arm64: nhpoly1305 - Switch to 'ksimd' scoped guard API
  crypto/arm64: aes-gcm - Switch to 'ksimd' scoped guard API
  crypto/arm64: aes-blk - Switch to 'ksimd' scoped guard API
  crypto/arm64: aes-ccm - Switch to 'ksimd' scoped guard API
  raid6: Move to more abstract 'ksimd' guard API
  crypto: aegis128-neon - Move to more abstract 'ksimd' guard API
  crypto/arm64: sm4-ce-gcm - Avoid pointless yield of the NEON unit
  crypto/arm64: sm4-ce-ccm - Avoid pointless yield of the NEON unit
  crypto/arm64: aes-ce-ccm - Avoid pointless yield of the NEON unit
  lib/crc: Switch ARM and arm64 to 'ksimd' scoped guard API
  ...
parents 906003e1 5dc8d277
Loading
Loading
Loading
Loading
+7 −0
Original line number Diff line number Diff line
@@ -2,14 +2,21 @@
#ifndef _ASM_SIMD_H
#define _ASM_SIMD_H

#include <linux/cleanup.h>
#include <linux/compiler_attributes.h>
#include <linux/preempt.h>
#include <linux/types.h>

#include <asm/neon.h>

static __must_check inline bool may_use_simd(void)
{
	return IS_ENABLED(CONFIG_KERNEL_MODE_NEON) && !in_hardirq()
	       && !irqs_disabled();
}

DEFINE_LOCK_GUARD_0(ksimd, kernel_neon_begin(), kernel_neon_end())

#define scoped_ksimd()	scoped_guard(ksimd)

#endif	/* _ASM_SIMD_H */
+55 −61
Original line number Diff line number Diff line
@@ -8,7 +8,6 @@
 * Author: Ard Biesheuvel <ardb@kernel.org>
 */

#include <asm/neon.h>
#include <linux/unaligned.h>
#include <crypto/aes.h>
#include <crypto/scatterwalk.h>
@@ -16,6 +15,8 @@
#include <crypto/internal/skcipher.h>
#include <linux/module.h>

#include <asm/simd.h>

#include "aes-ce-setkey.h"

MODULE_IMPORT_NS("CRYPTO_INTERNAL");
@@ -114,11 +115,8 @@ static u32 ce_aes_ccm_auth_data(u8 mac[], u8 const in[], u32 abytes,
			in += adv;
			abytes -= adv;

			if (unlikely(rem)) {
				kernel_neon_end();
				kernel_neon_begin();
			if (unlikely(rem))
				macp = 0;
			}
		} else {
			u32 l = min(AES_BLOCK_SIZE - macp, abytes);

@@ -187,8 +185,7 @@ static int ccm_encrypt(struct aead_request *req)
	if (unlikely(err))
		return err;

	kernel_neon_begin();

	scoped_ksimd() {
		if (req->assoclen)
			ccm_calculate_auth_mac(req, mac);

@@ -219,8 +216,7 @@ static int ccm_encrypt(struct aead_request *req)
				err = skcipher_walk_done(&walk, tail);
			}
		} while (walk.nbytes);

	kernel_neon_end();
	}

	if (unlikely(err))
		return err;
@@ -254,8 +250,7 @@ static int ccm_decrypt(struct aead_request *req)
	if (unlikely(err))
		return err;

	kernel_neon_begin();

	scoped_ksimd() {
		if (req->assoclen)
			ccm_calculate_auth_mac(req, mac);

@@ -286,8 +281,7 @@ static int ccm_decrypt(struct aead_request *req)
				err = skcipher_walk_done(&walk, tail);
			}
		} while (walk.nbytes);

	kernel_neon_end();
	}

	if (unlikely(err))
		return err;
+43 −44
Original line number Diff line number Diff line
@@ -52,9 +52,8 @@ static void aes_cipher_encrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[])
		return;
	}

	kernel_neon_begin();
	scoped_ksimd()
		__aes_ce_encrypt(ctx->key_enc, dst, src, num_rounds(ctx));
	kernel_neon_end();
}

static void aes_cipher_decrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[])
@@ -66,9 +65,8 @@ static void aes_cipher_decrypt(struct crypto_tfm *tfm, u8 dst[], u8 const src[])
		return;
	}

	kernel_neon_begin();
	scoped_ksimd()
		__aes_ce_decrypt(ctx->key_dec, dst, src, num_rounds(ctx));
	kernel_neon_end();
}

int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
@@ -94,12 +92,13 @@ int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
	for (i = 0; i < kwords; i++)
		ctx->key_enc[i] = get_unaligned_le32(in_key + i * sizeof(u32));

	kernel_neon_begin();
	scoped_ksimd() {
		for (i = 0; i < sizeof(rcon); i++) {
			u32 *rki = ctx->key_enc + (i * kwords);
			u32 *rko = rki + kwords;

		rko[0] = ror32(__aes_ce_sub(rki[kwords - 1]), 8) ^ rcon[i] ^ rki[0];
			rko[0] = ror32(__aes_ce_sub(rki[kwords - 1]), 8) ^
				 rcon[i] ^ rki[0];
			rko[1] = rko[0] ^ rki[1];
			rko[2] = rko[1] ^ rki[2];
			rko[3] = rko[2] ^ rki[3];
@@ -120,10 +119,10 @@ int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
		}

		/*
	 * Generate the decryption keys for the Equivalent Inverse Cipher.
	 * This involves reversing the order of the round keys, and applying
	 * the Inverse Mix Columns transformation on all but the first and
	 * the last one.
		 * Generate the decryption keys for the Equivalent Inverse
		 * Cipher.  This involves reversing the order of the round
		 * keys, and applying the Inverse Mix Columns transformation on
		 * all but the first and the last one.
		 */
		key_enc = (struct aes_block *)ctx->key_enc;
		key_dec = (struct aes_block *)ctx->key_dec;
@@ -133,8 +132,8 @@ int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
		for (i = 1, j--; j > 0; i++, j--)
			__aes_ce_invert(key_dec + i, key_enc + j);
		key_dec[i] = key_enc[0];
	}

	kernel_neon_end();
	return 0;
}
EXPORT_SYMBOL(ce_aes_expandkey);
+63 −76
Original line number Diff line number Diff line
@@ -5,8 +5,6 @@
 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/hwcap.h>
#include <asm/neon.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/hash.h>
@@ -20,6 +18,9 @@
#include <linux/module.h>
#include <linux/string.h>

#include <asm/hwcap.h>
#include <asm/simd.h>

#include "aes-ce-setkey.h"

#ifdef USE_V8_CRYPTO_EXTENSIONS
@@ -186,10 +187,9 @@ static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
	err = skcipher_walk_virt(&walk, req, false);

	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
		kernel_neon_begin();
		scoped_ksimd()
			aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
					ctx->key_enc, rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
@@ -206,10 +206,9 @@ static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
	err = skcipher_walk_virt(&walk, req, false);

	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
		kernel_neon_begin();
		scoped_ksimd()
			aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
					ctx->key_dec, rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
@@ -224,10 +223,9 @@ static int cbc_encrypt_walk(struct skcipher_request *req,
	unsigned int blocks;

	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
		kernel_neon_begin();
		scoped_ksimd()
			aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
					ctx->key_enc, rounds, blocks, walk->iv);
		kernel_neon_end();
		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
	}
	return err;
@@ -253,10 +251,9 @@ static int cbc_decrypt_walk(struct skcipher_request *req,
	unsigned int blocks;

	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
		kernel_neon_begin();
		scoped_ksimd()
			aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
					ctx->key_dec, rounds, blocks, walk->iv);
		kernel_neon_end();
		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
	}
	return err;
@@ -322,10 +319,9 @@ static int cts_cbc_encrypt(struct skcipher_request *req)
	if (err)
		return err;

	kernel_neon_begin();
	scoped_ksimd()
		aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				    ctx->key_enc, rounds, walk.nbytes, walk.iv);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
@@ -379,10 +375,9 @@ static int cts_cbc_decrypt(struct skcipher_request *req)
	if (err)
		return err;

	kernel_neon_begin();
	scoped_ksimd()
		aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				    ctx->key_dec, rounds, walk.nbytes, walk.iv);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
@@ -399,11 +394,11 @@ static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)

	blocks = walk.nbytes / AES_BLOCK_SIZE;
	if (blocks) {
		kernel_neon_begin();
		aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
		scoped_ksimd()
			aes_essiv_cbc_encrypt(walk.dst.virt.addr,
					      walk.src.virt.addr,
					      ctx->key1.key_enc, rounds, blocks,
					      req->iv, ctx->key2.key_enc);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err ?: cbc_encrypt_walk(req, &walk);
@@ -421,11 +416,11 @@ static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)

	blocks = walk.nbytes / AES_BLOCK_SIZE;
	if (blocks) {
		kernel_neon_begin();
		aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
		scoped_ksimd()
			aes_essiv_cbc_decrypt(walk.dst.virt.addr,
					      walk.src.virt.addr,
					      ctx->key1.key_dec, rounds, blocks,
					      req->iv, ctx->key2.key_enc);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err ?: cbc_decrypt_walk(req, &walk);
@@ -461,10 +456,9 @@ static int __maybe_unused xctr_encrypt(struct skcipher_request *req)
		else if (nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		scoped_ksimd()
			aes_xctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
							 walk.iv, byte_ctr);
		kernel_neon_end();

		if (unlikely(nbytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
@@ -506,10 +500,9 @@ static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
		else if (nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		scoped_ksimd()
			aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
					walk.iv);
		kernel_neon_end();

		if (unlikely(nbytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
@@ -562,11 +555,10 @@ static int __maybe_unused xts_encrypt(struct skcipher_request *req)
		if (walk.nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		scoped_ksimd()
			aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
					ctx->key1.key_enc, rounds, nbytes,
					ctx->key2.key_enc, walk.iv, first);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

@@ -584,11 +576,10 @@ static int __maybe_unused xts_encrypt(struct skcipher_request *req)
	if (err)
		return err;

	kernel_neon_begin();
	scoped_ksimd()
		aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				ctx->key1.key_enc, rounds, walk.nbytes,
				ctx->key2.key_enc, walk.iv, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
@@ -634,11 +625,10 @@ static int __maybe_unused xts_decrypt(struct skcipher_request *req)
		if (walk.nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		scoped_ksimd()
			aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
					ctx->key1.key_dec, rounds, nbytes,
					ctx->key2.key_enc, walk.iv, first);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

@@ -657,11 +647,10 @@ static int __maybe_unused xts_decrypt(struct skcipher_request *req)
		return err;


	kernel_neon_begin();
	scoped_ksimd()
		aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				ctx->key1.key_dec, rounds, walk.nbytes,
				ctx->key2.key_enc, walk.iv, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
@@ -808,10 +797,9 @@ static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
		return err;

	/* encrypt the zero vector */
	kernel_neon_begin();
	aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
			rounds, 1);
	kernel_neon_end();
	scoped_ksimd()
		aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){},
				ctx->key.key_enc, rounds, 1);

	cmac_gf128_mul_by_x(consts, consts);
	cmac_gf128_mul_by_x(consts + 1, consts);
@@ -837,10 +825,10 @@ static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
	if (err)
		return err;

	kernel_neon_begin();
	scoped_ksimd() {
		aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
		aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
	kernel_neon_end();
	}

	return cbcmac_setkey(tfm, key, sizeof(key));
}
@@ -860,10 +848,9 @@ static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
	int rem;

	do {
		kernel_neon_begin();
		scoped_ksimd()
			rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
					     dg, enc_before, !enc_before);
		kernel_neon_end();
		in += (blocks - rem) * AES_BLOCK_SIZE;
		blocks = rem;
	} while (blocks);
+75 −75
Original line number Diff line number Diff line
@@ -85,9 +85,8 @@ static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	scoped_ksimd()
		aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}
@@ -110,10 +109,9 @@ static int __ecb_crypt(struct skcipher_request *req,
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		scoped_ksimd()
			fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
			   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
@@ -146,9 +144,8 @@ static int aesbs_cbc_ctr_setkey(struct crypto_skcipher *tfm, const u8 *in_key,

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	kernel_neon_begin();
	scoped_ksimd()
		aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();
	memzero_explicit(&rk, sizeof(rk));

	return 0;
@@ -167,11 +164,11 @@ static int cbc_encrypt(struct skcipher_request *req)
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		/* fall back to the non-bitsliced NEON implementation */
		kernel_neon_begin();
		neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
		scoped_ksimd()
			neon_aes_cbc_encrypt(walk.dst.virt.addr,
					     walk.src.virt.addr,
					     ctx->enc, ctx->key.rounds, blocks,
					     walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
@@ -193,11 +190,10 @@ static int cbc_decrypt(struct skcipher_request *req)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		scoped_ksimd()
			aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
					  ctx->key.rk, ctx->key.rounds, blocks,
					  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
@@ -220,10 +216,11 @@ static int ctr_encrypt(struct skcipher_request *req)
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_neon_begin();
		scoped_ksimd() {
			if (blocks >= 8) {
			aesbs_ctr_encrypt(dst, src, ctx->key.rk, ctx->key.rounds,
					  blocks, walk.iv);
				aesbs_ctr_encrypt(dst, src, ctx->key.rk,
						  ctx->key.rounds, blocks,
						  walk.iv);
				dst += blocks * AES_BLOCK_SIZE;
				src += blocks * AES_BLOCK_SIZE;
			}
@@ -232,18 +229,19 @@ static int ctr_encrypt(struct skcipher_request *req)
				u8 *d = dst;

				if (unlikely(nbytes < AES_BLOCK_SIZE))
				src = dst = memcpy(buf + sizeof(buf) - nbytes,
						   src, nbytes);
					src = dst = memcpy(buf + sizeof(buf) -
							   nbytes, src, nbytes);

			neon_aes_ctr_encrypt(dst, src, ctx->enc, ctx->key.rounds,
					     nbytes, walk.iv);
				neon_aes_ctr_encrypt(dst, src, ctx->enc,
						     ctx->key.rounds, nbytes,
						     walk.iv);

				if (unlikely(nbytes < AES_BLOCK_SIZE))
					memcpy(d, dst, nbytes);

				nbytes = 0;
			}
		kernel_neon_end();
		}
		err = skcipher_walk_done(&walk, nbytes);
	}
	return err;
@@ -320,7 +318,7 @@ static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		in = walk.src.virt.addr;
		nbytes = walk.nbytes;

		kernel_neon_begin();
		scoped_ksimd() {
			if (blocks >= 8) {
				if (first == 1)
					neon_aes_ecb_encrypt(walk.iv, walk.iv,
@@ -346,7 +344,7 @@ static int __xts_crypt(struct skcipher_request *req, bool encrypt,
							     ctx->twkey, walk.iv, first);
				nbytes = first = 0;
			}
		kernel_neon_end();
		}
		err = skcipher_walk_done(&walk, nbytes);
	}

@@ -369,14 +367,16 @@ static int __xts_crypt(struct skcipher_request *req, bool encrypt,
	in = walk.src.virt.addr;
	nbytes = walk.nbytes;

	kernel_neon_begin();
	scoped_ksimd() {
		if (encrypt)
		neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first);
			neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
					     ctx->key.rounds, nbytes, ctx->twkey,
					     walk.iv, first);
		else
		neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first);
	kernel_neon_end();
			neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
					     ctx->key.rounds, nbytes, ctx->twkey,
					     walk.iv, first);
	}

	return skcipher_walk_done(&walk, 0);
}
Loading