Commit 4883830e authored by Benno Lossin's avatar Benno Lossin
Browse files

rust: pin-init: rewrite the initializer macros using `syn`



Rewrite the initializer macros `[pin_]init!` using `syn`. No functional
changes intended aside from improved error messages on syntactic and
semantical errors. For example if one forgets to use `<-` with an
initializer (and instead uses `:`):

    impl Bar {
        fn new() -> impl PinInit<Self> { ... }
    }

    impl Foo {
        fn new() -> impl PinInit<Self> {
            pin_init!(Self { bar: Bar::new() })
        }
    }

Then the declarative macro would report:

    error[E0308]: mismatched types
      --> tests/ui/compile-fail/init/colon_instead_of_arrow.rs:21:9
       |
    14 |     fn new() -> impl PinInit<Self> {
       |                 ------------------ the found opaque type
    ...
    21 |         pin_init!(Self { bar: Bar::new() })
       |         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       |         |
       |         expected `Bar`, found opaque type
       |         arguments to this function are incorrect
       |
       = note:   expected struct `Bar`
               found opaque type `impl pin_init::PinInit<Bar>`
    note: function defined here
      --> $RUST/core/src/ptr/mod.rs
       |
       | pub const unsafe fn write<T>(dst: *mut T, src: T) {
       |                     ^^^^^
       = note: this error originates in the macro `$crate::__init_internal` which comes from the expansion of the macro `pin_init` (in Nightly builds, run with -Z macro-backtrace for more info)

And the new error is:

    error[E0308]: mismatched types
      --> tests/ui/compile-fail/init/colon_instead_of_arrow.rs:21:31
       |
    14 |     fn new() -> impl PinInit<Self> {
       |                 ------------------ the found opaque type
    ...
    21 |         pin_init!(Self { bar: Bar::new() })
       |                          ---  ^^^^^^^^^^ expected `Bar`, found opaque type
       |                          |
       |                          arguments to this function are incorrect
       |
       = note:   expected struct `Bar`
               found opaque type `impl pin_init::PinInit<Bar>`
    note: function defined here
      --> $RUST/core/src/ptr/mod.rs
       |
       | pub const unsafe fn write<T>(dst: *mut T, src: T) {
       |                     ^^^^^

Importantly, this error gives much more accurate span locations,
pointing to the offending field, rather than the entire macro
invocation.

Tested-by: default avatarAndreas Hindborg <a.hindborg@kernel.org>
Reviewed-by: default avatarGary Guo <gary@garyguo.net>
Signed-off-by: default avatarBenno Lossin <lossin@kernel.org>
parent dae5466c
Loading
Loading
Loading
Loading
+445 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: Apache-2.0 OR MIT

use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use syn::{
    braced,
    parse::{End, Parse},
    parse_quote,
    punctuated::Punctuated,
    spanned::Spanned,
    token, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
};

use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};

pub(crate) struct Initializer {
    this: Option<This>,
    path: Path,
    brace_token: token::Brace,
    fields: Punctuated<InitializerField, Token![,]>,
    rest: Option<(Token![..], Expr)>,
    error: Option<(Token![?], Type)>,
}

struct This {
    _and_token: Token![&],
    ident: Ident,
    _in_token: Token![in],
}

enum InitializerField {
    Value {
        ident: Ident,
        value: Option<(Token![:], Expr)>,
    },
    Init {
        ident: Ident,
        _left_arrow_token: Token![<-],
        value: Expr,
    },
    Code {
        _underscore_token: Token![_],
        _colon_token: Token![:],
        block: Block,
    },
}

impl InitializerField {
    fn ident(&self) -> Option<&Ident> {
        match self {
            Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
            Self::Code { .. } => None,
        }
    }
}

pub(crate) fn expand(
    Initializer {
        this,
        path,
        brace_token,
        fields,
        rest,
        error,
    }: Initializer,
    default_error: Option<&'static str>,
    pinned: bool,
    dcx: &mut DiagCtxt,
) -> Result<TokenStream, ErrorGuaranteed> {
    let error = error.map_or_else(
        || {
            if let Some(default_error) = default_error {
                syn::parse_str(default_error).unwrap()
            } else {
                dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
                parse_quote!(::core::convert::Infallible)
            }
        },
        |(_, err)| err,
    );
    let slot = format_ident!("slot");
    let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
        (
            format_ident!("HasPinData"),
            format_ident!("PinData"),
            format_ident!("__pin_data"),
            format_ident!("pin_init_from_closure"),
        )
    } else {
        (
            format_ident!("HasInitData"),
            format_ident!("InitData"),
            format_ident!("__init_data"),
            format_ident!("init_from_closure"),
        )
    };
    let init_kind = get_init_kind(rest, dcx);
    let zeroable_check = match init_kind {
        InitKind::Normal => quote!(),
        InitKind::Zeroing => quote! {
            // The user specified `..Zeroable::zeroed()` at the end of the list of fields.
            // Therefore we check if the struct implements `Zeroable` and then zero the memory.
            // This allows us to also remove the check that all fields are present (since we
            // already set the memory to zero and that is a valid bit pattern).
            fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
            where T: ::pin_init::Zeroable
            {}
            // Ensure that the struct is indeed `Zeroable`.
            assert_zeroable(#slot);
            // SAFETY: The type implements `Zeroable` by the check above.
            unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
        },
    };
    let this = match this {
        None => quote!(),
        Some(This { ident, .. }) => quote! {
            // Create the `this` so it can be referenced by the user inside of the
            // expressions creating the individual fields.
            let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
        },
    };
    // `mixed_site` ensures that the data is not accessible to the user-controlled code.
    let data = Ident::new("__data", Span::mixed_site());
    let init_fields = init_fields(&fields, pinned, &data, &slot);
    let field_check = make_field_check(&fields, init_kind, &path);
    Ok(quote! {{
        // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return
        // type and shadow it later when we insert the arbitrary user code. That way there will be
        // no possibility of returning without `unsafe`.
        struct __InitOk;

        // Get the data about fields from the supplied type.
        // SAFETY: TODO
        let #data = unsafe {
            use ::pin_init::__internal::#has_data_trait;
            // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
            // generics (which need to be present with that syntax).
            #path::#get_data()
        };
        // Ensure that `#data` really is of type `#data` and help with type inference:
        let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>(
            #data,
            move |slot| {
                {
                    // Shadow the structure so it cannot be used to return early.
                    struct __InitOk;
                    #zeroable_check
                    #this
                    #init_fields
                    #field_check
                }
                Ok(__InitOk)
            }
        );
        let init = move |slot| -> ::core::result::Result<(), #error> {
            init(slot).map(|__InitOk| ())
        };
        // SAFETY: TODO
        let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
        init
    }})
}

enum InitKind {
    Normal,
    Zeroing,
}

fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
    let Some((dotdot, expr)) = rest else {
        return InitKind::Normal;
    };
    match &expr {
        Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
            Expr::Path(ExprPath {
                attrs,
                qself: None,
                path:
                    Path {
                        leading_colon: None,
                        segments,
                    },
            }) if attrs.is_empty()
                && segments.len() == 2
                && segments[0].ident == "Zeroable"
                && segments[0].arguments.is_none()
                && segments[1].ident == "init_zeroed"
                && segments[1].arguments.is_none() =>
            {
                return InitKind::Zeroing;
            }
            _ => {}
        },
        _ => {}
    }
    dcx.error(
        dotdot.span().join(expr.span()).unwrap_or(expr.span()),
        "expected nothing or `..Zeroable::init_zeroed()`.",
    );
    InitKind::Normal
}

