rust: pin-init: add #[default_error(<type>)] attribute to initializer macros

The `#[default_error(<type>)]` attribute can be used to supply a default
type as the error used for the `[pin_]init!` macros. This way one can
easily define custom `try_[pin_]init!` variants that default to your
project specific error type. Just write the following declarative macro:

    macro_rules! try_init {
        ($($args:tt)*) => {
            ::pin_init::init!(
                #[default_error(YourCustomErrorType)]
                $($args)*
            )
        }
    }

Tested-by: Andreas Hindborg <a.hindborg@kernel.org>
Reviewed-by: Gary Guo <gary@garyguo.net>
Signed-off-by: Benno Lossin <lossin@kernel.org>
This commit is contained in:
Benno Lossin
2026-01-16 11:54:25 +01:00
parent 4883830e97
commit aeabc92eb2

View File

@@ -8,12 +8,13 @@ use syn::{
parse_quote,
punctuated::Punctuated,
spanned::Spanned,
token, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
};
use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
pub(crate) struct Initializer {
attrs: Vec<InitializerAttribute>,
this: Option<This>,
path: Path,
brace_token: token::Brace,
@@ -54,8 +55,17 @@ impl InitializerField {
}
}
enum InitializerAttribute {
DefaultError(DefaultErrorAttribute),
}
struct DefaultErrorAttribute {
ty: Box<Type>,
}
pub(crate) fn expand(
Initializer {
attrs,
this,
path,
brace_token,
@@ -69,14 +79,23 @@ pub(crate) fn expand(
) -> Result<TokenStream, ErrorGuaranteed> {
let error = error.map_or_else(
|| {
if let Some(default_error) = default_error {
if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
#[expect(irrefutable_let_patterns)]
if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
Some(ty.clone())
} else {
acc
}
}) {
default_error
} else if let Some(default_error) = default_error {
syn::parse_str(default_error).unwrap()
} else {
dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
parse_quote!(::core::convert::Infallible)
}
},
|(_, err)| err,
|(_, err)| Box::new(err),
);
let slot = format_ident!("slot");
let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
@@ -358,6 +377,7 @@ fn make_field_check(
impl Parse for Initializer {
fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
let attrs = input.call(Attribute::parse_outer)?;
let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
let path = input.parse()?;
let content;
@@ -389,7 +409,19 @@ impl Parse for Initializer {
.peek(Token![?])
.then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
.transpose()?;
let attrs = attrs
.into_iter()
.map(|a| {
if a.path().is_ident("default_error") {
a.parse_args::<DefaultErrorAttribute>()
.map(InitializerAttribute::DefaultError)
} else {
Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
}
})
.collect::<Result<Vec<_>, _>>()?;
Ok(Self {
attrs,
this,
path,
brace_token,
@@ -400,6 +432,12 @@ impl Parse for Initializer {
}
}
impl Parse for DefaultErrorAttribute {
fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
Ok(Self { ty: input.parse()? })
}
}
impl Parse for This {
fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
Ok(Self {