Commit 03bc4768 authored by Ard Biesheuvel's avatar Ard Biesheuvel
Browse files

crypto/arm64: sm4 - Switch to 'ksimd' scoped guard API



Switch to the more abstract 'scoped_ksimd()' API, which will be modified
in a future patch to transparently allocate a kernel mode FP/SIMD state
buffer on the stack, so that kernel mode FP/SIMD code remains
preemptible in principle, but without the memory overhead that adds 528
bytes to the size of struct task_struct.

Reviewed-by: default avatarEric Biggers <ebiggers@kernel.org>
Reviewed-by: default avatarJonathan Cameron <jonathan.cameron@huawei.com>
Acked-by: default avatarCatalin Marinas <catalin.marinas@arm.com>
Signed-off-by: default avatarArd Biesheuvel <ardb@kernel.org>
parent ab9615b5
Loading
Loading
Loading
Loading
+17 −21
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <linux/cpufeature.h>
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/scatterwalk.h>
#include <crypto/internal/aead.h>
#include <crypto/internal/skcipher.h>
@@ -35,10 +35,9 @@ static int ccm_setkey(struct crypto_aead *tfm, const u8 *key,
	if (key_len != SM4_KEY_SIZE)
		return -EINVAL;

	kernel_neon_begin();
	scoped_ksimd()
		sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
				  crypto_sm4_fk, crypto_sm4_ck);
	kernel_neon_end();

	return 0;
}
@@ -167,8 +166,7 @@ static int ccm_crypt(struct aead_request *req, struct skcipher_walk *walk,
	memcpy(ctr0, walk->iv, SM4_BLOCK_SIZE);
	crypto_inc(walk->iv, SM4_BLOCK_SIZE);

	kernel_neon_begin();

	scoped_ksimd() {
		if (req->assoclen)
			ccm_calculate_auth_mac(req, mac);

@@ -184,10 +182,8 @@ static int ccm_crypt(struct aead_request *req, struct skcipher_walk *walk,

			err = skcipher_walk_done(walk, tail);
		}

		sm4_ce_ccm_final(rkey_enc, ctr0, mac);

	kernel_neon_end();
	}

	return err;
}
+4 −6
Original line number Diff line number Diff line
@@ -32,9 +32,8 @@ static void sm4_ce_encrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in)
	if (!crypto_simd_usable()) {
		sm4_crypt_block(ctx->rkey_enc, out, in);
	} else {
		kernel_neon_begin();
		scoped_ksimd()
			sm4_ce_do_crypt(ctx->rkey_enc, out, in);
		kernel_neon_end();
	}
}

@@ -45,9 +44,8 @@ static void sm4_ce_decrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in)
	if (!crypto_simd_usable()) {
		sm4_crypt_block(ctx->rkey_dec, out, in);
	} else {
		kernel_neon_begin();
		scoped_ksimd()
			sm4_ce_do_crypt(ctx->rkey_dec, out, in);
		kernel_neon_end();
	}
}

+24 −29
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <linux/cpufeature.h>
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/b128ops.h>
#include <crypto/scatterwalk.h>
#include <crypto/internal/aead.h>
@@ -48,13 +48,11 @@ static int gcm_setkey(struct crypto_aead *tfm, const u8 *key,
	if (key_len != SM4_KEY_SIZE)
		return -EINVAL;

	kernel_neon_begin();

	scoped_ksimd() {
		sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
				crypto_sm4_fk, crypto_sm4_ck);
		sm4_ce_pmull_ghash_setup(ctx->key.rkey_enc, ctx->ghash_table);

	kernel_neon_end();
	}
	return 0;
}

@@ -149,8 +147,7 @@ static int gcm_crypt(struct aead_request *req, struct skcipher_walk *walk,
	memcpy(iv, req->iv, GCM_IV_SIZE);
	put_unaligned_be32(2, iv + GCM_IV_SIZE);

	kernel_neon_begin();

	scoped_ksimd() {
		if (req->assoclen)
			gcm_calculate_auth_mac(req, ghash);

@@ -171,9 +168,7 @@ static int gcm_crypt(struct aead_request *req, struct skcipher_walk *walk,

			err = skcipher_walk_done(walk, tail);
		} while (walk->nbytes);

	kernel_neon_end();

	}
	return err;
}

+96 −118
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@
 * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/b128ops.h>
#include <crypto/internal/hash.h>
#include <crypto/internal/skcipher.h>
@@ -74,10 +74,9 @@ static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
	if (key_len != SM4_KEY_SIZE)
		return -EINVAL;

	kernel_neon_begin();
	scoped_ksimd()
		sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
				  crypto_sm4_fk, crypto_sm4_ck);
	kernel_neon_end();
	return 0;
}

