Commit 84837cf6 authored by Benno Lossin's avatar Benno Lossin Committed by Miguel Ojeda
Browse files

rust: pin-init: change examples to the user-space version

Replace the examples in the documentation by the ones from the
user-space version and introduce the standalone examples from the
user-space version such as the `CMutex<T>` type.

The `CMutex<T>` example from the pinned-init repository [1] is used in
several documentation examples in the user-space version instead of the
kernel `Mutex<T>` type (as it's not available). In order to split off
the pin-init crate, all examples need to be free of kernel-specific
types.

Link: https://github.com/rust-for-Linux/pinned-init

 [1]
Signed-off-by: default avatarBenno Lossin <benno.lossin@proton.me>
Reviewed-by: default avatarFiona Behrens <me@kloenk.dev>
Tested-by: default avatarAndreas Hindborg <a.hindborg@kernel.org>
Link: https://lore.kernel.org/r/20250308110339.2997091-6-benno.lossin@proton.me


Signed-off-by: default avatarMiguel Ojeda <ojeda@kernel.org>
parent 4b11798e
Loading
Loading
Loading
Loading
+39 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: Apache-2.0 OR MIT

use pin_init::*;

// Struct with size over 1GiB
#[derive(Debug)]
pub struct BigStruct {
    buf: [u8; 1024 * 1024 * 1024],
    a: u64,
    b: u64,
    c: u64,
    d: u64,
    managed_buf: ManagedBuf,
}

#[derive(Debug)]
pub struct ManagedBuf {
    buf: [u8; 1024 * 1024],
}

impl ManagedBuf {
    pub fn new() -> impl Init<Self> {
        init!(ManagedBuf { buf <- zeroed() })
    }
}

fn main() {
    // we want to initialize the struct in-place, otherwise we would get a stackoverflow
    let buf: Box<BigStruct> = Box::init(init!(BigStruct {
        buf <- zeroed(),
        a: 7,
        b: 186,
        c: 7789,
        d: 34,
        managed_buf <- ManagedBuf::new(),
    }))
    .unwrap();
    println!("{}", core::mem::size_of_val(&*buf));
}
+27 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: Apache-2.0 OR MIT

#![cfg_attr(feature = "alloc", feature(allocator_api))]

use core::convert::Infallible;

#[cfg(feature = "alloc")]
use std::alloc::AllocError;

#[derive(Debug)]
pub struct Error;

impl From<Infallible> for Error {
    fn from(e: Infallible) -> Self {
        match e {}
    }
}

#[cfg(feature = "alloc")]
impl From<AllocError> for Error {
    fn from(_: AllocError) -> Self {
        Self
    }
}

#[allow(dead_code)]
fn main() {}
+161 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: Apache-2.0 OR MIT

#![allow(clippy::undocumented_unsafe_blocks)]
#![cfg_attr(feature = "alloc", feature(allocator_api))]

use core::{
    cell::Cell,
    convert::Infallible,
    marker::PhantomPinned,
    pin::Pin,
    ptr::{self, NonNull},
};

use pin_init::*;

#[expect(unused_attributes)]
mod error;
use error::Error;

#[pin_data(PinnedDrop)]
#[repr(C)]
#[derive(Debug)]
pub struct ListHead {
    next: Link,
    prev: Link,
    #[pin]
    pin: PhantomPinned,
}

impl ListHead {
    #[inline]
    pub fn new() -> impl PinInit<Self, Infallible> {
        try_pin_init!(&this in Self {
            next: unsafe { Link::new_unchecked(this) },
            prev: unsafe { Link::new_unchecked(this) },
            pin: PhantomPinned,
        }? Infallible)
    }

    #[inline]
    pub fn insert_next(list: &ListHead) -> impl PinInit<Self, Infallible> + '_ {
        try_pin_init!(&this in Self {
            prev: list.next.prev().replace(unsafe { Link::new_unchecked(this)}),
            next: list.next.replace(unsafe { Link::new_unchecked(this)}),
            pin: PhantomPinned,
        }? Infallible)
    }

    #[inline]
    pub fn insert_prev(list: &ListHead) -> impl PinInit<Self, Infallible> + '_ {
        try_pin_init!(&this in Self {
            next: list.prev.next().replace(unsafe { Link::new_unchecked(this)}),
            prev: list.prev.replace(unsafe { Link::new_unchecked(this)}),
            pin: PhantomPinned,
        }? Infallible)
    }

    #[inline]
    pub fn next(&self) -> Option<NonNull<Self>> {
        if ptr::eq(self.next.as_ptr(), self) {
            None
        } else {
            Some(unsafe { NonNull::new_unchecked(self.next.as_ptr() as *mut Self) })
        }
    }

    #[allow(dead_code)]
    pub fn size(&self) -> usize {
        let mut size = 1;
        let mut cur = self.next.clone();
        while !ptr::eq(self, cur.cur()) {
            cur = cur.next().clone();
            size += 1;
        }
        size
    }
}

