Commit c652dc44 authored by Kaibo Ma's avatar Kaibo Ma Committed by Shuah Khan
Browse files

rust: kunit: allow `cfg` on `test`s

The `kunit_test` proc macro only checks for the `test` attribute
immediately preceding a `fn`. If the function is disabled via a `cfg`,
the generated code would result in a compile error referencing a
non-existent function [1].

This collects attributes and specifically cherry-picks `cfg` attributes
to be duplicated inside KUnit wrapper functions such that a test function
disabled via `cfg` compiles and is marked as skipped in KUnit correctly.

Link: https://lore.kernel.org/r/20250916021259.115578-1-ent3rm4n@gmail.com
Link: https://lore.kernel.org/rust-for-linux/CANiq72==48=69hYiDo1321pCzgn_n1_jg=ez5UYXX91c+g5JVQ@mail.gmail.com/ [1]
Closes: https://github.com/Rust-for-Linux/linux/issues/1185


Suggested-by: default avatarMiguel Ojeda <ojeda@kernel.org>
Suggested-by: default avatarDavid Gow <davidgow@google.com>
Signed-off-by: default avatarKaibo Ma <ent3rm4n@gmail.com>
Reviewed-by: default avatarDavid Gow <davidgow@google.com>
Signed-off-by: default avatarShuah Khan <skhan@linuxfoundation.org>
parent f20e2642
Loading
Loading
Loading
Loading
+7 −0
Original line number Diff line number Diff line
@@ -361,4 +361,11 @@ fn rust_test_kunit_example_test() {
    fn rust_test_kunit_in_kunit_test() {
        assert!(in_kunit_test());
    }

    #[test]
    #[cfg(not(all()))]
    fn rust_test_kunit_always_disabled_test() {
        // This test should never run because of the `cfg`.
        assert!(false);
    }
}
+36 −12
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@
//! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com>

use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
use std::collections::HashMap;
use std::fmt::Write;

pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
@@ -41,20 +42,32 @@ pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
    // Get the functions set as tests. Search for `[test]` -> `fn`.
    let mut body_it = body.stream().into_iter();
    let mut tests = Vec::new();
    let mut attributes: HashMap<String, TokenStream> = HashMap::new();
    while let Some(token) = body_it.next() {
        match token {
            TokenTree::Group(ident) if ident.to_string() == "[test]" => match body_it.next() {
                Some(TokenTree::Ident(ident)) if ident.to_string() == "fn" => {
                    let test_name = match body_it.next() {
                        Some(TokenTree::Ident(ident)) => ident.to_string(),
                        _ => continue,
                    };
                    tests.push(test_name);
            TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() {
                Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
                    if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() {
                        // Collect attributes because we need to find which are tests. We also
                        // need to copy `cfg` attributes so tests can be conditionally enabled.
                        attributes
                            .entry(name.to_string())
                            .or_default()
                            .extend([token, TokenTree::Group(g)]);
                    }
                    continue;
                }
                _ => continue,
                _ => (),
            },
            TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => {
                if let Some(TokenTree::Ident(test_name)) = body_it.next() {
                    tests.push((test_name, attributes.remove("cfg").unwrap_or_default()))
                }
            }

            _ => (),
        }
        attributes.clear();
    }

    // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration.
@@ -100,11 +113,22 @@ pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
    let mut test_cases = "".to_owned();
    let mut assert_macros = "".to_owned();
    let path = crate::helpers::file();
    for test in &tests {
    let num_tests = tests.len();
    for (test, cfg_attr) in tests {
        let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}");
        // An extra `use` is used here to reduce the length of the message.
        // Append any `cfg` attributes the user might have written on their tests so we don't
        // attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce
        // the length of the assert message.
        let kunit_wrapper = format!(
            "unsafe extern \"C\" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit) {{ use ::kernel::kunit::is_test_result_ok; assert!(is_test_result_ok({test}())); }}",
            r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit)
            {{
                (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
                {cfg_attr} {{
                    (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
                    use ::kernel::kunit::is_test_result_ok;
                    assert!(is_test_result_ok({test}()));
                }}
            }}"#,
        );
        writeln!(kunit_macros, "{kunit_wrapper}").unwrap();
        writeln!(
@@ -139,7 +163,7 @@ macro_rules! assert_eq {{
    writeln!(
        kunit_macros,
        "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases}    ::kernel::kunit::kunit_case_null(),\n];",
        tests.len() + 1
        num_tests + 1
    )
    .unwrap();