[<prev] [next>] [<thread-prev] [day] [month] [year] [list]
Message-Id: <58f312f85a30d1da0130b10735ddba89244241cb.1754228164.git.y.j3ms.n@gmail.com>
Date: Sun, 3 Aug 2025 14:20:53 +0000
From: Jesung Yang <y.j3ms.n@...il.com>
To: Miguel Ojeda <ojeda@...nel.org>,
Alex Gaynor <alex.gaynor@...il.com>,
Boqun Feng <boqun.feng@...il.com>,
Gary Guo <gary@...yguo.net>,
Björn Roy Baron <bjorn3_gh@...tonmail.com>,
Benno Lossin <lossin@...nel.org>,
Andreas Hindborg <a.hindborg@...nel.org>,
Alice Ryhl <aliceryhl@...gle.com>,
Trevor Gross <tmgross@...ch.edu>,
Danilo Krummrich <dakr@...nel.org>,
Alexandre Courbot <acourbot@...dia.com>
Cc: linux-kernel@...r.kernel.org,
rust-for-linux@...r.kernel.org,
nouveau@...ts.freedesktop.org,
Jesung Yang <y.j3ms.n@...il.com>
Subject: [PATCH 3/4] rust: macro: add derive macro for `TryFrom`
Introduce a procedural macro `TryFrom` to automatically implement the
`TryFrom` trait for unit-only enums.
This reduces boilerplate in cases where numeric values need to be
interpreted as the corresponding enum variants. This situation often arises when
working with low-level data sources. A typical example is the `Chipset`
enum in nova-core, where the value read from a GPU register should be
mapped to a corresponding variant.
Since a pending RFC [1] proposes adding the `syn` crate [2] as a
dependency, keep the parsing logic minimal.
Link: https://lore.kernel.org/rust-for-linux/20250304225536.2033853-1-benno.lossin@proton.me [1]
Link: https://docs.rs/syn/latest/syn [2]
Signed-off-by: Jesung Yang <y.j3ms.n@...il.com>
---
rust/macros/convert.rs | 337 +++++++++++++++++++++++++++++++++++++++++
rust/macros/lib.rs | 124 +++++++++++++++
2 files changed, 461 insertions(+)
create mode 100644 rust/macros/convert.rs
diff --git a/rust/macros/convert.rs b/rust/macros/convert.rs
new file mode 100644
index 000000000000..0084bc4308c1
--- /dev/null
+++ b/rust/macros/convert.rs
@@ -0,0 +1,337 @@
+// SPDX-License-Identifier: GPL-2.0
+
+use proc_macro::{token_stream, Delimiter, Ident, Span, TokenStream, TokenTree};
+use std::iter::Peekable;
+
+/// Integer types requested by the outer attributes of the derive input.
+#[derive(Debug)]
+struct TypeArgs {
+ /// Types listed in the derive helper attribute(s), in order of appearance;
+ /// empty when no helper attribute is present.
+ helper: Vec<Ident>,
+ /// The integer type named in a `#[repr(...)]` attribute, if one was found.
+ repr: Option<Ident>,
+}
+
+/// Primitive integer types accepted as `TryFrom` source types, whether they
+/// come from the derive helper attribute or from `#[repr(...)]`.
+const VALID_TYPES: [&str; 12] = [
+    // Unsigned integers.
+    "u8", "u16", "u32", "u64", "u128", "usize",
+    // Signed integers.
+    "i8", "i16", "i32", "i64", "i128", "isize",
+];
+
+/// Entry point for `#[derive(TryFrom)]`, invoked from the crate root.
+///
+/// Forwards the derive input unchanged to [`derive`] and returns its expansion.
+pub(crate) fn derive_try_from(input: TokenStream) -> TokenStream {
+    self::derive(input)
+}
+
+/// Expands `#[derive(TryFrom)]` for the item in `input`.
+///
+/// Only unit-only enums without generic parameters are supported; anything
+/// else expands to `compile_error!` invocations. A problem in the enum header
+/// (generic parameters) and problems in its variants are reported together.
+///
+/// The source types are those listed in `#[try_from(...)]` if present,
+/// otherwise the integer type from `#[repr(...)]`, otherwise `isize` (the
+/// discriminant type of the default representation). One `TryFrom` impl is
+/// generated per source type.
+fn derive(input: TokenStream) -> TokenStream {
+    let derive_target = "TryFrom";
+    let derive_helper = "try_from";
+
+    let mut tokens = input.into_iter().peekable();
+
+    let type_args = match parse_types(&mut tokens, derive_helper) {
+        Ok(type_args) => type_args,
+        Err(errs) => return errs,
+    };
+
+    // Skip until the `enum` keyword, including the `enum` itself. For a
+    // `struct` or `union` this consumes the whole input, reported just below.
+    for tt in tokens.by_ref() {
+        if matches!(tt, TokenTree::Ident(ident) if ident.to_string() == "enum") {
+            break;
+        }
+    }
+
+    let Some(TokenTree::Ident(enum_ident)) = tokens.next() else {
+        return format!(
+            "::core::compile_error!(\"`#[derive({derive_target})]` can only \
+             be applied to an enum\");"
+        )
+        .parse::<TokenStream>()
+        .unwrap();
+    };
+
+    let mut errs = TokenStream::new();
+
+    if matches!(tokens.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '<') {
+        errs.extend(
+            format!(
+                "::core::compile_error!(\"`#[derive({derive_target})]` \
+                 does not support enums with generic parameters\");"
+            )
+            .parse::<TokenStream>()
+            .unwrap(),
+        );
+    }
+
+    // The variant list is always the last token of an enum definition. Taking
+    // the last token, rather than the first brace-delimited group, keeps a
+    // block inside the generics or the `where` clause (e.g. the default in
+    // `enum Foo<const N: usize = { 1 }>`) from being mistaken for the body.
+    let variants_group = match tokens.last() {
+        Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Brace => g,
+        _ => unreachable!("Enums have a brace-delimited body"),
+    };
+
+    let enum_body_tokens = variants_group.stream().into_iter().peekable();
+    let variants = match parse_enum_variants(enum_body_tokens, &enum_ident, derive_target) {
+        Ok(variants) => variants,
+        Err(new_errs) => {
+            errs.extend(new_errs);
+            return errs;
+        }
+    };
+
+    if !errs.is_empty() {
+        return errs;
+    }
+
+    if type_args.helper.is_empty() {
+        // Extract the representation passed by `#[repr(...)]` if present.
+        // If nothing is specified, the default is `Rust` representation,
+        // which uses `isize` for the discriminant type.
+        // See: https://doc.rust-lang.org/reference/items/enumerations.html#r-items.enum.discriminant.repr-rust
+        let ty = type_args
+            .repr
+            .unwrap_or_else(|| Ident::new("isize", Span::mixed_site()));
+        impl_try_from(&ty, &enum_ident, &variants)
+    } else {
+        let impls = type_args
+            .helper
+            .iter()
+            .map(|ty| impl_try_from(ty, &enum_ident, &variants));
+        quote! { #(#impls)* }
+    }
+}
+
+/// Scans the outer attributes at the start of the derive input.
+///
+/// Consumes every leading `#[...]` attribute from `attrs`, leaving the
+/// iterator at the first token after them (the visibility or the `enum`
+/// keyword). `helper` is the name of the derive helper attribute. Collects:
+///
+/// - the integer types listed in every `#[<helper>(...)]` attribute, in order;
+/// - the integer type from `#[repr(...)]`, if any `repr` attribute names one.
+///
+/// # Errors
+///
+/// Returns `compile_error!` invocations if a helper attribute is malformed or
+/// if a type is listed more than once across the helper attributes.
+fn parse_types(
+    attrs: &mut Peekable<token_stream::IntoIter>,
+    helper: &str,
+) -> Result<TypeArgs, TokenStream> {
+    let mut helper_args: Vec<Ident> = vec![];
+    let mut repr_arg = None;
+
+    // Scan only the attributes. As soon as we see a token that is
+    // not `#`, we know we have consumed all attributes.
+    while let Some(TokenTree::Punct(p)) = attrs.peek() {
+        if p.as_char() != '#' {
+            unreachable!("Outer attributes start with `#`");
+        }
+        attrs.next();
+
+        // The next token should be a `Group` delimited by brackets.
+        // (e.g., #[try_from(u8, u16)])
+        //         ^^^^^^^^^^^^^^^^^^^
+        let Some(TokenTree::Group(attr_group)) = attrs.next() else {
+            unreachable!("Outer attributes are surrounded by `[` and `]`");
+        };
+
+        let mut inner = attr_group.stream().into_iter();
+
+        // Extract the attribute identifier.
+        // (e.g., #[try_from(u8, u16)])
+        //          ^^^^^^^^
+        // An attribute whose path does not start with a plain identifier
+        // (e.g. `#[::some::attr]`) is neither `repr` nor the helper: skip it.
+        let Some(TokenTree::Ident(attr_ident)) = inner.next() else {
+            continue;
+        };
+        let attr_name = attr_ident.to_string();
+
+        if attr_name == helper {
+            helper_args.extend(parse_helper_args(inner, helper)?);
+        } else if attr_name == "repr" {
+            // `#[repr(u8)] #[repr(C)]` means the same as `#[repr(u8, C)]`, so a
+            // later `repr` attribute without an integer type must not discard
+            // an integer type found in an earlier one.
+            if let Some(ty) = parse_repr_args(inner) {
+                repr_arg = Some(ty);
+            }
+        }
+    }
+
+    // A repeated type would otherwise surface as a less obvious
+    // "conflicting implementations" error (E0119) on the generated impls.
+    let names: Vec<String> = helper_args.iter().map(ToString::to_string).collect();
+    let mut errs = TokenStream::new();
+    for (i, name) in names.iter().enumerate() {
+        if names[..i].contains(name) {
+            errs.extend(
+                format!(
+                    "::core::compile_error!(\"`{name}` is listed more than once \
+                     in `{helper}` attributes\");"
+                )
+                .parse::<TokenStream>()
+                .unwrap(),
+            );
+        }
+    }
+    if !errs.is_empty() {
+        return Err(errs);
+    }
+
+    Ok(TypeArgs {
+        helper: helper_args,
+        repr: repr_arg,
+    })
+}
+
+/// Returns the first primitive integer type listed in a `repr` attribute.
+///
+/// `tt_group` yields the tokens that follow the `repr` identifier, e.g. `(C, u8)`
+/// for `#[repr(C, u8)]`. Entries that are not in `VALID_TYPES`, such as `C` or
+/// `align(8)`, are skipped; `None` is returned when no integer type is listed.
+///
+/// A malformed attribute without an argument group (e.g. a bare `#[repr]`)
+/// also yields `None` rather than panicking the macro, so that the compiler's
+/// own diagnostic for the attribute is the one the user sees.
+fn parse_repr_args(mut tt_group: impl Iterator<Item = TokenTree>) -> Option<Ident> {
+    // The next token should be a `Group` delimited by parentheses.
+    // (e.g., #[repr(C, u8)])
+    //              ^^^^^^^
+    let Some(TokenTree::Group(args_group)) = tt_group.next() else {
+        return None;
+    };
+
+    args_group.stream().into_iter().find_map(|arg| match arg {
+        TokenTree::Ident(ty) if VALID_TYPES.contains(&ty.to_string().as_str()) => Some(ty),
+        _ => None,
+    })
+}
+
+/// Parses the arguments of the derive helper attribute, e.g. `(u8, u16)` in
+/// `#[try_from(u8, u16)]`.
+///
+/// `tt_group` yields the tokens that follow the helper's name, and `helper` is
+/// that name (used only in diagnostics). Returns the listed types in order; a
+/// trailing comma is accepted.
+///
+/// # Errors
+///
+/// Returns `compile_error!` invocations when the arguments are missing or not
+/// wrapped in parentheses, are not comma-separated identifiers, or name
+/// something other than a primitive integer type. Every unknown type is
+/// reported; a syntax error stops parsing immediately but still returns the
+/// errors collected so far.
+fn parse_helper_args(
+    mut tt_group: impl Iterator<Item = TokenTree>,
+    helper: &str,
+) -> Result<Vec<Ident>, TokenStream> {
+    let missing_args = || {
+        format!(
+            "::core::compile_error!(\"`{helper}` attribute expects at \
+             least one integer type argument (e.g., `#[{helper}(u8)]`)\");"
+        )
+        .parse::<TokenStream>()
+        .unwrap()
+    };
+    let not_comma_separated = || {
+        format!(
+            "::core::compile_error!(\"`{helper}` attribute expects \
+             comma-separated integer type arguments \
+             (e.g., `#[{helper}(u8, u16)]`)\");"
+        )
+        .parse::<TokenStream>()
+        .unwrap()
+    };
+
+    // The next token should be a parenthesized `Group`; `#[try_from[u8]]` and
+    // `#[try_from{u8}]` are rejected like a missing argument list.
+    // (e.g., #[try_from(u8, u16)])
+    //                  ^^^^^^^^^
+    let args_group = match tt_group.next() {
+        Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Parenthesis => g,
+        _ => return Err(missing_args()),
+    };
+
+    let raw_args = args_group.stream();
+    if raw_args.is_empty() {
+        return Err(missing_args());
+    }
+
+    let mut errs = TokenStream::new();
+    let mut args = vec![];
+
+    // Expect `type (, type)* ,?`, collecting valid integer type identifiers.
+    let mut raw_args = raw_args.into_iter();
+    while let Some(tt) = raw_args.next() {
+        let TokenTree::Ident(type_ident) = tt else {
+            errs.extend(not_comma_separated());
+            return Err(errs);
+        };
+
+        if VALID_TYPES.contains(&type_ident.to_string().as_str()) {
+            args.push(type_ident);
+        } else {
+            // Keep going so that every unknown type is reported at once.
+            errs.extend(
+                format!(
+                    "::core::compile_error!(\"`{type_ident}` in `{helper}` \
+                     attribute is not an integer type\");"
+                )
+                .parse::<TokenStream>()
+                .unwrap(),
+            );
+        }
+
+        match raw_args.next() {
+            Some(TokenTree::Punct(p)) if p.as_char() == ',' => {}
+            None => break,
+            Some(_) => {
+                errs.extend(not_comma_separated());
+                return Err(errs);
+            }
+        }
+    }
+
+    if errs.is_empty() {
+        Ok(args)
+    } else {
+        Err(errs)
+    }
+}
+
+/// Collects the variant names of an enum from the tokens of its body.
+///
+/// `tokens` iterates over the contents of the enum's `{ ... }` group. Variant
+/// attributes and explicit discriminants (`= expr`) are skipped. `enum_ident`
+/// and `derive_target` are only used to build diagnostics.
+///
+/// # Errors
+///
+/// Returns `compile_error!` invocations if the enum has no variants or if any
+/// variant is tuple-like or struct-like; every offending variant is reported.
+fn parse_enum_variants(
+    mut tokens: Peekable<token_stream::IntoIter>,
+    enum_ident: &Ident,
+    derive_target: &str,
+) -> Result<Vec<Ident>, TokenStream> {
+    let mut found = Vec::new();
+    let mut diagnostics = TokenStream::new();
+
+    if tokens.peek().is_none() {
+        diagnostics.extend(
+            format!(
+                "::core::compile_error!(\"`#[derive({derive_target})]` \
+                 does not support zero-variant enums\");"
+            )
+            .parse::<TokenStream>()
+            .unwrap(),
+        );
+    }
+
+    while let Some(tt) = tokens.next() {
+        let variant = match tt {
+            // `#` starts a variant attribute: drop it together with its `[...]` group.
+            TokenTree::Punct(p) if p.as_char() == '#' => {
+                tokens.next();
+                continue;
+            }
+            TokenTree::Ident(name) => name,
+            _ => unreachable!("Enum variants have its corresponding identifier"),
+        };
+
+        // A group right after the name means the variant carries fields.
+        let shape = match tokens.peek() {
+            Some(TokenTree::Group(g)) => match g.delimiter() {
+                Delimiter::Brace => Some("struct-like"),
+                Delimiter::Parenthesis => Some("tuple-like"),
+                _ => unreachable!("Invalid enum variant syntax"),
+            },
+            _ => None,
+        };
+        if let Some(variant_kind) = shape {
+            diagnostics.extend(
+                format!(
+                    "::core::compile_error!(\"`#[derive({derive_target})]` does not \
+                     support {variant_kind} variant `{enum_ident}::{variant}`; \
+                     only unit variants are allowed\");"
+                )
+                .parse::<TokenStream>()
+                .unwrap(),
+            );
+        }
+
+        // Discard the rest of the variant (e.g. `= 0x17`) up to and including the comma.
+        tokens.find(|t| matches!(t, TokenTree::Punct(p) if p.as_char() == ','));
+
+        found.push(variant);
+    }
+
+    if diagnostics.is_empty() {
+        Ok(found)
+    } else {
+        Err(diagnostics)
+    }
+}
+
+/// Generates `impl ::core::convert::TryFrom<ty> for enum_ident`.
+///
+/// The generated `try_from` compares its argument, in declaration order,
+/// against each variant's discriminant cast to `ty` with `as`. It returns the
+/// first variant that matches, or `Err(EINVAL)` if none does.
+///
+/// Every path in the expansion is fully qualified, so items in scope at the
+/// derive site (e.g. a local `Result` alias) cannot change its meaning.
+fn impl_try_from(ty: &Ident, enum_ident: &Ident, variants: &[Ident]) -> TokenStream {
+    // `Span::mixed_site()` gives this local variable definition-site hygiene.
+    let param = Ident::new("value", Span::mixed_site());
+
+    // Each variant contributes `if value == Self::V as ty { Ok(Self::V) } else`;
+    // the chain is closed by the trailing `Err` block below.
+    let clauses = variants.iter().map(|variant| {
+        quote! {
+            if #param == Self::#variant as #ty {
+                ::core::result::Result::Ok(Self::#variant)
+            } else
+        }
+    });
+
+    quote! {
+        #[automatically_derived]
+        impl ::core::convert::TryFrom<#ty> for #enum_ident {
+            type Error = ::kernel::prelude::Error;
+            fn try_from(#param: #ty) -> ::core::result::Result<Self, Self::Error> {
+                #(#clauses)* {
+                    ::core::result::Result::Err(::kernel::prelude::EINVAL)
+                }
+            }
+        }
+    }
+}
diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs
index fa847cf3a9b5..569198f188f7 100644
--- a/rust/macros/lib.rs
+++ b/rust/macros/lib.rs
@@ -14,6 +14,7 @@
#[macro_use]
mod quote;
mod concat_idents;
+mod convert;
mod export;
mod helpers;
mod kunit;
@@ -425,3 +426,126 @@ pub fn paste(input: TokenStream) -> TokenStream {
pub fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
kunit::kunit_tests(attr, ts)
}
+
+/// A derive macro for generating an impl of the [`TryFrom`] trait.
+///
+/// This macro automatically derives the [`TryFrom`] trait for a given enum. Currently,
+/// it only supports [unit-only enum]s without generic parameters.
+///
+/// [unit-only enum]: https://doc.rust-lang.org/reference/items/enumerations.html#r-items.enum.unit-only
+///
+/// # Notes
+///
+/// The macro generates [`TryFrom`] implementations that:
+/// - Convert numeric values to enum variants by matching discriminant values.
+/// - Return `Ok(VARIANT)` for valid matches.
+/// - Return `Err(EINVAL)` for values that match no variant (where `EINVAL` is from
+///   [`kernel::error::code`]).
+///
+/// The macro uses the `try_from` custom attribute or `repr` attribute to generate
+/// corresponding [`TryFrom`] implementations. `try_from` always takes precedence
+/// over `repr`.
+///
+/// [`kernel::error::code`]: ../kernel/error/code/index.html
+///
+/// # Caveats
+///
+/// Ensure that every integer type specified in `#[try_from(...)]` is large enough
+/// to cover all enum discriminants. Otherwise, the internal `as` casts silently
+/// truncate the discriminants, and a value may be converted to a variant whose
+/// discriminant it does not equal.
+///
+/// # Examples
+///
+/// ## Without Attributes
+///
+/// Since [the default `Rust` representation uses `isize` for the discriminant type][repr-rs],
+/// the macro implements `TryFrom<isize>`:
+///
+/// [repr-rs]: https://doc.rust-lang.org/reference/items/enumerations.html#r-items.enum.discriminant.repr-rust
+///
+/// ```rust
+/// use kernel::macros::TryFrom;
+/// use kernel::prelude::*;
+///
+/// #[derive(Debug, Default, PartialEq, TryFrom)]
+/// enum Foo {
+///     #[default]
+///     A,
+///     B = 0x17,
+/// }
+///
+/// assert_eq!(Foo::try_from(0isize), Ok(Foo::A));
+/// assert_eq!(Foo::try_from(0x17isize), Ok(Foo::B));
+/// assert_eq!(Foo::try_from(0x19isize), Err(EINVAL));
+/// ```
+///
+/// ## With `#[repr(T)]`
+///
+/// The macro implements `TryFrom<T>`:
+///
+/// ```rust
+/// use kernel::macros::TryFrom;
+/// use kernel::prelude::*;
+///
+/// #[derive(Debug, Default, PartialEq, TryFrom)]
+/// #[repr(u8)]
+/// enum Foo {
+///     #[default]
+///     A,
+///     B = 0x17,
+/// }
+///
+/// assert_eq!(Foo::try_from(0u8), Ok(Foo::A));
+/// assert_eq!(Foo::try_from(0x17u8), Ok(Foo::B));
+/// assert_eq!(Foo::try_from(0x19u8), Err(EINVAL));
+/// ```
+///
+/// ## With `#[try_from(...)]`
+///
+/// The macro implements `TryFrom<T>` for each `T` specified in `#[try_from(...)]`,
+/// which always overrides `#[repr(...)]`:
+///
+/// ```rust
+/// use kernel::macros::TryFrom;
+/// use kernel::prelude::*;
+///
+/// #[derive(Debug, Default, PartialEq, TryFrom)]
+/// #[try_from(u8, u16)]
+/// #[repr(u8)]
+/// enum Foo {
+///     #[default]
+///     A,
+///     B = 0x17,
+/// }
+///
+/// assert_eq!(Foo::try_from(0x17u8), Ok(Foo::B));
+/// assert_eq!(Foo::try_from(0u16), Ok(Foo::A));
+/// assert_eq!(Foo::try_from(0x17u16), Ok(Foo::B));
+/// assert_eq!(Foo::try_from(0x19u16), Err(EINVAL));
+/// ```
+///
+/// ## Unsupported Cases
+///
+/// Each of the following examples fails to compile.
+///
+// Each case lives in its own `compile_fail` block: such a block passes as soon
+// as anything in it fails to compile, so bundling the cases would let one of
+// them regress unnoticed.
+///
+/// Generic parameters are not allowed:
+///
+// A const parameter is used because, unlike a type parameter, it may be left
+// unused: `enum Foo<T> { A }` fails to compile (E0392) even without this derive.
+/// ```compile_fail
+/// # use kernel::macros::TryFrom;
+/// #[derive(TryFrom)]
+/// enum Foo<const N: usize> {
+///     A,
+/// }
+/// ```
+///
+/// Tuple-like variants are not allowed:
+///
+/// ```compile_fail
+/// # use kernel::macros::TryFrom;
+/// #[derive(TryFrom)]
+/// enum Bar {
+///     A(u8),
+/// }
+/// ```
+///
+/// Struct-like variants are not allowed:
+///
+/// ```compile_fail
+/// # use kernel::macros::TryFrom;
+/// #[derive(TryFrom)]
+/// enum Bar {
+///     B { inner: u8 },
+/// }
+/// ```
+///
+/// Structs are not allowed:
+///
+/// ```compile_fail
+/// # use kernel::macros::TryFrom;
+/// #[derive(TryFrom)]
+/// struct Baz(u8);
+/// ```
+#[proc_macro_derive(TryFrom, attributes(try_from))]
+pub fn derive_try_from(input: TokenStream) -> TokenStream {
+    convert::derive_try_from(input)
+}
--
2.39.5
Powered by blists - more mailing lists