#[pinned_drop]
impl PinnedDrop for ListHead {
    //#[inline]
    fn drop(self: Pin<&mut Self>) {
        if !ptr::eq(self.next.as_ptr(), &*self) {
            let next = unsafe { &*self.next.as_ptr() };
            let prev = unsafe { &*self.prev.as_ptr() };
            next.prev.set(&self.prev);
            prev.next.set(&self.next);
        }
    }
}

#[repr(transparent)]
#[derive(Clone, Debug)]
struct Link(Cell<NonNull<ListHead>>);

impl Link {
    /// # Safety
    ///
    /// The contents of the pointer should form a consistent circular
    /// linked list; for example, a "next" link should be pointed back
    /// by the target `ListHead`'s "prev" link and a "prev" link should be
    /// pointed back by the target `ListHead`'s "next" link.
    #[inline]
    unsafe fn new_unchecked(ptr: NonNull<ListHead>) -> Self {
        Self(Cell::new(ptr))
    }

    #[inline]
    fn next(&self) -> &Link {
        unsafe { &(*self.0.get().as_ptr()).next }
    }

    #[inline]
    fn prev(&self) -> &Link {
        unsafe { &(*self.0.get().as_ptr()).prev }
    }

    #[allow(dead_code)]
    fn cur(&self) -> &ListHead {
        unsafe { &*self.0.get().as_ptr() }
    }

    #[inline]
    fn replace(&self, other: Link) -> Link {
        unsafe { Link::new_unchecked(self.0.replace(other.0.get())) }
    }

    #[inline]
    fn as_ptr(&self) -> *const ListHead {
        self.0.get().as_ptr()
    }

    #[inline]
    fn set(&self, val: &Link) {
        self.0.set(val.0.get());
    }
}

#[allow(dead_code)]
#[cfg_attr(test, test)]
fn main() -> Result<(), Error> {
    let a = Box::pin_init(ListHead::new())?;
    stack_pin_init!(let b = ListHead::insert_next(&a));
    stack_pin_init!(let c = ListHead::insert_next(&a));
    stack_pin_init!(let d = ListHead::insert_next(&b));
    let e = Box::pin_init(ListHead::insert_next(&b))?;
    println!("a ({a:p}): {a:?}");
    println!("b ({b:p}): {b:?}");
    println!("c ({c:p}): {c:?}");
    println!("d ({d:p}): {d:?}");
    println!("e ({e:p}): {e:?}");
    let mut inspect = &*a;
    while let Some(next) = inspect.next() {
        println!("({inspect:p}): {inspect:?}");
        inspect = unsafe { &*next.as_ptr() };
        if core::ptr::eq(inspect, &*a) {
            break;
        }
    }
    Ok(())
}
+209 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: Apache-2.0 OR MIT

#![allow(clippy::undocumented_unsafe_blocks)]
#![cfg_attr(feature = "alloc", feature(allocator_api))]
#![allow(clippy::missing_safety_doc)]

use core::{
    cell::{Cell, UnsafeCell},
    marker::PhantomPinned,
    ops::{Deref, DerefMut},
    pin::Pin,
    sync::atomic::{AtomicBool, Ordering},
};
use std::{
    sync::Arc,
    thread::{self, park, sleep, Builder, Thread},
    time::Duration,
};

use pin_init::*;
#[expect(unused_attributes)]
#[path = "./linked_list.rs"]
pub mod linked_list;
use linked_list::*;

pub struct SpinLock {
    inner: AtomicBool,
}

impl SpinLock {
    #[inline]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        while self
            .inner
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            while self.inner.load(Ordering::Relaxed) {
                thread::yield_now();
            }
        }
        SpinLockGuard(self)
    }

    #[inline]
    #[allow(clippy::new_without_default)]
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    #[inline]
    fn drop(&mut self) {
        self.0.inner.store(false, Ordering::Release);
    }
}

