From 20f12fd8932f4ec62d925e465bf1f000164e26ec Mon Sep 17 00:00:00 2001 From: Alex Kuzmin Date: Thu, 30 May 2024 14:21:51 +0800 Subject: [PATCH] Introduce usage tracking in the symbol table --- src/compiler/semantic/analyser.rs | 62 +++++++++++++- src/compiler/semantic/mod.rs | 68 +++++++++++++++- src/compiler/semantic/rules.rs | 4 +- src/parser/ast/debug_sym_factory.rs | 51 ++---------- src/parser/ast/mod.rs | 120 ++++++++++++++++++++-------- test/circuit.chiquito | 2 +- test/circuit_error.chiquito | 2 +- 7 files changed, 225 insertions(+), 84 deletions(-) diff --git a/src/compiler/semantic/analyser.rs b/src/compiler/semantic/analyser.rs index 319ae91b..f00b518a 100644 --- a/src/compiler/semantic/analyser.rs +++ b/src/compiler/semantic/analyser.rs @@ -62,7 +62,9 @@ impl Analyser { block, } => { let sym = SymTableEntry { + id: id.name(), definition_ref: dsym, + usages: vec![], category: SymbolCategory::Machine, ty: None, }; @@ -99,7 +101,9 @@ impl Analyser { params.iter().for_each(|param| match param { Statement::SignalDecl(dsym, ids) => ids.iter().for_each(|id| { let sym = SymTableEntry { + id: id.id.name(), definition_ref: dsym.clone(), + usages: vec![], category: SymbolCategory::InputSignal, ty: id.ty.clone().map(|ty| ty.name()), }; @@ -110,7 +114,9 @@ impl Analyser { }), Statement::WGVarDecl(dsym, ids) => ids.iter().for_each(|id| { let sym = SymTableEntry { + id: id.id.name(), definition_ref: dsym.clone(), + usages: vec![], category: SymbolCategory::InputWGVar, ty: id.ty.clone().map(|ty| ty.name()), }; @@ -127,7 +133,9 @@ impl Analyser { params.iter().for_each(|param| match param { Statement::SignalDecl(dsym, ids) => ids.iter().for_each(|id| { let sym = SymTableEntry { + id: id.id.name(), definition_ref: dsym.clone(), + usages: vec![], category: SymbolCategory::OutputSignal, ty: id.ty.clone().map(|ty| ty.name()), }; @@ -139,7 +147,9 @@ impl Analyser { }), Statement::WGVarDecl(dsym, ids) => ids.iter().for_each(|id| { let sym = SymTableEntry { + id: id.id.name(), definition_ref: dsym.clone(), + usages: vec![], category: SymbolCategory::OutputWGVar, ty: id.ty.clone().map(|ty| ty.name()), }; @@ -162,7 +172,9 @@ impl Analyser { stmts.iter().for_each(|stmt| { if let Statement::StateDecl(dsym, id, _) = stmt { let sym = SymTableEntry { + id: id.name(), definition_ref: dsym.clone(), + usages: vec![], category: SymbolCategory::State, ty: None, }; @@ -189,7 +201,9 @@ impl Analyser { ); let sym = SymTableEntry { + id: id.name(), definition_ref: block.get_dsym(), + usages: vec![], category: SymbolCategory::State, ty: None, }; @@ -220,7 +234,9 @@ impl Analyser { match stmt.clone() { Statement::SignalDecl(dsym, ids) => ids.into_iter().for_each(|id| { let sym = SymTableEntry { + id: id.id.name(), category: SymbolCategory::Signal, + usages: vec![], definition_ref: dsym.clone(), ty: id.ty.map(|ty| ty.name()), }; @@ -231,7 +247,9 @@ impl Analyser { }), Statement::WGVarDecl(dsym, ids) => ids.into_iter().for_each(|id| { let sym = SymTableEntry { + id: id.id.name(), category: SymbolCategory::WGVar, + usages: vec![], definition_ref: dsym.clone(), ty: id.ty.map(|ty| ty.name()), }; @@ -243,6 +261,17 @@ impl Analyser { // State decl symbols are added in // add_state_decls Statement::StateDecl(_, _, _) => {} + Statement::Transition(dsym_ref, id, _) => { + // Find the corresponding symbol and add usage + if let Some(entry) = self.symbols.find_symbol(&self.cur_scope, id.name()) { + // TODO implement find by id AND category? + if entry.symbol.category == SymbolCategory::State { + let mut entry = entry.symbol.clone(); + entry.usages.push(dsym_ref); + self.symbols.add_symbol(&self.cur_scope, id.name(), entry); + } + } + } _ => {} } } @@ -282,9 +311,40 @@ impl Analyser { } fn analyse_expression(&mut self, expr: Expression) { + self.extract_usages_recursively(&expr); RULES.apply_expression(self, &expr) } + fn extract_usages_recursively(&mut self, expr: &Expression) { + match expr.clone() { + Expression::Query(dsym_ref, id) => { + // Find the corresponding symbol and add usage + if let Some(entry) = self.symbols.find_symbol(&self.cur_scope, id.name()) { + let mut entry = entry.symbol.clone(); + entry.usages.push(dsym_ref); + self.symbols.add_symbol(&self.cur_scope, id.name(), entry); + } + } + Expression::BinOp { + dsym: _, + op: _, + lhs, + rhs, + } => { + self.extract_usages_recursively(&lhs); + self.extract_usages_recursively(&rhs); + } + Expression::UnaryOp { + dsym: _, + op: _, + sub, + } => { + self.extract_usages_recursively(&sub); + } + _ => {} + } + } + pub(super) fn error>(&mut self, msg: S, dsym: &DebugSymRef) { self.messages.push(Message::SemErr { msg: msg.into(), @@ -360,7 +420,7 @@ mod test { assert_eq!( format!("{:?}", result), - r#"AnalysisResult { symbols: "/": ScopeTable { symbols: "\"fibo\": SymTableEntry { definition_ref: DebugSymRef { start: \"2:9\", end: \"40:13\" }, category: Machine, ty: None }", scope: Global },"//fibo": ScopeTable { symbols: "\"a\": SymTableEntry { definition_ref: DebugSymRef { line: 5, cols: \"13-32\" }, category: Signal, ty: Some(\"field\") },\"b\": SymTableEntry { definition_ref: DebugSymRef { line: 2, cols: \"33-48\" }, category: OutputSignal, ty: Some(\"field\") },\"final\": SymTableEntry { definition_ref: DebugSymRef { start: \"2:50\", end: \"40:13\" }, category: State, ty: None },\"i\": SymTableEntry { definition_ref: DebugSymRef { line: 5, cols: \"13-32\" }, category: Signal, ty: None },\"initial\": SymTableEntry { definition_ref: DebugSymRef { start: \"10:13\", end: \"18:14\" }, category: State, ty: None },\"middle\": SymTableEntry { definition_ref: DebugSymRef { start: \"20:13\", end: \"34:14\" }, category: State, ty: None },\"n\": SymTableEntry { definition_ref: DebugSymRef { line: 2, cols: \"22-30\" }, category: InputSignal, ty: None }", scope: Machine },"//fibo/final": ScopeTable { symbols: "", scope: State },"//fibo/initial": ScopeTable { symbols: "\"c\": SymTableEntry { definition_ref: DebugSymRef { line: 11, cols: \"14-23\" }, category: Signal, ty: None }", scope: State },"//fibo/middle": ScopeTable { symbols: "\"c\": SymTableEntry { definition_ref: DebugSymRef { line: 21, cols: \"14-23\" }, category: Signal, ty: None }", scope: State }, messages: [] }"# + r#"AnalysisResult { symbols: "/": ScopeTable { symbols: "\\"fibo\\": SymTableEntry { id: \\"fibo\\", definition_ref: DebugSymRef { start: \\"2:9\\", end: \\"40:13\\" }, usages: [], category: Machine, ty: None }", scope: Global },"//fibo": ScopeTable { symbols: "\\"a\\": SymTableEntry { id: \\"a\\", definition_ref: DebugSymRef { line: 5, cols: \\"13-32\\" }, usages: [], category: Signal, ty: Some(\\"field\\") },\\"b\\": SymTableEntry { id: \\"b\\", definition_ref: DebugSymRef { line: 2, cols: \\"33-48\\" }, usages: [], category: OutputSignal, ty: Some(\\"field\\") },\\"final\\": SymTableEntry { id: \\"final\\", definition_ref: DebugSymRef { start: \\"2:50\\", end: \\"40:13\\" }, usages: [], category: State, ty: None },\\"i\\": SymTableEntry { id: \\"i\\", definition_ref: DebugSymRef { line: 5, cols: \\"13-32\\" }, usages: [], category: Signal, ty: None },\\"initial\\": SymTableEntry { id: \\"initial\\", definition_ref: DebugSymRef { start: \\"10:13\\", end: \\"18:14\\" }, usages: [], category: State, ty: None },\\"middle\\": SymTableEntry { id: \\"middle\\", definition_ref: DebugSymRef { start: \\"20:13\\", end: \\"34:14\\" }, usages: [], category: State, ty: None },\\"n\\": SymTableEntry { id: \\"n\\", definition_ref: DebugSymRef { line: 2, cols: \\"22-30\\" }, usages: [], category: InputSignal, ty: None }", scope: Machine },"//fibo/final": ScopeTable { symbols: "", scope: State },"//fibo/initial": ScopeTable { symbols: "\\"b\\": SymTableEntry { id: \\"b\\", definition_ref: DebugSymRef { line: 2, cols: \\"33-48\\" }, usages: [DebugSymRef { line: 16, cols: \\"30-31\\" }], category: OutputSignal, ty: Some(\\"field\\") },\\"c\\": SymTableEntry { id: \\"c\\", definition_ref: DebugSymRef { line: 11, cols: \\"14-23\\" }, usages: [DebugSymRef { line: 16, cols: \\"33-34\\" }], category: Signal, ty: None },\\"middle\\": SymTableEntry { id: \\"middle\\", definition_ref: DebugSymRef { start: \\"20:13\\", end: \\"34:14\\" }, usages: [DebugSymRef { start: \\"15:14\\", end: \\"17:15\\" }], category: State, ty: None },\\"n\\": SymTableEntry { id: \\"n\\", definition_ref: DebugSymRef { line: 2, cols: \\"22-30\\" }, usages: [DebugSymRef { line: 16, cols: \\"36-37\\" }], category: InputSignal, ty: None }", scope: State },"//fibo/initial/middle": ScopeTable { symbols: "", scope: State },"//fibo/middle": ScopeTable { symbols: "\\"a\\": SymTableEntry { id: \\"a\\", definition_ref: DebugSymRef { line: 5, cols: \\"13-32\\" }, usages: [DebugSymRef { line: 23, cols: \\"20-21\\" }], category: Signal, ty: Some(\\"field\\") },\\"b\\": SymTableEntry { id: \\"b\\", definition_ref: DebugSymRef { line: 2, cols: \\"33-48\\" }, usages: [DebugSymRef { line: 23, cols: \\"24-25\\" }, DebugSymRef { line: 31, cols: \\"42-43\\" }], category: OutputSignal, ty: Some(\\"field\\") },\\"c\\": SymTableEntry { id: \\"c\\", definition_ref: DebugSymRef { line: 21, cols: \\"14-23\\" }, usages: [DebugSymRef { line: 27, cols: \\"38-39\\" }, DebugSymRef { line: 31, cols: \\"45-46\\" }], category: Signal, ty: None },\\"final\\": SymTableEntry { id: \\"final\\", definition_ref: DebugSymRef { start: \\"2:50\\", end: \\"40:13\\" }, usages: [DebugSymRef { start: \\"26:15\\", end: \\"28:16\\" }], category: State, ty: None },\\"i\\": SymTableEntry { id: \\"i\\", definition_ref: DebugSymRef { line: 5, cols: \\"13-32\\" }, usages: [DebugSymRef { line: 25, cols: \\"17-18\\" }, DebugSymRef { line: 27, cols: \\"31-32\\" }, DebugSymRef { line: 31, cols: \\"35-36\\" }], category: Signal, ty: None },\\"middle\\": SymTableEntry { id: \\"middle\\", definition_ref: DebugSymRef { start: \\"20:13\\", end: \\"34:14\\" }, usages: [DebugSymRef { start: \\"30:15\\", end: \\"32:16\\" }], category: State, ty: None },\\"n\\": SymTableEntry { id: \\"n\\", definition_ref: DebugSymRef { line: 2, cols: \\"22-30\\" }, usages: [DebugSymRef { line: 25, cols: \\"26-27\\" }, DebugSymRef { line: 27, cols: \\"41-42\\" }, DebugSymRef { line: 31, cols: \\"48-49\\" }], category: InputSignal, ty: None }", scope: State },"//fibo/middle/final": ScopeTable { symbols: "", scope: State },"//fibo/middle/middle": ScopeTable { symbols: "", scope: State }, messages: [] }"# ) } } diff --git a/src/compiler/semantic/mod.rs b/src/compiler/semantic/mod.rs index afbecfb7..d53bf9e3 100644 --- a/src/compiler/semantic/mod.rs +++ b/src/compiler/semantic/mod.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{BTreeMap, HashMap}, fmt::{Debug, Display}, }; @@ -48,7 +48,9 @@ pub enum ScopeCategory { /// Information about a symbol #[derive(Clone, Debug)] pub struct SymTableEntry { + pub id: String, pub definition_ref: DebugSymRef, + pub usages: Vec, pub category: SymbolCategory, /// Type pub ty: Option, @@ -85,6 +87,7 @@ pub struct FoundSymbol { pub symbol: SymTableEntry, pub scope: ScopeCategory, pub level: usize, + pub usages: Vec, } /// Contains the symbols of an scope @@ -211,6 +214,7 @@ impl SymTable { symbol: symbol.clone(), scope: table.scope.clone(), level, + usages: symbol.usages.clone(), }); } @@ -286,6 +290,68 @@ impl SymTable { .join("/") } } + + pub fn find_symbol_by_offset(&self, filename: String, offset: usize) -> Option { + let mut symbols_by_proximity = BTreeMap::::new(); + + for scope in self.scopes.values() { + for (_, entry) in &scope.symbols { + // If the entry is not in the same file, check its usages + if entry.definition_ref.get_filename() != filename.clone() { + SymTable::look_in_usages( + entry, + filename.clone(), + offset, + &mut symbols_by_proximity, + ); + } else { + let proximity = entry.definition_ref.proximity_score(offset); + // If the current entry is not enclosing the offset, check the usages of that + // entry + if proximity == -1 { + SymTable::look_in_usages( + entry, + filename.clone(), + offset, + &mut symbols_by_proximity, + ); + // If the current entry is enclosing the offset, add it to the map + } else { + symbols_by_proximity.insert(proximity, entry.clone()); + } + } + } + } + + if symbols_by_proximity.is_empty() { + return None; + } else { + // Return the first symbol in the map because BTreeMap is sorted by the key (which is + // the proximity in our case) + return symbols_by_proximity + .iter() + .next() + .map(|(_, entry)| entry.clone()); + } + } + + fn look_in_usages( + entry: &SymTableEntry, + filename: String, + offset: usize, + symbols_by_proximity: &mut BTreeMap, + ) { + for usage in &entry.usages { + if usage.get_filename() != filename { + continue; + } + let usage_proximity = usage.proximity_score(offset); + if usage_proximity != -1 { + symbols_by_proximity.insert(usage_proximity, entry.clone()); + break; + } + } + } } /// Result from running the semantic analyser. diff --git a/src/compiler/semantic/rules.rs b/src/compiler/semantic/rules.rs index 5ddea94c..fc8d6889 100644 --- a/src/compiler/semantic/rules.rs +++ b/src/compiler/semantic/rules.rs @@ -630,7 +630,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot declare state nested here", dsym: DebugSymRef { start: 0, end: 0 } }]"# + r#"[SemErr { msg: "Cannot declare state nested here", dsym: DebugSymRef { start: "13:17", end: "15:18" } }]"# ); let circuit = " @@ -687,7 +687,7 @@ mod test { assert_eq!( format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot declare state nested here", dsym: DebugSymRef { start: 0, end: 0 } }]"# + r#"[SemErr { msg: "Cannot declare state nested here", dsym: DebugSymRef { start: "18:1", end: "20:29" } }]"# ); } diff --git a/src/parser/ast/debug_sym_factory.rs b/src/parser/ast/debug_sym_factory.rs index 8c4f6115..cdaad8ea 100644 --- a/src/parser/ast/debug_sym_factory.rs +++ b/src/parser/ast/debug_sym_factory.rs @@ -1,13 +1,13 @@ -use std::rc::Rc; +use std::sync::Arc; -use codespan_reporting::files::{Files, SimpleFile}; +use codespan_reporting::files::SimpleFile; use super::DebugSymRef; /// Factory for creating debug symbol references. pub struct DebugSymRefFactory { /// Source file reference. - file: Rc>, + file: Arc>, } impl DebugSymRefFactory { @@ -22,7 +22,7 @@ impl DebugSymRefFactory { /// /// A new debug symbol reference factory. pub fn new(file_path: &str, contents: &str) -> DebugSymRefFactory { - let file = Rc::new(SimpleFile::new(file_path.to_string(), contents.to_string())); + let file = Arc::new(SimpleFile::new(file_path.to_string(), contents.to_string())); DebugSymRefFactory { file } } @@ -34,47 +34,6 @@ impl DebugSymRefFactory { /// * `start` - Start position of the debug symbol byte reference in the source string. /// * `end` - End position of the debug symbol byte reference in the source string. pub fn create(&self, start: usize, end: usize) -> DebugSymRef { - let start_line_index = self.get_line_index(start); - let start_line_number = self.get_line_number(start_line_index); - let start_col_number = self.get_column_number(start_line_index, start); - - let end_line_index = self.get_line_index(end); - let end_line_number = self.get_line_number(end_line_index); - let end_col_number = self.get_column_number(end_line_index, end); - - DebugSymRef::new( - start_line_number, - start_col_number, - end_line_number, - end_col_number, - Rc::clone(&self.file), - ) - } - - fn get_column_number(&self, line_index: usize, start: usize) -> usize { - match self.file.column_number((), line_index, start) { - Ok(number) => number, - Err(err) => { - panic!("Column number at {} not found: {}", line_index, err); - } - } - } - - fn get_line_index(&self, start: usize) -> usize { - match self.file.line_index((), start) { - Ok(index) => index, - Err(err) => { - panic!("Line index at {} not found: {}", start, err); - } - } - } - - fn get_line_number(&self, line_index: usize) -> usize { - match self.file.line_number((), line_index) { - Ok(number) => number, - Err(err) => { - panic!("Line number at {} not found: {}", line_index, err); - } - } + DebugSymRef::new(start, end, Arc::clone(&self.file)) } } diff --git a/src/parser/ast/mod.rs b/src/parser/ast/mod.rs index 8bceb2af..b7072162 100644 --- a/src/parser/ast/mod.rs +++ b/src/parser/ast/mod.rs @@ -1,7 +1,7 @@ use core::fmt::Debug; -use std::rc::Rc; +use std::sync::Arc; -use codespan_reporting::files::SimpleFile; +use codespan_reporting::files::{Files, SimpleFile}; pub mod debug_sym_factory; pub mod expression; @@ -11,32 +11,81 @@ pub mod tl; /// Debug symbol reference, points to the source file, where a AST node comes from. #[derive(Clone)] pub struct DebugSymRef { - /// Starting line number in the file - pub line_start: usize, - /// Starting column number in the file - pub col_start: usize, - /// Ending line number in the file - pub line_end: usize, - /// Ending column number in the file - pub col_end: usize, + /// Starting byte number in the file + pub start: usize, + /// Ending byte number in the file + pub end: usize, /// Source file reference - file: Rc>, + file: Arc>, } impl DebugSymRef { - pub fn new( - line_start: usize, - col_start: usize, - line_end: usize, - col_end: usize, - file: Rc>, - ) -> DebugSymRef { - DebugSymRef { - line_start, - col_start, - line_end, - col_end, - file, + pub fn new(start: usize, end: usize, file: Arc>) -> DebugSymRef { + DebugSymRef { start, end, file } + } + + fn get_column_number(&self, line_index: usize, start: usize) -> usize { + match self.file.column_number((), line_index, start) { + Ok(number) => number, + Err(err) => { + panic!("Column number at {} not found: {}", line_index, err); + } + } + } + + fn get_line_index(&self, start: usize) -> usize { + match self.file.line_index((), start) { + Ok(index) => index, + Err(err) => { + panic!("Line index at {} not found: {}", start, err); + } + } + } + + fn get_line_number(&self, line_index: usize) -> usize { + match self.file.line_number((), line_index) { + Ok(number) => number, + Err(err) => { + panic!("Line number at {} not found: {}", line_index, err); + } + } + } + + fn get_line_start(&self) -> usize { + let line_idx = self.get_line_index(self.start); + self.get_line_number(line_idx) + } + + fn get_col_start(&self) -> usize { + let line_idx = self.get_line_index(self.start); + self.get_column_number(line_idx, self.start) + } + + fn get_line_end(&self) -> usize { + let line_idx = self.get_line_index(self.end); + self.get_line_number(line_idx) + } + + fn get_col_end(&self) -> usize { + let line_idx = self.get_line_index(self.end); + self.get_column_number(line_idx, self.end) + } + + pub(crate) fn get_filename(&self) -> String { + self.file.name().to_string() + } + + /// Returns the proximity score of the given offset to the debug symbol. + /// The proximity score is the sum of the distance from the start and end of the symbol. + /// If the offset is not within the symbol, -1 is returned. + pub fn proximity_score(&self, offset: usize) -> i32 { + if self.start <= offset && offset <= self.end { + let start_diff = offset as i32 - self.start as i32; + let end_diff = self.end as i32 - offset as i32; + + start_diff + end_diff + } else { + -1 } } } @@ -49,21 +98,28 @@ impl Debug for DebugSymRef { f, "{}:{}:{}", self.file.name(), - self.line_start, - self.col_start + self.get_line_start(), + self.get_col_start() ); } let mut debug_print = f.debug_struct("DebugSymRef"); - if self.line_start == self.line_end { - debug_print - .field("line", &self.line_start) - .field("cols", &format!("{}-{}", self.col_start, self.col_end)); + if self.get_line_start() == self.get_line_end() { + debug_print.field("line", &self.get_line_start()).field( + "cols", + &format!("{}-{}", self.get_col_start(), self.get_col_end()), + ); } else { debug_print - .field("start", &format!("{}:{}", self.line_start, self.col_start)) - .field("end", &format!("{}:{}", self.line_end, self.col_end)); + .field( + "start", + &format!("{}:{}", self.get_line_start(), self.get_col_start()), + ) + .field( + "end", + &format!("{}:{}", self.get_line_end(), self.get_col_end()), + ); } debug_print.finish() diff --git a/test/circuit.chiquito b/test/circuit.chiquito index d7c7dd0c..25a90df9 100644 --- a/test/circuit.chiquito +++ b/test/circuit.chiquito @@ -23,7 +23,7 @@ machine fibo(signal n) (signal b: field) { if i + 1 == n { -> final { - i', b', n' <== i + 1, c, n; + i', b', n' <== i + 1, c, n; } } else { -> middle { diff --git a/test/circuit_error.chiquito b/test/circuit_error.chiquito index 18c8e650..20bb1f3c 100644 --- a/test/circuit_error.chiquito +++ b/test/circuit_error.chiquito @@ -18,7 +18,7 @@ machine fibo(signal n) (signal b: field) { state middle { c <== a + b; - + if i + 1 == n { -> final { i', b', n' <== i + 1, c, n;