Unverified Commit a4348546 authored by Alexandre Ghiti's avatar Alexandre Ghiti Committed by Palmer Dabbelt
Browse files

riscv: make unsafe user copy routines use existing assembly routines



The current implementation is underperforming and in addition, it
triggers misaligned access traps on platforms which do not handle
misaligned accesses in hardware.

Use the existing assembly routines to solve both problems at once.

Signed-off-by: default avatarAlexandre Ghiti <alexghiti@rivosinc.com>
Link: https://lore.kernel.org/r/20250602193918.868962-2-cleger@rivosinc.com


Signed-off-by: default avatarPalmer Dabbelt <palmer@dabbelt.com>
parent 259aaf03
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ long long __ashlti3(long long a, int b);
#ifdef CONFIG_RISCV_ISA_V

#ifdef CONFIG_MMU
asmlinkage int enter_vector_usercopy(void *dst, void *src, size_t n);
asmlinkage int enter_vector_usercopy(void *dst, void *src, size_t n, bool enable_sum);
#endif /* CONFIG_MMU  */

void xor_regs_2_(unsigned long bytes, unsigned long *__restrict p1,
+8 −25
Original line number Diff line number Diff line
@@ -450,35 +450,18 @@ static inline void user_access_restore(unsigned long enabled) { }
	(x) = (__force __typeof__(*(ptr)))__gu_val;			\
} while (0)

#define unsafe_copy_loop(dst, src, len, type, op, label)		\
	while (len >= sizeof(type)) {					\
		op(*(type *)(src), (type __user *)(dst), label);	\
		dst += sizeof(type);					\
		src += sizeof(type);					\
		len -= sizeof(type);					\
	}
unsigned long __must_check __asm_copy_to_user_sum_enabled(void __user *to,
	const void *from, unsigned long n);
unsigned long __must_check __asm_copy_from_user_sum_enabled(void *to,
	const void __user *from, unsigned long n);

#define unsafe_copy_to_user(_dst, _src, _len, label)			\
do {									\
	char __user *__ucu_dst = (_dst);				\
	const char *__ucu_src = (_src);					\
	size_t __ucu_len = (_len);					\
	unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u64, unsafe_put_user, label);	\
	unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u32, unsafe_put_user, label);	\
	unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u16, unsafe_put_user, label);	\
	unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u8, unsafe_put_user, label);	\
} while (0)
	if (__asm_copy_to_user_sum_enabled(_dst, _src, _len))		\
		goto label;

#define unsafe_copy_from_user(_dst, _src, _len, label)			\
do {									\
	char *__ucu_dst = (_dst);					\
	const char __user *__ucu_src = (_src);				\
	size_t __ucu_len = (_len);					\
	unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u64, unsafe_get_user, label);	\
	unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u32, unsafe_get_user, label);	\
	unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u16, unsafe_get_user, label);	\
	unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u8, unsafe_get_user, label);	\
} while (0)
	if (__asm_copy_from_user_sum_enabled(_dst, _src, _len))		\
		goto label;

