Skip to content

Commit

Permalink
Add a #[derive(Invariant)] macro (#3250)
Browse files Browse the repository at this point in the history
This PR adds a `#[derive(Invariant)]` macro for structs which allows
users to automatically derive the `Invariant` implementations for any
struct. The derived implementation determines the invariant for the
struct as the conjunction of invariants of its fields. In other words,
the invariant is derived as `true && self.field1.is_safe() &&
self.field2.is_safe() && ..`.

For example, for the struct

```rs
#[derive(kani::Invariant)]
struct Point<X, Y> {
    x: X,
    y: Y,
}
```

we derive the `Invariant` implementation as

```rs
impl<X: kani::Invariant, Y: kani::Invariant> kani::Invariant for Point<X, Y> {
    fn is_safe(&self) -> bool {
        true && self.x.is_safe() && self.y.is_safe()
    }
}
```

Related #3095
  • Loading branch information
adpaco-aws authored Jun 13, 2024
1 parent 7dad847 commit 5c7cd63
Show file tree
Hide file tree
Showing 12 changed files with 244 additions and 2 deletions.
94 changes: 92 additions & 2 deletions library/kani_macros/src/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub fn expand_derive_arbitrary(item: proc_macro::TokenStream) -> proc_macro::Tok
let item_name = &derive_item.ident;

// Add a bound `T: Arbitrary` to every type parameter T.
let generics = add_trait_bound(derive_item.generics);
let generics = add_trait_bound_arbitrary(derive_item.generics);
// Generate an expression to sum up the heap size of each field.
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

Expand All @@ -40,7 +40,7 @@ pub fn expand_derive_arbitrary(item: proc_macro::TokenStream) -> proc_macro::Tok
}

