Unverified Commit 0b24f974 authored by Andreas Hindborg's avatar Andreas Hindborg Committed by Daniel Gomez
Browse files

rust: module: update the module macro with module parameter support



Allow module parameters to be declared in the rust `module!` macro.

Reviewed-by: default avatarBenno Lossin <lossin@kernel.org>
Signed-off-by: default avatarAndreas Hindborg <a.hindborg@kernel.org>
Tested-by: default avatarDaniel Gomez <da.gomez@samsung.com>
Signed-off-by: default avatarDaniel Gomez <da.gomez@kernel.org>
parent 3809d7a8
Loading
Loading
Loading
Loading
+25 −0
Original line number Diff line number Diff line
@@ -10,6 +10,17 @@ pub(crate) fn try_ident(it: &mut token_stream::IntoIter) -> Option<String> {
    }
}

pub(crate) fn try_sign(it: &mut token_stream::IntoIter) -> Option<char> {
    let peek = it.clone().next();
    match peek {
        Some(TokenTree::Punct(punct)) if punct.as_char() == '-' => {
            let _ = it.next();
            Some(punct.as_char())
        }
        _ => None,
    }
}

pub(crate) fn try_literal(it: &mut token_stream::IntoIter) -> Option<String> {
    if let Some(TokenTree::Literal(literal)) = it.next() {
        Some(literal.to_string())
@@ -103,3 +114,17 @@ pub(crate) fn file() -> String {
        proc_macro::Span::call_site().file()
    }
}

/// Parse a token stream of the form `expected_name: "value",` and return the
/// string in the position of "value".
///
/// # Panics
///
/// - On parse error.
pub(crate) fn expect_string_field(it: &mut token_stream::IntoIter, expected_name: &str) -> String {
    assert_eq!(expect_ident(it), expected_name);
    assert_eq!(expect_punct(it), ':');
    let string = expect_string(it);
    assert_eq!(expect_punct(it), ',');
    string
}
+31 −0
Original line number Diff line number Diff line
@@ -28,6 +28,30 @@
/// The `type` argument should be a type which implements the [`Module`]
/// trait. Also accepts various forms of kernel metadata.
///
/// The `params` field describe module parameters. Each entry has the form
///
/// ```ignore
/// parameter_name: type {
///     default: default_value,
///     description: "Description",
/// }
/// ```
///
/// `type` may be one of
///
/// - [`i8`]
/// - [`u8`]
/// - [`i8`]
/// - [`u8`]
/// - [`i16`]
/// - [`u16`]
/// - [`i32`]
/// - [`u32`]
/// - [`i64`]
/// - [`u64`]
/// - [`isize`]
/// - [`usize`]
///
/// C header: [`include/linux/moduleparam.h`](srctree/include/linux/moduleparam.h)
///
/// [`Module`]: ../kernel/trait.Module.html
@@ -44,6 +68,12 @@
///     description: "My very own kernel module!",
///     license: "GPL",
///     alias: ["alternate_module_name"],
///     params: {
///         my_parameter: i64 {
///             default: 1,
///             description: "This parameter has a default of 1",
///         },
///     },
/// }
///
/// struct MyModule(i32);
@@ -52,6 +82,7 @@
///     fn init(_module: &'static ThisModule) -> Result<Self> {
///         let foo: i32 = 42;
///         pr_info!("I contain:  {}\n", foo);
///         pr_info!("i32 param is:  {}\n", module_parameters::my_parameter.read());
///         Ok(Self(foo))
///     }
/// }
+168 −10
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ struct ModInfoBuilder<'a> {
    module: &'a str,
    counter: usize,
    buffer: String,
    param_buffer: String,
}

