Skip to content

Commit

Permalink
Feat: Closure compilation (the VM is still missing the closure support)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yag000 committed Jul 31, 2023
1 parent 78293c4 commit 17eec25
Show file tree
Hide file tree
Showing 5 changed files with 453 additions and 64 deletions.
11 changes: 10 additions & 1 deletion crates/compiler/src/code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ pub enum Opcode {
SetLocal,
GetLocal,

GetFree,

// Custom types
Array,
HashMap,
Expand Down Expand Up @@ -147,6 +149,7 @@ impl Display for Opcode {
Opcode::GetGlobal => "OpGetGlobal",
Opcode::SetLocal => "OpSetLocal",
Opcode::GetLocal => "OpGetLocal",
Opcode::GetFree => "OpGetFree",
Opcode::Array => "OpArray",
Opcode::HashMap => "OpHashMap",
Opcode::Index => "OpIndex",
Expand All @@ -171,8 +174,14 @@ impl Opcode {
| Opcode::GetGlobal
| Opcode::Array
| Opcode::HashMap => vec![2],
Opcode::Call | Opcode::SetLocal | Opcode::GetLocal | Opcode::GetBuiltin => vec![1],

Opcode::Call
| Opcode::SetLocal
| Opcode::GetLocal
| Opcode::GetBuiltin
| Opcode::GetFree => vec![1],
Opcode::Closure => vec![2, 1],

_ => vec![],
}
}
Expand Down
113 changes: 69 additions & 44 deletions crates/compiler/src/compiler.rs
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cell::RefCell;
use std::rc::Rc;

use crate::code::{Instructions, Opcode};
Expand All @@ -6,8 +7,8 @@ use lexer::token::Token;
use num_traits::FromPrimitive;
use object::builtins::BuiltinFunction;
use object::object::{CompiledFunction, Object};
use parser::ast::Program;
use parser::ast::{BlockStatement, Conditional, Expression, InfixOperator, Primitive, Statement};
use parser::ast::{FunctionLiteral, Program};

#[allow(dead_code)]
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -111,6 +112,9 @@ impl Compiler {
SymbolScope::Local => {
self.emit(Opcode::SetLocal, vec![symbol.index as i32]);
}
SymbolScope::Free => {
unimplemented!();
}
SymbolScope::Builtin => {
unreachable!("Builtin symbols should not be set, the compiler should panic before this")
}
Expand Down Expand Up @@ -172,35 +176,7 @@ impl Compiler {
self.emit(Opcode::Index, vec![]);
}
Expression::FunctionLiteral(fun) => {
self.enter_scope();

let num_parameters = fun.parameters.len();

for param in fun.parameters {
self.symbol_table.define(param.value);
}

self.compile_block_statement(fun.body)?;

if self.last_instruction_is(Opcode::Pop) {
self.replace_last_pop_with_return();
}
if !self.last_instruction_is(Opcode::ReturnValue) {
self.emit(Opcode::Return, vec![]);
}

let num_locals = self.symbol_table.num_definitions;
let instructions = self.leave_scope().data;

let compiled_function = Object::COMPILEDFUNCTION(CompiledFunction {
instructions,
num_locals,
num_parameters,
});
let operands = i32::from_usize(self.add_constant(compiled_function))
.ok_or("Invalid integer type")?;

self.emit(Opcode::Closure, vec![operands, 0]);
self.compile_function_literal(fun)?;
}
Expression::FunctionCall(call) => {
self.compile_expression(*call.function)?;
Expand Down Expand Up @@ -312,6 +288,50 @@ impl Compiler {
Ok(())
}

fn compile_function_literal(&mut self, fun: FunctionLiteral) -> Result<(), String> {
self.enter_scope();

let num_parameters = fun.parameters.len();

for param in fun.parameters {
self.symbol_table.define(param.value);
}

self.compile_block_statement(fun.body)?;

if self.last_instruction_is(Opcode::Pop) {
self.replace_last_pop_with_return();
}
if !self.last_instruction_is(Opcode::ReturnValue) {
self.emit(Opcode::Return, vec![]);
}

let free_symbols = self.symbol_table.free_symbols.clone();
let free_symbols_len = free_symbols.len();

let num_locals = self.symbol_table.num_definitions;
let instructions = self.leave_scope().data;

for symbol in free_symbols {
// Te symbols must be loaded after the scope is left, but
// we need to get them before leaving the scope.
self.load_symbol(&symbol);
}

let compiled_function = Object::COMPILEDFUNCTION(CompiledFunction {
instructions,
num_locals,
num_parameters,
});

let operands =
i32::from_usize(self.add_constant(compiled_function)).ok_or("Invalid integer type")?;

self.emit(Opcode::Closure, vec![operands, free_symbols_len as i32]);

Ok(())
}

fn last_instruction_is(&self, opcode: Opcode) -> bool {
match self.scopes[self.scope_index].last_instruction {
Some(ref last) => last.opcode == opcode,
Expand Down Expand Up @@ -384,15 +404,23 @@ impl Compiler {

fn enter_scope(&mut self) {
let scope = CompilerScope::default();
self.symbol_table = SymbolTable::new_enclosed(Rc::new(self.symbol_table.clone()));
self.symbol_table =
SymbolTable::new_enclosed(Rc::new(RefCell::new(self.symbol_table.clone())));
self.scopes.push(scope);
self.scope_index += 1;
}

fn leave_scope(&mut self) -> Instructions {
let instructions = self.current_instructions();

self.symbol_table = self.symbol_table.outer.as_ref().unwrap().as_ref().clone();
self.symbol_table = self
.symbol_table
.outer
.clone()
.unwrap()
.as_ref()
.clone()
.into_inner();

self.scopes.pop();
self.scope_index -= 1;
Expand All @@ -415,17 +443,14 @@ impl Compiler {
}

fn load_symbol(&mut self, symbol: &Symbol) {
match symbol.scope {
SymbolScope::Global => {
self.emit(Opcode::GetGlobal, vec![symbol.index as i32]);
}
SymbolScope::Local => {
self.emit(Opcode::GetLocal, vec![symbol.index as i32]);
}
SymbolScope::Builtin => {
self.emit(Opcode::GetBuiltin, vec![symbol.index as i32]);
}
}
let opcode = match symbol.scope {
SymbolScope::Global => Opcode::GetGlobal,
SymbolScope::Local => Opcode::GetLocal,
SymbolScope::Builtin => Opcode::GetBuiltin,
SymbolScope::Free => Opcode::GetFree,
};

self.emit(opcode, vec![symbol.index as i32]);
}

pub fn bytecode(&self) -> Bytecode {
Expand Down Expand Up @@ -484,7 +509,7 @@ pub mod tests {

assert_eq!(
compiler.symbol_table.outer,
Some(Rc::new(global_symbol_table.clone())),
Some(Rc::new(RefCell::new(global_symbol_table.clone()))),
"Compiler did not enclose symbol table when entering new scope"
);

Expand Down
159 changes: 159 additions & 0 deletions crates/compiler/src/compiler_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -905,4 +905,163 @@ pub mod tests {

run_compiler(tests);
}

#[test]

fn test_closures() {
let tests = vec![
CompilerTestCase {
input: r#"
fn(a){
fn(b){
a + b
}
}"#
.to_string(),
expected_constants: vec![
Object::COMPILEDFUNCTION(CompiledFunction {
instructions: flatten_u8_instructions(vec![
Opcode::GetFree.make(vec![0]),
Opcode::GetLocal.make(vec![0]),
Opcode::Add.make(vec![]),
Opcode::ReturnValue.make(vec![]),
]),
num_locals: 1,
num_parameters: 1,
}),
Object::COMPILEDFUNCTION(CompiledFunction {
instructions: flatten_u8_instructions(vec![
Opcode::GetLocal.make(vec![0]),
Opcode::Closure.make(vec![0, 1]),
Opcode::ReturnValue.make(vec![]),
]),
num_locals: 1,
num_parameters: 1,
}),
],
expected_instructions: flatten_instructions(vec![
Opcode::Closure.make(vec![1, 0]),
Opcode::Pop.make(vec![]),
]),
},
CompilerTestCase {
input: r#"
fn(a) {
fn(b) {
fn(c) {
a + b + c
}
}
};"#
.to_string(),

expected_constants: vec![
Object::COMPILEDFUNCTION(CompiledFunction {
instructions: flatten_u8_instructions(vec![
Opcode::GetFree.make(vec![0]),
Opcode::GetFree.make(vec![1]),
Opcode::Add.make(vec![]),
Opcode::GetLocal.make(vec![0]),
Opcode::Add.make(vec![]),
Opcode::ReturnValue.make(vec![]),
]),
num_locals: 1,
num_parameters: 1,
}),
Object::COMPILEDFUNCTION(CompiledFunction {
instructions: flatten_u8_instructions(vec![
Opcode::GetFree.make(vec![0]),
Opcode::GetLocal.make(vec![0]),
Opcode::Closure.make(vec![0, 2]),
Opcode::ReturnValue.make(vec![]),
]),
num_locals: 1,
num_parameters: 1,
}),
Object::COMPILEDFUNCTION(CompiledFunction {
instructions: flatten_u8_instructions(vec![
Opcode::GetLocal.make(vec![0]),
Opcode::Closure.make(vec![1, 1]),
Opcode::ReturnValue.make(vec![]),
]),
num_locals: 1,
num_parameters: 1,
}),
],
expected_instructions: flatten_instructions(vec![
Opcode::Closure.make(vec![2, 0]),
Opcode::Pop.make(vec![]),
]),
},
CompilerTestCase {
input: r#"
let global = 55;
fn() {
let a = 66;
fn() {
let b = 77;
fn() {
let c = 88;
global + a + b + c;
}
}
}
"#
.to_string(),
expected_constants: vec![
Object::INTEGER(55),
Object::INTEGER(66),
Object::INTEGER(77),
Object::INTEGER(88),
Object::COMPILEDFUNCTION(CompiledFunction {
instructions: flatten_u8_instructions(vec![
Opcode::Constant.make(vec![3]),
Opcode::SetLocal.make(vec![0]),
Opcode::GetGlobal.make(vec![0]),
Opcode::GetFree.make(vec![0]),
Opcode::Add.make(vec![]),
Opcode::GetFree.make(vec![1]),
Opcode::Add.make(vec![]),
Opcode::GetLocal.make(vec![0]),
Opcode::Add.make(vec![]),
Opcode::ReturnValue.make(vec![]),
]),
num_locals: 1,
num_parameters: 0,
}),
Object::COMPILEDFUNCTION(CompiledFunction {
instructions: flatten_u8_instructions(vec![
Opcode::Constant.make(vec![2]),
Opcode::SetLocal.make(vec![0]),
Opcode::GetFree.make(vec![0]),
Opcode::GetLocal.make(vec![0]),
Opcode::Closure.make(vec![4, 2]),
Opcode::ReturnValue.make(vec![]),
]),
num_locals: 1,
num_parameters: 0,
}),
Object::COMPILEDFUNCTION(CompiledFunction {
instructions: flatten_u8_instructions(vec![
Opcode::Constant.make(vec![1]),
Opcode::SetLocal.make(vec![0]),
Opcode::GetLocal.make(vec![0]),
Opcode::Closure.make(vec![5, 1]),
Opcode::ReturnValue.make(vec![]),
]),
num_locals: 1,
num_parameters: 0,
}),
],
expected_instructions: flatten_instructions(vec![
Opcode::Constant.make(vec![0]),
Opcode::SetGlobal.make(vec![0]),
Opcode::Closure.make(vec![6, 0]),
Opcode::Pop.make(vec![]),
]),
},
];

run_compiler(tests);
}
}
Loading

0 comments on commit 17eec25

Please sign in to comment.