Skip to content

Commit

Permalink
Add support to simd_bitmask (#2677)
Browse files Browse the repository at this point in the history
Create a sequential model of `simd_bitmask` and create a transformation pass that replace calls to the `simd_bitmask` intrinsic by the new model.

We will only replace those usages that we can safely get information about the `simd` representation. All other cases will be reported as unsupported feature. Type checking is not currently done for any `simd` operation and should be fixed (#2692).

Co-authored-by: Zyad Hassan <88045115+zhassan-aws@users.noreply.github.com>
  • Loading branch information
celinval and zhassan-aws committed Aug 17, 2023
1 parent d262d1a commit db73dad
Show file tree
Hide file tree
Showing 12 changed files with 492 additions and 6 deletions.
2 changes: 1 addition & 1 deletion kani-compiler/src/codegen_cprover_gotoc/codegen/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,9 @@ impl<'tcx> GotocCtx<'tcx> {
proj: ProjectionElem<Local, Ty<'tcx>>,
) -> Result<ProjectedPlace<'tcx>, UnimplementedData> {
let before = before?;
trace!(?before, ?proj, "codegen_projection");
match proj {
ProjectionElem::Deref => {
trace!(?before, ?proj, "codegen_projection");
let base_type = before.mir_typ();
let inner_goto_expr = if base_type.is_box() {
self.deref_box(before.goto_expr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ impl CodegenBackend for GotocCodegenBackend {
}

fn provide_extern(&self, providers: &mut ExternProviders) {
provide::provide_extern(providers);
provide::provide_extern(providers, &self.queries.lock().unwrap());
}

fn print_version(&self) {
Expand Down
108 changes: 108 additions & 0 deletions kani-compiler/src/kani_middle/intrinsics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright Kani Contributors
// SPDX-License-Identifier: Apache-2.0 OR MIT
//! This module contains a MIR pass that replaces some intrinsics by rust intrinsics models as
//! well as validation logic that can only be added during monomorphization.
use rustc_index::IndexVec;
use rustc_middle::mir::{interpret::ConstValue, Body, ConstantKind, Operand, TerminatorKind};
use rustc_middle::mir::{Local, LocalDecl};
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_middle::ty::{Const, GenericArgsRef};
use rustc_span::symbol::{sym, Symbol};
use tracing::{debug, trace};

pub struct ModelIntrinsics<'tcx> {
tcx: TyCtxt<'tcx>,
/// Local declarations of the function being transformed.
local_decls: IndexVec<Local, LocalDecl<'tcx>>,
}

impl<'tcx> ModelIntrinsics<'tcx> {
/// Function that replace calls to some intrinsics that have a high level model in our library.
///
/// For now, we only look at intrinsic calls, which are modelled by a terminator.
///
/// However, this pass runs after lowering intrinsics, which may replace the terminator by
/// an intrinsic statement (non-diverging intrinsic).
pub fn run_pass(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
ModelIntrinsics { tcx, local_decls: body.local_decls.clone() }.transform(body)
}

pub fn transform(&self, body: &mut Body<'tcx>) {
for block in body.basic_blocks.as_mut() {
let terminator = block.terminator_mut();
if let TerminatorKind::Call { func, args, .. } = &mut terminator.kind {
let func_ty = func.ty(&self.local_decls, self.tcx);
if let Some((intrinsic_name, generics)) = resolve_rust_intrinsic(self.tcx, func_ty)
{
trace!(?func, ?intrinsic_name, "run_pass");
if intrinsic_name == sym::simd_bitmask {
self.replace_simd_bitmask(func, args, generics)
}
}
}
}
}

/// Change the function call to use the stubbed version.
/// We only replace calls if we can ensure the input has simd representation.
fn replace_simd_bitmask(
&self,
func: &mut Operand<'tcx>,
args: &[Operand<'tcx>],
gen_args: GenericArgsRef<'tcx>,
) {
assert_eq!(args.len(), 1);
let tcx = self.tcx;
let arg_ty = args[0].ty(&self.local_decls, tcx);
if arg_ty.is_simd() {
// Get the stub definition.
let stub_id = tcx.get_diagnostic_item(Symbol::intern("KaniModelSimdBitmask")).unwrap();
debug!(?func, ?stub_id, "replace_simd_bitmask");

// Get SIMD information from the type.
let (len, elem_ty) = simd_len_and_type(tcx, arg_ty);
debug!(?len, ?elem_ty, "replace_simd_bitmask Ok");

// Increment the list of generic arguments since our stub also takes element type and len.
let mut new_gen_args = Vec::from_iter(gen_args.iter());
new_gen_args.push(elem_ty.into());
new_gen_args.push(len.into());

let Operand::Constant(fn_def) = func else { unreachable!() };
fn_def.literal = ConstantKind::from_value(
ConstValue::ZeroSized,
tcx.type_of(stub_id).instantiate(tcx, &new_gen_args),
);
} else {
debug!(?arg_ty, "replace_simd_bitmask failed");
}
}
}

fn simd_len_and_type<'tcx>(tcx: TyCtxt<'tcx>, simd_ty: Ty<'tcx>) -> (Const<'tcx>, Ty<'tcx>) {
match simd_ty.kind() {
ty::Adt(def, args) => {
assert!(def.repr().simd(), "`simd_size_and_type` called on non-SIMD type");
let variant = def.non_enum_variant();
let f0_ty = variant.fields[0u32.into()].ty(tcx, args);

match f0_ty.kind() {
ty::Array(elem_ty, len) => (*len, *elem_ty),
_ => (Const::from_target_usize(tcx, variant.fields.len() as u64), f0_ty),
}
}
_ => unreachable!("unexpected layout for simd type {simd_ty}"),
}
}

fn resolve_rust_intrinsic<'tcx>(
tcx: TyCtxt<'tcx>,
func_ty: Ty<'tcx>,
) -> Option<(Symbol, GenericArgsRef<'tcx>)> {
if let ty::FnDef(def_id, args) = *func_ty.kind() {
if tcx.is_intrinsic(def_id) {
return Some((tcx.item_name(def_id), args));
}
}
None
}
1 change: 1 addition & 0 deletions kani-compiler/src/kani_middle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use self::attributes::KaniAttributes;
pub mod analysis;
pub mod attributes;
pub mod coercion;
mod intrinsics;
pub mod metadata;
pub mod provide;
pub mod reachability;
Expand Down
16 changes: 13 additions & 3 deletions kani-compiler/src/kani_middle/provide.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//! to run during code generation. For example, this can be used to hook up
//! custom MIR transformations.

use crate::kani_middle::intrinsics::ModelIntrinsics;
use crate::kani_middle::reachability::{collect_reachable_items, filter_crate_items};
use crate::kani_middle::stubbing;
use crate::kani_queries::{QueryDb, ReachabilityType};
Expand All @@ -18,7 +19,7 @@ use rustc_middle::{
/// Sets up rustc's query mechanism to apply Kani's custom queries to code from
/// the present crate.
pub fn provide(providers: &mut Providers, queries: &QueryDb) {
if queries.reachability_analysis != ReachabilityType::None && !queries.build_std {
if should_override(queries) {
// Don't override queries if we are only compiling our dependencies.
providers.optimized_mir = run_mir_passes;
if queries.stubbing_enabled {
Expand All @@ -30,8 +31,15 @@ pub fn provide(providers: &mut Providers, queries: &QueryDb) {

/// Sets up rustc's query mechanism to apply Kani's custom queries to code from
/// external crates.
pub fn provide_extern(providers: &mut ExternProviders) {
providers.optimized_mir = run_mir_passes_extern;
pub fn provide_extern(providers: &mut ExternProviders, queries: &QueryDb) {
if should_override(queries) {
// Don't override queries if we are only compiling our dependencies.
providers.optimized_mir = run_mir_passes_extern;
}
}

fn should_override(queries: &QueryDb) -> bool {
queries.reachability_analysis != ReachabilityType::None && !queries.build_std
}

/// Returns the optimized code for the external function associated with `def_id` by
Expand Down Expand Up @@ -61,6 +69,8 @@ fn run_kani_mir_passes<'tcx>(
tracing::debug!(?def_id, "Run Kani transformation passes");
let mut transformed_body = stubbing::transform(tcx, def_id, body);
stubbing::transform_foreign_functions(tcx, &mut transformed_body);
// This should be applied after stubbing so user stubs take precedence.
ModelIntrinsics::run_pass(tcx, &mut transformed_body);
tcx.arena.alloc(transformed_body)
}

Expand Down
8 changes: 8 additions & 0 deletions library/kani/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@
#![feature(register_tool)]
#![register_tool(kanitool)]
// Used for rustc_diagnostic_item.
// Note: We could use a kanitool attribute instead.
#![feature(rustc_attrs)]
// This is required for the optimized version of `any_array()`
#![feature(generic_const_exprs)]
#![allow(incomplete_features)]
// Used to model simd.
#![feature(repr_simd)]
// Features used for tests only.
#![cfg_attr(test, feature(platform_intrinsics, portable_simd))]
// Required for rustc_diagnostic_item
#![allow(internal_features)]

pub mod arbitrary;
Expand All @@ -19,6 +25,8 @@ pub mod slice;
pub mod tuple;
pub mod vec;

mod models;

pub use arbitrary::Arbitrary;
#[cfg(feature = "concrete_playback")]
pub use concrete_playback::concrete_playback_run;
Expand Down
Loading

0 comments on commit db73dad

Please sign in to comment.