diff --git a/crates/interpreter/src/frame.rs b/crates/interpreter/src/frame.rs index e2899b5a..16bb825e 100644 --- a/crates/interpreter/src/frame.rs +++ b/crates/interpreter/src/frame.rs @@ -1,31 +1,28 @@ -use cranelift_entity::SecondaryMap; +use cranelift_entity::{packed_option::PackedOption, SecondaryMap}; use sonatina_ir::{module::ModuleCtx, DataFlowGraph, Type, Value, I256}; use crate::{types, EvalValue, ProgramCounter}; +#[derive(Default)] pub struct Frame { - pub ret_addr: ProgramCounter, + pub ret_addr: PackedOption, local_values: SecondaryMap, // 256-bit register alloca_region: Vec, // big endian } impl Frame { - pub fn new( - ret_addr: ProgramCounter, - args: impl Iterator, - arg_literals: impl Iterator, - ) -> Self { - let mut local_values = SecondaryMap::new(); - for (v, literal_value) in args.zip(arg_literals) { - local_values[v] = EvalValue::from_i256(literal_value) - } - let alloca_region = Vec::new(); + pub fn new() -> Self { + Self::default() + } + + pub fn set_ret_addr(&mut self, ret_addr: ProgramCounter) { + self.ret_addr = ret_addr.into(); + } - Self { - ret_addr, - local_values, - alloca_region, + pub fn load_args(&mut self, args: &[Value], arg_literals: impl Iterator) { + for (v, literal_value) in args.iter().zip(arg_literals) { + self.local_values[*v] = EvalValue::from_i256(literal_value) } } diff --git a/crates/interpreter/src/pc.rs b/crates/interpreter/src/pc.rs index b4d0206e..57332d14 100644 --- a/crates/interpreter/src/pc.rs +++ b/crates/interpreter/src/pc.rs @@ -1,11 +1,24 @@ +use cranelift_entity::packed_option::ReservedValue; use sonatina_ir::{module::FuncRef, Block, Insn, Layout}; -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct ProgramCounter { pub func_ref: FuncRef, pub insn: Insn, } +impl ReservedValue for ProgramCounter { + fn reserved_value() -> Self { + let func_ref = FuncRef::reserved_value(); + let insn = Insn::reserved_value(); + ProgramCounter { func_ref, insn } + } + + fn is_reserved_value(&self) -> bool { + self.func_ref == FuncRef::reserved_value() && self.insn == Insn::reserved_value() + } +} + impl ProgramCounter { pub fn new(entry_func: FuncRef, layout: &Layout) -> Self { let entry = layout.entry_block().unwrap(); diff --git a/crates/interpreter/src/state.rs b/crates/interpreter/src/state.rs index 77190188..1a86402d 100644 --- a/crates/interpreter/src/state.rs +++ b/crates/interpreter/src/state.rs @@ -1,12 +1,9 @@ -use std::{ - iter, - ops::{Add, BitAnd, BitOr, BitXor, Mul, Neg, Not, Sub}, -}; +use std::ops::{Add, BitAnd, BitOr, BitXor, Mul, Neg, Not, Sub}; use sonatina_ir::{ insn::{BinaryOp, CastOp, UnaryOp}, module::FuncRef, - Block, DataLocationKind, Immediate, InsnData, Module, + Block, DataLocationKind, Immediate, InsnData, Module, Value, }; use crate::{types, EvalResult, Frame, ProgramCounter}; @@ -19,11 +16,15 @@ pub struct State { } impl State { - pub fn new(module: Module, entry_func: FuncRef) -> Self { + pub fn new(module: Module, entry_func: FuncRef, args: &[Value]) -> Self { let func = &module.funcs[entry_func]; let pc = ProgramCounter::new(entry_func, &func.layout); - debug_assert!(func.arg_values.is_empty()); - let entry_frame = Frame::new(pc, iter::empty(), iter::empty()); + + let mut entry_frame = Frame::new(); + debug_assert!(func.arg_values.len() == args.len()); + for arg in args { + entry_frame.load(&module.ctx, *arg, &func.dfg); + } let frames = vec![entry_frame]; Self { @@ -153,9 +154,10 @@ impl State { let ret_addr = self.pc; let callee = &self.module.funcs[*func]; + let mut new_frame = Frame::new(); debug_assert!(callee.arg_values.len() == args.len()); - let new_frame = - Frame::new(ret_addr, callee.arg_values.iter().copied(), arg_literals); + new_frame.load_args(&callee.arg_values, arg_literals); + new_frame.set_ret_addr(ret_addr); self.frames.push(new_frame); self.pc.call(*func, &callee.layout); @@ -211,7 +213,7 @@ impl State { Some(caller_frame) => { // Function epilogue - self.pc.resume_frame_at(frame.ret_addr); + self.pc.resume_frame_at(frame.ret_addr.unwrap()); let caller = &self.module.funcs[self.pc.func_ref]; if let Some(arg) = *args { @@ -275,7 +277,7 @@ mod test { let module = parser.parse(input).unwrap().module; let func_ref = module.iter_functions().next().unwrap(); - State::new(module, func_ref) + State::new(module, func_ref, &[]) } #[test] @@ -394,7 +396,7 @@ mod test { let module = parser.parse(input).unwrap().module; let func_ref = module.iter_functions().nth(1).unwrap(); - let state = State::new(module, func_ref); + let state = State::new(module, func_ref, &[]); let data = state.run();