#[pin_data]
pub struct CMutex<T> {
    #[pin]
    wait_list: ListHead,
    spin_lock: SpinLock,
    locked: Cell<bool>,
    #[pin]
    data: UnsafeCell<T>,
}

impl<T> CMutex<T> {
    #[inline]
    pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
        pin_init!(CMutex {
            wait_list <- ListHead::new(),
            spin_lock: SpinLock::new(),
            locked: Cell::new(false),
            data <- unsafe {
                pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
                    val.__pinned_init(slot.cast::<T>())
                })
            },
        })
    }

    #[inline]
    pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
        let mut sguard = self.spin_lock.acquire();
        if self.locked.get() {
            stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
            // println!("wait list length: {}", self.wait_list.size());
            while self.locked.get() {
                drop(sguard);
                park();
                sguard = self.spin_lock.acquire();
            }
            // This does have an effect, as the ListHead inside wait_entry implements Drop!
            #[expect(clippy::drop_non_drop)]
            drop(wait_entry);
        }
        self.locked.set(true);
        unsafe {
            Pin::new_unchecked(CMutexGuard {
                mtx: self,
                _pin: PhantomPinned,
            })
        }
    }

    #[allow(dead_code)]
    pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
        // SAFETY: we have an exclusive reference and thus nobody has access to data.
        unsafe { &mut *self.data.get() }
    }
}

unsafe impl<T: Send> Send for CMutex<T> {}
unsafe impl<T: Send> Sync for CMutex<T> {}

pub struct CMutexGuard<'a, T> {
    mtx: &'a CMutex<T>,
    _pin: PhantomPinned,
}

impl<T> Drop for CMutexGuard<'_, T> {
    #[inline]
    fn drop(&mut self) {
        let sguard = self.mtx.spin_lock.acquire();
        self.mtx.locked.set(false);
        if let Some(list_field) = self.mtx.wait_list.next() {
            let wait_entry = list_field.as_ptr().cast::<WaitEntry>();
            unsafe { (*wait_entry).thread.unpark() };
        }
        drop(sguard);
    }
}

impl<T> Deref for CMutexGuard<'_, T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        unsafe { &*self.mtx.data.get() }
    }
}

impl<T> DerefMut for CMutexGuard<'_, T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        unsafe { &mut *self.mtx.data.get() }
    }
}

#[pin_data]
#[repr(C)]
struct WaitEntry {
    #[pin]
    wait_list: ListHead,
    thread: Thread,
}

impl WaitEntry {
    #[inline]
    fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
        pin_init!(Self {
            thread: thread::current(),
            wait_list <- ListHead::insert_prev(list),
        })
    }
}

#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {}

#[allow(dead_code)]
#[cfg_attr(test, test)]
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let mut handles = vec![];
    let thread_count = 20;
    let workload = if cfg!(miri) { 100 } else { 1_000 };
    for i in 0..thread_count {
        let mtx = mtx.clone();
        handles.push(
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    for _ in 0..workload {
                        *mtx.lock() += 1;
                    }
                    println!("{i} halfway");
                    sleep(Duration::from_millis((i as u64) * 10));
                    for _ in 0..workload {
                        *mtx.lock() += 1;
                    }
                    println!("{i} finished");
                })
                .expect("should not fail"),
        );
    }
    for h in handles {
        h.join().expect("thread panicked");
    }
    println!("{:?}", &*mtx.lock());
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
}
+178 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: Apache-2.0 OR MIT

// inspired by https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs
#![allow(clippy::undocumented_unsafe_blocks)]
#![cfg_attr(feature = "alloc", feature(allocator_api))]
#[cfg(not(windows))]
mod pthread_mtx {
    #[cfg(feature = "alloc")]
    use core::alloc::AllocError;
    use core::{
        cell::UnsafeCell,
        marker::PhantomPinned,
        mem::MaybeUninit,
        ops::{Deref, DerefMut},
        pin::Pin,
    };
    use pin_init::*;
    use std::convert::Infallible;

    #[pin_data(PinnedDrop)]
    pub struct PThreadMutex<T> {
        #[pin]
        raw: UnsafeCell<libc::pthread_mutex_t>,
        data: UnsafeCell<T>,
        #[pin]
        pin: PhantomPinned,
    }

