diff --git a/library/kani_macros/src/derive.rs b/library/kani_macros/src/derive.rs index de173b53bbb6..36424f5c8c3e 100644 --- a/library/kani_macros/src/derive.rs +++ b/library/kani_macros/src/derive.rs @@ -157,7 +157,7 @@ fn inv_conds_inner(ident: &Ident, fields: &Fields) -> Option { /// ``` /// which allows us to refer to the struct fields without using `self`. /// Note that the actual stream is generated in the `field_refs_inner` function. -fn field_refs(ident: &Ident, data: &Data) -> TokenStream { +pub fn field_refs(ident: &Ident, data: &Data) -> TokenStream { match data { Data::Struct(struct_data) => field_refs_inner(ident, &struct_data.fields), Data::Enum(_) => unreachable!(), @@ -344,7 +344,7 @@ pub fn expand_derive_invariant(item: proc_macro::TokenStream) -> proc_macro::Tok } /// Add a bound `T: Invariant` to every type parameter T. -fn add_trait_bound_invariant(mut generics: Generics) -> Generics { +pub fn add_trait_bound_invariant(mut generics: Generics) -> Generics { generics.params.iter_mut().for_each(|param| { if let GenericParam::Type(type_param) = param { type_param.bounds.push(parse_quote!(kani::Invariant)); @@ -422,43 +422,3 @@ fn struct_invariant_conjunction(ident: &Ident, fields: &Fields) -> TokenStream { } } } - -/// Generates an `Invariant` implementation where the `is_safe` function body is -/// the `attr` expression passed to the attribute macro. -/// Only available for structs. -pub fn attr_custom_invariant( - attr: TokenStream, - item: proc_macro::TokenStream, -) -> proc_macro::TokenStream { - let derive_item = parse_macro_input!(item as DeriveInput); - let item_name = &derive_item.ident; - - if !matches!(derive_item.data, Data::Struct(..)) { - abort!(Span::call_site(), "Cannot define invariant for `{}`", item_name; - note = item_name.span() => - "`#[kani::invariant(..)]` is only available for structs" - ) - } - - // Keep a copy of the original item to re-emit it later. - // Note that this isn't a derive macro - let original_item = derive_item.clone(); - - // Add a bound `T: Invariant` to every type parameter T. - let generics = add_trait_bound_invariant(derive_item.generics); - // Generate an expression to sum up the heap size of each field. - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - - let expanded = quote! { - // Re-emit the original item - #original_item - - // The generated implementation. - impl #impl_generics kani::Invariant for #item_name #ty_generics #where_clause { - fn is_safe(&self) -> bool { - #attr - } - } - }; - proc_macro::TokenStream::from(expanded) -} diff --git a/library/kani_macros/src/lib.rs b/library/kani_macros/src/lib.rs index d37b75b85f1e..9c9023341111 100644 --- a/library/kani_macros/src/lib.rs +++ b/library/kani_macros/src/lib.rs @@ -12,8 +12,11 @@ mod derive; // proc_macro::quote is nightly-only, so we'll cobble things together instead +use derive::{add_trait_bound_invariant, field_refs}; use proc_macro::TokenStream; -use proc_macro_error::proc_macro_error; +use proc_macro2::{Span, TokenStream as TokenStream2}; +use proc_macro_error::{abort, proc_macro_error}; +use quote::quote; #[cfg(kani_sysroot)] use sysroot as attr_impl; @@ -95,7 +98,7 @@ pub fn solver(attr: TokenStream, item: TokenStream) -> TokenStream { #[proc_macro_attribute] #[proc_macro_error] pub fn invariant(attr: TokenStream, item: TokenStream) -> TokenStream { - derive::attr_custom_invariant(attr.into(), item) + attr_custom_invariant(attr.into(), item) } /// Mark an API as unstable. This should only be used inside the Kani sysroot. @@ -398,3 +401,46 @@ mod regular { no_op!(proof_for_contract); no_op!(stub_verified); } + +/// Generates an `Invariant` implementation where the `is_safe` function body is +/// the `attr` expression passed to the attribute macro. +/// Only available for structs. +fn attr_custom_invariant( + attr: TokenStream2, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let derive_item = syn::parse_macro_input!(item as syn::DeriveInput); + let item_name = &derive_item.ident; + + if !matches!(derive_item.data, syn::Data::Struct(..)) { + abort!(Span::call_site(), "Cannot define invariant for `{}`", item_name; + note = item_name.span() => + "`#[kani::invariant(..)]` is only available for structs" + ) + } + + // Keep a copy of the original item to re-emit it later. + // Note that this isn't a derive macro + let original_item = derive_item.clone(); + + // Add a bound `T: Invariant` to every type parameter T. + let generics = add_trait_bound_invariant(derive_item.generics); + // Generate an expression to sum up the heap size of each field. + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let field_refs = field_refs(&item_name, &derive_item.data); + let expanded = quote! { + // Re-emit the original item + #original_item + + // The generated implementation. + impl #impl_generics kani::Invariant for #item_name #ty_generics #where_clause { + fn is_safe(&self) -> bool { + let obj = self; + #field_refs + #attr + } + } + }; + TokenStream::from(expanded) +} diff --git a/tests/expected/attr-invariant/check_invariant/check_invariant.rs b/tests/expected/attr-invariant/check_invariant/check_invariant.rs index ffbeab5af619..2e26e9f7b01c 100644 --- a/tests/expected/attr-invariant/check_invariant/check_invariant.rs +++ b/tests/expected/attr-invariant/check_invariant/check_invariant.rs @@ -8,7 +8,7 @@ extern crate kani; use kani::Invariant; #[derive(kani::Arbitrary)] -#[kani::invariant(self.x.is_safe() && self.y.is_safe())] +#[kani::invariant(x.is_safe() && y.is_safe())] struct Point { x: i32, y: i32,