From dba6e934e1363f9c7e82d7f1f9bc66de743f0813 Mon Sep 17 00:00:00 2001 From: SynodicMonth Date: Wed, 14 Aug 2024 23:13:58 +0800 Subject: [PATCH] feat(static_bp): add `eq_is_uncommon` heuristic --- src/bin/compiler.rs | 7 ++ src/ir/inst.rs | 8 ++ src/ir/passes/mod.rs | 1 + src/ir/passes/static_branch_prediction.rs | 94 +++++++++++++++++++++++ 4 files changed, 110 insertions(+) create mode 100644 src/ir/passes/static_branch_prediction.rs diff --git a/src/bin/compiler.rs b/src/bin/compiler.rs index b469f2f..05b7745 100644 --- a/src/bin/compiler.rs +++ b/src/bin/compiler.rs @@ -44,6 +44,7 @@ use orzcc::{ }, mem2reg::{Mem2reg, MEM2REG}, simple_dce::{SimpleDce, SIMPLE_DCE}, + static_branch_prediction::{StaticBranchPrediction, STATIC_BRANCH_PREDICTION}, tco::{Tco, TCO}, }, passman::{PassManager, Pipeline, TransformPass}, @@ -212,6 +213,11 @@ fn main() -> Result<(), Box> { passman.run_transform(ADCE, &mut ir, 1); passman.run_transform(CFG_SIMPLIFY, &mut ir, 32); } + + // reorder + passman.run_transform(STATIC_BRANCH_PREDICTION, &mut ir, 1); + passman.run_transform(BRANCH_CONDITION_SINK, &mut ir, 1); + passman.run_transform(BLOCK_REORDER, &mut ir, 1); } else { passman.run_transform(LEGALIZE, &mut ir, 1); } @@ -290,6 +296,7 @@ fn register_passes(passman: &mut PassManager) { GlobalValueNumbering::register(passman); Gcm::register(passman); BranchConditionSink::register(passman); + StaticBranchPrediction::register(passman); Legalize::register(passman); BlockReorder::register(passman); diff --git a/src/ir/inst.rs b/src/ir/inst.rs index ac86941..f0eced5 100644 --- a/src/ir/inst.rs +++ b/src/ir/inst.rs @@ -1543,6 +1543,14 @@ impl Inst { None } } + + pub fn inverse_br(self, ctx: &mut Context) { + if !self.is_br(ctx) { + panic!("instruction is not a branch instruction"); + } + + self.deref_mut(ctx).successors.swap(0, 1); + } } impl LinkedListNodePtr for Inst { diff --git a/src/ir/passes/mod.rs b/src/ir/passes/mod.rs index a8e458c..aba27f9 100644 --- a/src/ir/passes/mod.rs +++ b/src/ir/passes/mod.rs @@ -16,4 +16,5 @@ pub mod loops; pub mod mem2reg; pub mod side_effect; pub mod simple_dce; +pub mod static_branch_prediction; pub mod tco; diff --git a/src/ir/passes/static_branch_prediction.rs b/src/ir/passes/static_branch_prediction.rs new file mode 100644 index 0000000..ae13465 --- /dev/null +++ b/src/ir/passes/static_branch_prediction.rs @@ -0,0 +1,94 @@ +//! Use heuristics to predict the branch direction of conditional branches. +//! +//! - [`equal_is_uncommon`] +//! ```orzir +//! %0 = icmp.eq %a, %b: i1 +//! ... +//! br %0, %then, %else +//! ``` +//! --> +//! +//! ```orzir +//! %0 = icmp.ne %a, %b: i1 +//! br %0, %else, %then +//! ``` + +use std::vec; + +use super::control_flow::CfgCanonicalize; +use crate::{ + collections::linked_list::{LinkedListContainerPtr, LinkedListNodePtr}, + ir::{ + passman::{GlobalPassMut, LocalPassMut, PassManager, PassResult, TransformPass}, + Context, + Func, + IBinaryOp, + ICmpCond, + Inst, + InstKind, + }, + utils::def_use::User, +}; + +pub const STATIC_BRANCH_PREDICTION: &str = "static-branch-prediction"; + +pub struct StaticBranchPrediction; + +impl StaticBranchPrediction { + pub fn equal_is_uncommon(ctx: &mut Context, func: Func) { + let mut cursor = func.cursor(); + + while let Some(block) = cursor.next(ctx) { + let tail = block.tail(ctx).unwrap(); + + if tail.is_br(ctx) { + let cond = tail.operand(ctx, 0); + + if let Some(cond_def) = cond.def_inst(ctx) { + if let InstKind::IBinary(IBinaryOp::Cmp(ICmpCond::Eq)) = cond_def.kind(ctx) { + let lhs = cond_def.operand(ctx, 0); + let rhs = cond_def.operand(ctx, 1); + + let new_cond = Inst::ibinary(ctx, IBinaryOp::Cmp(ICmpCond::Eq), lhs, rhs); + cond_def.insert_after(ctx, new_cond); + + tail.inverse_br(ctx); + tail.replace(ctx, cond, new_cond.result(ctx, 0)); + } + } + } + } + } +} + +impl LocalPassMut for StaticBranchPrediction { + type Output = (); + + fn run(&mut self, ctx: &mut Context, func: Func) -> PassResult<(Self::Output, bool)> { + Self::equal_is_uncommon(ctx, func); + Ok(((), false)) + } +} + +impl GlobalPassMut for StaticBranchPrediction { + type Output = (); + + fn run(&mut self, ctx: &mut Context) -> PassResult<(Self::Output, bool)> { + let mut changed = false; + for func in ctx.funcs() { + let (_, local_changed) = LocalPassMut::run(self, ctx, func)?; + changed |= local_changed; + } + Ok(((), changed)) + } +} + +impl TransformPass for StaticBranchPrediction { + fn register(passman: &mut PassManager) { + passman.register_transform( + STATIC_BRANCH_PREDICTION, + StaticBranchPrediction, + vec![Box::new(CfgCanonicalize)], + ) + } +}