#else /* CONFIG_MMU */
#include <asm-generic/uaccess.h>
+8 −3
Original line number Diff line number Diff line
@@ -16,8 +16,11 @@
#ifdef CONFIG_MMU
size_t riscv_v_usercopy_threshold = CONFIG_RISCV_ISA_V_UCOPY_THRESHOLD;
int __asm_vector_usercopy(void *dst, void *src, size_t n);
int __asm_vector_usercopy_sum_enabled(void *dst, void *src, size_t n);
int fallback_scalar_usercopy(void *dst, void *src, size_t n);
asmlinkage int enter_vector_usercopy(void *dst, void *src, size_t n)
int fallback_scalar_usercopy_sum_enabled(void *dst, void *src, size_t n);
asmlinkage int enter_vector_usercopy(void *dst, void *src, size_t n,
				     bool enable_sum)
{
	size_t remain, copied;

@@ -26,7 +29,8 @@ asmlinkage int enter_vector_usercopy(void *dst, void *src, size_t n)
		goto fallback;

	kernel_vector_begin();
	remain = __asm_vector_usercopy(dst, src, n);
	remain = enable_sum ? __asm_vector_usercopy(dst, src, n) :
			      __asm_vector_usercopy_sum_enabled(dst, src, n);
	kernel_vector_end();

	if (remain) {
@@ -40,6 +44,7 @@ asmlinkage int enter_vector_usercopy(void *dst, void *src, size_t n)
	return remain;

fallback:
	return fallback_scalar_usercopy(dst, src, n);
	return enable_sum ? fallback_scalar_usercopy(dst, src, n) :
			    fallback_scalar_usercopy_sum_enabled(dst, src, n);
}
#endif
+34 −16
Original line number Diff line number Diff line
@@ -17,14 +17,43 @@ SYM_FUNC_START(__asm_copy_to_user)
	ALTERNATIVE("j fallback_scalar_usercopy", "nop", 0, RISCV_ISA_EXT_ZVE32X, CONFIG_RISCV_ISA_V)
	REG_L	t0, riscv_v_usercopy_threshold
	bltu	a2, t0, fallback_scalar_usercopy
	li	a3, 1
	tail 	enter_vector_usercopy
#endif
SYM_FUNC_START(fallback_scalar_usercopy)
SYM_FUNC_END(__asm_copy_to_user)
EXPORT_SYMBOL(__asm_copy_to_user)
SYM_FUNC_ALIAS(__asm_copy_from_user, __asm_copy_to_user)
EXPORT_SYMBOL(__asm_copy_from_user)

SYM_FUNC_START(fallback_scalar_usercopy)
	/* Enable access to user memory */
	li	t6, SR_SUM
	csrs 	CSR_STATUS, t6
	mv 	t6, ra

	call 	fallback_scalar_usercopy_sum_enabled

	/* Disable access to user memory */
	mv 	ra, t6
	li 	t6, SR_SUM
	csrc 	CSR_STATUS, t6
	ret
SYM_FUNC_END(fallback_scalar_usercopy)

SYM_FUNC_START(__asm_copy_to_user_sum_enabled)
#ifdef CONFIG_RISCV_ISA_V
	ALTERNATIVE("j fallback_scalar_usercopy_sum_enabled", "nop", 0, RISCV_ISA_EXT_ZVE32X, CONFIG_RISCV_ISA_V)
	REG_L	t0, riscv_v_usercopy_threshold
	bltu	a2, t0, fallback_scalar_usercopy_sum_enabled
	li	a3, 0
	tail 	enter_vector_usercopy
#endif
SYM_FUNC_END(__asm_copy_to_user_sum_enabled)
SYM_FUNC_ALIAS(__asm_copy_from_user_sum_enabled, __asm_copy_to_user_sum_enabled)
EXPORT_SYMBOL(__asm_copy_from_user_sum_enabled)
EXPORT_SYMBOL(__asm_copy_to_user_sum_enabled)

SYM_FUNC_START(fallback_scalar_usercopy_sum_enabled)
	/*
	 * Save the terminal address which will be used to compute the number
	 * of bytes copied in case of a fixup exception.
@@ -178,23 +207,12 @@ SYM_FUNC_START(fallback_scalar_usercopy)
	bltu	a0, t0, 4b	/* t0 - end of dst */

.Lout_copy_user:
	/* Disable access to user memory */
	csrc CSR_STATUS, t6
	li	a0, 0
	ret

	/* Exception fixup code */
10:
	/* Disable access to user memory */
	csrc CSR_STATUS, t6
	sub a0, t5, a0
	ret
SYM_FUNC_END(__asm_copy_to_user)
SYM_FUNC_END(fallback_scalar_usercopy)
EXPORT_SYMBOL(__asm_copy_to_user)
SYM_FUNC_ALIAS(__asm_copy_from_user, __asm_copy_to_user)
EXPORT_SYMBOL(__asm_copy_from_user)

SYM_FUNC_END(fallback_scalar_usercopy_sum_enabled)

SYM_FUNC_START(__clear_user)

+12 −3
Original line number Diff line number Diff line
@@ -24,7 +24,18 @@ SYM_FUNC_START(__asm_vector_usercopy)
	/* Enable access to user memory */
	li	t6, SR_SUM
	csrs	CSR_STATUS, t6
	mv	t6, ra

	call 	__asm_vector_usercopy_sum_enabled

	/* Disable access to user memory */
	mv 	ra, t6
	li 	t6, SR_SUM
	csrc	CSR_STATUS, t6
	ret
SYM_FUNC_END(__asm_vector_usercopy)

SYM_FUNC_START(__asm_vector_usercopy_sum_enabled)
loop:
	vsetvli iVL, iNum, e8, ELEM_LMUL_SETTING, ta, ma
	fixup vle8.v vData, (pSrc), 10f
@@ -36,8 +47,6 @@ loop:

	/* Exception fixup for vector load is shared with normal exit */
10:
	/* Disable access to user memory */
	csrc	CSR_STATUS, t6
	mv	a0, iNum
	ret

@@ -49,4 +58,4 @@ loop:
	csrr	t2, CSR_VSTART
	sub	iNum, iNum, t2
	j	10b
SYM_FUNC_END(__asm_vector_usercopy)
SYM_FUNC_END(__asm_vector_usercopy_sum_enabled)