Commit b1863fd0 authored by Tianjia Zhang's avatar Tianjia Zhang Committed by Herbert Xu
Browse files

crypto: arm64/sm4 - add CE implementation for CTS-CBC mode



This patch is a CE-optimized assembly implementation for CTS-CBC mode.

Benchmark on T-Head Yitian-710 2.75 GHz, the data comes from the 218 mode of
tcrypt, and compared the performance before and after this patch (the driver
used before this patch is cts(cbc-sm4-ce)). The abscissas are blocks of
different lengths. The data is tabulated and the unit is Mb/s:

Before:

cts(cbc-sm4-ce) |      16       64      128      256     1024     1420     4096
----------------+--------------------------------------------------------------
    CTS-CBC enc |  286.09   297.17   457.97   627.75   868.58   900.80   957.69
    CTS-CBC dec |  286.67   285.63   538.35   947.08  2241.03  2577.32  3391.14

After:

cts-cbc-sm4-ce  |      16       64      128      256     1024     1420     4096
----------------+--------------------------------------------------------------
    CTS-CBC enc |  288.19   428.80   593.57   741.04   911.73   931.80   950.00
    CTS-CBC dec |  292.22   468.99   838.23  1380.76  2741.17  3036.42  3409.62

Signed-off-by: default avatarTianjia Zhang <tianjia.zhang@linux.alibaba.com>
Signed-off-by: default avatarHerbert Xu <herbert@gondor.apana.org.au>
parent 45089dbe
Loading
Loading
Loading
Loading
+102 −0
Original line number Diff line number Diff line
@@ -306,6 +306,100 @@ SYM_FUNC_START(sm4_ce_cbc_dec)
	ret
SYM_FUNC_END(sm4_ce_cbc_dec)

.align 3
SYM_FUNC_START(sm4_ce_cbc_cts_enc)
	/* input:
	 *   x0: round key array, CTX
	 *   x1: dst
	 *   x2: src
	 *   x3: iv (big endian, 128 bit)
	 *   w4: nbytes
	 */
	SM4_PREPARE(x0)

	sub		w5, w4, #16
	uxtw		x5, w5

	ld1		{RIV.16b}, [x3]

	ld1		{v0.16b}, [x2]
	eor		RIV.16b, RIV.16b, v0.16b
	SM4_CRYPT_BLK(RIV)

	/* load permute table */
	adr_l		x6, .Lcts_permute_table
	add		x7, x6, #32
	add		x6, x6, x5
	sub		x7, x7, x5
	ld1		{v3.16b}, [x6]
	ld1		{v4.16b}, [x7]

	/* overlapping loads */
	add		x2, x2, x5
	ld1		{v1.16b}, [x2]

	/* create Cn from En-1 */
	tbl		v0.16b, {RIV.16b}, v3.16b
	/* padding Pn with zeros */
	tbl		v1.16b, {v1.16b}, v4.16b

	eor		v1.16b, v1.16b, RIV.16b
	SM4_CRYPT_BLK(v1)

	/* overlapping stores */
	add		x5, x1, x5
	st1		{v0.16b}, [x5]
	st1		{v1.16b}, [x1]

	ret
SYM_FUNC_END(sm4_ce_cbc_cts_enc)

.align 3
SYM_FUNC_START(sm4_ce_cbc_cts_dec)
	/* input:
	 *   x0: round key array, CTX
	 *   x1: dst
	 *   x2: src
	 *   x3: iv (big endian, 128 bit)
	 *   w4: nbytes
	 */
	SM4_PREPARE(x0)

	sub		w5, w4, #16
	uxtw		x5, w5

	ld1		{RIV.16b}, [x3]

	/* load permute table */
	adr_l		x6, .Lcts_permute_table
	add		x7, x6, #32
	add		x6, x6, x5
	sub		x7, x7, x5
	ld1		{v3.16b}, [x6]
	ld1		{v4.16b}, [x7]

	/* overlapping loads */
	ld1		{v0.16b}, [x2], x5
	ld1		{v1.16b}, [x2]

	SM4_CRYPT_BLK(v0)
	/* select the first Ln bytes of Xn to create Pn */
	tbl		v2.16b, {v0.16b}, v3.16b
	eor		v2.16b, v2.16b, v1.16b

	/* overwrite the first Ln bytes with Cn to create En-1 */
	tbx		v0.16b, {v1.16b}, v4.16b
	SM4_CRYPT_BLK(v0)
	eor		v0.16b, v0.16b, RIV.16b

	/* overlapping stores */
	add		x5, x1, x5
	st1		{v2.16b}, [x5]
	st1		{v0.16b}, [x1]

	ret
