Unverified Commit 259aaf03 authored by Palmer Dabbelt
Browse files

Merge patch series "riscv: uaccess: optimisations"

Cyril Bur <cyrilbur@tenstorrent.com> says:

This series tries to optimize riscv uaccess by allowing the use of
user_access_begin() and user_access_end() which permits grouping user accesses
and avoiding the CSR write penalty for each access.

The error path can also be optimised using asm goto, which patches 3 and 4
achieve. This will speed up jumping to labels by avoiding the need for an
intermediary error type variable within the uaccess macros.

I did read the discussion this series generated. It isn't clear to me
which direction to take the patches, if any.

* b4-shazam-merge:
  riscv: uaccess: use 'asm_goto_output' for get_user()
  riscv: uaccess: use 'asm goto' for put_user()
  riscv: uaccess: use input constraints for ptr of __put_user()
  riscv: implement user_access_begin() and families
  riscv: save the SR_SUM status over switches

Link: https://lore.kernel.org/r/20250410070526.3160847-1-cyrilbur@tenstorrent.com


Signed-off-by: Palmer Dabbelt <palmer@rivosinc.com>
parents 85f79dec f6bff782
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -103,6 +103,7 @@ struct thread_struct {
	struct __riscv_d_ext_state fstate;
	unsigned long bad_cause;
	unsigned long envcfg;
	unsigned long status;
	u32 riscv_v_flags;
	u32 vstate_ctrl;
	struct __riscv_v_ext_state vstate;
+165 −53
Original line number Diff line number Diff line
@@ -61,6 +61,19 @@ static inline unsigned long __untagged_addr_remote(struct mm_struct *mm, unsigne
#define __disable_user_access()							\
	__asm__ __volatile__ ("csrc sstatus, %0" : : "r" (SR_SUM) : "memory")

/*
 * __inttype(x) names the narrowest unsigned integer type that is at least
 * as wide as 'x'.  Candidates are tried in the order char, short, int,
 * long; when none of them is wide enough the result is the type of 0ULL
 * (unsigned long long).
 *
 * __typefits() is the single selection step: it yields an unsigned zero of
 * 'candidate' when 'val' fits in it, and 'fallback' otherwise.
 */
#define __typefits(val, candidate, fallback)				\
	__builtin_choose_expr(sizeof(val) <= sizeof(candidate),	\
			      (unsigned candidate)0, fallback)

#define __inttype(val)							\
	__typeof__(__typefits(val, char,				\
		   __typefits(val, short,				\
		   __typefits(val, int,					\
		   __typefits(val, long, 0ULL)))))

/*
 * The exception table consists of pairs of addresses: the first is the
 * address of an instruction that is allowed to fault, and the second is
@@ -83,27 +96,58 @@ static inline unsigned long __untagged_addr_remote(struct mm_struct *mm, unsigne
 * call.
 */

#define __get_user_asm(insn, x, ptr, err)			\
#ifdef CONFIG_CC_HAS_ASM_GOTO_OUTPUT
/*
 * asm-goto-with-outputs flavour of the single-load accessor.
 *
 * Emits one load instruction 'insn' that reads the memory operand *(ptr)
 * (%1) into 'x' (%0).  The exception-table entry attached to the load
 * names 'label' (%l2) as its target, so no separate error variable is
 * produced or tested here.
 *
 * NOTE(review): what the fixup does with the register operand is decided
 * by _ASM_EXTABLE_UACCESS_ERR, which is defined outside this view.
 */
#define __get_user_asm(insn, x, ptr, label)			\
	asm_goto_output(					\
		"1:\n"						\
		"	" insn " %0, %1\n"			\
		_ASM_EXTABLE_UACCESS_ERR(1b, %l2, %0)		\
		: "=&r" (x)					\
		: "m" (*(ptr)) : : label)
#else /* !CONFIG_CC_HAS_ASM_GOTO_OUTPUT */
#define __get_user_asm(insn, x, ptr, label)			\
do {								\
	__typeof__(x) __x;					\
	long __gua_err = 0;					\
	__asm__ __volatile__ (					\
		"1:\n"						\
		"	" insn " %1, %2\n"			\
		"2:\n"						\
		_ASM_EXTABLE_UACCESS_ERR_ZERO(1b, 2b, %0, %1)	\
		: "+r" (err), "=&r" (__x)			\
		: "+r" (__gua_err), "=&r" (x)			\
		: "m" (*(ptr)));				\
	(x) = __x;						\
	if (__gua_err)						\
		goto label;					\
} while (0)
#endif /* CONFIG_CC_HAS_ASM_GOTO_OUTPUT */

#ifdef CONFIG_64BIT
#define __get_user_8(x, ptr, err) \
	__get_user_asm("ld", x, ptr, err)
#define __get_user_8(x, ptr, label) \
	__get_user_asm("ld", x, ptr, label)
#else /* !CONFIG_64BIT */
#define __get_user_8(x, ptr, err)				\

#ifdef CONFIG_CC_HAS_ASM_GOTO_OUTPUT
/*
 * 64-bit load for 32-bit kernels, asm-goto-with-outputs flavour.
 *
 * The value is fetched as two 32-bit words (__ptr[__LSW] and __ptr[__MSW])
 * and recombined into 'x'.  The exception-table entries for both loads name
 * 'label' (%l4) as their target.
 *
 * The body is wrapped in do { } while (0) so that the macro expands to a
 * single statement: it is expanded directly after the 'case 8:' label in
 * __get_user_nocheck(), where a bare declaration is not a valid statement,
 * and the wrapper keeps __ptr, __lo and __hi out of the caller's scope.
 */
#define __get_user_8(x, ptr, label)				\
do {								\
	u32 __user *__ptr = (u32 __user *)(ptr);		\
	u32 __lo, __hi;						\
	asm_goto_output(					\
		"1:\n"						\
		"	lw %0, %2\n"				\
		"2:\n"						\
		"	lw %1, %3\n"				\
		_ASM_EXTABLE_UACCESS_ERR(1b, %l4, %0)		\
		_ASM_EXTABLE_UACCESS_ERR(2b, %l4, %0)		\
		: "=&r" (__lo), "=r" (__hi)			\
		: "m" (__ptr[__LSW]), "m" (__ptr[__MSW])	\
		: : label);					\
	(x) = (__typeof__(x))((__typeof__((x) - (x)))(		\
		(((u64)__hi << 32) | __lo)));			\
} while (0)

#else /* !CONFIG_CC_HAS_ASM_GOTO_OUTPUT */
#define __get_user_8(x, ptr, label)				\
do {								\
	u32 __user *__ptr = (u32 __user *)(ptr);		\
	u32 __lo, __hi;						\
	long __gu8_err = 0;					\
	__asm__ __volatile__ (					\
		"1:\n"						\
		"	lw %1, %3\n"				\
@@ -112,35 +156,51 @@ do { \
		"3:\n"						\
		_ASM_EXTABLE_UACCESS_ERR_ZERO(1b, 3b, %0, %1)	\
		_ASM_EXTABLE_UACCESS_ERR_ZERO(2b, 3b, %0, %1)	\
		: "+r" (err), "=&r" (__lo), "=r" (__hi)		\
		: "+r" (__gu8_err), "=&r" (__lo), "=r" (__hi)	\
		: "m" (__ptr[__LSW]), "m" (__ptr[__MSW]));	\
	if (err)						\
	if (__gu8_err) {					\
		__hi = 0;					\
		goto label;					\
	}							\
	(x) = (__typeof__(x))((__typeof__((x) - (x)))(		\
		(((u64)__hi << 32) | __lo)));			\
} while (0)
#endif /* CONFIG_CC_HAS_ASM_GOTO_OUTPUT */

#endif /* CONFIG_64BIT */

#define __get_user_nocheck(x, __gu_ptr, __gu_err)		\
#define __get_user_nocheck(x, __gu_ptr, label)			\
do {								\
	switch (sizeof(*__gu_ptr)) {				\
	case 1:							\
		__get_user_asm("lb", (x), __gu_ptr, __gu_err);	\
		__get_user_asm("lb", (x), __gu_ptr, label);	\
		break;						\
	case 2:							\
		__get_user_asm("lh", (x), __gu_ptr, __gu_err);	\
		__get_user_asm("lh", (x), __gu_ptr, label);	\
		break;						\
	case 4:							\
		__get_user_asm("lw", (x), __gu_ptr, __gu_err);	\
		__get_user_asm("lw", (x), __gu_ptr, label);	\
		break;						\
	case 8:							\
		__get_user_8((x), __gu_ptr, __gu_err);	\
		__get_user_8((x), __gu_ptr, label);		\
		break;						\
	default:						\
		BUILD_BUG();					\
	}							\
} while (0)

/*
 * Adapter from the label-based __get_user_nocheck() to an error code.
 *
 * On success 'err' is set to 0.  If the access branches to the local
 * failure label, 'x' is zeroed and 'err' is set to -EFAULT.
 *
 * 'x' and 'err' are parenthesised where they are assigned so that any
 * lvalue expression can be passed safely, matching __put_user_error().
 */
#define __get_user_error(x, ptr, err)					\
do {									\
	__label__ __gu_failed;						\
									\
	__get_user_nocheck(x, ptr, __gu_failed);			\
	(err) = 0;							\
	break;								\
__gu_failed:								\
	(x) = 0;							\
	(err) = -EFAULT;						\
} while (0)

/**
 * __get_user: - Get a simple variable from user space, with less checking.
 * @x:   Variable to store result.
@@ -165,13 +225,16 @@ do { \
({								\
	const __typeof__(*(ptr)) __user *__gu_ptr = untagged_addr(ptr); \
	long __gu_err = 0;					\
	__typeof__(x) __gu_val;					\
								\
	__chk_user_ptr(__gu_ptr);				\
								\
	__enable_user_access();					\
	__get_user_nocheck(x, __gu_ptr, __gu_err);		\
	__get_user_error(__gu_val, __gu_ptr, __gu_err);		\
	__disable_user_access();				\
								\
	(x) = __gu_val;						\
								\
	__gu_err;						\
})

@@ -201,61 +264,66 @@ do { \
		((x) = (__force __typeof__(x))0, -EFAULT);	\
})

#define __put_user_asm(insn, x, ptr, err)			\
#define __put_user_asm(insn, x, ptr, label)			\
do {								\
	__typeof__(*(ptr)) __x = x;				\
	__asm__ __volatile__ (					\
	asm goto(						\
		"1:\n"						\
		"	" insn " %z2, %1\n"			\
		"2:\n"						\
		_ASM_EXTABLE_UACCESS_ERR(1b, 2b, %0)		\
		: "+r" (err), "=m" (*(ptr))			\
		: "rJ" (__x));					\
		"	" insn " %z0, %1\n"			\
		_ASM_EXTABLE(1b, %l2)				\
		: : "rJ" (__x), "m"(*(ptr)) : : label);		\
} while (0)

#ifdef CONFIG_64BIT
#define __put_user_8(x, ptr, err) \
	__put_user_asm("sd", x, ptr, err)
#define __put_user_8(x, ptr, label) \
	__put_user_asm("sd", x, ptr, label)
#else /* !CONFIG_64BIT */
#define __put_user_8(x, ptr, err)				\
#define __put_user_8(x, ptr, label)				\
do {								\
	u32 __user *__ptr = (u32 __user *)(ptr);		\
	u64 __x = (__typeof__((x)-(x)))(x);			\
	__asm__ __volatile__ (					\
	asm goto(						\
		"1:\n"						\
		"	sw %z3, %1\n"				\
		"	sw %z0, %2\n"				\
		"2:\n"						\
		"	sw %z4, %2\n"				\
		"3:\n"						\
		_ASM_EXTABLE_UACCESS_ERR(1b, 3b, %0)		\
		_ASM_EXTABLE_UACCESS_ERR(2b, 3b, %0)		\
		: "+r" (err),					\
			"=m" (__ptr[__LSW]),			\
			"=m" (__ptr[__MSW])			\
		: "rJ" (__x), "rJ" (__x >> 32));		\
		"	sw %z1, %3\n"				\
		_ASM_EXTABLE(1b, %l4)				\
		_ASM_EXTABLE(2b, %l4)				\
		: : "rJ" (__x), "rJ" (__x >> 32),		\
			"m" (__ptr[__LSW]),			\
			"m" (__ptr[__MSW]) : : label);		\
} while (0)
#endif /* CONFIG_64BIT */

#define __put_user_nocheck(x, __gu_ptr, __pu_err)					\
#define __put_user_nocheck(x, __gu_ptr, label)			\
do {								\
	switch (sizeof(*__gu_ptr)) {				\
	case 1:							\
		__put_user_asm("sb", (x), __gu_ptr, __pu_err);	\
		__put_user_asm("sb", (x), __gu_ptr, label);	\
		break;						\
	case 2:							\
		__put_user_asm("sh", (x), __gu_ptr, __pu_err);	\
		__put_user_asm("sh", (x), __gu_ptr, label);	\
		break;						\
	case 4:							\
		__put_user_asm("sw", (x), __gu_ptr, __pu_err);	\
		__put_user_asm("sw", (x), __gu_ptr, label);	\
		break;						\
	case 8:							\
		__put_user_8((x), __gu_ptr, __pu_err);	\
		__put_user_8((x), __gu_ptr, label);		\
		break;						\
	default:						\
		BUILD_BUG();					\
	}							\
} while (0)

/*
 * Adapter from the label-based __put_user_nocheck() to an error code.
 *
 * A store that branches to the local failure label sets 'err' to -EFAULT.
 * On success 'err' is not written, so the caller has to initialise it
 * before invoking this macro.
 */
#define __put_user_error(x, ptr, err)				\
do {								\
	__label__ __pu_failed;					\
								\
	__put_user_nocheck(x, ptr, __pu_failed);		\
	break;							\
__pu_failed:							\
	(err) = -EFAULT;					\
} while (0)

/**
 * __put_user: - Write a simple value into user space, with less checking.
 * @x:   Value to copy to user space.
@@ -286,7 +354,7 @@ do { \
	__chk_user_ptr(__gu_ptr);				\
								\
	__enable_user_access();					\
	__put_user_nocheck(__val, __gu_ptr, __pu_err);		\
	__put_user_error(__val, __gu_ptr, __pu_err);		\
	__disable_user_access();				\
								\
	__pu_err;						\
@@ -351,21 +419,65 @@ unsigned long __must_check clear_user(void __user *to, unsigned long n)
}

/*
 * "nofault" accessors expressed through the label-based user accessors:
 * each transfers one object of 'type' from 'src' to 'dst' and passes
 * 'err_label' straight through as the failure label.
 *
 * __get_kernel_nofault() stores the loaded value through (type *)dst;
 * __put_kernel_nofault() stores the value read through (type *)src.
 */
#define __get_kernel_nofault(dst, src, type, err_label)			\
	__get_user_nocheck(*((type *)(dst)), (type *)(src), err_label)

#define __put_kernel_nofault(dst, src, type, err_label)			\
	__put_user_nocheck(*((type *)(src)), (type *)(dst), err_label)

/*
 * Begin a section of unsafe user accesses covering 'len' bytes at 'ptr'.
 *
 * Returns 0 without enabling user access when access_ok() rejects the
 * range; otherwise calls __enable_user_access() and returns 1.  The
 * result is __must_check: callers may only proceed when it is non-zero.
 */
static __must_check __always_inline bool user_access_begin(const void __user *ptr, size_t len)
{
	/* The rejected-range case is hinted as the unexpected path. */
	if (unlikely(!access_ok(ptr, len)))
		return 0;
	__enable_user_access();
	return 1;
}
#define user_access_begin user_access_begin
#define user_access_end __disable_user_access

/*
 * No access state is captured or reinstated by these helpers:
 * user_access_save() always reports 0 and user_access_restore()
 * ignores its argument.
 */
static inline unsigned long user_access_save(void)
{
	return 0UL;
}

static inline void user_access_restore(unsigned long enabled)
{
}

/*
 * We want the unsafe accessors to always be inlined and use
 * the error labels - thus the macro games.
 *
 * Both macros forward to the *_nocheck() primitives and hand 'label'
 * through as the failure target; neither enables or disables user
 * access itself.
 */
#define unsafe_put_user(x, ptr, label)					\
	__put_user_nocheck(x, (ptr), label)

/*
 * The value is first loaded into a temporary whose type is the unsigned
 * integer type chosen by __inttype() for *(ptr), and only then cast to
 * the pointee type and assigned to 'x'.
 */
#define unsafe_get_user(x, ptr, label)	do {				\
	__inttype(*(ptr)) __gu_val;					\
	__get_user_nocheck(__gu_val, (ptr), label);			\
	(x) = (__force __typeof__(*(ptr)))__gu_val;			\
} while (0)

/*
 * Copy as many whole 'type'-sized units as 'len' allows by applying 'op'
 * (an unsafe_put_user()/unsafe_get_user()-style accessor) once per unit,
 * with 'label' as its failure target.
 *
 * 'dst', 'src' and 'len' are modified in place - both pointers advance and
 * 'len' shrinks by sizeof(type) per iteration - so they must be modifiable
 * lvalues owned by the caller.  Whatever is left in 'len' afterwards is
 * smaller than sizeof(type).
 *
 * Note that the expansion is a bare while statement, not a do/while(0)
 * wrapper.
 */
#define unsafe_copy_loop(dst, src, len, type, op, label)		\
	while (len >= sizeof(type)) {					\
		op(*(type *)(src), (type __user *)(dst), label);	\
		dst += sizeof(type);					\
		src += sizeof(type);					\
		len -= sizeof(type);					\
	}

#define unsafe_copy_to_user(_dst, _src, _len, label)			\
do {									\
	long __kr_err = 0;						\
									\
	__get_user_nocheck(*((type *)(dst)), (type *)(src), __kr_err);	\
	if (unlikely(__kr_err))						\
		goto err_label;						\
	char __user *__ucu_dst = (_dst);				\
	const char *__ucu_src = (_src);					\
	size_t __ucu_len = (_len);					\
	unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u64, unsafe_put_user, label);	\
	unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u32, unsafe_put_user, label);	\
	unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u16, unsafe_put_user, label);	\
	unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u8, unsafe_put_user, label);	\
} while (0)