@@ -94,12 +93,12 @@ static int sm4_xts_setkey(struct crypto_skcipher *tfm, const u8 *key,
	if (ret)
		return ret;

	kernel_neon_begin();
	scoped_ksimd() {
		sm4_ce_expand_key(key, ctx->key1.rkey_enc,
				ctx->key1.rkey_dec, crypto_sm4_fk, crypto_sm4_ck);
		sm4_ce_expand_key(&key[SM4_KEY_SIZE], ctx->key2.rkey_enc,
				ctx->key2.rkey_dec, crypto_sm4_fk, crypto_sm4_ck);
	kernel_neon_end();
	}

	return 0;
}
@@ -117,15 +116,13 @@ static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		scoped_ksimd() {
			nblks = BYTES2BLKS(nbytes);
			if (nblks) {
				sm4_ce_crypt(rkey, dst, src, nblks);
				nbytes -= nblks * SM4_BLOCK_SIZE;
			}

		kernel_neon_end();
		}

		err = skcipher_walk_done(&walk, nbytes);
	}
@@ -167,16 +164,14 @@ static int sm4_cbc_crypt(struct skcipher_request *req,

		nblocks = nbytes / SM4_BLOCK_SIZE;
		if (nblocks) {
			kernel_neon_begin();

			scoped_ksimd() {
				if (encrypt)
					sm4_ce_cbc_enc(ctx->rkey_enc, dst, src,
						       walk.iv, nblocks);
				else
					sm4_ce_cbc_dec(ctx->rkey_dec, dst, src,
						       walk.iv, nblocks);

			kernel_neon_end();
			}
		}

		err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
@@ -249,16 +244,14 @@ static int sm4_cbc_cts_crypt(struct skcipher_request *req, bool encrypt)
	if (err)
		return err;

	kernel_neon_begin();

	scoped_ksimd() {
		if (encrypt)
			sm4_ce_cbc_cts_enc(ctx->rkey_enc, walk.dst.virt.addr,
					   walk.src.virt.addr, walk.iv, walk.nbytes);
		else
			sm4_ce_cbc_cts_dec(ctx->rkey_dec, walk.dst.virt.addr,
					   walk.src.virt.addr, walk.iv, walk.nbytes);

	kernel_neon_end();
	}

	return skcipher_walk_done(&walk, 0);
}
@@ -288,8 +281,7 @@ static int sm4_ctr_crypt(struct skcipher_request *req)
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		scoped_ksimd() {
			nblks = BYTES2BLKS(nbytes);
			if (nblks) {
				sm4_ce_ctr_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
@@ -307,8 +299,7 @@ static int sm4_ctr_crypt(struct skcipher_request *req)
				crypto_xor_cpy(dst, src, keystream, nbytes);
				nbytes = 0;
			}

		kernel_neon_end();
		}

		err = skcipher_walk_done(&walk, nbytes);
	}
@@ -359,8 +350,7 @@ static int sm4_xts_crypt(struct skcipher_request *req, bool encrypt)
		if (nbytes < walk.total)
			nbytes &= ~(SM4_BLOCK_SIZE - 1);

		kernel_neon_begin();

		scoped_ksimd() {
			if (encrypt)
				sm4_ce_xts_enc(ctx->key1.rkey_enc, walk.dst.virt.addr,
						walk.src.virt.addr, walk.iv, nbytes,
@@ -369,8 +359,7 @@ static int sm4_xts_crypt(struct skcipher_request *req, bool encrypt)
				sm4_ce_xts_dec(ctx->key1.rkey_dec, walk.dst.virt.addr,
						walk.src.virt.addr, walk.iv, nbytes,
						rkey2_enc);

		kernel_neon_end();
		}

		rkey2_enc = NULL;

@@ -395,8 +384,7 @@ static int sm4_xts_crypt(struct skcipher_request *req, bool encrypt)
	if (err)
		return err;

	kernel_neon_begin();

	scoped_ksimd() {
		if (encrypt)
			sm4_ce_xts_enc(ctx->key1.rkey_enc, walk.dst.virt.addr,
					walk.src.virt.addr, walk.iv, walk.nbytes,
@@ -405,8 +393,7 @@ static int sm4_xts_crypt(struct skcipher_request *req, bool encrypt)
			sm4_ce_xts_dec(ctx->key1.rkey_dec, walk.dst.virt.addr,
					walk.src.virt.addr, walk.iv, walk.nbytes,
					rkey2_enc);

	kernel_neon_end();
	}

	return skcipher_walk_done(&walk, 0);
}
@@ -510,11 +497,9 @@ static int sm4_cbcmac_setkey(struct crypto_shash *tfm, const u8 *key,
	if (key_len != SM4_KEY_SIZE)
		return -EINVAL;

	kernel_neon_begin();
	scoped_ksimd()
		sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
				crypto_sm4_fk, crypto_sm4_ck);
	kernel_neon_end();

	return 0;
}

