From 5182b5b323693fb85c8bcd0d01af41ce06aa08d9 Mon Sep 17 00:00:00 2001 From: glyh Date: Wed, 23 Oct 2024 02:35:12 +0800 Subject: [PATCH] fix closure freezing --- src/riscv/before_alloc.mbt | 49 +++++++++++--------------- src/riscv/codegen.mbt | 1 - src/riscv/collect_labels.mbt | 1 - src/riscv/interference_graph_build.mbt | 1 - src/riscv/reg_spill.mbt | 2 +- src/ssacfg/ssa_ir.mbt | 2 -- 6 files changed, 22 insertions(+), 34 deletions(-) diff --git a/src/riscv/before_alloc.mbt b/src/riscv/before_alloc.mbt index d8f263d..b32f416 100644 --- a/src/riscv/before_alloc.mbt +++ b/src/riscv/before_alloc.mbt @@ -31,51 +31,44 @@ fn reserve_fregs(cfg : @ssacfg.SsaCfg, block : @ssacfg.Block) -> Unit { block.insts = insts } -// NOTE: load constant closure into scope whenever it's needed -fn ensure_closure_load( +// NOTE: replace constant closure with labels +fn freeze_closure( all_constant_closures : @hashset.T[Var], block : @ssacfg.Block ) -> Unit { let insts_backup = block.insts let insts : Array[@ssacfg.Inst] = [] - let defined_constant_closure : @hashset.T[Var] = @hashset.new() - fn check_constant_closure_val(v : Value) { + fn fix_val(v : Value) -> Value { match v { Var(v) => - if all_constant_closures.contains(v) && - not(defined_constant_closure.contains(v)) { - insts.push(LoadAddr(v, v.to_string())) - defined_constant_closure.insert(v) + if all_constant_closures.contains(v) { + Label(v) + } else { + Var(v) } - _ => () + _ => v } } for inst in insts_backup { - match inst { - MakeTuple(bind, vals) => { - vals.each(check_constant_closure_val) - defined_constant_closure.insert(bind) - } - KthTuple(bind, tup, _) => { - check_constant_closure_val(tup) - defined_constant_closure.insert(bind) - } - Prim(bind, _, args) => { - args.each(check_constant_closure_val) - defined_constant_closure.insert(bind) - } + let inst : @ssacfg.Inst = match inst { + MakeTuple(bind, vals) => MakeTuple(bind, vals.map(fix_val)) + KthTuple(bind, tup, k) => KthTuple(bind, fix_val(tup), k) + Prim(bind, op, args) => Prim(bind, op, args.map(fix_val)) Store(_) | Load(_) => @util.die("unreachable: load/store occurs before allocation") - Copy(bind, copied) => { - check_constant_closure_val(copied) - defined_constant_closure.insert(bind) - } - LoadAddr(bind, _) => defined_constant_closure.insert(bind) + Copy(bind, copied) => Copy(bind, fix_val(copied)) } insts.push(inst) } block.insts = insts + block.last_inst.val = match block.last_inst.val { + Branch(cond, _then, _else) => Branch(fix_val(cond), _then, _else) + Call(f, args) => Call(fix_val(f), args.map(fix_val)) + MakeArray(len, elem, kont) => + MakeArray(fix_val(len), fix_val(elem), fix_val(kont)) + Exit => Exit + } } fn before_alloc(cfg : @ssacfg.SsaCfg) -> @ssacfg.SsaCfg { @@ -87,7 +80,7 @@ fn before_alloc(cfg : @ssacfg.SsaCfg) -> @ssacfg.SsaCfg { for item in cfg.blocks { let (_, block) = item reserve_fregs(cfg, block) - ensure_closure_load(all_constant_closures, block) + freeze_closure(all_constant_closures, block) } cfg } diff --git a/src/riscv/codegen.mbt b/src/riscv/codegen.mbt index a344fbf..befdda3 100644 --- a/src/riscv/codegen.mbt +++ b/src/riscv/codegen.mbt @@ -598,7 +598,6 @@ fn CodegenBlock::codegen(self : CodegenBlock) -> Unit { for inst in block.insts { self.insert_asm(Comment(inst.to_string())) match inst { - LoadAddr(var, addr) => self.assign_i(var, fn(reg) { [La(reg, addr)] }) Copy(bind, copied) => match get_reg_ty(copied) { F64 => { diff --git a/src/riscv/collect_labels.mbt b/src/riscv/collect_labels.mbt index 0d29f55..057b055 100644 --- a/src/riscv/collect_labels.mbt +++ b/src/riscv/collect_labels.mbt @@ -38,7 +38,6 @@ fn collect_externals(cfg : @ssacfg.SsaCfg) -> ExternalLabels { collect_label_val(val) } Store(var) | Load(var) => collect_label_var(var) - LoadAddr(_) => () // load addr can't be external } } match blk.last_inst.val { diff --git a/src/riscv/interference_graph_build.mbt b/src/riscv/interference_graph_build.mbt index 371239d..8138172 100644 --- a/src/riscv/interference_graph_build.mbt +++ b/src/riscv/interference_graph_build.mbt @@ -114,7 +114,6 @@ fn LiveVarAnalysis::collect_inst( args.each(fn(v) { self.collect_val(v) }) } Load(bind) => self.var_set.remove(bind) - LoadAddr(bind, _) => self.var_set.remove(bind) Store(bind) => self.var_set.insert(bind) } } diff --git a/src/riscv/reg_spill.mbt b/src/riscv/reg_spill.mbt index aa7baba..c59572e 100644 --- a/src/riscv/reg_spill.mbt +++ b/src/riscv/reg_spill.mbt @@ -72,7 +72,7 @@ fn reg_spill_block(blk : @ssacfg.Block, spilled_var : Var) -> @ssacfg.Block { insts_new.push(Store(spilled_var)) } } - LoadAddr(_) | Load(_) | Store(_) => insts_new.push(inst) + Load(_) | Store(_) => insts_new.push(inst) } } let load_before_exit = match blk.last_inst.val { diff --git a/src/ssacfg/ssa_ir.mbt b/src/ssacfg/ssa_ir.mbt index 218339a..862c8c2 100644 --- a/src/ssacfg/ssa_ir.mbt +++ b/src/ssacfg/ssa_ir.mbt @@ -20,8 +20,6 @@ pub enum Inst { Load(Var) // So we can deal with the case where 1 tmp reg is not enough Copy(Var, Value) - // Support for constant closure - LoadAddr(Var, String) } derive(Show) pub enum PCInst {