impl<'a> ModInfoBuilder<'a> {
@@ -34,10 +35,11 @@ fn new(module: &'a str) -> Self {
            module,
            counter: 0,
            buffer: String::new(),
            param_buffer: String::new(),
        }
    }

    fn emit_base(&mut self, field: &str, content: &str, builtin: bool) {
    fn emit_base(&mut self, field: &str, content: &str, builtin: bool, param: bool) {
        let string = if builtin {
            // Built-in modules prefix their modinfo strings by `module.`.
            format!(
@@ -51,8 +53,14 @@ fn emit_base(&mut self, field: &str, content: &str, builtin: bool) {
            format!("{field}={content}\0")
        };

        let buffer = if param {
            &mut self.param_buffer
        } else {
            &mut self.buffer
        };

        write!(
            &mut self.buffer,
            buffer,
            "
                {cfg}
                #[doc(hidden)]
@@ -75,19 +83,118 @@ fn emit_base(&mut self, field: &str, content: &str, builtin: bool) {
        self.counter += 1;
    }

    fn emit_only_builtin(&mut self, field: &str, content: &str) {
        self.emit_base(field, content, true)
    fn emit_only_builtin(&mut self, field: &str, content: &str, param: bool) {
        self.emit_base(field, content, true, param)
    }

    fn emit_only_loadable(&mut self, field: &str, content: &str) {
        self.emit_base(field, content, false)
    fn emit_only_loadable(&mut self, field: &str, content: &str, param: bool) {
        self.emit_base(field, content, false, param)
    }

    fn emit(&mut self, field: &str, content: &str) {
        self.emit_only_builtin(field, content);
        self.emit_only_loadable(field, content);
        self.emit_internal(field, content, false);
    }

    fn emit_internal(&mut self, field: &str, content: &str, param: bool) {
        self.emit_only_builtin(field, content, param);
        self.emit_only_loadable(field, content, param);
    }

    fn emit_param(&mut self, field: &str, param: &str, content: &str) {
        let content = format!("{param}:{content}", param = param, content = content);
        self.emit_internal(field, &content, true);
    }

    fn emit_params(&mut self, info: &ModuleInfo) {
        let Some(params) = &info.params else {
            return;
        };

        for param in params {
            let ops = param_ops_path(&param.ptype);

            // Note: The spelling of these fields is dictated by the user space
            // tool `modinfo`.
            self.emit_param("parmtype", &param.name, &param.ptype);
            self.emit_param("parm", &param.name, &param.description);

            write!(
                self.param_buffer,
                "
                pub(crate) static {param_name}:
                    ::kernel::module_param::ModuleParamAccess<{param_type}> =
                        ::kernel::module_param::ModuleParamAccess::new({param_default});

                const _: () = {{
                    #[link_section = \"__param\"]
                    #[used]
                    static __{module_name}_{param_name}_struct:
                        ::kernel::module_param::KernelParam =
                        ::kernel::module_param::KernelParam::new(
                            ::kernel::bindings::kernel_param {{
                                name: if ::core::cfg!(MODULE) {{
                                    ::kernel::c_str!(\"{param_name}\").as_bytes_with_nul()
                                }} else {{
                                    ::kernel::c_str!(\"{module_name}.{param_name}\")
                                        .as_bytes_with_nul()
                                }}.as_ptr(),
                                // SAFETY: `__this_module` is constructed by the kernel at load
                                // time and will not be freed until the module is unloaded.
                                #[cfg(MODULE)]
                                mod_: unsafe {{
                                    core::ptr::from_ref(&::kernel::bindings::__this_module)
                                        .cast_mut()
                                }},
                                #[cfg(not(MODULE))]
                                mod_: ::core::ptr::null_mut(),
                                ops: core::ptr::from_ref(&{ops}),
                                perm: 0, // Will not appear in sysfs
                                level: -1,
                                flags: 0,
                                __bindgen_anon_1: ::kernel::bindings::kernel_param__bindgen_ty_1 {{
                                    arg: {param_name}.as_void_ptr()
                                }},
                            }}
                        );
                }};
                ",
                module_name = info.name,
                param_type = param.ptype,
                param_default = param.default,
                param_name = param.name,
                ops = ops,
            )
            .unwrap();
        }
    }
}

fn param_ops_path(param_type: &str) -> &'static str {
    match param_type {
        "i8" => "::kernel::module_param::PARAM_OPS_I8",
        "u8" => "::kernel::module_param::PARAM_OPS_U8",
        "i16" => "::kernel::module_param::PARAM_OPS_I16",
        "u16" => "::kernel::module_param::PARAM_OPS_U16",
        "i32" => "::kernel::module_param::PARAM_OPS_I32",
        "u32" => "::kernel::module_param::PARAM_OPS_U32",
        "i64" => "::kernel::module_param::PARAM_OPS_I64",
        "u64" => "::kernel::module_param::PARAM_OPS_U64",
        "isize" => "::kernel::module_param::PARAM_OPS_ISIZE",
        "usize" => "::kernel::module_param::PARAM_OPS_USIZE",
        t => panic!("Unsupported parameter type {}", t),
    }
}

fn expect_param_default(param_it: &mut token_stream::IntoIter) -> String {
    assert_eq!(expect_ident(param_it), "default");
    assert_eq!(expect_punct(param_it), ':');
    let sign = try_sign(param_it);
    let default = try_literal(param_it).expect("Expected default param value");
    assert_eq!(expect_punct(param_it), ',');
    let mut value = sign.map(String::from).unwrap_or_default();
    value.push_str(&default);
    value
}

#[derive(Debug, Default)]
struct ModuleInfo {
@@ -98,6 +205,50 @@ struct ModuleInfo {
    description: Option<String>,
    alias: Option<Vec<String>>,
    firmware: Option<Vec<String>>,
    params: Option<Vec<Parameter>>,
}

#[derive(Debug)]
struct Parameter {
    name: String,
    ptype: String,
    default: String,
    description: String,
}

fn expect_params(it: &mut token_stream::IntoIter) -> Vec<Parameter> {
    let params = expect_group(it);
    assert_eq!(params.delimiter(), Delimiter::Brace);
    let mut it = params.stream().into_iter();
    let mut parsed = Vec::new();

    loop {
        let param_name = match it.next() {
            Some(TokenTree::Ident(ident)) => ident.to_string(),
            Some(_) => panic!("Expected Ident or end"),
            None => break,
        };

        assert_eq!(expect_punct(&mut it), ':');
        let param_type = expect_ident(&mut it);
        let group = expect_group(&mut it);
        assert_eq!(group.delimiter(), Delimiter::Brace);
        assert_eq!(expect_punct(&mut it), ',');

        let mut param_it = group.stream().into_iter();
        let param_default = expect_param_default(&mut param_it);
        let param_description = expect_string_field(&mut param_it, "description");
        expect_end(&mut param_it);

        parsed.push(Parameter {
            name: param_name,
            ptype: param_type,
            default: param_default,
            description: param_description,
        })
    }

    parsed
}

impl ModuleInfo {
@@ -112,6 +263,7 @@ fn parse(it: &mut token_stream::IntoIter) -> Self {
            "license",
            "alias",
            "firmware",
            "params",
        ];
        const REQUIRED_KEYS: &[&str] = &["type", "name", "license"];
        let mut seen_keys = Vec::new();
@@ -137,6 +289,7 @@ fn parse(it: &mut token_stream::IntoIter) -> Self {
                "license" => info.license = expect_string_ascii(it),
                "alias" => info.alias = Some(expect_string_array(it)),
                "firmware" => info.firmware = Some(expect_string_array(it)),
                "params" => info.params = Some(expect_params(it)),
                _ => panic!("Unknown key \"{key}\". Valid keys are: {EXPECTED_KEYS:?}."),
            }

@@ -199,7 +352,9 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream {
    // Built-in modules also export the `file` modinfo string.
    let file =
        std::env::var("RUST_MODFILE").expect("Unable to fetch RUST_MODFILE environmental variable");
    modinfo.emit_only_builtin("file", &file);
    modinfo.emit_only_builtin("file", &file, false);

    modinfo.emit_params(&info);

    format!(
        "
@@ -363,15 +518,18 @@ unsafe fn __exit() {{
                            __MOD.assume_init_drop();
                        }}
                    }}

                    {modinfo}
                }}
            }}
            mod module_parameters {{
                {params}
            }}
        ",
        type_ = info.type_,
        name = info.name,
        ident = ident,
        modinfo = modinfo.buffer,
        params = modinfo.param_buffer,
        initcall_section = ".initcall6.init"
    )
    .parse()