Commit 3f2a5ba7 authored by Lyude Paul's avatar Lyude Paul Committed by Andreas Hindborg
Browse files

rust: hrtimer: Add HrTimerCallbackContext and ::forward()



With Linux's hrtimer API, there's a number of methods that can only be
called in two situations:

* When we have exclusive access to the hrtimer and it is not currently
  active
* When we're within the context of an hrtimer callback context

This commit handles the second situation and implements hrtimer_forward()
support in the context of a timer callback. We do this by introducing a
HrTimerCallbackContext type which is provided to users during the
RawHrTimerCallback::run() callback, and then add a forward() function to
the type.

Signed-off-by: default avatarLyude Paul <lyude@redhat.com>
Reviewed-by: default avatarDaniel Almeida <daniel.almeida@collabora.com>
Reviewed-by: default avatarAndreas Hindborg <a.hindborg@kernel.org>
Link: https://lore.kernel.org/r/20250821193259.964504-5-lyude@redhat.com


Signed-off-by: default avatarAndreas Hindborg <a.hindborg@kernel.org>
parent 3efb9ce9
Loading
Loading
Loading
Loading
+60 −3
Original line number Diff line number Diff line
@@ -69,7 +69,7 @@

use super::{ClockSource, Delta, Instant};
use crate::{prelude::*, types::Opaque};
use core::marker::PhantomData;
use core::{marker::PhantomData, ptr::NonNull};
use pin_init::PinInit;

/// A type-alias to refer to the [`Instant<C>`] for a given `T` from [`HrTimer<T>`].
@@ -196,6 +196,10 @@ unsafe fn raw_forward(self_ptr: *mut Self, now: HrTimerInstant<T>, interval: Del
    /// expires after `now` and then returns the number of times the timer was forwarded by
    /// `interval`.
    ///
    /// This function is mainly useful for timer types which can provide exclusive access to the
    /// timer when the timer is not running. For forwarding the timer from within the timer callback
    /// context, see [`HrTimerCallbackContext::forward()`].
    ///
    /// Returns the number of overruns that occurred as a result of the timer expiry change.
    pub fn forward(self: Pin<&mut Self>, now: HrTimerInstant<T>, interval: Delta) -> u64
    where
@@ -345,9 +349,13 @@ pub trait HrTimerCallback {
    type Pointer<'a>: RawHrTimerCallback;

    /// Called by the timer logic when the timer fires.
    fn run(this: <Self::Pointer<'_> as RawHrTimerCallback>::CallbackTarget<'_>) -> HrTimerRestart
    fn run(
        this: <Self::Pointer<'_> as RawHrTimerCallback>::CallbackTarget<'_>,
        ctx: HrTimerCallbackContext<'_, Self>,
    ) -> HrTimerRestart
    where
        Self: Sized;
        Self: Sized,
        Self: HasHrTimer<Self>;
}

/// A handle representing a potentially running timer.
@@ -632,6 +640,55 @@ impl<C: ClockSource> HrTimerMode for RelativePinnedHardMode<C> {
    type Expires = Delta;
}

/// Privileged smart-pointer for a [`HrTimer`] callback context.
///
/// Many [`HrTimer`] methods can only be called in two situations:
///
/// * When the caller has exclusive access to the `HrTimer` and the `HrTimer` is guaranteed not to
///   be running.
/// * From within the context of an `HrTimer`'s callback method.
///
/// This type provides access to said methods from within a timer callback context.
///
/// # Invariants
///
/// * The existence of this type means the caller is currently within the callback for an
///   [`HrTimer`].
/// * `self.0` always points to a live instance of [`HrTimer<T>`].
pub struct HrTimerCallbackContext<'a, T: HasHrTimer<T>>(NonNull<HrTimer<T>>, PhantomData<&'a ()>);

impl<'a, T: HasHrTimer<T>> HrTimerCallbackContext<'a, T> {
    /// Create a new [`HrTimerCallbackContext`].
    ///
    /// # Safety
    ///
    /// This function relies on the caller being within the context of a timer callback, so it must
    /// not be used anywhere except for within implementations of [`RawHrTimerCallback::run`]. The
    /// caller promises that `timer` points to a valid initialized instance of
    /// [`bindings::hrtimer`].
    ///
    /// The returned `Self` must not outlive the function context of [`RawHrTimerCallback::run`]
    /// where this function is called.
    pub(crate) unsafe fn from_raw(timer: *mut HrTimer<T>) -> Self {
        // SAFETY: The caller guarantees `timer` is a valid pointer to an initialized
        // `bindings::hrtimer`
        // INVARIANT: Our safety contract ensures that we're within the context of a timer callback
        // and that `timer` points to a live instance of `HrTimer<T>`.
        Self(unsafe { NonNull::new_unchecked(timer) }, PhantomData)
    }