@@ -530,15 +515,13 @@ static int sm4_cmac_setkey(struct crypto_shash *tfm, const u8 *key,

	memset(consts, 0, SM4_BLOCK_SIZE);

	kernel_neon_begin();

	scoped_ksimd() {
		sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
				crypto_sm4_fk, crypto_sm4_ck);

		/* encrypt the zero block */
		sm4_ce_crypt_block(ctx->key.rkey_enc, (u8 *)consts, (const u8 *)consts);

	kernel_neon_end();
	}

	/* gf(2^128) multiply zero-ciphertext with u and u^2 */
	a = be64_to_cpu(consts[0].a);
@@ -568,8 +551,7 @@ static int sm4_xcbc_setkey(struct crypto_shash *tfm, const u8 *key,
	if (key_len != SM4_KEY_SIZE)
		return -EINVAL;

	kernel_neon_begin();

	scoped_ksimd() {
		sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
				crypto_sm4_fk, crypto_sm4_ck);

@@ -578,8 +560,7 @@ static int sm4_xcbc_setkey(struct crypto_shash *tfm, const u8 *key,

		sm4_ce_expand_key(key2, ctx->key.rkey_enc, ctx->key.rkey_dec,
				crypto_sm4_fk, crypto_sm4_ck);

	kernel_neon_end();
	}

	return 0;
}
@@ -600,10 +581,9 @@ static int sm4_mac_update(struct shash_desc *desc, const u8 *p,
	unsigned int nblocks = len / SM4_BLOCK_SIZE;

	len %= SM4_BLOCK_SIZE;
	kernel_neon_begin();
	scoped_ksimd()
		sm4_ce_mac_update(tctx->key.rkey_enc, ctx->digest, p,
				nblocks, false, true);
	kernel_neon_end();
	return len;
}

@@ -619,10 +599,9 @@ static int sm4_cmac_finup(struct shash_desc *desc, const u8 *src,
		ctx->digest[len] ^= 0x80;
		consts += SM4_BLOCK_SIZE;
	}
	kernel_neon_begin();
	scoped_ksimd()
		sm4_ce_mac_update(tctx->key.rkey_enc, ctx->digest, consts, 1,
				  false, true);
	kernel_neon_end();
	memcpy(out, ctx->digest, SM4_BLOCK_SIZE);
	return 0;
}
@@ -635,10 +614,9 @@ static int sm4_cbcmac_finup(struct shash_desc *desc, const u8 *src,

	if (len) {
		crypto_xor(ctx->digest, src, len);
		kernel_neon_begin();
		scoped_ksimd()
			sm4_ce_crypt_block(tctx->key.rkey_enc, ctx->digest,
					   ctx->digest);
		kernel_neon_end();
	}
	memcpy(out, ctx->digest, SM4_BLOCK_SIZE);
	return 0;
+8 −17
Original line number Diff line number Diff line
@@ -48,11 +48,8 @@ static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)

		nblocks = nbytes / SM4_BLOCK_SIZE;
		if (nblocks) {
			kernel_neon_begin();

			scoped_ksimd()
				sm4_neon_crypt(rkey, dst, src, nblocks);

			kernel_neon_end();
		}

		err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
@@ -126,12 +123,9 @@ static int sm4_cbc_decrypt(struct skcipher_request *req)

		nblocks = nbytes / SM4_BLOCK_SIZE;
		if (nblocks) {
			kernel_neon_begin();

			scoped_ksimd()
				sm4_neon_cbc_dec(ctx->rkey_dec, dst, src,
						 walk.iv, nblocks);

			kernel_neon_end();
		}

		err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
@@ -157,13 +151,10 @@ static int sm4_ctr_crypt(struct skcipher_request *req)

		nblocks = nbytes / SM4_BLOCK_SIZE;
		if (nblocks) {
			kernel_neon_begin();

			scoped_ksimd()
				sm4_neon_ctr_crypt(ctx->rkey_enc, dst, src,
						   walk.iv, nblocks);

			kernel_neon_end();

			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;