Skip to content

Commit

Permalink
fix(ph_layout): better branch prob predict
Browse files Browse the repository at this point in the history
  • Loading branch information
SynodicMonth committed Aug 17, 2024
1 parent b8b9093 commit de00828
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 11 deletions.
24 changes: 18 additions & 6 deletions src/backend/riscv64/peephole.rs
Original file line number Diff line number Diff line change
Expand Up @@ -994,9 +994,6 @@ pub fn remove_redundant_labels(mctx: &mut MContext<RvInst>) -> bool {
inst.unlink(mctx);
prev.push_back(mctx, inst);
}
cursor.next(mctx);
block.remove(mctx);
local_changed = true;
}
}
// rule 2: if the block is empty, remove it, then retarget all jumps to it to the
Expand All @@ -1009,9 +1006,17 @@ pub fn remove_redundant_labels(mctx: &mut MContext<RvInst>) -> bool {
match inst.kind(mctx) {
RvInstKind::J { .. } => {
inst.redirect_branch(mctx, block.next(mctx).unwrap());
label_usage
.entry(block.next(mctx).unwrap().label(mctx).clone())
.or_insert_with(FxHashSet::default)
.insert(inst);
}
RvInstKind::Br { .. } => {
inst.redirect_branch(mctx, block.next(mctx).unwrap());
label_usage
.entry(block.next(mctx).unwrap().label(mctx).clone())
.or_insert_with(FxHashSet::default)
.insert(inst);
}
RvInstKind::Li { .. }
| RvInstKind::AluRR { .. }
Expand All @@ -1028,7 +1033,14 @@ pub fn remove_redundant_labels(mctx: &mut MContext<RvInst>) -> bool {
| RvInstKind::LoadAddr { .. } => unreachable!(),
}
}
cursor.next(mctx);
}
}

// remove empty blocks
let mut curr_block = func.head(mctx);
while let Some(block) = curr_block {
curr_block = block.next(mctx);
if block.size(mctx) == 0 {
block.remove(mctx);
local_changed = true;
}
Expand Down Expand Up @@ -1090,8 +1102,8 @@ pub fn run_peephole_after_regalloc(mctx: &mut MContext<RvInst>, config: &LowerCo

// NOTE: remove redundant jump need to be run after tail duplication
changed |= remove_redundant_jump(mctx);
// changed |= remove_redundant_labels(mctx);
// changed |= remove_redundant_jump(mctx);
changed |= remove_redundant_labels(mctx);
changed |= remove_redundant_jump(mctx);

changed
}
57 changes: 52 additions & 5 deletions src/ir/passes/control_flow/ph_layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ use crate::{
Block,
Context,
Func,
IBinaryOp,
ICmpCond,
InstKind,
},
utils::cfg::CfgNode,
};
Expand Down Expand Up @@ -99,6 +102,48 @@ pub const PH_BLOCK_LAYOUT: &str = "ph-block-layout";

pub struct PHBlockLayout;

impl PHBlockLayout {
pub fn estimate_branch_prob(ctx: &Context, block: Block) -> (f64, f64) {
let succs = block.succs(ctx);
if succs.len() != 2 {
panic!("block has more than 2 successors");
}

// rule 1: P(left_br) = 0.99 if br in loop
if let Some(block_name) = block.name(ctx) {
println!("block name: {:?}", block_name);
if block_name.contains("while") {
return (0.99, 0.01);
} else if block_name.contains("if") {
let br_inst = block.tail(ctx).unwrap();
if !br_inst.is_br(ctx) {
panic!("block tail is not a br");
}

let cond = br_inst.operand(ctx, 0);

if let Some(cond_def) = cond.def_inst(ctx) {
// rule 2: P(left_br) = 0.9 if br is from a icmp.ne
if let InstKind::IBinary(IBinaryOp::Cmp(ICmpCond::Ne)) = cond_def.kind(ctx) {
return (0.9, 0.1);
}
}
}
}

// rule 3: P(jump_to_ret) = 0.01
if let Some(return_block) = block.container(ctx).unwrap().tail(ctx) {
if succs[0] == return_block {
return (0.01, 0.99);
} else if succs[1] == return_block {
return (0.99, 0.01);
}
}

// rule 4: P(left_br) = 0.5
(0.5, 0.5)
}
}
impl LocalPassMut for PHBlockLayout {
type Output = ();

Expand Down Expand Up @@ -128,8 +173,9 @@ impl LocalPassMut for PHBlockLayout {
let succ1_index = index_map[&succs[0]];
let succ2_index = index_map[&succs[1]];

mat[(index, succ1_index)] -= 0.99;
mat[(index, succ2_index)] -= 0.01;
let (p1, p2) = PHBlockLayout::estimate_branch_prob(ctx, block);
mat[(index, succ1_index)] -= p1;
mat[(index, succ2_index)] -= p2;
} else if succs.len() == 1 {
let succ_index = index_map[&block.succs(ctx)[0]];
mat[(index, succ_index)] -= 1.0;
Expand Down Expand Up @@ -157,7 +203,7 @@ impl LocalPassMut for PHBlockLayout {
let mut b = DVector::<f64>::zeros(n_blocks + 1);
b[n_blocks] = 1.0;

let stationary = decomp.solve(&b, 1e-6);
let stationary = decomp.solve(&b, 1e-5);

if stationary.is_err() {
println!("[ ph_layout ] stationary distribution not found");
Expand Down Expand Up @@ -186,12 +232,13 @@ impl LocalPassMut for PHBlockLayout {
let succ = block.succs(ctx);
let index = index_map[&block];
if succ.len() == 2 {
let (p1, p2) = PHBlockLayout::estimate_branch_prob(ctx, block);
let succ1 = succ[0];
if succ1 != block {
let edge = Edge {
from: block,
to: succ1,
weight: stationary_norm[index] * 0.99,
weight: stationary_norm[index] * p1,
};
edges.push(edge);
}
Expand All @@ -200,7 +247,7 @@ impl LocalPassMut for PHBlockLayout {
let edge = Edge {
from: block,
to: succ2,
weight: stationary_norm[index] * 0.01,
weight: stationary_norm[index] * p2,
};
edges.push(edge);
}
Expand Down

0 comments on commit de00828

Please sign in to comment.