diff --git a/src/compiler/compiler.rs b/src/compiler/compiler.rs index 74786a42..98126a94 100644 --- a/src/compiler/compiler.rs +++ b/src/compiler/compiler.rs @@ -10,7 +10,13 @@ use crate::{ }, interpreter::InterpreterTraceGenerator, parser::{ - ast::{debug_sym_factory::DebugSymRefFactory, tl::TLDecl, Identifiable, Identifier}, + ast::{ + debug_sym_factory::DebugSymRefFactory, + expression::Expression, + statement::{Statement, TypedIdDecl}, + tl::TLDecl, + DebugSymRef, Identifiable, Identifier, + }, lang::TLDeclsParser, }, plonkish::{self, compiler::PlonkishCompilationResult}, @@ -21,7 +27,7 @@ use crate::{ use super::{ semantic::{SymTable, SymbolCategory}, - setup_inter::{interpret, Setup}, + setup_inter::{interpret, MachineSetup, Setup}, Config, Message, Messages, }; @@ -60,7 +66,10 @@ pub(super) struct Compiler { impl Compiler { /// Creates a configured compiler. - pub fn new(config: Config) -> Self { + pub fn new(mut config: Config) -> Self { + if config.max_steps == 0 { + config.max_steps = 1000; // TODO: organise this better + } Compiler { config, ..Compiler::default() @@ -76,6 +85,7 @@ impl Compiler { let ast = self .parse(source, debug_sym_ref_factory) .map_err(|_| self.messages.clone())?; + let ast = self.add_virtual(ast); let symbols = self.semantic(&ast).map_err(|_| self.messages.clone())?; let setup = Self::interpret(&ast, &symbols); let setup = Self::map_consts(setup); @@ -87,8 +97,12 @@ impl Compiler { circuit }; - let circuit = - circuit.with_trace(InterpreterTraceGenerator::new(ast, symbols, self.mapping)); + let circuit = circuit.with_trace(InterpreterTraceGenerator::new( + ast, + symbols, + self.mapping, + self.config.max_steps, + )); Ok(CompilerResult { messages: self.messages, @@ -114,6 +128,113 @@ impl Compiler { } } + fn add_virtual( + &mut self, + mut ast: Vec>, + ) -> Vec> { + for tldc in ast.iter_mut() { + match tldc { + TLDecl::MachineDecl { + dsym, + id: _, + input_params: _, + output_params, + block, + } => self.add_virtual_to_machine(dsym, output_params, block), + } + } + + ast + } + + fn add_virtual_to_machine( + &mut self, + dsym: &DebugSymRef, + output_params: &Vec>, + block: &mut Statement, + ) { + let dsym = DebugSymRef::into_virtual(dsym); + let output_params = Self::get_decls(output_params); + + if let Statement::Block(_, stmts) = block { + let mut has_final = false; + + for stmt in stmts.iter() { + if let Statement::StateDecl(_, id, _) = stmt + && id.name() == "final" + { + has_final = true + } + } + if !has_final { + stmts.push(Statement::StateDecl( + dsym.clone(), + Identifier::new("final", dsym.clone()), + Box::new(Statement::Block(dsym.clone(), vec![])), + )); + } + + let final_state = Self::find_state_mut("final", stmts).unwrap(); + + let mut padding_transitions = output_params + .iter() + .map(|output_signal| { + Statement::SignalAssignmentAssert( + dsym.clone(), + vec![output_signal.id.next()], + vec![Expression::Query::( + dsym.clone(), + output_signal.id.clone(), + )], + ) + }) + .collect::>(); + + padding_transitions.push(Statement::Transition( + dsym.clone(), + Identifier::new("__padding", dsym.clone()), + Box::new(Statement::Block(dsym.clone(), vec![])), + )); + + Self::add_virtual_to_state(final_state, padding_transitions.clone()); + + stmts.push(Statement::StateDecl( + dsym.clone(), + Identifier::new("__padding", dsym.clone()), + Box::new(Statement::Block(dsym.clone(), padding_transitions)), + )); + } // Semantic analyser must show an error in the else case + } + + fn find_state_mut>( + state_id: S, + stmts: &mut [Statement], + ) -> Option<&mut Statement> { + let state_id = state_id.into(); + let mut final_state: Option<&mut Statement> = None; + + for stmt in stmts.iter_mut() { + if let Statement::StateDecl(_, id, _) = stmt + && id.name() == state_id + { + final_state = Some(stmt) + } + } + + final_state + } + + fn add_virtual_to_state( + state: &mut Statement, + add_statements: Vec>, + ) { + if let Statement::StateDecl(_, _, final_state_stmts) = state { + if let Statement::Block(_, stmts) = final_state_stmts.as_mut() { + stmts.extend(add_statements) + } + } + } + fn semantic(&mut self, ast: &[TLDecl]) -> Result { let result = super::semantic::analyser::analyse(ast); let has_errors = result.messages.has_errors(); @@ -127,26 +248,26 @@ impl Compiler { } } - fn interpret( - ast: &[TLDecl], - symbols: &SymTable, - ) -> Setup { + fn interpret(ast: &[TLDecl], symbols: &SymTable) -> Setup { interpret(ast, symbols) } - fn map_consts(setup: Setup) -> Setup { + fn map_consts(setup: Setup) -> Setup { setup .iter() .map(|(machine_id, machine)| { - let new_machine: HashMap>> = machine - .iter() + let poly_constraints: HashMap>> = machine + .iter_states_poly_constraints() .map(|(step_id, step)| { - let new_step = step.iter().map(|pi| Self::map_pi_consts(pi)).collect(); + let new_step: Vec> = + step.iter().map(|pi| Self::map_pi_consts(pi)).collect(); (step_id.clone(), new_step) }) .collect(); + let new_machine: MachineSetup = + machine.replace_poly_constraints(poly_constraints); (machine_id.clone(), new_machine) }) .collect() @@ -166,17 +287,17 @@ impl Compiler { } } - fn build( - &mut self, - setup: &Setup, - symbols: &SymTable, - ) -> SBPIR { + fn build(&mut self, setup: &Setup, symbols: &SymTable) -> SBPIR { circuit::("circuit", |ctx| { for (machine_id, machine) in setup { self.add_forwards(ctx, symbols, machine_id); self.add_step_type_handlers(ctx, symbols, machine_id); - for state_id in machine.keys() { + ctx.pragma_num_steps(self.config.max_steps); + ctx.pragma_first_step(self.mapping.get_step_type_handler(machine_id, "initial")); + ctx.pragma_last_step(self.mapping.get_step_type_handler(machine_id, "__padding")); + + for state_id in machine.states() { ctx.step_type_def( self.mapping.get_step_type_handler(machine_id, state_id), |ctx| { @@ -239,11 +360,15 @@ impl Compiler { fn translate_queries( &mut self, symbols: &SymTable, - setup: &Setup, + setup: &Setup, machine_id: &str, state_id: &str, ) -> Vec, ()>> { - let exprs = setup.get(machine_id).unwrap().get(state_id).unwrap(); + let exprs = setup + .get(machine_id) + .unwrap() + .get_poly_constraints(state_id) + .unwrap(); exprs .iter() @@ -465,6 +590,18 @@ impl Compiler { } } } + + fn get_decls(stmts: &Vec>) -> Vec> { + let mut result: Vec> = vec![]; + + for stmt in stmts { + if let Statement::SignalDecl(_, ids) = stmt { + result.extend(ids.clone()) + } + } + + result + } } // Basic signal factory. diff --git a/src/compiler/mod.rs b/src/compiler/mod.rs index 5cd5f1f8..2475f8fc 100644 --- a/src/compiler/mod.rs +++ b/src/compiler/mod.rs @@ -19,6 +19,7 @@ mod setup_inter; #[derive(Default)] pub struct Config { pub(self) max_degree: Option, + pub(self) max_steps: usize, } impl Config { @@ -27,6 +28,12 @@ impl Config { self } + + pub fn max_steps(mut self, steps: usize) -> Self { + self.max_steps = steps; + + self + } } /// Compiler message. diff --git a/src/compiler/semantic/analyser.rs b/src/compiler/semantic/analyser.rs index 8a6abb55..3664608d 100644 --- a/src/compiler/semantic/analyser.rs +++ b/src/compiler/semantic/analyser.rs @@ -181,25 +181,6 @@ impl Analyser { } else { unreachable!("the parser should produce machine declaration with a block"); } - - if self - .symbols - .get_symbol(&self.cur_scope, "final".to_string()) - .is_none() - { - let id = Identifier("final".to_string(), 0, block.get_dsym()); - let stmt = Statement::StateDecl( - block.get_dsym(), - id.clone(), - Box::new(Statement::Block(block.get_dsym(), vec![])), - ); - - let sym = SymTableEntry::new(id.name(), block.get_dsym(), SymbolCategory::State, None); - - RULES.apply_new_symbol_statement(self, &stmt, &id, &sym); - - self.symbols.add_symbol(&self.cur_scope, id.name(), sym); - } } fn analyse_state(&mut self, id: Identifier, block: Statement) { @@ -406,7 +387,7 @@ mod test { assert_eq!( format!("{:?}", result), - r#"AnalysisResult { symbols: "/": ScopeTable { symbols: "\"fibo\": SymTableEntry { id: \"fibo\", definition_ref: DebugSymRef { line: 2, cols: \"17-21\" }, usages: [], category: Machine, ty: None }", scope: Global },"//fibo": ScopeTable { symbols: "\"a\": SymTableEntry { id: \"a\", definition_ref: DebugSymRef { line: 5, cols: \"20-21\" }, usages: [DebugSymRef { line: 13, cols: \"17-18\" }, DebugSymRef { line: 16, cols: \"15-17\" }, DebugSymRef { line: 23, cols: \"20-21\" }, DebugSymRef { line: 31, cols: \"20-22\" }], category: Signal, ty: Some(\"field\") },\"b\": SymTableEntry { id: \"b\", definition_ref: DebugSymRef { line: 2, cols: \"40-41\" }, usages: [DebugSymRef { line: 13, cols: \"20-21\" }, DebugSymRef { line: 16, cols: \"30-31\" }, DebugSymRef { line: 16, cols: \"19-21\" }, DebugSymRef { line: 23, cols: \"24-25\" }, DebugSymRef { line: 27, cols: \"20-22\" }, DebugSymRef { line: 31, cols: \"42-43\" }, DebugSymRef { line: 31, cols: \"24-26\" }], category: OutputSignal, ty: Some(\"field\") },\"final\": SymTableEntry { id: \"final\", definition_ref: DebugSymRef { start: \"2:50\", end: \"40:13\" }, usages: [DebugSymRef { line: 26, cols: \"18-23\" }], category: State, ty: None },\"i\": SymTableEntry { id: \"i\", definition_ref: DebugSymRef { line: 5, cols: \"30-31\" }, usages: [DebugSymRef { line: 13, cols: \"14-15\" }, DebugSymRef { line: 25, cols: \"17-18\" }, DebugSymRef { line: 27, cols: \"31-32\" }, DebugSymRef { line: 27, cols: \"16-18\" }, DebugSymRef { line: 31, cols: \"35-36\" }, DebugSymRef { line: 31, cols: \"16-18\" }], category: Signal, ty: None },\"initial\": SymTableEntry { id: \"initial\", definition_ref: DebugSymRef { line: 10, cols: \"19-26\" }, usages: [], category: State, ty: None },\"middle\": SymTableEntry { id: \"middle\", definition_ref: DebugSymRef { line: 20, cols: \"19-25\" }, usages: [DebugSymRef { line: 15, cols: \"17-23\" }, DebugSymRef { line: 30, cols: \"18-24\" }], category: State, ty: None },\"n\": SymTableEntry { id: \"n\", definition_ref: DebugSymRef { line: 2, cols: \"29-30\" }, usages: [DebugSymRef { line: 16, cols: \"36-37\" }, DebugSymRef { line: 16, cols: \"23-25\" }, DebugSymRef { line: 25, cols: \"26-27\" }, DebugSymRef { line: 27, cols: \"41-42\" }, DebugSymRef { line: 27, cols: \"24-26\" }, DebugSymRef { line: 31, cols: \"48-49\" }, DebugSymRef { line: 31, cols: \"28-30\" }], category: InputSignal, ty: None }", scope: Machine },"//fibo/final": ScopeTable { symbols: "", scope: State },"//fibo/initial": ScopeTable { symbols: "\"c\": SymTableEntry { id: \"c\", definition_ref: DebugSymRef { line: 11, cols: \"21-22\" }, usages: [DebugSymRef { line: 13, cols: \"23-24\" }, DebugSymRef { line: 16, cols: \"33-34\" }], category: Signal, ty: None }", scope: State },"//fibo/middle": ScopeTable { symbols: "\"c\": SymTableEntry { id: \"c\", definition_ref: DebugSymRef { line: 21, cols: \"21-22\" }, usages: [DebugSymRef { line: 23, cols: \"14-15\" }, DebugSymRef { line: 27, cols: \"38-39\" }, DebugSymRef { line: 31, cols: \"45-46\" }], category: Signal, ty: None }", scope: State }, messages: [] }"# - ) + r#"AnalysisResult { symbols: /: ScopeTable { symbols: "fibo: SymTableEntry { id: \"fibo\", definition_ref: nofile:2:17, usages: [], category: Machine, ty: None }", scope: Global },//fibo: ScopeTable { symbols: "a: SymTableEntry { id: \"a\", definition_ref: nofile:5:20, usages: [nofile:13:17, nofile:16:15, nofile:23:20, nofile:31:20], category: Signal, ty: Some(\"field\") },b: SymTableEntry { id: \"b\", definition_ref: nofile:2:40, usages: [nofile:13:20, nofile:16:30, nofile:16:19, nofile:23:24, nofile:27:20, nofile:31:42, nofile:31:24], category: OutputSignal, ty: Some(\"field\") },i: SymTableEntry { id: \"i\", definition_ref: nofile:5:30, usages: [nofile:13:14, nofile:25:17, nofile:27:31, nofile:27:16, nofile:31:35, nofile:31:16], category: Signal, ty: None },initial: SymTableEntry { id: \"initial\", definition_ref: nofile:10:19, usages: [], category: State, ty: None },middle: SymTableEntry { id: \"middle\", definition_ref: nofile:20:19, usages: [nofile:15:17, nofile:30:18], category: State, ty: None },n: SymTableEntry { id: \"n\", definition_ref: nofile:2:29, usages: [nofile:16:36, nofile:16:23, nofile:25:26, nofile:27:41, nofile:27:24, nofile:31:48, nofile:31:28], category: InputSignal, ty: None }", scope: Machine },//fibo/initial: ScopeTable { symbols: "c: SymTableEntry { id: \"c\", definition_ref: nofile:11:21, usages: [nofile:13:23, nofile:16:33], category: Signal, ty: None }", scope: State },//fibo/middle: ScopeTable { symbols: "c: SymTableEntry { id: \"c\", definition_ref: nofile:21:21, usages: [nofile:23:14, nofile:27:38, nofile:31:45], category: Signal, ty: None }", scope: State }, messages: [] }"# + ); } } diff --git a/src/compiler/semantic/mod.rs b/src/compiler/semantic/mod.rs index 0d6a6d8b..47d08aaa 100644 --- a/src/compiler/semantic/mod.rs +++ b/src/compiler/semantic/mod.rs @@ -156,7 +156,7 @@ impl Debug for ScopeTable { .symbols .keys() .sorted() - .map(|id| format!("\"{}\": {:?}", id, self.symbols[id])) + .map(|id| format!("{}: {:?}", id, self.symbols[id])) .collect::>() .join(","); @@ -232,7 +232,7 @@ impl Debug for SymTable { .scopes .keys() .sorted() - .map(|scope| format!("\"{}\": {:?}", scope, self.scopes[scope])) + .map(|scope| format!("{}: {:?}", scope, self.scopes[scope])) .collect::>() .join(","); @@ -531,8 +531,6 @@ mod test { let test_cases = [ (396, "a"), (397, "a"), - (395, "final"), - (398, "final"), (460, "a"), (584, "a"), (772, "a"), @@ -566,6 +564,7 @@ mod test { ]; for (offset, expected_id) in test_cases { + println!("{} {}", offset, expected_id); let SymTableEntry { id, .. } = result .symbols .find_symbol_by_offset("some".to_string(), offset) diff --git a/src/compiler/semantic/rules.rs b/src/compiler/semantic/rules.rs index f3bc8207..5bf846e8 100644 --- a/src/compiler/semantic/rules.rs +++ b/src/compiler/semantic/rules.rs @@ -479,7 +479,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "use of undeclared variable a", dsym: DebugSymRef { line: 23, cols: "20-21" } }]"# + r#"[SemErr { msg: "use of undeclared variable a", dsym: nofile:23:20 }]"# ) } @@ -536,7 +536,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "There cannot be rotation in identifier declaration of fibo", dsym: DebugSymRef { start: "2:9", end: "40:13" } }]"# + r#"[SemErr { msg: "There cannot be rotation in identifier declaration of fibo", dsym: nofile:2:9 }]"# ); let circuit = " @@ -589,7 +589,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "There cannot be rotation in identifier declaration of initial", dsym: DebugSymRef { start: "10:12", end: "18:14" } }]"# + r#"[SemErr { msg: "There cannot be rotation in identifier declaration of initial", dsym: nofile:10:12 }]"# ); let circuit = " @@ -642,7 +642,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "There cannot be rotation in identifier declaration of c", dsym: DebugSymRef { line: 11, cols: "13-23" } }]"# + r#"[SemErr { msg: "There cannot be rotation in identifier declaration of c", dsym: nofile:11:13 }]"# ) } @@ -703,7 +703,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot declare state nested here", dsym: DebugSymRef { start: "13:17", end: "15:18" } }]"# + r#"[SemErr { msg: "Cannot declare state nested here", dsym: nofile:13:17 }]"# ); let circuit = " @@ -760,7 +760,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot declare state nested here", dsym: DebugSymRef { start: "18:1", end: "20:29" } }]"# + r#"[SemErr { msg: "Cannot declare state nested here", dsym: nofile:18:1 }]"# ); } @@ -818,7 +818,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot assign with <-- or <== to variable wrong with category WGVar, you can only assign to signals. Use = instead.", dsym: DebugSymRef { line: 14, cols: "14-50" } }]"# + r#"[SemErr { msg: "Cannot assign with <-- or <== to variable wrong with category WGVar, you can only assign to signals. Use = instead.", dsym: nofile:14:14 }]"# ); } @@ -878,7 +878,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot use wgvar wrong in statement assert wrong == 3;", dsym: DebugSymRef { line: 24, cols: "14-32" } }, SemErr { msg: "Cannot use wgvar wrong in statement [c] <== [(a + b) + wrong];", dsym: DebugSymRef { line: 26, cols: "14-34" } }]"# + r#"[SemErr { msg: "Cannot use wgvar wrong in statement assert wrong == 3;", dsym: nofile:24:14 }, SemErr { msg: "Cannot use wgvar wrong in statement [c] <== [(a + b) + wrong];", dsym: nofile:26:14 }]"# ) } @@ -943,7 +943,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot declare [i, a, b, c] <== [1, 1, 1, 2]; in the machine, only states, wgvars and signals are allowed", dsym: DebugSymRef { start: "2:9", end: "48:13" } }, SemErr { msg: "Cannot declare if (i + 1) == n { [a] <-- [3]; } else { [b] <== [3]; } in the machine, only states, wgvars and signals are allowed", dsym: DebugSymRef { start: "2:9", end: "48:13" } }]"# + r#"[SemErr { msg: "Cannot declare [i, a, b, c] <== [1, 1, 1, 2]; in the machine, only states, wgvars and signals are allowed", dsym: nofile:2:9 }, SemErr { msg: "Cannot declare if (i + 1) == n { [a] <-- [3]; } else { [b] <== [3]; } in the machine, only states, wgvars and signals are allowed", dsym: nofile:2:9 }]"# ); } @@ -1009,7 +1009,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot redeclare middle in the same scope [\"/\", \"fibo\"]", dsym: DebugSymRef { start: "28:13", end: "43:14" } }, SemErr { msg: "Cannot redeclare n in the same scope [\"/\", \"fibo\"]", dsym: DebugSymRef { line: 20, cols: "13-22" } }, SemErr { msg: "Cannot redeclare c in the same scope [\"/\", \"fibo\", \"middle\"]", dsym: DebugSymRef { line: 30, cols: "14-23" } }]"# + r#"[SemErr { msg: "Cannot redeclare middle in the same scope [\"/\", \"fibo\"]", dsym: nofile:28:13 }, SemErr { msg: "Cannot redeclare n in the same scope [\"/\", \"fibo\"]", dsym: nofile:20:13 }, SemErr { msg: "Cannot redeclare c in the same scope [\"/\", \"fibo\", \"middle\"]", dsym: nofile:30:14 }]"# ); } @@ -1066,7 +1066,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot declare n with type uint, only field and bool are allowed.", dsym: DebugSymRef { line: 2, cols: "22-36" } }, SemErr { msg: "Cannot declare c with type int, only field and bool are allowed.", dsym: DebugSymRef { line: 21, cols: "14-28" } }]"# + r#"[SemErr { msg: "Cannot declare n with type uint, only field and bool are allowed.", dsym: nofile:2:22 }, SemErr { msg: "Cannot declare c with type int, only field and bool are allowed.", dsym: nofile:21:14 }]"# ); } @@ -1132,7 +1132,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot use true in expression 2 + true", dsym: DebugSymRef { line: 15, cols: "42-46" } }, SemErr { msg: "Cannot use true in expression 1 * true", dsym: DebugSymRef { line: 32, cols: "24-28" } }, SemErr { msg: "Cannot use false in expression false - 123", dsym: DebugSymRef { line: 32, cols: "31-36" } }, SemErr { msg: "Cannot use false in expression false * false", dsym: DebugSymRef { line: 32, cols: "50-55" } }, SemErr { msg: "Cannot use false in expression false * false", dsym: DebugSymRef { line: 32, cols: "58-63" } }]"# + r#"[SemErr { msg: "Cannot use true in expression 2 + true", dsym: nofile:15:42 }, SemErr { msg: "Cannot use true in expression 1 * true", dsym: nofile:32:24 }, SemErr { msg: "Cannot use false in expression false - 123", dsym: nofile:32:31 }, SemErr { msg: "Cannot use false in expression false * false", dsym: nofile:32:50 }, SemErr { msg: "Cannot use false in expression false * false", dsym: nofile:32:58 }]"# ); } @@ -1209,7 +1209,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "Condition i + 1 in if statement must be a logic expression", dsym: DebugSymRef { start: "36:14", end: "53:15" } }, SemErr { msg: "Signal c in if statement condition must be bool", dsym: DebugSymRef { start: "37:17", end: "39:18" } }, SemErr { msg: "Condition 4 in if statement must be a logic expression", dsym: DebugSymRef { start: "43:17", end: "45:18" } }]"# + r#"[SemErr { msg: "Condition i + 1 in if statement must be a logic expression", dsym: nofile:36:14 }, SemErr { msg: "Signal c in if statement condition must be bool", dsym: nofile:37:17 }, SemErr { msg: "Condition 4 in if statement must be a logic expression", dsym: nofile:43:17 }]"# ); } @@ -1270,7 +1270,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot assign with = to Signal i, you can only assign to WGVars. Use <-- or <== instead.", dsym: DebugSymRef { line: 15, cols: "14-20" } }]"# + r#"[SemErr { msg: "Cannot assign with = to Signal i, you can only assign to WGVars. Use <-- or <== instead.", dsym: nofile:15:14 }]"# ); } } diff --git a/src/compiler/setup_inter.rs b/src/compiler/setup_inter.rs index 4c3120ff..8647b8c5 100644 --- a/src/compiler/setup_inter.rs +++ b/src/compiler/setup_inter.rs @@ -1,18 +1,20 @@ use std::collections::HashMap; +use itertools::Itertools; use num_bigint::BigInt; use crate::{ - parser::ast::{statement::Statement, tl::TLDecl, DebugSymRef, Identifiable, Identifier}, + parser::ast::{ + statement::{Statement, TypedIdDecl}, + tl::TLDecl, + DebugSymRef, Identifiable, Identifier, + }, poly::Expr, }; use super::{abepi::CompilationUnit, semantic::SymTable}; -pub(super) fn interpret( - ast: &[TLDecl], - _symbols: &SymTable, -) -> Setup { +pub(super) fn interpret(ast: &[TLDecl], _symbols: &SymTable) -> Setup { let mut interpreter = SetupInterpreter { abepi: CompilationUnit::default(), setup: Setup::default(), @@ -25,12 +27,99 @@ pub(super) fn interpret( interpreter.setup } -pub(super) type Setup = HashMap>>>; +pub(super) type Setup = HashMap>; -pub(super) struct SetupInterpreter { +pub(super) struct MachineSetup { + poly_constraints: HashMap>>, + + input_signals: Vec>, + output_signals: Vec>, +} + +impl Default for MachineSetup { + fn default() -> Self { + Self { + poly_constraints: HashMap::new(), + input_signals: vec![], + output_signals: vec![], + } + } +} + +impl MachineSetup { + fn new( + inputs: Vec>, + outputs: Vec>, + ) -> Self { + let mut created = Self::default(); + + for input in inputs { + if let Statement::SignalDecl(_, ids) = input { + created.input_signals.extend(ids) + } + } + + for output in outputs { + if let Statement::SignalDecl(_, ids) = output { + created.output_signals.extend(ids) + } + } + + created + } + + fn new_state>(&mut self, id: S) { + self.poly_constraints.insert(id.into(), vec![]); + } + + fn _has_state>(&self, id: S) -> bool { + self.poly_constraints.contains_key(&id.into()) + } + + fn add_poly_constraints>( + &mut self, + state: S, + poly_constraints: Vec>, + ) { + self.poly_constraints + .get_mut(&state.into()) + .unwrap() + .extend(poly_constraints); + } + + pub(super) fn iter_states_poly_constraints( + &self, + ) -> std::collections::hash_map::Iter>> { + self.poly_constraints.iter() + } + + pub(super) fn replace_poly_constraints( + &self, + poly_constraints: HashMap>>, + ) -> MachineSetup { + MachineSetup { + poly_constraints, + input_signals: self.input_signals.clone(), + output_signals: self.output_signals.clone(), + } + } + + pub(super) fn states(&self) -> Vec<&String> { + self.poly_constraints.keys().collect_vec() + } + + pub(super) fn get_poly_constraints>( + &self, + state: S, + ) -> Option<&Vec>> { + self.poly_constraints.get(&state.into()) + } +} + +struct SetupInterpreter { abepi: CompilationUnit, - setup: Setup, + setup: Setup, current_machine: String, current_state: String, @@ -52,24 +141,18 @@ impl SetupInterpreter { fn interpret_machine( &mut self, - dsym: &DebugSymRef, + _dsym: &DebugSymRef, id: &Identifier, - _input_params: &[Statement], - _output_params: &[Statement], + input_params: &[Statement], + output_params: &[Statement], block: &Statement, ) { self.current_machine = id.name(); - self.setup.insert(id.name(), HashMap::default()); + self.setup.insert( + id.name(), + MachineSetup::new(input_params.to_owned(), output_params.to_owned()), + ); self.interpret_machine_statement(block); - - // There is a final state that is empty by default - if !self.setup.get(&id.name()).unwrap().contains_key("final") { - self.interpret_state_decl( - dsym, - &Identifier::new("final", dsym.clone()), - &Statement::Block(dsym.clone(), vec![]), - ) - } } fn interpret_machine_statement(&mut self, stmt: &Statement) { @@ -97,7 +180,7 @@ impl SetupInterpreter { self.setup .get_mut(&self.current_machine) .unwrap() - .insert(id.name(), Vec::default()); + .new_state(id.name()); self.interpret_state_statement(stmt); } @@ -122,15 +205,13 @@ impl SetupInterpreter { SignalDecl(_, _) | WGVarDecl(_, _) => vec![], }; - self.add_pis(result.into_iter().map(|cr| cr.anti_booly).collect()); + self.add_poly_constraints(result.into_iter().map(|cr| cr.anti_booly).collect()); } - fn add_pis(&mut self, pis: Vec>) { + fn add_poly_constraints(&mut self, pis: Vec>) { self.setup .get_mut(&self.current_machine) .unwrap() - .get_mut(&self.current_state) - .unwrap() - .extend(pis); + .add_poly_constraints(&self.current_state, pis); } } diff --git a/src/interpreter/frame.rs b/src/interpreter/frame.rs index 35d11a37..7050755b 100644 --- a/src/interpreter/frame.rs +++ b/src/interpreter/frame.rs @@ -127,7 +127,7 @@ impl<'a, F: Field + Hash> StackFrame<'a, F> { self.enter_state("initial"); } - fn enter_state>(&mut self, next_state: S) { + pub(super) fn enter_state>(&mut self, next_state: S) { self.cur_state = next_state.into(); self.lex_scope.push(self.cur_state.clone()); self.scopes diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs index 94bb3394..71c0d043 100644 --- a/src/interpreter/mod.rs +++ b/src/interpreter/mod.rs @@ -22,15 +22,17 @@ mod value; struct Interpreter<'a, F: Field + Hash> { mapping: &'a SymbolSignalMapping, cur_frame: StackFrame<'a, F>, + num_steps: usize, witness: Vec>, } impl<'a, F: Field + Hash> Interpreter<'a, F> { - fn new(symbols: &'a SymTable, mapping: &'a SymbolSignalMapping) -> Self { + fn new(symbols: &'a SymTable, mapping: &'a SymbolSignalMapping, num_steps: usize) -> Self { Self { mapping, cur_frame: StackFrame::new(symbols), + num_steps, witness: Vec::default(), } } @@ -88,17 +90,20 @@ impl<'a, F: Field + Hash> Interpreter<'a, F> { ); } - while next_state.is_some() { - next_state = self.exec_step(dsym, &machine_block)?; + while next_state.is_some() && self.witness.len() < self.num_steps { + next_state = self.exec_step(&machine_block)?; self.transition(&next_state); - if next_state.is_none() && self.cur_frame.get_state().as_str() != "final" { - panic!( - "last state is not final state but {}", - self.cur_frame.get_state() - ); - } + println!("{}", self.witness.len()) + } + + self.cur_frame.enter_state("__padding"); + + while self.witness.len() <= self.num_steps { + self.exec_step(&machine_block)?; + self.transition(&Some("__padding".to_string())); + println!("{}", self.witness.len()) } Ok(()) @@ -106,10 +111,9 @@ impl<'a, F: Field + Hash> Interpreter<'a, F> { fn exec_step( &mut self, - machine_dsym: &DebugSymRef, machine_block: &[Statement], ) -> Result, Message> { - let state_decl = self.find_state_decl(machine_dsym, machine_block).unwrap(); + let state_decl = self.find_state_decl(machine_block).unwrap(); if let Statement::StateDecl(_, _, block) = state_decl { if let Statement::Block(_, stmts) = *block { @@ -220,7 +224,6 @@ impl<'a, F: Field + Hash> Interpreter<'a, F> { fn find_state_decl( &mut self, - machine_dsym: &DebugSymRef, machine_block: &[Statement], ) -> Option> { for stmt in machine_block { @@ -231,16 +234,7 @@ impl<'a, F: Field + Hash> Interpreter<'a, F> { } } - // final state can be omited - if self.cur_frame.get_state() == "final" { - Some(Statement::StateDecl( - machine_dsym.clone(), - Identifier::new(self.cur_frame.get_state(), machine_dsym.clone()), - Box::new(Statement::Block(machine_dsym.clone(), vec![])), - )) - } else { - None - } + None } } @@ -252,9 +246,10 @@ pub fn run( program: &[TLDecl], symbols: &SymTable, mapping: &SymbolSignalMapping, + num_steps: usize, input: HashMap, ) -> Result, Message> { - let mut inter = Interpreter::::new(symbols, mapping); + let mut inter = Interpreter::::new(symbols, mapping, num_steps); inter.run(program, input) } @@ -265,6 +260,7 @@ pub struct InterpreterTraceGenerator { program: Vec>, symbols: SymTable, mapping: SymbolSignalMapping, + num_steps: usize, } impl InterpreterTraceGenerator { @@ -272,11 +268,13 @@ impl InterpreterTraceGenerator { program: Vec>, symbols: SymTable, mapping: SymbolSignalMapping, + num_steps: usize, ) -> Self { Self { program, symbols, mapping, + num_steps, } } @@ -299,7 +297,14 @@ impl TraceGenerator for InterpreterTraceGenerator { type TraceArgs = HashMap; fn generate(&self, args: Self::TraceArgs) -> TraceWitness { - run(&self.program, &self.symbols, &self.mapping, args).unwrap_or_else(|msgs| { + run( + &self.program, + &self.symbols, + &self.mapping, + self.num_steps, + args, + ) + .unwrap_or_else(|msgs| { panic!("errors when running wg interpreter: {:?}", msgs); }) } @@ -383,8 +388,12 @@ mod test { } "; - let compiled = - compile::(code, Config::default(), &DebugSymRefFactory::new("", code)).unwrap(); + let compiled = compile::( + code, + Config::default().max_steps(20), + &DebugSymRefFactory::new("", code), + ) + .unwrap(); let result = compiled .circuit @@ -439,10 +448,12 @@ mod test { } "; - let mut chiquito = - compile::(code, Config::default(), &DebugSymRefFactory::new("", code)).unwrap(); - - chiquito.circuit.num_steps = 12; + let chiquito = compile::( + code, + Config::default().max_steps(20), + &DebugSymRefFactory::new("", code), + ) + .unwrap(); let mut plonkish = chiquito.plonkish(config( SingleRowCellManager {}, @@ -467,6 +478,7 @@ mod test { halo2_prover.get_vk(), instance, ); + assert!(result.is_ok()); } @@ -515,11 +527,9 @@ mod test { } "; - let mut chiquito = + let chiquito = compile::(code, Config::default(), &DebugSymRefFactory::new("", code)).unwrap(); - chiquito.circuit.num_steps = 12; - // TODO: re-stablish evil witness // chiquito // .wit_gen diff --git a/src/parser/ast/mod.rs b/src/parser/ast/mod.rs index 1656e2de..9a855436 100644 --- a/src/parser/ast/mod.rs +++ b/src/parser/ast/mod.rs @@ -12,16 +12,36 @@ pub mod tl; #[derive(Clone)] pub struct DebugSymRef { /// Starting byte number in the file - pub start: usize, + start: usize, /// Ending byte number in the file - pub end: usize, + end: usize, /// Source file reference file: Arc>, + /// Virtual: created by the compiler, not present in source code + virt: bool, } impl DebugSymRef { pub fn new(start: usize, end: usize, file: Arc>) -> DebugSymRef { - DebugSymRef { start, end, file } + DebugSymRef { + start, + end, + file, + virt: false, + } + } + + /// Convert to virtual: created by the compiler, not present in source code. + pub fn into_virtual(other: &DebugSymRef) -> DebugSymRef { + let mut other = other.clone(); + other.virt = true; + + other + } + + /// Returns whether it is virtual: created by the compiler, not present in source code. + pub fn is_virtual(&self) -> bool { + self.virt } fn get_column_number(&self, line_index: usize, start: usize) -> usize { @@ -61,12 +81,12 @@ impl DebugSymRef { self.get_column_number(line_idx, self.start) } - fn get_line_end(&self) -> usize { + pub fn get_line_end(&self) -> usize { let line_idx = self.get_line_index(self.end); self.get_line_number(line_idx) } - fn get_col_end(&self) -> usize { + pub fn get_col_end(&self) -> usize { let line_idx = self.get_line_index(self.end); self.get_column_number(line_idx, self.end) } @@ -79,7 +99,11 @@ impl DebugSymRef { /// `filename`. The proximity score is calculated as the size of the symbol. /// If the offset is not within the symbol, returns `None`. pub fn proximity_score(&self, filename: &String, offset: usize) -> Option { - if self.get_filename() == *filename && self.start <= offset && offset <= self.end { + if !self.is_virtual() + && self.get_filename() == *filename + && self.start <= offset + && offset <= self.end + { Some(self.end - self.start) } else { None @@ -91,35 +115,21 @@ impl Debug for DebugSymRef { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if !self.file.name().is_empty() { // Produces clickable output in the terminal - return write!( + write!( f, "{}:{}:{}", self.file.name(), self.get_line_start(), self.get_col_start() - ); - } - - let mut debug_print = f.debug_struct("DebugSymRef"); - - if self.get_line_start() == self.get_line_end() { - debug_print.field("line", &self.get_line_start()).field( - "cols", - &format!("{}-{}", self.get_col_start(), self.get_col_end()), - ); + ) } else { - debug_print - .field( - "start", - &format!("{}:{}", self.get_line_start(), self.get_col_start()), - ) - .field( - "end", - &format!("{}:{}", self.get_line_end(), self.get_col_end()), - ); + write!( + f, + "nofile:{}:{}", + self.get_line_start(), + self.get_col_start() + ) } - - debug_print.finish() } } @@ -142,6 +152,10 @@ impl Identifier { pub(crate) fn debug_sym_ref(&self) -> DebugSymRef { self.2.clone() } + + pub(crate) fn next(&self) -> Self { + Self(self.0.clone(), self.1 + 1, self.2.clone()) + } } impl Debug for Identifier { @@ -213,6 +227,7 @@ mod test { start: 0, end: 1, file: Arc::new(SimpleFile::new("file_path".to_string(), "".to_string())), + virt: false, }; let result = Identifier::new("abc", debug_sym_ref.clone()); @@ -238,6 +253,7 @@ mod test { start: 10, end: 12, file: Arc::new(SimpleFile::new(file_path.clone(), "".to_string())), + virt: false, }; assert_eq!(debug_sym_ref.proximity_score(&file_path, 9), None); diff --git a/src/parser/ast/tl.rs b/src/parser/ast/tl.rs index 91f4b043..4bfa4e68 100644 --- a/src/parser/ast/tl.rs +++ b/src/parser/ast/tl.rs @@ -2,6 +2,20 @@ use std::fmt::Debug; use super::{statement::Statement, DebugSymRef, Identifiable, Identifier}; +pub struct AST(Vec>); + +impl AST { + pub fn machines_iter_mut(&mut self) -> std::slice::IterMut> { + self.0.iter_mut() + } + + pub fn find_machine(&self, id_machine: V) -> Option<&TLDecl> { + self.0.iter().find(|tldecl| match tldecl { + TLDecl::MachineDecl { id, .. } => *id == id_machine, + }) + } +} + #[derive(Clone)] pub enum TLDecl { MachineDecl { diff --git a/src/plonkish/compiler/mod.rs b/src/plonkish/compiler/mod.rs index cacb59ec..50a14d05 100644 --- a/src/plonkish/compiler/mod.rs +++ b/src/plonkish/compiler/mod.rs @@ -84,6 +84,7 @@ pub fn compile_phase1< panic!("Cannot calculate the number of rows"); } unit.num_rows = unit.num_steps * (unit.placement.first_step_height() as usize); + unit.additional_rows = unit.placement.first_step_height() as usize; compile_fixed(ast, &mut unit); @@ -100,7 +101,7 @@ pub fn compile_phase1< unit.selector.clone(), (*v).clone(), AutoTraceGenerator::from(ast), - unit.num_rows, + unit.num_rows + unit.additional_rows, unit.uuid, ) }); diff --git a/src/plonkish/compiler/unit.rs b/src/plonkish/compiler/unit.rs index 0591f774..cbd27888 100644 --- a/src/plonkish/compiler/unit.rs +++ b/src/plonkish/compiler/unit.rs @@ -38,6 +38,8 @@ pub struct CompilationUnit { pub last_step: Option<(Option, Column)>, pub num_rows: usize, + /// Additional rows for the last padding step instance that doesn't have the q_enable = 1 + pub additional_rows: usize, pub polys: Vec>, pub lookups: Vec>, @@ -74,6 +76,7 @@ impl Default for CompilationUnit { last_step: Default::default(), num_rows: Default::default(), + additional_rows: Default::default(), polys: Default::default(), lookups: Default::default(),