    /// Conditionally forward the timer.
    ///
    /// This function is identical to [`HrTimer::forward()`] except that it may only be used from
    /// within the context of a [`HrTimer`] callback.
    pub fn forward(&mut self, now: HrTimerInstant<T>, interval: Delta) -> u64 {
        // SAFETY:
        // - We are guaranteed to be within the context of a timer callback by our type invariants
        // - By our type invariants, `self.0` always points to a valid `HrTimer<T>`
        unsafe { HrTimer::<T>::raw_forward(self.0.as_ptr(), now, interval) }
    }
}

/// Use to implement the [`HasHrTimer<T>`] trait.
///
/// See [`module`] documentation for an example.
+8 −1
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@
use super::HasHrTimer;
use super::HrTimer;
use super::HrTimerCallback;
use super::HrTimerCallbackContext;
use super::HrTimerHandle;
use super::HrTimerMode;
use super::HrTimerPointer;
@@ -99,6 +100,12 @@ impl<T> RawHrTimerCallback for Arc<T>
        //    allocation from other `Arc` clones.
        let receiver = unsafe { ArcBorrow::from_raw(data_ptr) };

        T::run(receiver).into_c()
        // SAFETY:
        // - By C API contract `timer_ptr` is the pointer that we passed when queuing the timer, so
        //   it is a valid pointer to a `HrTimer<T>` embedded in a `T`.
        // - We are within `RawHrTimerCallback::run`
        let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) };

        T::run(receiver, context).into_c()
    }
}
+8 −1
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@
use super::HasHrTimer;
use super::HrTimer;
use super::HrTimerCallback;
use super::HrTimerCallbackContext;
use super::HrTimerHandle;
use super::HrTimerMode;
use super::RawHrTimerCallback;
@@ -103,6 +104,12 @@ impl<'a, T> RawHrTimerCallback for Pin<&'a T>
        // here.
        let receiver_pin = unsafe { Pin::new_unchecked(receiver_ref) };

        T::run(receiver_pin).into_c()
        // SAFETY:
        // - By C API contract `timer_ptr` is the pointer that we passed when queuing the timer, so
        //   it is a valid pointer to a `HrTimer<T>` embedded in a `T`.
        // - We are within `RawHrTimerCallback::run`
        let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) };

        T::run(receiver_pin, context).into_c()
    }
}
+9 −3
Original line number Diff line number Diff line
// SPDX-License-Identifier: GPL-2.0

use super::{
    HasHrTimer, HrTimer, HrTimerCallback, HrTimerHandle, HrTimerMode, RawHrTimerCallback,
    UnsafeHrTimerPointer,
    HasHrTimer, HrTimer, HrTimerCallback, HrTimerCallbackContext, HrTimerHandle, HrTimerMode,
    RawHrTimerCallback, UnsafeHrTimerPointer,
};
use core::{marker::PhantomData, pin::Pin, ptr::NonNull};

@@ -107,6 +107,12 @@ impl<'a, T> RawHrTimerCallback for Pin<&'a mut T>
        // here.
        let receiver_pin = unsafe { Pin::new_unchecked(receiver_ref) };

        T::run(receiver_pin).into_c()
        // SAFETY:
        // - By C API contract `timer_ptr` is the pointer that we passed when queuing the timer, so
        //   it is a valid pointer to a `HrTimer<T>` embedded in a `T`.
        // - We are within `RawHrTimerCallback::run`
        let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) };

        T::run(receiver_pin, context).into_c()
    }
}
+8 −1
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@
use super::HasHrTimer;
use super::HrTimer;
use super::HrTimerCallback;
use super::HrTimerCallbackContext;
use super::HrTimerHandle;
use super::HrTimerMode;
use super::HrTimerPointer;
@@ -119,6 +120,12 @@ impl<T, A> RawHrTimerCallback for Pin<Box<T, A>>
        //   `data_ptr` exist.
        let data_mut_ref = unsafe { Pin::new_unchecked(&mut *data_ptr) };

        T::run(data_mut_ref).into_c()
        // SAFETY:
        // - By C API contract `timer_ptr` is the pointer that we passed when queuing the timer, so
        //   it is a valid pointer to a `HrTimer<T>` embedded in a `T`.
        // - We are within `RawHrTimerCallback::run`
        let context = unsafe { HrTimerCallbackContext::from_raw(timer_ptr) };

        T::run(data_mut_ref, context).into_c()
    }
}