/// Add a bound `T: Arbitrary` to every type parameter T.
fn add_trait_bound(mut generics: Generics) -> Generics {
fn add_trait_bound_arbitrary(mut generics: Generics) -> Generics {
generics.params.iter_mut().for_each(|param| {
if let GenericParam::Type(type_param) = param {
type_param.bounds.push(parse_quote!(kani::Arbitrary));
Expand Down Expand Up @@ -165,3 +165,93 @@ fn fn_any_enum(ident: &Ident, data: &DataEnum) -> TokenStream {
}
}
}

pub fn expand_derive_invariant(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
let derive_item = parse_macro_input!(item as DeriveInput);
let item_name = &derive_item.ident;

// Add a bound `T: Invariant` to every type parameter T.
let generics = add_trait_bound_invariant(derive_item.generics);
// Generate an expression to sum up the heap size of each field.
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let body = is_safe_body(&item_name, &derive_item.data);
let expanded = quote! {
// The generated implementation.
impl #impl_generics kani::Invariant for #item_name #ty_generics #where_clause {
fn is_safe(&self) -> bool {
#body
}
}
};
proc_macro::TokenStream::from(expanded)
}

/// Add a bound `T: Invariant` to every type parameter T.
fn add_trait_bound_invariant(mut generics: Generics) -> Generics {
generics.params.iter_mut().for_each(|param| {
if let GenericParam::Type(type_param) = param {
type_param.bounds.push(parse_quote!(kani::Invariant));
}
});
generics
}

fn is_safe_body(ident: &Ident, data: &Data) -> TokenStream {
match data {
Data::Struct(struct_data) => struct_safe_conjunction(ident, &struct_data.fields),
Data::Enum(_) => {
abort!(Span::call_site(), "Cannot derive `Invariant` for `{}` enum", ident;
note = ident.span() =>
"`#[derive(Invariant)]` cannot be used for enums such as `{}`", ident
)
}
Data::Union(_) => {
abort!(Span::call_site(), "Cannot derive `Invariant` for `{}` union", ident;
note = ident.span() =>
"`#[derive(Invariant)]` cannot be used for unions such as `{}`", ident
)
}
}
}

/// Generates an expression that is the conjunction of `is_safe` calls for each field in the struct.
fn struct_safe_conjunction(_ident: &Ident, fields: &Fields) -> TokenStream {
match fields {
// Expands to the expression
// `true && self.field1.is_safe() && self.field2.is_safe() && ..`
Fields::Named(ref fields) => {
let safe_calls = fields.named.iter().map(|field| {
let name = &field.ident;
quote_spanned! {field.span()=>
self.#name.is_safe()
}
});
// An initial value is required for empty structs
safe_calls.fold(quote! { true }, |acc, call| {
quote! { #acc && #call }
})
}
Fields::Unnamed(ref fields) => {
// Expands to the expression
// `true && self.0.is_safe() && self.1.is_safe() && ..`
let safe_calls = fields.unnamed.iter().enumerate().map(|(i, field)| {
let idx = syn::Index::from(i);
quote_spanned! {field.span()=>
self.#idx.is_safe()
}
});
// An initial value is required for empty structs
safe_calls.fold(quote! { true }, |acc, call| {
quote! { #acc && #call }
})
}
// Expands to the expression
// `true`
Fields::Unit => {
quote! {
true
}
}
}
}
7 changes: 7 additions & 0 deletions library/kani_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ pub fn derive_arbitrary(item: TokenStream) -> TokenStream {
derive::expand_derive_arbitrary(item)
}

/// Allow users to auto generate Invariant implementations by using `#[derive(Invariant)]` macro.
#[proc_macro_error]
#[proc_macro_derive(Invariant)]
pub fn derive_invariant(item: TokenStream) -> TokenStream {
derive::expand_derive_invariant(item)
}

/// Add a precondition to this function.
///
/// This is part of the function contract API, for more general information see
Expand Down
37 changes: 37 additions & 0 deletions tests/expected/derive-invariant/empty_struct/empty_struct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright Kani Contributors
// SPDX-License-Identifier: Apache-2.0 OR MIT

//! Check that Kani can automatically derive `Invariant` for empty structs.

extern crate kani;
use kani::Invariant;

#[derive(kani::Arbitrary)]
#[derive(kani::Invariant)]
struct Void;

#[derive(kani::Arbitrary)]
#[derive(kani::Invariant)]
struct Void2(());

#[derive(kani::Arbitrary)]
#[derive(kani::Invariant)]
struct VoidOfVoid(Void, Void2);

#[kani::proof]
fn check_empty_struct_invariant_1() {
let void1: Void = kani::any();
assert!(void1.is_safe());
}

#[kani::proof]
fn check_empty_struct_invariant_2() {
let void2: Void2 = kani::any();
assert!(void2.is_safe());
}

#[kani::proof]
fn check_empty_struct_invariant_3() {
let void3: VoidOfVoid = kani::any();
assert!(void3.is_safe());
}
8 changes: 8 additions & 0 deletions tests/expected/derive-invariant/empty_struct/expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
- Status: SUCCESS\
- Description: "assertion failed: void1.is_safe()"

- Status: SUCCESS\
- Description: "assertion failed: void2.is_safe()"

- Status: SUCCESS\
- Description: "assertion failed: void3.is_safe()"
2 changes: 2 additions & 0 deletions tests/expected/derive-invariant/generic_struct/expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- Status: SUCCESS\
- Description: "assertion failed: point.is_safe()"
20 changes: 20 additions & 0 deletions tests/expected/derive-invariant/generic_struct/generic_struct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright Kani Contributors
// SPDX-License-Identifier: Apache-2.0 OR MIT

//! Check that Kani can automatically derive `Invariant` for structs with generics.

extern crate kani;
use kani::Invariant;

#[derive(kani::Arbitrary)]
#[derive(kani::Invariant)]
struct Point<X, Y> {
x: X,
y: Y,
}

#[kani::proof]
fn check_generic_struct_invariant() {
let point: Point<i32, i8> = kani::any();
assert!(point.is_safe());
}
4 changes: 4 additions & 0 deletions tests/expected/derive-invariant/invariant_fail/expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- Status: FAILURE\
- Description: "assertion failed: wrapper.is_safe()"

Verification failed for - check_invariant_fail
33 changes: 33 additions & 0 deletions tests/expected/derive-invariant/invariant_fail/invariant_fail.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright Kani Contributors
// SPDX-License-Identifier: Apache-2.0 OR MIT

//! Check that a verification failure is triggered when the derived `Invariant`
//! method is checked but not satisfied.

extern crate kani;
use kani::Invariant;
// Note: This represents an incorrect usage of `Arbitrary` and `Invariant`.
//
// The `Arbitrary` implementation should respect the type invariant,
// but Kani does not enforce this in any way at the moment.
// <https://github.com/model-checking/kani/issues/3265>
#[derive(kani::Arbitrary)]
struct NotNegative(i32);

impl kani::Invariant for NotNegative {
fn is_safe(&self) -> bool {
self.0 >= 0
}
}

#[derive(kani::Arbitrary)]
#[derive(kani::Invariant)]
struct NotNegativeWrapper {
x: NotNegative,
}

#[kani::proof]
fn check_invariant_fail() {
let wrapper: NotNegativeWrapper = kani::any();
assert!(wrapper.is_safe());
}
2 changes: 2 additions & 0 deletions tests/expected/derive-invariant/named_struct/expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- Status: SUCCESS\
- Description: "assertion failed: point.is_safe()"
20 changes: 20 additions & 0 deletions tests/expected/derive-invariant/named_struct/named_struct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright Kani Contributors
// SPDX-License-Identifier: Apache-2.0 OR MIT

//! Check that Kani can automatically derive `Invariant` for structs with named fields.

extern crate kani;
use kani::Invariant;

#[derive(kani::Arbitrary)]
#[derive(kani::Invariant)]
struct Point {
x: i32,
y: i32,
}

#[kani::proof]
fn check_generic_struct_invariant() {
let point: Point = kani::any();
assert!(point.is_safe());
}
2 changes: 2 additions & 0 deletions tests/expected/derive-invariant/unnamed_struct/expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- Status: SUCCESS\
- Description: "assertion failed: point.is_safe()"
17 changes: 17 additions & 0 deletions tests/expected/derive-invariant/unnamed_struct/unnamed_struct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright Kani Contributors
// SPDX-License-Identifier: Apache-2.0 OR MIT

//! Check that Kani can automatically derive `Invariant` for structs with unnamed fields.

extern crate kani;
use kani::Invariant;

#[derive(kani::Arbitrary)]
#[derive(kani::Invariant)]
struct Point(i32, i32);

#[kani::proof]
fn check_generic_struct_invariant() {
let point: Point = kani::any();
assert!(point.is_safe());
}

0 comments on commit 5c7cd63

Please sign in to comment.