lib/crypto: riscv/aes: Migrate optimized code into library

Move the aes_encrypt_zvkned() and aes_decrypt_zvkned() assembly
functions into lib/crypto/, wire them up to the AES library API, and
remove the "aes-riscv64-zvkned" crypto_cipher algorithm.

To make this possible, change the prototypes of these functions to
take (rndkeys, key_len) instead of a pointer to crypto_aes_ctx, and
change the RISC-V AES-XTS code to implement tweak encryption using the
AES library instead of directly calling aes_encrypt_zvkned().

The result is that both the AES library and crypto_cipher APIs use
RISC-V's AES instructions, whereas previously only crypto_cipher did
(and it wasn't enabled by default, which this commit fixes as well).

Acked-by: Ard Biesheuvel <ardb@kernel.org>
Link: https://lore.kernel.org/r/20260112192035.10427-15-ebiggers@kernel.org
Signed-off-by: Eric Biggers <ebiggers@kernel.org>
Author: Eric Biggers
Date:   2026-01-12 11:20:12 -08:00
parent 7cf2082e74
commit a4e573db06

8 changed files with 166 additions and 106 deletions
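For reference, the prototype change at the core of this commit, as it appears in the hunks below: the old assembly entry points took the whole crypto_aes_ctx and loaded the key length from a fixed offset in it, while the new ones take the round-key array and key length directly, with the (out, in) argument order of the AES library:

	/* Old (removed from the arch glue code): */
	asmlinkage void aes_encrypt_zvkned(const struct crypto_aes_ctx *key,
					   const u8 in[AES_BLOCK_SIZE],
					   u8 out[AES_BLOCK_SIZE]);

	/* New (declared in lib/crypto/riscv/aes.h): */
	void aes_encrypt_zvkned(const u32 rndkeys[], int key_len,
				u8 out[AES_BLOCK_SIZE],
				const u8 in[AES_BLOCK_SIZE]);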

arch/riscv/crypto/Kconfig

@@ -6,11 +6,9 @@ config CRYPTO_AES_RISCV64
 	tristate "Ciphers: AES, modes: ECB, CBC, CTS, CTR, XTS"
 	depends on 64BIT && TOOLCHAIN_HAS_VECTOR_CRYPTO && \
 		   RISCV_EFFICIENT_VECTOR_UNALIGNED_ACCESS
-	select CRYPTO_ALGAPI
 	select CRYPTO_LIB_AES
 	select CRYPTO_SKCIPHER
 	help
-	  Block cipher: AES cipher algorithms
 	  Length-preserving ciphers: AES with ECB, CBC, CTS, CTR, XTS
 	  Architecture: riscv64 using:

arch/riscv/crypto/aes-macros.S

@@ -51,8 +51,10 @@
 // - If AES-256, loads round keys into v1-v15 and continues onwards.
 //
 // Also sets vl=4 and vtype=e32,m1,ta,ma.  Clobbers t0 and t1.
-.macro	aes_begin	keyp, label128, label192
+.macro	aes_begin	keyp, label128, label192, key_len
+.ifb \key_len
 	lwu		t0, 480(\keyp)	// t0 = key length in bytes
+.endif
 	li		t1, 24	// t1 = key length for AES-192
 	vsetivli	zero, 4, e32, m1, ta, ma
 	vle32.v		v1, (\keyp)
@@ -76,12 +78,20 @@
 	vle32.v		v10, (\keyp)
 	addi		\keyp, \keyp, 16
 	vle32.v		v11, (\keyp)
+.ifb \key_len
 	blt		t0, t1, \label128	// If AES-128, goto label128.
+.else
+	blt		\key_len, t1, \label128	// If AES-128, goto label128.
+.endif
 	addi		\keyp, \keyp, 16
 	vle32.v		v12, (\keyp)
 	addi		\keyp, \keyp, 16
 	vle32.v		v13, (\keyp)
+.ifb \key_len
 	beq		t0, t1, \label192	// If AES-192, goto label192.
+.else
+	beq		\key_len, t1, \label192	// If AES-192, goto label192.
+.endif
 	// Else, it's AES-256.
 	addi		\keyp, \keyp, 16
 	vle32.v		v14, (\keyp)
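The new optional key_len macro argument is what lets both callers share aes_begin: existing mode code still passes only the key pointer, so the .ifb branch assembles the lwu that loads the key length from offset 480 of the crypto_aes_ctx, while the new library entry points already have the length in a register and compare it directly. Schematically, using the register names of the two call sites in this commit:

	// Mode code: key length loaded from the context.
	aes_begin	KEYP, 128f, 192f

	// Library entry points: key length already in a register.
	aes_begin	RNDKEYS, 128f, 192f, KEY_LEN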

arch/riscv/crypto/aes-riscv64-glue.c

@@ -1,7 +1,6 @@
 // SPDX-License-Identifier: GPL-2.0-only
 /*
- * AES using the RISC-V vector crypto extensions.  Includes the bare block
- * cipher and the ECB, CBC, CBC-CTS, CTR, and XTS modes.
+ * AES modes using the RISC-V vector crypto extensions
  *
  * Copyright (C) 2023 VRULL GmbH
  * Author: Heiko Stuebner <heiko.stuebner@vrull.eu>
@@ -15,7 +14,6 @@
 #include <asm/simd.h>
 #include <asm/vector.h>
 #include <crypto/aes.h>
-#include <crypto/internal/cipher.h>
 #include <crypto/internal/simd.h>
 #include <crypto/internal/skcipher.h>
 #include <crypto/scatterwalk.h>
@@ -23,13 +21,6 @@
 #include <linux/linkage.h>
 #include <linux/module.h>
 
-asmlinkage void aes_encrypt_zvkned(const struct crypto_aes_ctx *key,
-				   const u8 in[AES_BLOCK_SIZE],
-				   u8 out[AES_BLOCK_SIZE]);
-asmlinkage void aes_decrypt_zvkned(const struct crypto_aes_ctx *key,
-				   const u8 in[AES_BLOCK_SIZE],
-				   u8 out[AES_BLOCK_SIZE]);
-
 asmlinkage void aes_ecb_encrypt_zvkned(const struct crypto_aes_ctx *key,
 				       const u8 *in, u8 *out, size_t len);
 asmlinkage void aes_ecb_decrypt_zvkned(const struct crypto_aes_ctx *key,
@@ -86,14 +77,6 @@ static int riscv64_aes_setkey(struct crypto_aes_ctx *ctx,
 	return aes_expandkey(ctx, key, keylen);
 }
 
-static int riscv64_aes_setkey_cipher(struct crypto_tfm *tfm,
-				     const u8 *key, unsigned int keylen)
-{
-	struct crypto_aes_ctx *ctx = crypto_tfm_ctx(tfm);
-
-	return riscv64_aes_setkey(ctx, key, keylen);
-}
-
 static int riscv64_aes_setkey_skcipher(struct crypto_skcipher *tfm,
 				       const u8 *key, unsigned int keylen)
 {
@@ -102,34 +85,6 @@ static int riscv64_aes_setkey_skcipher(struct crypto_skcipher *tfm,
 	return riscv64_aes_setkey(ctx, key, keylen);
 }
 
-/* Bare AES, without a mode of operation */
-
-static void riscv64_aes_encrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src)
-{
-	const struct crypto_aes_ctx *ctx = crypto_tfm_ctx(tfm);
-
-	if (crypto_simd_usable()) {
-		kernel_vector_begin();
-		aes_encrypt_zvkned(ctx, src, dst);
-		kernel_vector_end();
-	} else {
-		aes_encrypt(ctx, dst, src);
-	}
-}
-
-static void riscv64_aes_decrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src)
-{
-	const struct crypto_aes_ctx *ctx = crypto_tfm_ctx(tfm);
-
-	if (crypto_simd_usable()) {
-		kernel_vector_begin();
-		aes_decrypt_zvkned(ctx, src, dst);
-		kernel_vector_end();
-	} else {
-		aes_decrypt(ctx, dst, src);
-	}
-}
-
 /* AES-ECB */
 
 static inline int riscv64_aes_ecb_crypt(struct skcipher_request *req, bool enc)
@@ -338,7 +293,7 @@ static int riscv64_aes_ctr_crypt(struct skcipher_request *req)
 
 struct riscv64_aes_xts_ctx {
 	struct crypto_aes_ctx ctx1;
-	struct crypto_aes_ctx ctx2;
+	struct aes_enckey tweak_key;
 };
 
 static int riscv64_aes_xts_setkey(struct crypto_skcipher *tfm, const u8 *key,
@@ -348,7 +303,7 @@ static int riscv64_aes_xts_setkey(struct crypto_skcipher *tfm, const u8 *key,
 
 	return xts_verify_key(tfm, key, keylen) ?:
 	       riscv64_aes_setkey(&ctx->ctx1, key, keylen / 2) ?:
-	       riscv64_aes_setkey(&ctx->ctx2, key + keylen / 2, keylen / 2);
+	       aes_prepareenckey(&ctx->tweak_key, key + keylen / 2, keylen / 2);
 }
 
 static int riscv64_aes_xts_crypt(struct skcipher_request *req, bool enc)
@@ -366,9 +321,7 @@ static int riscv64_aes_xts_crypt(struct skcipher_request *req, bool enc)
 		return -EINVAL;
 
 	/* Encrypt the IV with the tweak key to get the first tweak. */
-	kernel_vector_begin();
-	aes_encrypt_zvkned(&ctx->ctx2, req->iv, req->iv);
-	kernel_vector_end();
+	aes_encrypt(&ctx->tweak_key, req->iv, req->iv);
 
 	err = skcipher_walk_virt(&walk, req, false);
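Because the library's aes_encrypt() checks vector availability internally and falls back to generic code (see lib/crypto/riscv/aes.h below), the XTS code no longer has to bracket the tweak encryption with kernel_vector_begin()/kernel_vector_end() itself. A minimal sketch of the resulting tweak-key handling, using only names from this commit (the surrounding setkey/crypt functions are abbreviated):

	struct aes_enckey tweak_key;
	int err;

	/* Expand the second half of the XTS key as the tweak key... */
	err = aes_prepareenckey(&tweak_key, key + keylen / 2, keylen / 2);
	if (err)
		return err;

	/* ...then encrypt the IV in place to produce the first tweak. */
	aes_encrypt(&tweak_key, req->iv, req->iv);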
@@ -456,23 +409,6 @@ static int riscv64_aes_xts_decrypt(struct skcipher_request *req)
 
 /* Algorithm definitions */
 
-static struct crypto_alg riscv64_zvkned_aes_cipher_alg = {
-	.cra_flags = CRYPTO_ALG_TYPE_CIPHER,
-	.cra_blocksize = AES_BLOCK_SIZE,
-	.cra_ctxsize = sizeof(struct crypto_aes_ctx),
-	.cra_priority = 300,
-	.cra_name = "aes",
-	.cra_driver_name = "aes-riscv64-zvkned",
-	.cra_cipher = {
-		.cia_min_keysize = AES_MIN_KEY_SIZE,
-		.cia_max_keysize = AES_MAX_KEY_SIZE,
-		.cia_setkey = riscv64_aes_setkey_cipher,
-		.cia_encrypt = riscv64_aes_encrypt,
-		.cia_decrypt = riscv64_aes_decrypt,
-	},
-	.cra_module = THIS_MODULE,
-};
-
 static struct skcipher_alg riscv64_zvkned_aes_skcipher_algs[] = {
 	{
 		.setkey = riscv64_aes_setkey_skcipher,
@@ -574,15 +510,11 @@ static int __init riscv64_aes_mod_init(void)
 	if (riscv_isa_extension_available(NULL, ZVKNED) &&
 	    riscv_vector_vlen() >= 128) {
-		err = crypto_register_alg(&riscv64_zvkned_aes_cipher_alg);
-		if (err)
-			return err;
-
 		err = crypto_register_skciphers(
 			riscv64_zvkned_aes_skcipher_algs,
 			ARRAY_SIZE(riscv64_zvkned_aes_skcipher_algs));
 		if (err)
-			goto unregister_zvkned_cipher_alg;
+			return err;
 
 		if (riscv_isa_extension_available(NULL, ZVKB)) {
 			err = crypto_register_skcipher(
 				&riscv64_zvkned_zvkb_aes_skcipher_alg);
@@ -607,8 +539,6 @@ unregister_zvkned_zvkb_skcipher_alg:
 	crypto_unregister_skcipher(&riscv64_zvkned_zvkb_aes_skcipher_alg);
 unregister_zvkned_skcipher_algs:
 	crypto_unregister_skciphers(riscv64_zvkned_aes_skcipher_algs,
 				    ARRAY_SIZE(riscv64_zvkned_aes_skcipher_algs));
-unregister_zvkned_cipher_alg:
-	crypto_unregister_alg(&riscv64_zvkned_aes_cipher_alg);
 	return err;
 }
@@ -620,7 +550,6 @@ static void __exit riscv64_aes_mod_exit(void)
 		crypto_unregister_skcipher(&riscv64_zvkned_zvkb_aes_skcipher_alg);
 	crypto_unregister_skciphers(riscv64_zvkned_aes_skcipher_algs,
 				    ARRAY_SIZE(riscv64_zvkned_aes_skcipher_algs));
-	crypto_unregister_alg(&riscv64_zvkned_aes_cipher_alg);
 }
 
 module_init(riscv64_aes_mod_init);

arch/riscv/crypto/aes-riscv64-zvkned.S

@@ -56,33 +56,6 @@
 #define LEN		a3
 #define IVP		a4
 
-.macro	__aes_crypt_zvkned	enc, keylen
-	vle32.v		v16, (INP)
-	aes_crypt	v16, \enc, \keylen
-	vse32.v		v16, (OUTP)
-	ret
-.endm
-
-.macro	aes_crypt_zvkned	enc
-	aes_begin	KEYP, 128f, 192f
-	__aes_crypt_zvkned	\enc, 256
-128:
-	__aes_crypt_zvkned	\enc, 128
-192:
-	__aes_crypt_zvkned	\enc, 192
-.endm
-
-// void aes_encrypt_zvkned(const struct crypto_aes_ctx *key,
-//			   const u8 in[16], u8 out[16]);
-SYM_FUNC_START(aes_encrypt_zvkned)
-	aes_crypt_zvkned	1
-SYM_FUNC_END(aes_encrypt_zvkned)
-
-// Same prototype and calling convention as the encryption function
-SYM_FUNC_START(aes_decrypt_zvkned)
-	aes_crypt_zvkned	0
-SYM_FUNC_END(aes_decrypt_zvkned)
-
 .macro	__aes_ecb_crypt	enc, keylen
 	srli		t0, LEN, 2
 	// t0 is the remaining length in 32-bit words.  It's a multiple of 4.

lib/crypto/Kconfig

@@ -17,6 +17,8 @@ config CRYPTO_LIB_AES_ARCH
 	default y if ARM
 	default y if ARM64
 	default y if PPC && (SPE || (PPC64 && VSX))
+	default y if RISCV && 64BIT && TOOLCHAIN_HAS_VECTOR_CRYPTO && \
+		     RISCV_EFFICIENT_VECTOR_UNALIGNED_ACCESS
 
 config CRYPTO_LIB_AESCFB
 	tristate

lib/crypto/Makefile

@@ -50,6 +50,7 @@ OBJECT_FILES_NON_STANDARD_powerpc/aesp8-ppc.o := y
 endif # !CONFIG_SPE
 endif # CONFIG_PPC
 
+libaes-$(CONFIG_RISCV) += riscv/aes-riscv64-zvkned.o
 endif # CONFIG_CRYPTO_LIB_AES_ARCH
 
 ################################################################################

lib/crypto/riscv/aes-riscv64-zvkned.S (new file)

@@ -0,0 +1,84 @@
/* SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause */
//
// This file is dual-licensed, meaning that you can use it under your
// choice of either of the following two licenses:
//
// Copyright 2023 The OpenSSL Project Authors. All Rights Reserved.
//
// Licensed under the Apache License 2.0 (the "License"). You can obtain
// a copy in the file LICENSE in the source distribution or at
// https://www.openssl.org/source/license.html
//
// or
//
// Copyright (c) 2023, Christoph Müllner <christoph.muellner@vrull.eu>
// Copyright (c) 2023, Phoebe Chen <phoebe.chen@sifive.com>
// Copyright (c) 2023, Jerry Shih <jerry.shih@sifive.com>
// Copyright 2024 Google LLC
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

// The generated code of this file depends on the following RISC-V extensions:
// - RV64I
// - RISC-V Vector ('V') with VLEN >= 128
// - RISC-V Vector AES block cipher extension ('Zvkned')

#include <linux/linkage.h>

.text
.option arch, +zvkned

#include "../../arch/riscv/crypto/aes-macros.S"

#define RNDKEYS		a0
#define KEY_LEN		a1
#define OUTP		a2
#define INP		a3

.macro	__aes_crypt_zvkned	enc, keybits
	vle32.v		v16, (INP)
	aes_crypt	v16, \enc, \keybits
	vse32.v		v16, (OUTP)
	ret
.endm

.macro	aes_crypt_zvkned	enc
	aes_begin	RNDKEYS, 128f, 192f, KEY_LEN
	__aes_crypt_zvkned	\enc, 256
128:
	__aes_crypt_zvkned	\enc, 128
192:
	__aes_crypt_zvkned	\enc, 192
.endm

// void aes_encrypt_zvkned(const u32 rndkeys[], int key_len,
//			   u8 out[AES_BLOCK_SIZE], const u8 in[AES_BLOCK_SIZE]);
SYM_FUNC_START(aes_encrypt_zvkned)
	aes_crypt_zvkned	1
SYM_FUNC_END(aes_encrypt_zvkned)

// void aes_decrypt_zvkned(const u32 rndkeys[], int key_len,
//			   u8 out[AES_BLOCK_SIZE], const u8 in[AES_BLOCK_SIZE]);
SYM_FUNC_START(aes_decrypt_zvkned)
	aes_crypt_zvkned	0
SYM_FUNC_END(aes_decrypt_zvkned)
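Note the control flow here: aes_begin branches to the local labels 128f and 192f for AES-128 and AES-192 and falls through for AES-256 (per its comment, it "loads round keys into v1-v15 and continues onwards" in that case), and each __aes_crypt_zvkned expansion ends in ret, so the three per-key-size stanzas never run into one another.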

lib/crypto/riscv/aes.h (new file)

@@ -0,0 +1,63 @@
/* SPDX-License-Identifier: GPL-2.0-only */
/*
 * Copyright (C) 2023 VRULL GmbH
 * Copyright (C) 2023 SiFive, Inc.
 * Copyright 2024 Google LLC
 */

#include <asm/simd.h>
#include <asm/vector.h>

static __ro_after_init DEFINE_STATIC_KEY_FALSE(have_zvkned);

void aes_encrypt_zvkned(const u32 rndkeys[], int key_len,
			u8 out[AES_BLOCK_SIZE], const u8 in[AES_BLOCK_SIZE]);
void aes_decrypt_zvkned(const u32 rndkeys[], int key_len,
			u8 out[AES_BLOCK_SIZE], const u8 in[AES_BLOCK_SIZE]);

static void aes_preparekey_arch(union aes_enckey_arch *k,
				union aes_invkey_arch *inv_k,
				const u8 *in_key, int key_len, int nrounds)
{
	aes_expandkey_generic(k->rndkeys, inv_k ? inv_k->inv_rndkeys : NULL,
			      in_key, key_len);
}

static void aes_encrypt_arch(const struct aes_enckey *key,
			     u8 out[AES_BLOCK_SIZE],
			     const u8 in[AES_BLOCK_SIZE])
{
	if (static_branch_likely(&have_zvkned) && likely(may_use_simd())) {
		kernel_vector_begin();
		aes_encrypt_zvkned(key->k.rndkeys, key->len, out, in);
		kernel_vector_end();
	} else {
		aes_encrypt_generic(key->k.rndkeys, key->nrounds, out, in);
	}
}

static void aes_decrypt_arch(const struct aes_key *key,
			     u8 out[AES_BLOCK_SIZE],
			     const u8 in[AES_BLOCK_SIZE])
{
	/*
	 * Note that the Zvkned code uses the standard round keys, while the
	 * fallback uses the inverse round keys.  Thus both must be present.
	 */
	if (static_branch_likely(&have_zvkned) && likely(may_use_simd())) {
		kernel_vector_begin();
		aes_decrypt_zvkned(key->k.rndkeys, key->len, out, in);
		kernel_vector_end();
	} else {
		aes_decrypt_generic(key->inv_k.inv_rndkeys, key->nrounds,
				    out, in);
	}
}

#define aes_mod_init_arch aes_mod_init_arch
static void aes_mod_init_arch(void)
{
	if (riscv_isa_extension_available(NULL, ZVKNED) &&
	    riscv_vector_vlen() >= 128)
		static_branch_enable(&have_zvkned);
}
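With this glue in place, any kernel user of the AES library gets the Zvkned acceleration transparently. A consumer-side sketch under the API shown in this commit; raw_key stands in for the caller's key bytes, AES_KEYSIZE_128 is the standard constant from <crypto/aes.h>, and error handling is abbreviated:

	struct aes_enckey key;
	u8 block[AES_BLOCK_SIZE];

	if (aes_prepareenckey(&key, raw_key, AES_KEYSIZE_128))
		return -EINVAL;

	/* Dispatches to aes_encrypt_zvkned() when the static key is set
	 * and the vector unit may be used; otherwise uses generic code. */
	aes_encrypt(&key, block, block);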