diff --git a/notes.md b/notes.md index 4e43d7e..0cd7603 100644 --- a/notes.md +++ b/notes.md @@ -1,3 +1,14 @@ ## Call Convention 1. We use `ra` to store the continuation address, as it's otherwise unused in our language. We do need to push it onto stack to preserve it's value when doing a native call, though. 2. We store `closure` pointer after any arguments, so we should be able to work with native functions just fine. This differs from what is being done in the book "Compiling with Continuations". + +## TODO +1. fix call convention for external functions. For example: +``` +:print_int([?26, kont_main.4.22]) +``` +Inside print_int, we should do something like this: +``` +fn_ptr.99 = kont_main.4.22.0 +fn_ptr.99([result, kont_main.4.22]) +``` diff --git a/src/bin/externals.mbt b/src/bin/externals.mbt index dc4ae65..b3c8f0f 100644 --- a/src/bin/externals.mbt +++ b/src/bin/externals.mbt @@ -1,6 +1,6 @@ -fn add_interpreter_fns(interpreter : @knf_eval.KnfInterpreter) -> Unit { +fn add_clops_interp_fns(interpreter : @closureps_eval.CloPSInterpreter) -> Unit { interpreter.add_extern_fn( - "minimbt_print_int", + "print_int", fn(args) { match args[0] { Int(i) => @io.print(i) @@ -8,178 +8,93 @@ fn add_interpreter_fns(interpreter : @knf_eval.KnfInterpreter) -> Unit { } Unit }, + 1, ) interpreter.add_extern_fn( - "minimbt_print_endline", + "print_endline", fn(_args) { @io.print("\n") Unit }, + 0, ) - let create_array = fn(args : Array[@knf_eval.Value]) { + let create_array = fn(args : Array[@closureps_eval.Value]) { match args[0] { - Int(n) => @knf_eval.Value::Array(Array::make(n, args[1])) + Int(n) => @closureps_eval.Value::Array(Array::make(n, args[1])) _ => @util.die("create_array expects Int") } } - interpreter.add_extern_fn("minimbt_create_array", create_array) - interpreter.add_extern_fn("minimbt_create_float_array", create_array) - interpreter.add_extern_fn("minimbt_create_ptr_array", create_array) + interpreter.add_extern_fn("create_array", create_array, 2) + interpreter.add_extern_fn("create_float_array", create_array, 2) + interpreter.add_extern_fn("create_ptr_array", create_array, 2) interpreter.add_extern_fn( - "minimbt_truncate", + "truncate", fn(args) { match args[0] { Double(d) => Int(d.to_int()) - _ => @util.die("truncate expects Double") + _ => @util.die("expects Double") } }, + 1, ) interpreter.add_extern_fn( - "minimbt_sin", + "sin", fn(args) { match args[0] { Double(d) => Double(@math.sin(d)) _ => @util.die("sin expects Double") } }, + 1, ) interpreter.add_extern_fn( - "minimbt_cos", + "cos", fn(args) { match args[0] { Double(d) => Double(@math.cos(d)) _ => @util.die("cos expects Double") } }, + 1, ) interpreter.add_extern_fn( - "minimbt_sqrt", + "sqrt", fn(args) { match args[0] { Double(d) => Double(d.sqrt()) _ => @util.die("sqrt expects Double") } }, + 1, ) interpreter.add_extern_fn( - "minimbt_abs_float", + "abs_float", fn(args) { match args[0] { Double(d) => Double(@double.abs(d)) _ => @util.die("abs_float expects Double") } }, + 1, ) interpreter.add_extern_fn( - "minimbt_int_of_float", + "int_of_float", fn(args) { match args[0] { Double(d) => Int(d.to_int()) _ => @util.die("int_of_float expects Double") } }, + 1, ) interpreter.add_extern_fn( - "minimbt_float_of_int", - fn(args) { - match args[0] { - Int(i) => Double(i.to_double()) - _ => @util.die("float_of_int expects Int") - } - }, - ) -} - -fn add_closure_interpreter_fns( - interpreter : @closure_eval.ClosureInterpreter -) -> Unit { - interpreter.add_extern_fn( - "minimbt_print_int", - fn(args) { - match args[0] { - Int(i) => @io.print(i) - _ => @util.die("print_int expects Int") - } - Unit - }, - ) - interpreter.add_extern_fn( - "minimbt_print_endline", - fn(_args) { - @io.print("\n") - Unit - }, - ) - let create_array = fn(args : Array[@closure_eval.Value]) { - match args[0] { - Int(n) => @closure_eval.Value::Array(Array::make(n, args[1])) - _ => @util.die("create_array expects Int") - } - } - interpreter.add_extern_fn("minimbt_create_array", create_array) - interpreter.add_extern_fn("minimbt_create_float_array", create_array) - interpreter.add_extern_fn("minimbt_create_ptr_array", create_array) - interpreter.add_extern_fn( - "minimbt_truncate", - fn(args) { - match args[0] { - Double(d) => Int(d.to_int()) - _ => @util.die("truncate expects Double") - } - }, - ) - interpreter.add_extern_fn( - "minimbt_sin", - fn(args) { - match args[0] { - Double(d) => Double(@math.sin(d)) - _ => @util.die("sin expects Double") - } - }, - ) - interpreter.add_extern_fn( - "minimbt_cos", - fn(args) { - match args[0] { - Double(d) => Double(@math.cos(d)) - _ => @util.die("cos expects Double") - } - }, - ) - interpreter.add_extern_fn( - "minimbt_sqrt", - fn(args) { - match args[0] { - Double(d) => Double(d.sqrt()) - _ => @util.die("sqrt expects Double") - } - }, - ) - interpreter.add_extern_fn( - "minimbt_abs_float", - fn(args) { - match args[0] { - Double(d) => Double(@double.abs(d)) - _ => @util.die("abs_float expects Double") - } - }, - ) - interpreter.add_extern_fn( - "minimbt_int_of_float", - fn(args) { - match args[0] { - Double(d) => Int(d.to_int()) - _ => @util.die("int_of_float expects Double") - } - }, - ) - interpreter.add_extern_fn( - "minimbt_float_of_int", + "float_of_int", fn(args) { match args[0] { Int(i) => Double(i.to_double()) _ => @util.die("float_of_int expects Int") } }, + 1, ) } diff --git a/src/bin/main.mbt b/src/bin/main.mbt index 5d83a7e..dd38a24 100644 --- a/src/bin/main.mbt +++ b/src/bin/main.mbt @@ -102,7 +102,7 @@ fn CompileStatus::step(self : CompileStatus) -> Bool { } Cps => { let cpsenv = @cps.CpsEnv::new(self.counter) - let mut cps = cpsenv.precps2cps(self.precps.unwrap(), @cps.Cps::Just) + let mut cps = cpsenv.precps2cps(self.precps.unwrap(), fn { _ => Exit }) cps = @cps.alias_analysis(cps) self.cps = Some(cps) self.counter = cpsenv.counter.val @@ -152,15 +152,12 @@ fn CompileStatus::output(self : CompileStatus, json : Bool) -> String { fn main { let argv = @env.get_args() let mut file = None - let knf_opt_iters = Ref::new(10) - let knf_opt_inline_threshold = Ref::new(10) // Testing directives let json = Ref::new(false) let start_stage = Ref::new(Stages::Parse) let end_stage = Ref::new(Stages::Finished) - let knf_interpreter = Ref::new(false) - let closure_interpreter = Ref::new(false) + let closureps_interpreter = Ref::new(false) let out_file = Ref::new("-") let print = Ref::new([]) @@ -194,16 +191,10 @@ fn main { "End stage", ), ( - "--knf-interpreter", + "--clops-interp", "", - @ArgParser.Set(knf_interpreter), - "Run with KNF interpreter", - ), - ( - "--closure-interpreter", - "", - @ArgParser.Set(closure_interpreter), - "Run with closure interpreter", + @ArgParser.Set(closureps_interpreter), + "Run with closure passing style interpreter", ), ( "--out-file", @@ -220,34 +211,6 @@ fn main { @ArgParser.String(fn(s) { print.val = s.split(",").collect() }), "", ), - ( - "--knf-opt-iters", - "N", - @ArgParser.String( - fn(s) { - let i = @strconv.parse_int?(s) - match i { - Ok(i) => knf_opt_iters.val = i - Err(_) => @util.die("Invalid number") - } - }, - ), - "Number of optimization iterations", - ), - ( - "--knf-opt-inline-threshold", - "N", - @ArgParser.String( - fn(s) { - let i = @strconv.parse_int?(s) - match i { - Ok(i) => knf_opt_inline_threshold.val = i - Err(_) => @util.die("Invalid number") - } - }, - ), - "Inline threshold for KNF optimization", - ), ], fn(s) { if file.is_empty().not() { @@ -260,12 +223,6 @@ fn main { ) // Configure pipeline - //if knf_interpreter.val { - // end_stage.val = Stages::Knf - //} - //if closure_interpreter.val { - // end_stage.val = Stages::Closure - //} let stages_to_print = print.val.map( fn(s) { match Stages::from_string(s) { @@ -308,32 +265,23 @@ fn main { } // Output - //if knf_interpreter.val { - // let knfi = @knf_eval.KnfInterpreter::new() - // add_interpreter_fns(knfi) - // match knfi.eval_full?(status.knf.unwrap()) { - // Ok(_) => () - // Err(Failure(e)) => { - // println(e) - // @util.die("KNF interpreter error") - // } - // } - //} else if closure_interpreter.val { - // let clsi = @closure_eval.ClosureInterpreter::new() - // add_closure_interpreter_fns(clsi) - // match clsi.eval_full?(status.closure_ir.unwrap()) { - // Ok(_) => () - // Err(Failure(e)) => { - // println(e) - // @util.die("Closure interpreter error") - // } - // } - //} else { - let out_string = status.output(json.val) - if out_file.val == "-" { - println(out_string) + if closureps_interpreter.val { + let clops = status.clops.unwrap() + let interpreter = @closureps_eval.CloPSInterpreter::new(clops) + add_clops_interp_fns(interpreter) + try { + interpreter.eval!(clops.root) + } catch { + VariableNotFind(v) => println("Undefined variable: \{v}") + } else { + v => ignore(v) + } } else { - @fs.write_to_string(out_file.val, out_string) + let out_string = status.output(json.val) + if out_file.val == "-" { + println(out_string) + } else { + @fs.write_to_string(out_file.val, out_string) + } } - //} } diff --git a/src/bin/moon.pkg.json b/src/bin/moon.pkg.json index 936d6d8..8335cdb 100644 --- a/src/bin/moon.pkg.json +++ b/src/bin/moon.pkg.json @@ -16,6 +16,7 @@ "moonbitlang/minimbt/knf", "moonbitlang/minimbt/typing", "moonbitlang/minimbt/knf_eval", + "moonbitlang/minimbt/closureps_eval", "lijunchen/unstable_io/io", "moonbitlang/minimbt/closure", "moonbitlang/minimbt/riscv", diff --git a/src/closureps/cps2closureps.mbt b/src/closureps/cps2closureps.mbt index b617b06..ffa600a 100644 --- a/src/closureps/cps2closureps.mbt +++ b/src/closureps/cps2closureps.mbt @@ -9,11 +9,14 @@ fn CloEnv::rebind_var(self : CloEnv, v : Var) -> Var { fn CloEnv::rebind_value(self : CloEnv, v : Value) -> Value { match v { Var(v) => Var(self.rebind_var(v)) + Label(l) => Label(self.rebind_var(l)) v => v } } // collect all closures to top level and fix call convention +// NOTE: +// whenever we store a reference to an external call, we need to wrap it as a closure fn CloEnv::collect_closure(self : CloEnv, s : S) -> S { fn rec(c : S) { self.collect_closure(c) @@ -38,87 +41,105 @@ fn CloEnv::collect_closure(self : CloEnv, s : S) -> S { Prim(op, args, bind, rest) => Prim(op, args, bind, self.add_rebind(bind, bind).collect_closure(rest)) Fix(f, args, body, rest) => { - // Step 1. Calculate free variables of body - let fvs = body.free_variables() - fvs.remove(f) - args.each(fn { a => fvs.remove(a) }) - let free_vars = fvs.iter().collect() + // Step 1. recurse on body to collect free vars + // HACK: this is the closure passed into the function + // we set it's type to unit as we can't know it's type before collecting + // free vars, but to collect free var we have to have it bound + let closure_ref = self.new_named("ref_\{f.to_string()}", T::Unit) + let body_env = args + .fold(init=self, fn(acc, ele) { acc.add_rebind(ele, ele) }) + .add_rebind(f, closure_ref) + let mut body = body_env.collect_closure(body) - // Step 2. Calculate the free variable tuple we need to pass inside the - // closure - let fv_data_ty = match free_vars { - [] => T::Unit - _ => T::Tuple(free_vars.map(fn { v => v.ty })) - } + // Step 2. Calculate free variables + let free_vars = body + .free_variables() + .iter() + .filter(fn(v) { v != closure_ref && not(args.contains(v)) }) + .collect() + let has_free_vars = not(free_vars.is_empty()) - // this is the closure passed into the function - let closure_ref = self.new_named( - "closure_ref_\{f.to_string()}", - T::Tuple([f.ty, fv_data_ty]), - ) + // Step 3. Rewrap the body on demand + if has_free_vars { + let fn_ptr = self.new_named("fn_ptr", f.ty) + let fv_data_ty = T::Tuple(free_vars.map(fn { v => v.ty })) + let freevars = self.new_named("fvs", fv_data_ty) + body = free_vars.foldi( + init=body, + fn(idx, acc, ele) { KthTuple(idx, Var(freevars), ele, acc) }, + ) + body = KthTuple( + 0, + Var(closure_ref), + fn_ptr, + KthTuple(1, Var(closure_ref), freevars, body), + ) + } - // fix the type of f to accept an additional closure arg at the end - guard let T::Fun(args_ty, _) = f.ty else { + // Step 4. Patching f's args and types + guard let T::Fun(_args_ty, ret_ty) = f.ty else { _ => @util.die("calling non function") // NOTE: the following alters f's type } - args_ty.push(closure_ref.ty) // WARN: after this operation our ds is now self-recursive - let body_fixed = match free_vars { - [] => rec(body) - _ => { - let fn_ptr = self.new_named("fn_ptr", f.ty) - let freevars = self.new_named("freevars", fv_data_ty) - let body_to_wrap = self - .add_rebind(f, closure_ref) - .collect_closure(body) - let body_with_freevars_bound = free_vars.foldi( - init=body_to_wrap, - fn(idx, acc, ele) { KthTuple(idx, Var(freevars), ele, acc) }, - ) - KthTuple( - 0, - Var(closure_ref), - fn_ptr, - KthTuple(1, Var(closure_ref), freevars, body_with_freevars_bound), - ) - } + let args_ty = _args_ty.copy() + args_ty.push(closure_ref.ty) // WARN: types are self-recursive now + let f = f.retype(Fun(args_ty, ret_ty)) + let args = args.copy() + args.push(closure_ref) + self.fnblocks[f] = { args, free_vars, body } + + // Step 4. recurse on rest + + if has_free_vars { + let fv_data_ty = T::Tuple(free_vars.map(fn { v => v.ty })) + let freevars_captured = self.new_named("fvs_cap", fv_data_ty) + let closure_gen = self.new_named( + "clo_\{f.to_string()}", + T::Tuple([f.ty, fv_data_ty]), + ) + let rest_fixed = self.add_rebind(f, closure_gen).collect_closure(rest) + Tuple( + free_vars.map(Value::Var), + freevars_captured, + Tuple([Label(f), Var(freevars_captured)], closure_gen, rest_fixed), + ) + } else { + let closure_gen = self.new_named( + "clo_\{f.to_string()}", + T::Tuple([f.ty, Unit]), + ) + let rest_fixed = self.add_rebind(f, closure_gen).collect_closure(rest) + Tuple([Label(f), Unit], closure_gen, rest_fixed) } - self.fnblocks[f] = { args, free_vars, body: body_fixed } - let freevars_captured = self.new_named("freevars_captured", fv_data_ty) - let closure_gen = self.new_named( - "closure_\{f.to_string()}", - T::Tuple([f.ty, fv_data_ty]), - ) - let rest_fixed = self.add_rebind(f, closure_gen).collect_closure(rest) - Tuple( - free_vars.map(Value::Var), - freevars_captured, - Tuple([Label(f), Var(freevars_captured)], closure_gen, rest_fixed), - ) } App(f, args) => match f { Var(f_var) => + match self.bindings[f_var] { + Some(closure) => { + args.push(Var(closure)) + // Calling a user level function + let tmp = self.new_named("fn_ptr", f_var.ty) + KthTuple(0, Var(closure), tmp, App(Var(tmp), args.map(recrbva))) + } + None => @util.die("undefined function \{f} called") + } + Label(f_var) => // NOTE: always generate a call as if we're calling a closure. // Since there's no way for us to decide whether we're calling a // closure or not. match self.bindings[f_var] { Some(maybe_closure) => { args.push(Var(maybe_closure)) - // we know the called function statically, so we're allowed to mark - // it as a label - App(Label(f_var), args) - } - None => { - args.push(f) - App(f, args) + // Calling a user level function + App(Label(f_var), args.map(recrbva)) } + None => + // Calling an external function + App(Label(f_var), args.map(recrbva)) } - // NOTE: must be a native call - // there's no guarantee we always use this case for all native calls - Label(_) => App(f, args.map(recrbva)) _ => @util.die("Can't invoke call on \{f}") } - Just(v) => Just(recrbva(v)) + Exit => Exit } } diff --git a/src/closureps/show.mbt b/src/closureps/show.mbt index 5ab7ee8..14e3a48 100644 --- a/src/closureps/show.mbt +++ b/src/closureps/show.mbt @@ -6,7 +6,7 @@ pub fn ClosurePS::to_string(self : ClosurePS) -> String { let mut output = "" for item in self.fnblocks.iter() { let (name, def) = item - output += "[\{name}], args: \{def.args}, freevars: \{def.free_vars}\n" + output += "\{name}, args: \{def.args}, freevars: \{def.free_vars}\n" output += "\{def.body}\n\n" } output += "[root]\n\{self.root}\n" diff --git a/src/closureps_eval/interpreter.mbt b/src/closureps_eval/interpreter.mbt new file mode 100644 index 0000000..42c3237 --- /dev/null +++ b/src/closureps_eval/interpreter.mbt @@ -0,0 +1,247 @@ +pub enum Value { + Unit + Int(Int) + Double(Double) + Tuple(Array[Value]) + Label(Var) + Array(Array[Value]) + ExternFn(String) +} derive(Show) + +struct CloPSInterpreter { + clops : @closureps.ClosurePS + extern_fns : Map[String, (Int, (Array[Value]) -> Value)] + mut cur_env : @hashmap.T[Var, Value] +} + +pub typealias Var = @cps.Var + +pub typealias S = @cps.Cps + +pub fn Value::op_equal(self : Value, other : Value) -> Bool { + match (self, other) { + (Unit, Unit) => true + (Int(x), Int(y)) => x == y + (Double(x), Double(y)) => x == y + (Tuple(xs), Tuple(ys)) => xs == ys + (Array(xs), Array(ys)) => xs == ys + (ExternFn(x), ExternFn(y)) => x == y + _ => false + } +} + +pub type! EvalError { + VariableNotFind(Var) +} derive(Show) + +pub fn CloPSInterpreter::new(clops : @closureps.ClosurePS) -> CloPSInterpreter { + { extern_fns: Map::new(), clops, cur_env: @hashmap.new() } +} + +fn CloPSInterpreter::replace_env( + self : CloPSInterpreter, + new_env : @hashmap.T[Var, Value] +) -> CloPSInterpreter { + { ..self, cur_env: new_env } +} + +pub fn CloPSInterpreter::add_extern_fn( + self : CloPSInterpreter, + name : String, + f : (Array[Value]) -> Value, + argc : Int +) -> Unit { + self.extern_fns.set(name, (argc, f)) +} + +fn CloPSInterpreter::find(self : CloPSInterpreter, v : Var) -> Value!EvalError { + match self.cur_env[v] { + Some(val) => val + None => + match v.name.val { + None => raise VariableNotFind(v) + Some(name) => + if self.extern_fns.contains(name) && v.id == -1 { + ExternFn(name) + } else { + raise VariableNotFind(v) + } + } + } +} + +fn CloPSInterpreter::eval_v( + self : CloPSInterpreter, + value : @cps.Value +) -> Value!EvalError { + match value { + Var(v) => self.find!(v) + Label(v) => Label(v) + Unit => Unit + Int(i) => Int(i) + Double(f) => Double(f) + } +} + +pub fn CloPSInterpreter::eval( + self : CloPSInterpreter, + expr : S +) -> Value!EvalError { + loop expr { + Exit => break Unit + Tuple(vs, bind, rest) => { + let to_binds = [] + for v in vs { + to_binds.push(self.eval_v!(v)) + } + self.cur_env[bind] = Tuple(to_binds) + continue rest + } + KthTuple(idx, val, bind, rest) => { + match self.eval_v!(val) { + Tuple(tup) => self.cur_env[bind] = tup[idx] + _ => @util.die("extrating members from non tuple") + } + continue rest + } + Fix(f, _, _, _) => + @util.die("Unexpected non-top level function definiton \{f}") + Switch(val, branches) => + match self.eval_v!(val) { + Int(idx) => continue branches[idx] + _ => @util.die("branching on non ints") + } + Prim(Not, [v], bind, rest) => { + match self.eval_v!(v) { + Int(1) => self.cur_env[bind] = Int(0) + Int(0) => self.cur_env[bind] = Int(1) + v => @util.die("unexpected input \{v} for `not`") + } + continue rest + } + Prim(MakeArray, [len, elem], bind, rest) => { + match (self.eval_v!(len), self.eval_v!(elem)) { + (Int(len), elem) => self.cur_env[bind] = Array(Array::make(len, elem)) + (l, elem) => @util.die("unexpected input \{l}, \{elem} for `makearray`") + } + continue rest + } + Prim(Neg(Double), [f], bind, rest) => { + match self.eval_v!(f) { + Double(f) => self.cur_env[bind] = Double(-f) + v => @util.die("unexpected input \{v} for `neg_double`") + } + continue rest + } + Prim(Neg(Int), [i], bind, rest) => { + match self.eval_v!(i) { + Int(i) => self.cur_env[bind] = Int(-i) + v => @util.die("unexpected input \{v} for `neg_int`") + } + continue rest + } + Prim(Get, [arr, idx], bind, rest) => { + match (self.eval_v!(arr), self.eval_v!(idx)) { + (Array(arr), Int(idx)) => self.cur_env[bind] = arr[idx] + (arr, idx) => @util.die("unexpected input \{arr}, \{idx} for `get`") + } + continue rest + } + Prim(Put, [arr, idx, rhs], bind, rest) => { + match (self.eval_v!(arr), self.eval_v!(idx), self.eval_v!(rhs)) { + (Array(arr), Int(idx), rhs) => { + arr[idx] = rhs + self.cur_env[bind] = Unit + } + (arr, idx, rhs) => + @util.die("unexpected input \{arr}, \{idx} and \{rhs} for `put`") + } + continue rest + } + Prim(Math(op, Int), [lhs, rhs], bind, rest) => { + match (self.eval_v!(lhs), self.eval_v!(rhs)) { + (Int(a), Int(b)) => { + let result = match op { + Add => Int(a + b) + Sub => Int(a - b) + Mul => Int(a * b) + Div => Int(a / b) + } + self.cur_env[bind] = result + } + (lhs, rhs) => + @util.die("unexpected input \{lhs}, \{rhs} for `\{op}_int`") + } + continue rest + } + Prim(Math(op, Double), [lhs, rhs], bind, rest) => { + match (self.eval_v!(lhs), self.eval_v!(rhs)) { + (Double(a), Double(b)) => { + let result = match op { + Add => Double(a + b) + Sub => Double(a - b) + Mul => Double(a * b) + Div => Double(a / b) + } + self.cur_env[bind] = result + } + (lhs, rhs) => + @util.die("unexpected input \{lhs}, \{rhs} for `\{op}_double`") + } + continue rest + } + Prim(Eq, [lhs, rhs], bind, rest) => { + let lhs = self.eval_v!(lhs) + let rhs = self.eval_v!(rhs) + self.cur_env[bind] = Int(if lhs == rhs { 1 } else { 0 }) + continue rest + } + Prim(Le, [lhs, rhs], bind, rest) => { + match (self.eval_v!(lhs), self.eval_v!(rhs)) { + (Double(a), Double(b)) => + self.cur_env[bind] = Int(if a <= b { 1 } else { 0 }) + (Int(a), Int(b)) => self.cur_env[bind] = Int(if a <= b { 1 } else { 0 }) + (lhs, rhs) => @util.die("unexpected input \{lhs}, \{rhs} for `le`") + } + continue rest + } + Prim(_) => @util.die("malformed prim call \{expr}") + App(f, args) => { + let mut f_val = self.eval_v!(f) + let mut args_evaled = [] + for arg in args { + args_evaled.push(self.eval_v!(arg)) + } + while true { + guard let Label(address) = f_val else { + v => @util.die("jumping to non function \{v}") + } + match self.clops.fnblocks[address] { + None => { + let name = address.name.val.unwrap() + let (args_needed, extern_fn) = self.extern_fns[name].unwrap() + // NOTE: There's 2 cases: + // 1. args_needed + 1 = args_passed, then last arg is continuation + // 2. args_needed + 2 = args_passed, then 2nd-to-last arg is continuation + let return_val = extern_fn(args_evaled) + let cont = args_evaled[args_needed] // rememeber index starts from 0 + args_evaled = [return_val, cont] + guard let Tuple([fn_ptr, _]) = cont else { + v => @util.die("\{v} is not a closure") + } + f_val = fn_ptr + } + _ => break + } + } + guard let Label(address) = f_val else { _ => @util.die("unreachable") } + let f = self.clops.fnblocks[address].unwrap() + let new_env : @hashmap.T[Var, Value] = @hashmap.new() + //let args_to_pass = f.args.copy() + //args_to_pass.push(f.closure_ref) + zip2(f.args, args_evaled).each(fn { (k, v) => new_env[k] = v }) + self.cur_env = new_env + continue f.body + } + } +} diff --git a/src/closureps_eval/moon.pkg.json b/src/closureps_eval/moon.pkg.json new file mode 100644 index 0000000..2464606 --- /dev/null +++ b/src/closureps_eval/moon.pkg.json @@ -0,0 +1,17 @@ +{ + "import": [ + "moonbitlang/minimbt/closureps", + "moonbitlang/minimbt/cps", + "moonbitlang/minimbt/precps", + { + "path": "moonbitlang/minimbt", + "alias": "types" + }, + "moonbitlang/minimbt/util" + ], + "test-import": [ + "moonbitlang/minimbt/parser", + "moonbitlang/minimbt/lex", + "moonbitlang/minimbt/typing" + ] +} diff --git a/src/closureps_eval/utils.mbt b/src/closureps_eval/utils.mbt new file mode 100644 index 0000000..49ca0ce --- /dev/null +++ b/src/closureps_eval/utils.mbt @@ -0,0 +1,11 @@ +fn zip2[A : Show, B : Show](arr1 : Array[A], arr2 : Array[B]) -> Array[(A, B)] { + let out : Array[(A, B)] = [] + loop (arr1[:], arr2[:]) { + ([], []) => break out + ([a, .. as arr1], [b, .. as arr2]) => { + out.push((a, b)) + continue (arr1, arr2) + } + _ => @util.die("zipping arrays of different size") + } +} diff --git a/src/cps/cps_ir.mbt b/src/cps/cps_ir.mbt index e184b5a..684f10e 100644 --- a/src/cps/cps_ir.mbt +++ b/src/cps/cps_ir.mbt @@ -16,6 +16,10 @@ fn Var::from_precps(v : @precps.Var, t : T) -> Var { { id: v.id, name: { val: v.name }, ty: t } } +pub fn Var::retype(self : Var, ty : T) -> Var { + { ..self, ty, } +} + pub enum Value { Var(Var) Label(Var) @@ -40,7 +44,7 @@ pub enum Cps { Prim(PrimOp, Array[Value], Var, Cps) // T marks the return type App(Value, Array[Value]) - Just(Value) + Exit } fn Cps::replace_var_bind(self : Cps, from : Var, to : Value) -> Cps { @@ -76,7 +80,7 @@ fn Cps::replace_var_bind(self : Cps, from : Var, to : Value) -> Cps { Prim(op, args.map(recv), bind, rest_new) } App(f, args) => App(recv(f), args.map(recv)) - Just(v) => Just(recv(v)) + Exit => Exit } } @@ -122,6 +126,6 @@ pub fn Cps::free_variables(self : Cps) -> @hashset.T[Var] { init=f.free_variables(), fn(acc, ele) { acc.union(ele.free_variables()) }, ) - Just(v) => v.free_variables() + Exit => @hashset.new() } } diff --git a/src/cps/cps_ir_string.mbt b/src/cps/cps_ir_string.mbt index 7123c17..1244e5b 100644 --- a/src/cps/cps_ir_string.mbt +++ b/src/cps/cps_ir_string.mbt @@ -23,7 +23,7 @@ pub fn Value::to_string(self : Value) -> String { } } -impl Show for Cps with output(self, logger) { +pub impl Show for Cps with output(self, logger) { logger.write_string(self.to_string()) } @@ -51,13 +51,24 @@ fn to_str(cps : Cps, ~ident : String = "") -> String { Switch(v, branches) => ident + "switch(\{v}){\n" + - branches.map(fn { c => to_str(c, ident=ident + " ") }).join(";\n") + + branches + .mapi( + fn(i, c) { + ident + + " \{i} => {\n" + + to_str(c, ident=ident + " ") + + "\n" + + ident + + " }" + }, + ) + .join("\n") + "\n" + ident + "}" Prim(op, args, bind, rest) => ident + "prim \{bind} = \{op}(\{args})\n" + rec(rest) App(f, args) => ident + "\{f}(\{args})" - Just(v) => ident + "return \{v}" + Exit => ident + "exit" } } diff --git a/src/precps/ast2precps.mbt b/src/precps/ast2precps.mbt index b740c12..16306ed 100644 --- a/src/precps/ast2precps.mbt +++ b/src/precps/ast2precps.mbt @@ -62,7 +62,8 @@ pub fn TyEnv::ast2precps(self : TyEnv, s : S) -> PreCps { Let(ty, bind, rec(rhs), env_new.ast2precps(rest)) } LetRec(f, rest) => { - let (fvar, env_rest) = self.add(f.name.0, f.name.1) + // NOTE: we don't actually need any extra variables for simply refer to the function pointer once we've done closure conversion + let (fvar, env_rest) = self.add_label(f.name.0, f.name.1) let mut env_body = env_rest let args = [] f.args.each( diff --git a/src/precps/precps_ir.mbt b/src/precps/precps_ir.mbt index cb61d06..fcbd4e6 100644 --- a/src/precps/precps_ir.mbt +++ b/src/precps/precps_ir.mbt @@ -58,12 +58,12 @@ pub fn PreCps::get_type(self : PreCps) -> T { } } -enum Numeric { +pub enum Numeric { Double Int } derive(Show) -enum PrimOp { +pub enum PrimOp { Not MakeArray Neg(Numeric) diff --git a/src/precps/tyenv.mbt b/src/precps/tyenv.mbt index e37b888..3d13105 100644 --- a/src/precps/tyenv.mbt +++ b/src/precps/tyenv.mbt @@ -34,6 +34,13 @@ pub fn TyEnv::add(self : TyEnv, name : String, ty : T) -> (Var, TyEnv) { (to_bind, { ..self, bindings, }) } +pub fn TyEnv::add_label(self : TyEnv, name : String, ty : T) -> (Var, TyEnv) { + self.counter.val = self.counter.val + 1 + let to_bind = { id: self.counter.val, name: Some(name) } + let bindings = self.bindings.add(find_bind_key(to_bind), Label(ty, to_bind)) + (to_bind, { ..self, bindings, }) +} + pub fn TyEnv::add_many(self : TyEnv, args : Iter[(String, T)]) -> TyEnv { args.fold( init=self, diff --git a/test/test_src/fib_small.mbt b/test/test_src/fib_small.mbt new file mode 100644 index 0000000..5e947c6 --- /dev/null +++ b/test/test_src/fib_small.mbt @@ -0,0 +1,11 @@ +fn fib(n : Int) -> Int { + if n <= 1 { + n + } else { + fib(n - 1) + fib(n - 2) + } +}; + +fn main { + print_int(fib(5)) +}; diff --git a/test/test_src/loop.mbt b/test/test_src/loop.mbt new file mode 100644 index 0000000..eba9b4e --- /dev/null +++ b/test/test_src/loop.mbt @@ -0,0 +1,11 @@ +fn loop(x: Int) -> Int { + if x <= 1 { + x + } else { + loop(x-1) + loop(x-3) + } +}; + +fn main { + print_int(loop(10)) +};