diff --git a/cprover_bindings/src/irep/to_irep.rs b/cprover_bindings/src/irep/to_irep.rs index 4ed1ba5604d9..48cb9706a0ba 100644 --- a/cprover_bindings/src/irep/to_irep.rs +++ b/cprover_bindings/src/irep/to_irep.rs @@ -520,7 +520,10 @@ impl ToIrep for StmtBody { let stmt_goto = code_irep(IrepId::Goto, vec![]) .with_named_sub(IrepId::Destination, Irep::just_string_id(dest.to_string())); if let Some(inv) = loop_invariants { - stmt_goto.with_named_sub(IrepId::CSpecLoopInvariant, inv.to_irep(mm)) + stmt_goto.with_named_sub( + IrepId::CSpecLoopInvariant, + inv.clone().and(Expr::bool_true()).to_irep(mm), + ) } else { stmt_goto } diff --git a/kani-compiler/src/codegen_cprover_gotoc/codegen/block.rs b/kani-compiler/src/codegen_cprover_gotoc/codegen/block.rs index 7170504aec78..1b28de887002 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/codegen/block.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/codegen/block.rs @@ -20,55 +20,25 @@ impl<'tcx> GotocCtx<'tcx> { pub fn codegen_block(&mut self, bb: BasicBlockIdx, bbd: &BasicBlock) { debug!(?bb, "codegen_block"); let label = bb_label(bb); - - // record the seen bbidx if loop contracts enabled - if self.loop_contracts_ctx.loop_contracts_enabled() { - self.loop_contracts_ctx.add_new_seen_bbidx(bb); - } - // the first statement should be labelled. if there is no statements, then the // terminator should be labelled. match bbd.statements.len() { 0 => { let term = &bbd.terminator; - let tcode = if self.loop_contracts_ctx.loop_contracts_enabled() { - let codegen_result = self.codegen_terminator(term); - self.loop_contracts_ctx.push_onto_block(codegen_result) - } else { - self.codegen_terminator(term) - }; - + let tcode = self.codegen_terminator(term); self.current_fn_mut().push_onto_block(tcode.with_label(label)); } _ => { let stmt = &bbd.statements[0]; - let scode = if self.loop_contracts_ctx.loop_contracts_enabled() { - let codegen_result = self.codegen_statement(stmt); - self.loop_contracts_ctx.push_onto_block(codegen_result) - } else { - self.codegen_statement(stmt) - }; - + let scode = self.codegen_statement(stmt); self.current_fn_mut().push_onto_block(scode.with_label(label)); for s in &bbd.statements[1..] { - let stmt = if self.loop_contracts_ctx.loop_contracts_enabled() { - let codegen_result = self.codegen_statement(s); - self.loop_contracts_ctx.push_onto_block(codegen_result) - } else { - self.codegen_statement(s) - }; + let stmt = self.codegen_statement(s); self.current_fn_mut().push_onto_block(stmt); } let term = &bbd.terminator; - - let tcode = if self.loop_contracts_ctx.loop_contracts_enabled() { - let codegen_result = self.codegen_terminator(term); - self.loop_contracts_ctx.push_onto_block(codegen_result) - } else { - self.codegen_terminator(term) - }; - + let tcode = self.codegen_terminator(term); self.current_fn_mut().push_onto_block(tcode); } } diff --git a/kani-compiler/src/codegen_cprover_gotoc/codegen/function.rs b/kani-compiler/src/codegen_cprover_gotoc/codegen/function.rs index ac9f1c30b146..34f8363f4948 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/codegen/function.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/codegen/function.rs @@ -56,9 +56,6 @@ impl<'tcx> GotocCtx<'tcx> { let old_sym = self.symbol_table.lookup(&name).unwrap(); let _trace_span = debug_span!("CodegenFunction", name = instance.name()).entered(); - if self.loop_contracts_ctx.loop_contracts_enabled() { - self.loop_contracts_ctx.enter_new_function(); - } if old_sym.is_function_definition() { debug!("Double codegen of {:?}", old_sym); } else { diff --git a/kani-compiler/src/codegen_cprover_gotoc/codegen/statement.rs b/kani-compiler/src/codegen_cprover_gotoc/codegen/statement.rs index 590ea61e102a..ed0178511126 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/codegen/statement.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/codegen/statement.rs @@ -192,21 +192,12 @@ impl<'tcx> GotocCtx<'tcx> { /// /// See also [`GotocCtx::codegen_statement`] for ordinary [Statement]s. pub fn codegen_terminator(&mut self, term: &Terminator) -> Stmt { - let loc: Location = self.codegen_span_stable(term.span); + let loc = self.codegen_span_stable(term.span); let _trace_span = debug_span!("CodegenTerminator", statement = ?term.kind).entered(); debug!("handling terminator {:?}", term); //TODO: Instead of doing location::none(), and updating, just putit in when we make the stmt. match &term.kind { - TerminatorKind::Goto { target } => { - if self.loop_contracts_ctx.loop_contracts_enabled() - && self.loop_contracts_ctx.is_loop_latch(target) - { - Stmt::goto(bb_label(*target), loc) - .with_loop_contracts(self.loop_contracts_ctx.extract_block(loc)) - } else { - Stmt::goto(bb_label(*target), loc) - } - } + TerminatorKind::Goto { target } => Stmt::goto(bb_label(*target), loc), TerminatorKind::SwitchInt { discr, targets } => { self.codegen_switch_int(discr, targets, loc) } diff --git a/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs b/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs index 2fde1b9e21bc..3f17f4b87233 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs @@ -15,7 +15,6 @@ //! Any MIR specific functionality (e.g. codegen etc) should live in specialized files that use //! this structure as input. use super::current_fn::CurrentFnCtx; -use super::loop_contracts_ctx::LoopContractsCtx; use super::vtable_ctx::VtableCtx; use crate::codegen_cprover_gotoc::overrides::{fn_hooks, GotocHooks}; use crate::codegen_cprover_gotoc::utils::full_crate_name; @@ -75,8 +74,6 @@ pub struct GotocCtx<'tcx> { pub concurrent_constructs: UnsupportedConstructs, /// The body transformation agent. pub transformer: BodyTransformation, - /// The context for loop contracts code generation. - pub loop_contracts_ctx: LoopContractsCtx, } /// Constructor @@ -90,8 +87,6 @@ impl<'tcx> GotocCtx<'tcx> { let fhks = fn_hooks(); let symbol_table = SymbolTable::new(machine_model.clone()); let emit_vtable_restrictions = queries.args().emit_vtable_restrictions; - let loop_contracts_enabled = - queries.args().unstable_features.contains(&"loop-contracts".to_string()); GotocCtx { tcx, queries, @@ -108,7 +103,6 @@ impl<'tcx> GotocCtx<'tcx> { unsupported_constructs: FxHashMap::default(), concurrent_constructs: FxHashMap::default(), transformer, - loop_contracts_ctx: LoopContractsCtx::new(loop_contracts_enabled), } } } diff --git a/kani-compiler/src/codegen_cprover_gotoc/context/loop_contracts_ctx.rs b/kani-compiler/src/codegen_cprover_gotoc/context/loop_contracts_ctx.rs deleted file mode 100644 index d7d9f8a8a39c..000000000000 --- a/kani-compiler/src/codegen_cprover_gotoc/context/loop_contracts_ctx.rs +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright Kani Contributors -// SPDX-License-Identifier: Apache-2.0 OR MIT - -use crate::codegen_cprover_gotoc::codegen::bb_label; -use cbmc::goto_program::{CIntType, Expr, Location, Stmt, StmtBody, Type}; -use stable_mir::mir::BasicBlockIdx; -use std::collections::HashSet; - -pub struct LoopContractsCtx { - /// the GOTO block compiled from the corresponding loop invariants - invariants_block: Vec, - /// Which codegen state - stage: LoopContractsStage, - /// If enable loop contracts - loop_contracts_enabled: bool, - /// Seen basic block indexes. Used to decide if a jump is backward - seen_bbidx: HashSet, - /// Current unused bbidx label - current_bbidx_label: Option, - /// The lhs of evaluation of the loop invariant - loop_invariant_lhs: Option, -} - -/// We define two states: -/// 1. loop invariants block -/// In this state, we push all codegen stmts into the invariant block. -/// We enter this state when codegen for `KaniLoopInvariantBegin`. -/// We exit this state when codegen for `KaniLoopInvariantEnd`. -/// 2. loop latch block -/// In this state, we codegen a statement expression from the -/// invariant_block annotate the statement expression to the named sub -/// of the next backward jumping we codegen. -/// We enter this state when codegen for `KaniLoopInvariantEnd`. -/// We exit this state when codegen for the first backward jumping. -#[derive(Debug, PartialEq)] -enum LoopContractsStage { - /// Codegen for user code as usual - UserCode, - /// Codegen for loop invariants - InvariantBlock, - /// Codegen for loop latch node - FindingLatchNode, -} - -/// Constructor -impl LoopContractsCtx { - pub fn new(loop_contracts_enabled: bool) -> Self { - Self { - invariants_block: Vec::new(), - stage: LoopContractsStage::UserCode, - loop_contracts_enabled, - seen_bbidx: HashSet::new(), - current_bbidx_label: None, - loop_invariant_lhs: None, - } - } -} - -/// Getters -impl LoopContractsCtx { - pub fn loop_contracts_enabled(&self) -> bool { - self.loop_contracts_enabled - } - - /// decide if a GOTO with `target` is backward jump - pub fn is_loop_latch(&self, target: &BasicBlockIdx) -> bool { - self.stage == LoopContractsStage::FindingLatchNode && self.seen_bbidx.contains(target) - } -} - -/// Setters -impl LoopContractsCtx { - /// Returns the current block as a statement expression. - /// Exit loop latch block. - pub fn extract_block(&mut self, loc: Location) -> Expr { - assert!(self.loop_invariant_lhs.is_some()); - self.stage = LoopContractsStage::UserCode; - self.invariants_block.push(self.loop_invariant_lhs.as_ref().unwrap().clone()); - - // The first statement is the GOTO in the rhs of __kani_loop_invariant_begin() - // Ignore it - self.invariants_block.remove(0); - - Expr::statement_expression( - std::mem::take(&mut self.invariants_block), - Type::CInteger(CIntType::Bool), - loc, - ) - .cast_to(Type::bool()) - .and(Expr::bool_true()) - } - - /// Push the `s` onto the block if it is in the loop invariant block - /// and return `skip`. Otherwise, do nothing and return `s`. - pub fn push_onto_block(&mut self, s: Stmt) -> Stmt { - if self.stage == LoopContractsStage::InvariantBlock { - // Attach the label to the first `Stmt` in that block and reset it. - let to_push = if self.current_bbidx_label.is_none() { - s.clone() - } else { - s.clone().with_label(self.current_bbidx_label.clone().unwrap()) - }; - self.current_bbidx_label = None; - - match s.body() { - StmtBody::Assign { lhs, rhs: _ } => { - let lhs_stmt = lhs.clone().as_stmt(*s.location()); - self.loop_invariant_lhs = Some(lhs_stmt.clone()); - self.invariants_block.push(to_push); - } - _ => { - self.invariants_block.push(to_push); - } - }; - Stmt::skip(*s.location()) - } else { - s - } - } - - pub fn enter_loop_invariant_block(&mut self) { - assert!(self.invariants_block.is_empty()); - self.stage = LoopContractsStage::InvariantBlock; - } - - pub fn exit_loop_invariant_block(&mut self) { - self.stage = LoopContractsStage::FindingLatchNode; - } - - /// Enter a new function, reset the seen_bbidx set - pub fn enter_new_function(&mut self) { - self.seen_bbidx = HashSet::new() - } - - pub fn add_new_seen_bbidx(&mut self, bbidx: BasicBlockIdx) { - self.seen_bbidx.insert(bbidx); - self.current_bbidx_label = Some(bb_label(bbidx)); - } -} diff --git a/kani-compiler/src/codegen_cprover_gotoc/context/mod.rs b/kani-compiler/src/codegen_cprover_gotoc/context/mod.rs index 0978b299e309..0053b9add18b 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/context/mod.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/context/mod.rs @@ -8,7 +8,6 @@ mod current_fn; mod goto_ctx; -mod loop_contracts_ctx; mod vtable_ctx; pub use goto_ctx::GotocCtx; diff --git a/kani-compiler/src/codegen_cprover_gotoc/overrides/hooks.rs b/kani-compiler/src/codegen_cprover_gotoc/overrides/hooks.rs index 7ffd976ee8a8..2bed9109918e 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/overrides/hooks.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/overrides/hooks.rs @@ -8,7 +8,7 @@ //! It would be too nasty if we spread around these sort of undocumented hooks in place, so //! this module addresses this issue. -use super::loop_contracts_hooks::{LoopInvariantBegin, LoopInvariantEnd}; +use super::loop_contracts_hooks::LoopInvariantRegister; use crate::codegen_cprover_gotoc::codegen::{bb_label, PropertyClass}; use crate::codegen_cprover_gotoc::GotocCtx; use crate::kani_middle::attributes::matches_diagnostic as matches_function; @@ -556,8 +556,7 @@ pub fn fn_hooks() -> GotocHooks { Rc::new(MemCmp), Rc::new(UntrackedDeref), Rc::new(InitContracts), - Rc::new(LoopInvariantBegin), - Rc::new(LoopInvariantEnd), + Rc::new(LoopInvariantRegister), ], } } diff --git a/kani-compiler/src/codegen_cprover_gotoc/overrides/loop_contracts_hooks.rs b/kani-compiler/src/codegen_cprover_gotoc/overrides/loop_contracts_hooks.rs index f52769124a49..b0395e8afd32 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/overrides/loop_contracts_hooks.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/overrides/loop_contracts_hooks.rs @@ -4,61 +4,34 @@ use super::hooks::GotocHook; use crate::codegen_cprover_gotoc::codegen::bb_label; use crate::codegen_cprover_gotoc::GotocCtx; -use crate::kani_middle::attributes::matches_diagnostic as matches_function; -use cbmc::goto_program::{Expr, Stmt}; +use crate::kani_middle::attributes::KaniAttributes; +use cbmc::goto_program::{CIntType, Expr, Stmt, Type}; use rustc_middle::ty::TyCtxt; +use rustc_span::Symbol; use stable_mir::mir::mono::Instance; use stable_mir::mir::{BasicBlockIdx, Place}; use stable_mir::ty::Span; -pub struct LoopInvariantBegin; +pub struct LoopInvariantRegister; -impl GotocHook for LoopInvariantBegin { +impl GotocHook for LoopInvariantRegister { fn hook_applies(&self, tcx: TyCtxt, instance: Instance) -> bool { - matches_function(tcx, instance.def, "KaniLoopInvariantBegin") + KaniAttributes::for_instance(tcx, instance).fn_marker() + == Some(Symbol::intern("kani_register_loop_contract")) } fn handle( &self, gcx: &mut GotocCtx, - _instance: Instance, + instance: Instance, fargs: Vec, _assign_to: &Place, target: Option, span: Span, ) -> Stmt { - assert_eq!(fargs.len(), 0); let loc = gcx.codegen_span_stable(span); - - // Start to record loop invariant statement - gcx.loop_contracts_ctx.enter_loop_invariant_block(); - - Stmt::goto(bb_label(target.unwrap()), loc) - } -} - -pub struct LoopInvariantEnd; - -impl GotocHook for LoopInvariantEnd { - fn hook_applies(&self, tcx: TyCtxt, instance: Instance) -> bool { - matches_function(tcx, instance.def, "KaniLoopInvariantEnd") - } - - fn handle( - &self, - gcx: &mut GotocCtx, - _instance: Instance, - fargs: Vec, - _assign_to: &Place, - target: Option, - span: Span, - ) -> Stmt { - assert_eq!(fargs.len(), 0); - let loc = gcx.codegen_span_stable(span); - - // Stop to record loop invariant statement - gcx.loop_contracts_ctx.exit_loop_invariant_block(); - + let func_exp = gcx.codegen_func_expr(instance, loc); Stmt::goto(bb_label(target.unwrap()), loc) + .with_loop_contracts(func_exp.call(fargs).cast_to(Type::CInteger(CIntType::Bool))) } } diff --git a/kani-compiler/src/kani_middle/transform/loop_contracts.rs b/kani-compiler/src/kani_middle/transform/loop_contracts.rs new file mode 100644 index 000000000000..dd8a1c8704a7 --- /dev/null +++ b/kani-compiler/src/kani_middle/transform/loop_contracts.rs @@ -0,0 +1,249 @@ +use crate::kani_middle::codegen_units::CodegenUnit; +use crate::kani_middle::find_fn_def; +use crate::kani_middle::transform::body::{MutableBody, SourceInstruction}; +use crate::kani_middle::transform::{TransformPass, TransformationType}; +use crate::kani_middle::KaniAttributes; +use crate::kani_queries::QueryDb; +use crate::stable_mir::CrateDef; +use rustc_middle::ty::TyCtxt; +use rustc_span::Symbol; +use stable_mir::mir::mono::Instance; +use stable_mir::mir::{BasicBlockIdx, Body, Operand, Terminator, TerminatorKind}; +use stable_mir::ty::{FnDef, RigidTy}; +use stable_mir::DefId; +use std::collections::VecDeque; +use std::collections::{HashMap, HashSet}; +use std::fmt::Debug; +use tracing::trace; + +/// This pass will perform the following operations: +/// 1. Replace the body of `kani_register_loop_contract` by `kani::internal::run_contract_fn` +/// to invoke the closure. +/// +/// 2. Replace the dummy call to the register function with the actual call, i.e., transform +/// +/// ```ignore +/// let kani_loop_invariant = || -> bool {inv}; +/// kani_register_loop_contract(kani_loop_invariant) +/// while guard { +/// loop_body; +/// kani_register_loop_contract(||->bool{true}); +/// } +/// ``` +/// +/// to +/// +/// ```ignore +/// let kani_loop_invariant = || -> bool {inv}; +/// while guard { +/// loop_body; +/// kani_register_loop_contract(kani_loop_invariant); +/// } +/// +/// ``` +/// +/// 3. Move the call to the register function to the loop latch terminator. This is required +/// as in MIR, there could be some `StorageDead` statements between register calls and +/// loop latches. +#[derive(Debug, Default)] +pub struct FunctionWithLoopContractPass { + /// Cache KaniRunContract function used to implement contracts. + run_contract_fn: Option, + /// Function and Arguments of register functions. + registered_args: HashMap)>, + /// The terminator we are moving to the loop latch. + loop_terminator: Option, +} + +impl TransformPass for FunctionWithLoopContractPass { + fn transformation_type() -> TransformationType + where + Self: Sized, + { + TransformationType::Stubbing + } + + fn is_enabled(&self, _query_db: &QueryDb) -> bool + where + Self: Sized, + { + true + } + + /// Transform the function body by replacing it with the stub body. + fn transform(&mut self, tcx: TyCtxt, body: Body, instance: Instance) -> (bool, Body) { + trace!(function=?instance.name(), "FunctionWithLoopContractPass::transform"); + match instance.ty().kind().rigid().unwrap() { + RigidTy::FnDef(_func, args) => { + if KaniAttributes::for_instance(tcx, instance).fn_marker() + == Some(Symbol::intern("kani_register_loop_contract")) + { + // Replace the body of the register function with `run_contract_fn`'s. + let run = Instance::resolve(self.run_contract_fn.unwrap(), args).unwrap(); + (true, run.body().unwrap()) + } else { + // Replace the dummy register call with the actual register call. + let mut new_body = MutableBody::from(body); + let mut contain_loop_contracts: bool = false; + + // Visit basic blocks in control flow order. + let mut visited: HashSet = HashSet::new(); + let mut queue: VecDeque = VecDeque::new(); + queue.push_back(0); + + while let Some(bbidx) = queue.pop_front() { + visited.insert(bbidx); + // We only need to transform basic block with terminators as calls + // to the register functions, no matter dummy or actual calls. + let terminator = new_body.blocks()[bbidx].terminator.clone(); + if let TerminatorKind::Call { + func: terminator_func, + args: terminator_args, + destination, + target, + unwind, + } = &terminator.kind + { + // Get the function signature of the terminator call. + let fn_kind = terminator_func.ty(&[]).unwrap().kind(); + let RigidTy::FnDef(fn_def, ..) = fn_kind.rigid().unwrap() else { + unreachable!() + }; + + if KaniAttributes::for_def_id(tcx, fn_def.def_id()).fn_marker() + == Some(Symbol::intern("kani_register_loop_contract")) + { + contain_loop_contracts = true; + + if self.registered_args.contains_key(&fn_def.def_id()) { + // This call is a dummy call as it is not the first call + // to the register function. + // Replace it with `self.loop_terminator`. + self.loop_terminator = Some(Terminator { + kind: TerminatorKind::Call { + func: self.registered_args[&fn_def.def_id()].0.clone(), + args: self.registered_args[&fn_def.def_id()].1.to_vec(), + destination: destination.clone(), + target: target.clone(), + unwind: unwind.clone(), + }, + span: terminator.span, + }); + new_body.replace_terminator( + &SourceInstruction::Terminator { bb: bbidx }, + Terminator { + kind: TerminatorKind::Goto { target: target.unwrap() }, + span: terminator.span, + }, + ); + // Then move the loop terminator to the loop latch. + self.move_loop_terminator_to_loop_latch( + bbidx, + &mut new_body, + &mut visited, + ); + } else { + // This call is an actual call as it is the first call + // to the register function. + self.registered_args.insert( + fn_def.def_id(), + (terminator_func.clone(), terminator_args.clone()), + ); + new_body.replace_terminator( + &SourceInstruction::Terminator { bb: bbidx }, + Terminator { + kind: TerminatorKind::Goto { target: target.unwrap() }, + span: terminator.span, + }, + ); + } + } + } + + // Add successors of the current basic blocks to + // the visiting queue. + for to_visit in terminator.successors() { + if visited.contains(&to_visit) { + continue; + } + queue.push_back(to_visit); + } + } + (contain_loop_contracts, new_body.into()) + } + } + _ => { + /* static variables case */ + (false, body) + } + } + } +} + +impl FunctionWithLoopContractPass { + pub fn new(tcx: TyCtxt, _unit: &CodegenUnit) -> FunctionWithLoopContractPass { + let run_contract_fn = find_fn_def(tcx, "KaniRunContract"); + assert!(run_contract_fn.is_some(), "Failed to find Kani run contract function"); + FunctionWithLoopContractPass { + run_contract_fn, + registered_args: HashMap::new(), + loop_terminator: None, + } + } + + // Replace the next loop latch---a terminator that targets some basic block in `visited`--- + // with `self.loop_terminator`. + // We assume that there is no branching terminator with more than one targets between the + // current basic block `bbidx` and the next loop latch. + fn move_loop_terminator_to_loop_latch( + &mut self, + bbidx: BasicBlockIdx, + new_body: &mut MutableBody, + visited: &mut HashSet, + ) { + let mut current_bbidx = bbidx; + while self.loop_terminator.is_some() { + if new_body.blocks()[current_bbidx].terminator.successors().len() != 1 { + // Assume that there is no branching between the register function cal + // and the loop latch. + unreachable!() + } + let target = new_body.blocks()[current_bbidx].terminator.successors()[0]; + + if visited.contains(&target) { + // Current basic block is the loop latch. + let Some(Terminator { + kind: + TerminatorKind::Call { + func: ref loop_terminator_func, + args: ref loop_terminator_args, + destination: ref loop_terminator_destination, + target: _loop_terminator_target, + unwind: ref loop_terminator_unwind, + }, + span: loop_terminator_span, + }) = self.loop_terminator + else { + unreachable!() + }; + new_body.replace_terminator( + &SourceInstruction::Terminator { bb: current_bbidx }, + Terminator { + kind: TerminatorKind::Call { + func: loop_terminator_func.clone(), + args: loop_terminator_args.clone(), + destination: loop_terminator_destination.clone(), + target: Some(target), + unwind: loop_terminator_unwind.clone(), + }, + span: loop_terminator_span, + }, + ); + self.loop_terminator = None; + } else { + visited.insert(target); + current_bbidx = target; + } + } + } +} diff --git a/kani-compiler/src/kani_middle/transform/mod.rs b/kani-compiler/src/kani_middle/transform/mod.rs index 2d963cd1d6eb..549124143fac 100644 --- a/kani-compiler/src/kani_middle/transform/mod.rs +++ b/kani-compiler/src/kani_middle/transform/mod.rs @@ -23,6 +23,7 @@ use crate::kani_middle::transform::check_uninit::{DelayedUbPass, UninitPass}; use crate::kani_middle::transform::check_values::ValidValuePass; use crate::kani_middle::transform::contracts::{AnyModifiesPass, FunctionWithContractPass}; use crate::kani_middle::transform::kani_intrinsics::IntrinsicGeneratorPass; +use crate::kani_middle::transform::loop_contracts::FunctionWithLoopContractPass; use crate::kani_middle::transform::stubs::{ExternFnStubPass, FnStubPass}; use crate::kani_queries::QueryDb; use dump_mir_pass::DumpMirPass; @@ -41,6 +42,7 @@ mod contracts; mod dump_mir_pass; mod internal_mir; mod kani_intrinsics; +mod loop_contracts; mod stubs; /// Object used to retrieve a transformed instance body. @@ -74,6 +76,7 @@ impl BodyTransformation { // body that is relevant for this harness. transformer.add_pass(queries, AnyModifiesPass::new(tcx, &unit)); transformer.add_pass(queries, ValidValuePass { check_type: check_type.clone() }); + transformer.add_pass(queries, FunctionWithLoopContractPass::new(tcx, &unit)); // Putting `UninitPass` after `ValidValuePass` makes sure that the code generated by // `UninitPass` does not get unnecessarily instrumented by valid value checks. However, it // would also make sense to check that the values are initialized before checking their diff --git a/library/kani/src/lib.rs b/library/kani/src/lib.rs index 20589fe3a969..c3e9ae5497cc 100644 --- a/library/kani/src/lib.rs +++ b/library/kani/src/lib.rs @@ -87,7 +87,3 @@ macro_rules! implies { pub(crate) use kani_macros::unstable_feature as unstable; pub mod contracts; - -mod loop_contracts; - -pub use loop_contracts::{kani_loop_invariant_begin_marker, kani_loop_invariant_end_marker}; diff --git a/library/kani/src/loop_contracts.rs b/library/kani/src/loop_contracts.rs deleted file mode 100644 index 90949c71ff3c..000000000000 --- a/library/kani/src/loop_contracts.rs +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright Kani Contributors -// SPDX-License-Identifier: Apache-2.0 OR MIT - -/// This function is only used for loop contract annotation. -/// It behaves as a placeholder to telling us where the loop invariants stmts begin. -#[inline(never)] -#[rustc_diagnostic_item = "KaniLoopInvariantBegin"] -#[doc(hidden)] -#[crate::unstable( - feature = "loop-contracts", - issue = 3168, - reason = "experimental loop contracts support" -)] -pub const fn kani_loop_invariant_begin_marker() {} - -/// This function is only used for loop contract annotation. -/// It behaves as a placeholder to telling us where the loop invariants stmts end. -#[inline(never)] -#[rustc_diagnostic_item = "KaniLoopInvariantEnd"] -#[doc(hidden)] -#[crate::unstable( - feature = "loop-contracts", - issue = 3168, - reason = "experimental loop contracts support" -)] -pub const fn kani_loop_invariant_end_marker() {} diff --git a/library/kani_macros/src/lib.rs b/library/kani_macros/src/lib.rs index 0979275dcde2..f23819099cea 100644 --- a/library/kani_macros/src/lib.rs +++ b/library/kani_macros/src/lib.rs @@ -8,6 +8,7 @@ // So we have to enable this on the commandline (see kani-rustc) with: // RUSTFLAGS="-Zcrate-attr=feature(register_tool) -Zcrate-attr=register_tool(kanitool)" #![feature(proc_macro_diagnostic)] +#![feature(proc_macro_span)] mod derive; // proc_macro::quote is nightly-only, so we'll cobble things together instead diff --git a/library/kani_macros/src/sysroot/loop_contracts/mod.rs b/library/kani_macros/src/sysroot/loop_contracts/mod.rs index 742ef4f541d0..32cb52dfd3ea 100644 --- a/library/kani_macros/src/sysroot/loop_contracts/mod.rs +++ b/library/kani_macros/src/sysroot/loop_contracts/mod.rs @@ -5,11 +5,23 @@ //! use proc_macro::TokenStream; -use proc_macro2::TokenStream as TokenStream2; use proc_macro_error::abort_call_site; -use quote::quote; +use quote::{format_ident, quote}; +use syn::spanned::Spanned; use syn::{Expr, Stmt}; +fn generate_unique_id_from_span(stmt: &Stmt) -> String { + // Extract the span of the expression + let span = stmt.span().unwrap(); + + // Get the start and end line and column numbers + let start = span.start(); + let end = span.end(); + + // Create a tuple of location information (file path, start line, start column, end line, end column) + format!("_{:?}_{:?}_{:?}_{:?}", start.line(), start.column(), end.line(), end.column()) +} + /// Expand loop contracts macros. /// /// A while loop of the form @@ -20,43 +32,50 @@ use syn::{Expr, Stmt}; /// ``` /// will be annotated as /// ``` rust -/// while guard{ +/// #[inline(never)] +/// #[kanitool::fn_marker = "kani_register_loop_contract"] +/// const fn kani_register_loop_contract_id T>(f: F) -> T { +/// unreachable!() +/// } +/// let __kani_loop_invariant_id = || -> bool {inv}; +/// // The register function call with the actual invariant. +/// kani_register_loop_contract_id(__kani_loop_invariant_id); +/// while guard { /// body -/// kani::kani_loop_invariant_begin_marker(); -/// let __kani_loop_invariant: bool = inv; -/// kani::kani_loop_invariant_end_marker(); +/// // Call to the register function with a dummy argument +/// // for the sake of bypassing borrow checks. +/// kani_register_loop_contract_id(||->bool{true}); /// } /// ``` pub fn loop_invariant(attr: TokenStream, item: TokenStream) -> TokenStream { + // parse the stmt of the loop let mut loop_stmt: Stmt = syn::parse(item.clone()).unwrap(); - // Annotate a place holder function call at the end of the loop. + // name of the loop invariant as closure of the form + // __kani_loop_invariant_#startline_#startcol_#endline_#endcol + let mut inv_name: String = "__kani_loop_invariant".to_owned(); + let loop_id = generate_unique_id_from_span(&loop_stmt); + inv_name.push_str(&loop_id); + let inv_ident = format_ident!("{}", inv_name); + + // expr of the loop invariant + let inv_expr: Expr = syn::parse(attr).unwrap(); + + // ident of the register function + let mut register_name: String = "kani_register_loop_contract".to_owned(); + register_name.push_str(&loop_id); + let register_ident = format_ident!("{}", register_name); + match loop_stmt { Stmt::Expr(ref mut e, _) => match e { Expr::While(ref mut ew) => { - let mut to_parse = quote!( - let __kani_loop_invariant: bool = ); - to_parse.extend(TokenStream2::from(attr.clone())); - to_parse.extend(quote!(;)); - let inv_assign_stmt: Stmt = syn::parse(to_parse.into()).unwrap(); - - // kani::kani_loop_invariant_begin_marker(); - let inv_begin_stmt: Stmt = syn::parse( - quote!( - kani::kani_loop_invariant_begin_marker();) - .into(), - ) - .unwrap(); - - // kani::kani_loop_invariant_end_marker(); + // kani_register_loop_contract(#inv_ident); let inv_end_stmt: Stmt = syn::parse( quote!( - kani::kani_loop_invariant_end_marker();) + #register_ident(||->bool{true});) .into(), ) .unwrap(); - ew.body.stmts.push(inv_begin_stmt); - ew.body.stmts.push(inv_assign_stmt); ew.body.stmts.push(inv_end_stmt); } _ => (), @@ -66,6 +85,17 @@ pub fn loop_invariant(attr: TokenStream, item: TokenStream) -> TokenStream { ), } quote!( - {#loop_stmt}) + { + // Dummy function used to force the compiler to capture the environment. + // We cannot call closures inside constant functions. + // This function gets replaced by `kani::internal::call_closure`. + #[inline(never)] + #[kanitool::fn_marker = "kani_register_loop_contract"] + const fn #register_ident T>(f: F) -> T { + unreachable!() + } + let mut #inv_ident = || -> bool {#inv_expr}; + #register_ident(#inv_ident); + #loop_stmt}) .into() }