#define __put_kernel_nofault(dst, src, type, err_label)			\
#define unsafe_copy_from_user(_dst, _src, _len, label)			\
do {									\
	long __kr_err = 0;						\
									\
	__put_user_nocheck(*((type *)(src)), (type *)(dst), __kr_err);	\
	if (unlikely(__kr_err))						\
		goto err_label;						\
	char *__ucu_dst = (_dst);					\
	const char __user *__ucu_src = (_src);				\
	size_t __ucu_len = (_len);					\
	unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u64, unsafe_get_user, label);	\
	unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u32, unsafe_get_user, label);	\
	unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u16, unsafe_get_user, label);	\
	unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u8, unsafe_get_user, label);	\
} while (0)

#else /* CONFIG_MMU */
+5 −0
Original line number Diff line number Diff line
@@ -34,6 +34,7 @@ void asm_offsets(void)
	OFFSET(TASK_THREAD_S9, task_struct, thread.s[9]);
	OFFSET(TASK_THREAD_S10, task_struct, thread.s[10]);
	OFFSET(TASK_THREAD_S11, task_struct, thread.s[11]);
	OFFSET(TASK_THREAD_STATUS, task_struct, thread.status);

	OFFSET(TASK_TI_CPU, task_struct, thread_info.cpu);
	OFFSET(TASK_TI_PREEMPT_COUNT, task_struct, thread_info.preempt_count);