    unsafe impl<T: Send> Send for PThreadMutex<T> {}
    unsafe impl<T: Send> Sync for PThreadMutex<T> {}

    #[pinned_drop]
    impl<T> PinnedDrop for PThreadMutex<T> {
        fn drop(self: Pin<&mut Self>) {
            unsafe {
                libc::pthread_mutex_destroy(self.raw.get());
            }
        }
    }

    #[derive(Debug)]
    pub enum Error {
        #[expect(dead_code)]
        IO(std::io::Error),
        Alloc,
    }

    impl From<Infallible> for Error {
        fn from(e: Infallible) -> Self {
            match e {}
        }
    }

    #[cfg(feature = "alloc")]
    impl From<AllocError> for Error {
        fn from(_: AllocError) -> Self {
            Self::Alloc
        }
    }

    impl<T> PThreadMutex<T> {
        pub fn new(data: T) -> impl PinInit<Self, Error> {
            fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> {
                let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| {
                    // we can cast, because `UnsafeCell` has the same layout as T.
                    let slot: *mut libc::pthread_mutex_t = slot.cast();
                    let mut attr = MaybeUninit::uninit();
                    let attr = attr.as_mut_ptr();
                    // SAFETY: ptr is valid
                    let ret = unsafe { libc::pthread_mutexattr_init(attr) };
                    if ret != 0 {
                        return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
                    }
                    // SAFETY: attr is initialized
                    let ret = unsafe {
                        libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL)
                    };
                    if ret != 0 {
                        // SAFETY: attr is initialized
                        unsafe { libc::pthread_mutexattr_destroy(attr) };
                        return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
                    }
                    // SAFETY: slot is valid
                    unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) };
                    // SAFETY: attr and slot are valid ptrs and attr is initialized
                    let ret = unsafe { libc::pthread_mutex_init(slot, attr) };
                    // SAFETY: attr was initialized
                    unsafe { libc::pthread_mutexattr_destroy(attr) };
                    if ret != 0 {
                        return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
                    }
                    Ok(())
                };
                // SAFETY: mutex has been initialized
                unsafe { pin_init_from_closure(init) }
            }
            try_pin_init!(Self {
            data: UnsafeCell::new(data),
            raw <- init_raw(),
            pin: PhantomPinned,
        }? Error)
        }

        pub fn lock(&self) -> PThreadMutexGuard<'_, T> {
            // SAFETY: raw is always initialized
            unsafe { libc::pthread_mutex_lock(self.raw.get()) };
            PThreadMutexGuard { mtx: self }
        }
    }

    pub struct PThreadMutexGuard<'a, T> {
        mtx: &'a PThreadMutex<T>,
    }

    impl<T> Drop for PThreadMutexGuard<'_, T> {
        fn drop(&mut self) {
            // SAFETY: raw is always initialized
            unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) };
        }
    }

    impl<T> Deref for PThreadMutexGuard<'_, T> {
        type Target = T;

        fn deref(&self) -> &Self::Target {
            unsafe { &*self.mtx.data.get() }
        }
    }

    impl<T> DerefMut for PThreadMutexGuard<'_, T> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            unsafe { &mut *self.mtx.data.get() }
        }
    }
}

#[cfg_attr(test, test)]
fn main() {
    #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))]
    {
        use core::pin::Pin;
        use pin_init::*;
        use pthread_mtx::*;
        use std::{
            sync::Arc,
            thread::{sleep, Builder},
            time::Duration,
        };
        let mtx: Pin<Arc<PThreadMutex<usize>>> = Arc::try_pin_init(PThreadMutex::new(0)).unwrap();
        let mut handles = vec![];
        let thread_count = 20;
        let workload = 1_000_000;
        for i in 0..thread_count {
            let mtx = mtx.clone();
            handles.push(
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        for _ in 0..workload {
                            *mtx.lock() += 1;
                        }
                        println!("{i} halfway");
                        sleep(Duration::from_millis((i as u64) * 10));
                        for _ in 0..workload {
                            *mtx.lock() += 1;
                        }
                        println!("{i} finished");
                    })
                    .expect("should not fail"),
            );
        }
        for h in handles {
            h.join().expect("thread panicked");
        }
        println!("{:?}", &*mtx.lock());
        assert_eq!(*mtx.lock(), workload * thread_count * 2);
    }
}
Loading