/// Generate the code that initializes the fields of the struct using the initializers in `field`.
fn init_fields(
    fields: &Punctuated<InitializerField, Token![,]>,
    pinned: bool,
    data: &Ident,
    slot: &Ident,
) -> TokenStream {
    let mut guards = vec![];
    let mut res = TokenStream::new();
    for field in fields {
        let init = match field {
            InitializerField::Value { ident, value } => {
                let mut value_ident = ident.clone();
                let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
                    // Setting the span of `value_ident` to `value`'s span improves error messages
                    // when the type of `value` is wrong.
                    value_ident.set_span(value.span());
                    quote!(let #value_ident = #value;)
                });
                // Again span for better diagnostics
                let write = quote_spanned!(ident.span()=> ::core::ptr::write);
                let accessor = if pinned {
                    let project_ident = format_ident!("__project_{ident}");
                    quote! {
                        // SAFETY: TODO
                        unsafe { #data.#project_ident(&mut (*#slot).#ident) }
                    }
                } else {
                    quote! {
                        // SAFETY: TODO
                        unsafe { &mut (*#slot).#ident }
                    }
                };
                quote! {
                    {
                        #value_prep
                        // SAFETY: TODO
                        unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
                    }
                    #[allow(unused_variables)]
                    let #ident = #accessor;
                }
            }
            InitializerField::Init { ident, value, .. } => {
                // Again span for better diagnostics
                let init = format_ident!("init", span = value.span());
                if pinned {
                    let project_ident = format_ident!("__project_{ident}");
                    quote! {
                        {
                            let #init = #value;
                            // SAFETY:
                            // - `slot` is valid, because we are inside of an initializer closure, we
                            //   return when an error/panic occurs.
                            // - We also use `#data` to require the correct trait (`Init` or `PinInit`)
                            //   for `#ident`.
                            unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
                        }
                        // SAFETY: TODO
                        #[allow(unused_variables)]
                        let #ident = unsafe { #data.#project_ident(&mut (*#slot).#ident) };
                    }
                } else {
                    quote! {
                        {
                            let #init = #value;
                            // SAFETY: `slot` is valid, because we are inside of an initializer
                            // closure, we return when an error/panic occurs.
                            unsafe {
                                ::pin_init::Init::__init(
                                    #init,
                                    ::core::ptr::addr_of_mut!((*#slot).#ident),
                                )?
                            };
                        }
                        // SAFETY: TODO
                        #[allow(unused_variables)]
                        let #ident = unsafe { &mut (*#slot).#ident };
                    }
                }
            }
            InitializerField::Code { block: value, .. } => quote!(#[allow(unused_braces)] #value),
        };
        res.extend(init);
        if let Some(ident) = field.ident() {
            // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
            let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
            guards.push(guard.clone());
            res.extend(quote! {
                // Create the drop guard:
                //
                // We rely on macro hygiene to make it impossible for users to access this local
                // variable.
                // SAFETY: We forget the guard later when initialization has succeeded.
                let #guard = unsafe {
                    ::pin_init::__internal::DropGuard::new(
                        ::core::ptr::addr_of_mut!((*slot).#ident)
                    )
                };
            });
        }
    }
    quote! {
        #res
        // If execution reaches this point, all fields have been initialized. Therefore we can now
        // dismiss the guards by forgetting them.
        #(::core::mem::forget(#guards);)*
    }
}

/// Generate the check for ensuring that every field has been initialized.
fn make_field_check(
    fields: &Punctuated<InitializerField, Token![,]>,
    init_kind: InitKind,
    path: &Path,
) -> TokenStream {
    let fields = fields.iter().filter_map(|f| f.ident());
    match init_kind {
        InitKind::Normal => quote! {
            // We use unreachable code to ensure that all fields have been mentioned exactly once,
            // this struct initializer will still be type-checked and complain with a very natural
            // error message if a field is forgotten/mentioned more than once.
            #[allow(unreachable_code, clippy::diverging_sub_expression)]
            // SAFETY: this code is never executed.
            let _ = || unsafe {
                ::core::ptr::write(slot, #path {
                    #(
                        #fields: ::core::panic!(),
                    )*
                })
            };
        },
        InitKind::Zeroing => quote! {
            // We use unreachable code to ensure that all fields have been mentioned at most once.
            // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will
            // be zeroed. This struct initializer will still be type-checked and complain with a
            // very natural error message if a field is mentioned more than once, or doesn't exist.
            #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
            // SAFETY: this code is never executed.
            let _ = || unsafe {
                let mut zeroed = ::core::mem::zeroed();
                // We have to use type inference here to make zeroed have the correct type. This
                // does not get executed, so it has no effect.
                ::core::ptr::write(slot, zeroed);
                zeroed = ::core::mem::zeroed();
                ::core::ptr::write(slot, #path {
                    #(
                        #fields: ::core::panic!(),
                    )*
                    ..zeroed
                })
            };
        },
    }
}

impl Parse for Initializer {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
        let path = input.parse()?;
        let content;
        let brace_token = braced!(content in input);
        let mut fields = Punctuated::new();
        loop {
            let lh = content.lookahead1();
            if lh.peek(End) || lh.peek(Token![..]) {
                break;
            } else if lh.peek(Ident) || lh.peek(Token![_]) {
                fields.push_value(content.parse()?);
                let lh = content.lookahead1();
                if lh.peek(End) {
                    break;
                } else if lh.peek(Token![,]) {
                    fields.push_punct(content.parse()?);
                } else {
                    return Err(lh.error());
                }
            } else {
                return Err(lh.error());
            }
        }
        let rest = content
            .peek(Token![..])
            .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
            .transpose()?;
        let error = input
            .peek(Token![?])
            .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
            .transpose()?;
        Ok(Self {
            this,
            path,
            brace_token,
            fields,
            rest,
            error,
        })
    }
}