SYM_FUNC_END(sm4_ce_cbc_cts_dec)

.align 3
SYM_FUNC_START(sm4_ce_cfb_enc)
	/* input:
@@ -576,3 +670,11 @@ SYM_FUNC_END(sm4_ce_ctr_enc)
.Lbswap128_mask:
	.byte		0x0c, 0x0d, 0x0e, 0x0f, 0x08, 0x09, 0x0a, 0x0b
	.byte		0x04, 0x05, 0x06, 0x07, 0x00, 0x01, 0x02, 0x03

.Lcts_permute_table:
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.byte		 0x0,  0x1,  0x2,  0x3,  0x4,  0x5,  0x6,  0x7
	.byte		 0x8,  0x9,  0xa,  0xb,  0xc,  0xd,  0xe,  0xf
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+94 −0
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/sm4.h>

#define BYTES2BLKS(nbytes)	((nbytes) >> 4)
@@ -29,6 +30,10 @@ asmlinkage void sm4_ce_cbc_enc(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblocks);
asmlinkage void sm4_ce_cbc_dec(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblocks);
asmlinkage void sm4_ce_cbc_cts_enc(const u32 *rkey, u8 *dst, const u8 *src,
				   u8 *iv, unsigned int nbytes);
asmlinkage void sm4_ce_cbc_cts_dec(const u32 *rkey, u8 *dst, const u8 *src,
				   u8 *iv, unsigned int nbytes);
asmlinkage void sm4_ce_cfb_enc(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblks);
asmlinkage void sm4_ce_cfb_dec(const u32 *rkey, u8 *dst, const u8 *src,
@@ -153,6 +158,78 @@ static int sm4_cbc_decrypt(struct skcipher_request *req)
	return sm4_cbc_crypt(req, ctx, false);
}

static int sm4_cbc_cts_crypt(struct skcipher_request *req, bool encrypt)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct scatterlist *src = req->src;
	struct scatterlist *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;
	int cbc_blocks;
	int err;

	if (req->cryptlen < SM4_BLOCK_SIZE)
		return -EINVAL;

	if (req->cryptlen == SM4_BLOCK_SIZE)
		return sm4_cbc_crypt(req, ctx, encrypt);

	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	/* handle the CBC cryption part */
	cbc_blocks = DIV_ROUND_UP(req->cryptlen, SM4_BLOCK_SIZE) - 2;
	if (cbc_blocks) {
		skcipher_request_set_crypt(&subreq, src, dst,
					   cbc_blocks * SM4_BLOCK_SIZE,
					   req->iv);

		err = sm4_cbc_crypt(&subreq, ctx, encrypt);
		if (err)
			return err;

		dst = src = scatterwalk_ffwd(sg_src, src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * SM4_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_neon_begin();

	if (encrypt)
		sm4_ce_cbc_cts_enc(ctx->rkey_enc, walk.dst.virt.addr,
				   walk.src.virt.addr, walk.iv, walk.nbytes);
	else
		sm4_ce_cbc_cts_dec(ctx->rkey_dec, walk.dst.virt.addr,
				   walk.src.virt.addr, walk.iv, walk.nbytes);

	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int sm4_cbc_cts_encrypt(struct skcipher_request *req)
{
	return sm4_cbc_cts_crypt(req, true);
}

static int sm4_cbc_cts_decrypt(struct skcipher_request *req)
{
	return sm4_cbc_cts_crypt(req, false);
}

static int sm4_cfb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
@@ -342,6 +419,22 @@ static struct skcipher_alg sm4_algs[] = {
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ctr_crypt,
		.decrypt	= sm4_ctr_crypt,
	}, {
		.base = {
			.cra_name		= "cts(cbc(sm4))",
			.cra_driver_name	= "cts-cbc-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.walksize	= SM4_BLOCK_SIZE * 2,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cbc_cts_encrypt,
		.decrypt	= sm4_cbc_cts_decrypt,
	}
};

@@ -365,5 +458,6 @@ MODULE_ALIAS_CRYPTO("ecb(sm4)");
MODULE_ALIAS_CRYPTO("cbc(sm4)");
MODULE_ALIAS_CRYPTO("cfb(sm4)");
MODULE_ALIAS_CRYPTO("ctr(sm4)");
MODULE_ALIAS_CRYPTO("cts(cbc(sm4))");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_LICENSE("GPL v2");