Commit aeabc92e authored by Benno Lossin's avatar Benno Lossin
Browse files

rust: pin-init: add `#[default_error(<type>)]` attribute to initializer macros



The `#[default_error(<type>)]` attribute can be used to supply a default
type as the error used for the `[pin_]init!` macros. This way one can
easily define custom `try_[pin_]init!` variants that default to your
project specific error type. Just write the following declarative macro:

    macro_rules! try_init {
        ($($args:tt)*) => {
            ::pin_init::init!(
                #[default_error(YourCustomErrorType)]
                $($args)*
            )
        }
    }

Tested-by: default avatarAndreas Hindborg <a.hindborg@kernel.org>
Reviewed-by: default avatarGary Guo <gary@garyguo.net>
Signed-off-by: default avatarBenno Lossin <lossin@kernel.org>
parent 4883830e
Loading
Loading
Loading
Loading
+41 −3
Original line number Diff line number Diff line
@@ -8,12 +8,13 @@
    parse_quote,
    punctuated::Punctuated,
    spanned::Spanned,
    token, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
    token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
};

use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};

pub(crate) struct Initializer {
    attrs: Vec<InitializerAttribute>,
    this: Option<This>,
    path: Path,
    brace_token: token::Brace,
@@ -54,8 +55,17 @@ fn ident(&self) -> Option<&Ident> {
    }
}

enum InitializerAttribute {
    DefaultError(DefaultErrorAttribute),
}

struct DefaultErrorAttribute {
    ty: Box<Type>,
}

pub(crate) fn expand(
    Initializer {
        attrs,
        this,
        path,
        brace_token,
@@ -69,14 +79,23 @@ pub(crate) fn expand(
) -> Result<TokenStream, ErrorGuaranteed> {
    let error = error.map_or_else(
        || {
            if let Some(default_error) = default_error {
            if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
                #[expect(irrefutable_let_patterns)]
                if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
                    Some(ty.clone())
                } else {
                    acc
                }
            }) {
                default_error
            } else if let Some(default_error) = default_error {
                syn::parse_str(default_error).unwrap()
            } else {
                dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
                parse_quote!(::core::convert::Infallible)
            }
        },
        |(_, err)| err,
        |(_, err)| Box::new(err),
    );
    let slot = format_ident!("slot");
    let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
@@ -358,6 +377,7 @@ fn make_field_check(

impl Parse for Initializer {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let attrs = input.call(Attribute::parse_outer)?;
        let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
        let path = input.parse()?;
        let content;
@@ -389,7 +409,19 @@ fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
            .peek(Token![?])
            .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
            .transpose()?;
        let attrs = attrs
            .into_iter()
            .map(|a| {
                if a.path().is_ident("default_error") {
                    a.parse_args::<DefaultErrorAttribute>()
                        .map(InitializerAttribute::DefaultError)
                } else {
                    Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
                }
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(Self {
            attrs,
            this,
            path,
            brace_token,
@@ -400,6 +432,12 @@ fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
    }
}

impl Parse for DefaultErrorAttribute {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        Ok(Self { ty: input.parse()? })
    }
}

impl Parse for This {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        Ok(Self {