@@ -346,6 +347,10 @@ void asm_offsets(void)
		  offsetof(struct task_struct, thread.s[11])
		- offsetof(struct task_struct, thread.ra)
	);
	DEFINE(TASK_THREAD_STATUS_RA,
		  offsetof(struct task_struct, thread.status)
		- offsetof(struct task_struct, thread.ra)
	);

	DEFINE(TASK_THREAD_F0_F0,
		  offsetof(struct task_struct, thread.fstate.f[0])
+8 −0
Original line number Diff line number Diff line
@@ -397,9 +397,17 @@ SYM_FUNC_START(__switch_to)
	REG_S s9,  TASK_THREAD_S9_RA(a3)
	REG_S s10, TASK_THREAD_S10_RA(a3)
	REG_S s11, TASK_THREAD_S11_RA(a3)

	/* save the user space access flag */
	li    s0, SR_SUM
	csrr  s1, CSR_STATUS
	REG_S s1, TASK_THREAD_STATUS_RA(a3)

	/* Save the kernel shadow call stack pointer */
	scs_save_current
	/* Restore context from next->thread */
	REG_L s0,  TASK_THREAD_STATUS_RA(a4)
	csrs  CSR_STATUS, s0
	REG_L ra,  TASK_THREAD_RA_RA(a4)
	REG_L sp,  TASK_THREAD_SP_RA(a4)
	REG_L s0,  TASK_THREAD_S0_RA(a4)