impl Parse for This {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        Ok(Self {
            _and_token: input.parse()?,
            ident: input.parse()?,
            _in_token: input.parse()?,
        })
    }
}

impl Parse for InitializerField {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let lh = input.lookahead1();
        if lh.peek(Token![_]) {
            Ok(Self::Code {
                _underscore_token: input.parse()?,
                _colon_token: input.parse()?,
                block: input.parse()?,
            })
        } else if lh.peek(Ident) {
            let ident = input.parse()?;
            let lh = input.lookahead1();
            if lh.peek(Token![<-]) {
                Ok(Self::Init {
                    ident,
                    _left_arrow_token: input.parse()?,
                    value: input.parse()?,
                })
            } else if lh.peek(Token![:]) {
                Ok(Self::Value {
                    ident,
                    value: Some((input.parse()?, input.parse()?)),
                })
            } else if lh.peek(Token![,]) || lh.peek(End) {
                Ok(Self::Value { ident, value: None })
            } else {
                Err(lh.error())
            }
        } else {
            Err(lh.error())
        }
    }
}
+13 −0
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@
use crate::diagnostics::DiagCtxt;

mod diagnostics;
mod init;
mod pin_data;
mod pinned_drop;
mod zeroable;
@@ -45,3 +46,15 @@ pub fn maybe_derive_zeroable(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input);
    DiagCtxt::with(|dcx| zeroable::maybe_derive(input, dcx)).into()
}
#[proc_macro]
pub fn init(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input);
    DiagCtxt::with(|dcx| init::expand(input, Some("::core::convert::Infallible"), false, dcx))
        .into()
}

