Commit d26732e5 authored by Benno Lossin's avatar Benno Lossin
Browse files

rust: pin-init: internal: init: add support for attributes on initializer fields



Initializer fields ought to support the same attributes that are allowed
in struct initializers on fields. For example, `cfg` or lint levels such
as `expect`, `allow` etc. Add parsing support for these attributes using
syn to initializer fields and adjust the macro expansion accordingly.

Tested-by: default avatarAndreas Hindborg <a.hindborg@kernel.org>
Reviewed-by: default avatarGary Guo <gary@garyguo.net>
Signed-off-by: default avatarBenno Lossin <lossin@kernel.org>
parent d083a621
Loading
Loading
Loading
Loading
+55 −14
Original line number Diff line number Diff line
@@ -29,7 +29,12 @@ struct This {
    _in_token: Token![in],
}

enum InitializerField {
struct InitializerField {
    attrs: Vec<Attribute>,
    kind: InitializerKind,
}

enum InitializerKind {
    Value {
        ident: Ident,
        value: Option<(Token![:], Expr)>,
@@ -46,7 +51,7 @@ enum InitializerField {
    },
}

impl InitializerField {
impl InitializerKind {
    fn ident(&self) -> Option<&Ident> {
        match self {
            Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
@@ -227,10 +232,16 @@ fn init_fields(
    slot: &Ident,
) -> TokenStream {
    let mut guards = vec![];
    let mut guard_attrs = vec![];
    let mut res = TokenStream::new();
    for field in fields {
        let init = match field {
            InitializerField::Value { ident, value } => {
    for InitializerField { attrs, kind } in fields {
        let cfgs = {
            let mut cfgs = attrs.clone();
            cfgs.retain(|attr| attr.path().is_ident("cfg"));
            cfgs
        };
        let init = match kind {
            InitializerKind::Value { ident, value } => {
                let mut value_ident = ident.clone();
                let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
                    // Setting the span of `value_ident` to `value`'s span improves error messages
@@ -253,21 +264,24 @@ fn init_fields(
                    }
                };
                quote! {
                    #(#attrs)*
                    {
                        #value_prep
                        // SAFETY: TODO
                        unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
                    }
                    #(#cfgs)*
                    #[allow(unused_variables)]
                    let #ident = #accessor;
                }
            }
            InitializerField::Init { ident, value, .. } => {
            InitializerKind::Init { ident, value, .. } => {
                // Again span for better diagnostics
                let init = format_ident!("init", span = value.span());
                if pinned {
                    let project_ident = format_ident!("__project_{ident}");
                    quote! {
                        #(#attrs)*
                        {
                            let #init = #value;
                            // SAFETY:
@@ -277,12 +291,14 @@ fn init_fields(
                            //   for `#ident`.
                            unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
                        }
                        #(#cfgs)*
                        // SAFETY: TODO
                        #[allow(unused_variables)]
                        let #ident = unsafe { #data.#project_ident(&mut (*#slot).#ident) };
                    }
                } else {
                    quote! {
                        #(#attrs)*
                        {
                            let #init = #value;
                            // SAFETY: `slot` is valid, because we are inside of an initializer
@@ -294,20 +310,25 @@ fn init_fields(
                                )?
                            };
                        }
                        #(#cfgs)*
                        // SAFETY: TODO
                        #[allow(unused_variables)]
                        let #ident = unsafe { &mut (*#slot).#ident };
                    }
                }
            }
            InitializerField::Code { block: value, .. } => quote!(#[allow(unused_braces)] #value),
            InitializerKind::Code { block: value, .. } => quote! {
                #(#attrs)*
                #[allow(unused_braces)]
                #value
            },
        };
        res.extend(init);
        if let Some(ident) = field.ident() {
        if let Some(ident) = kind.ident() {
            // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
            let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
            guards.push(guard.clone());
            res.extend(quote! {
                #(#cfgs)*
                // Create the drop guard:
                //
                // We rely on macro hygiene to make it impossible for users to access this local
@@ -319,13 +340,18 @@ fn init_fields(
                    )
                };
            });
            guards.push(guard);
            guard_attrs.push(cfgs);
        }
    }
    quote! {
        #res
        // If execution reaches this point, all fields have been initialized. Therefore we can now
        // dismiss the guards by forgetting them.
        #(::core::mem::forget(#guards);)*
        #(
            #(#guard_attrs)*
            ::core::mem::forget(#guards);
        )*
    }
}

@@ -335,7 +361,10 @@ fn make_field_check(
    init_kind: InitKind,
    path: &Path,
) -> TokenStream {
    let fields = fields.iter().filter_map(|f| f.ident());
    let field_attrs = fields
        .iter()
        .filter_map(|f| f.kind.ident().map(|_| &f.attrs));
    let field_name = fields.iter().filter_map(|f| f.kind.ident());
    match init_kind {
        InitKind::Normal => quote! {
            // We use unreachable code to ensure that all fields have been mentioned exactly once,
@@ -346,7 +375,8 @@ fn make_field_check(
            let _ = || unsafe {
                ::core::ptr::write(slot, #path {
                    #(
                        #fields: ::core::panic!(),
                        #(#field_attrs)*
                        #field_name: ::core::panic!(),
                    )*
                })
            };
@@ -366,7 +396,8 @@ fn make_field_check(
                zeroed = ::core::mem::zeroed();
                ::core::ptr::write(slot, #path {
                    #(
                        #fields: ::core::panic!(),
                        #(#field_attrs)*
                        #field_name: ::core::panic!(),
                    )*
                    ..zeroed
                })
@@ -387,7 +418,7 @@ fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
            let lh = content.lookahead1();
            if lh.peek(End) || lh.peek(Token![..]) {
                break;
            } else if lh.peek(Ident) || lh.peek(Token![_]) {
            } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
                fields.push_value(content.parse()?);
                let lh = content.lookahead1();
                if lh.peek(End) {
@@ -449,6 +480,16 @@ fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
}

impl Parse for InitializerField {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let attrs = input.call(Attribute::parse_outer)?;
        Ok(Self {
            attrs,
            kind: input.parse()?,
        })
    }
}

impl Parse for InitializerKind {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let lh = input.lookahead1();
        if lh.peek(Token![_]) {