#[proc_macro]
pub fn pin_init(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input);
    DiagCtxt::with(|dcx| init::expand(input, Some("::core::convert::Infallible"), true, dcx)).into()
}
+2 −54
Original line number Diff line number Diff line
@@ -297,8 +297,6 @@

#[doc(hidden)]
pub mod __internal;
#[doc(hidden)]
pub mod macros;

#[cfg(any(feature = "std", feature = "alloc"))]
mod alloc;
@@ -781,32 +779,7 @@ macro_rules! stack_try_pin_init {
/// ```
///
/// [`NonNull<Self>`]: core::ptr::NonNull
// For a detailed example of how this macro works, see the module documentation of the hidden
// module `macros` inside of `macros.rs`.
#[macro_export]
macro_rules! pin_init {
    ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? {
        $($fields:tt)*
    }) => {
        $crate::pin_init!($(&$this in)? $t $(::<$($generics),*>)? {
            $($fields)*
        }? ::core::convert::Infallible)
    };
    ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? {
        $($fields:tt)*
    }? $err:ty) => {
        $crate::__init_internal!(
            @this($($this)?),
            @typ($t $(::<$($generics),*>)? ),
            @fields($($fields)*),
            @error($err),
            @data(PinData, use_data),
            @has_data(HasPinData, __pin_data),
            @construct_closure(pin_init_from_closure),
            @munch_fields($($fields)*),
        )
    }
}
pub use pin_init_internal::pin_init;

/// Construct an in-place, fallible initializer for `struct`s.
///
@@ -844,32 +817,7 @@ macro_rules! pin_init {
/// }
/// # let _ = Box::init(BigBuf::new());
/// ```
// For a detailed example of how this macro works, see the module documentation of the hidden
// module `macros` inside of `macros.rs`.
#[macro_export]
macro_rules! init {
    ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? {
        $($fields:tt)*
    }) => {
        $crate::init!($(&$this in)? $t $(::<$($generics),*>)? {
            $($fields)*
        }? ::core::convert::Infallible)
    };
    ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? {
        $($fields:tt)*
    }? $err:ty) => {
        $crate::__init_internal!(
            @this($($this)?),
            @typ($t $(::<$($generics),*>)?),
            @fields($($fields)*),
            @error($err),
            @data(InitData, /*no use_data*/),
            @has_data(HasInitData, __init_data),
            @construct_closure(init_from_closure),
            @munch_fields($($fields)*),
        )
    }
}
pub use pin_init_internal::init;

/// Asserts that a field on a struct using `#[pin_data]` is marked with `#[pin]` ie. that it is
/// structurally pinned.

rust/pin-init/src/macros.rs

deleted100644 → 0
+0 −951

File deleted.

Preview size limit exceeded, changes collapsed.