diff --git a/Cargo.lock b/Cargo.lock index 0044279e..025c8231 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -850,6 +850,12 @@ dependencies = [ "allocator-api2", ] +[[package]] +name = "hashbrown" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" + [[package]] name = "hashlink" version = "0.8.4" @@ -929,12 +935,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.2" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" +checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown 0.14.3", + "hashbrown 0.15.0", ] [[package]] @@ -1187,6 +1193,39 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "ntest" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb183f0a1da7a937f672e5ee7b7edb727bf52b8a52d531374ba8ebb9345c0330" +dependencies = [ + "ntest_test_cases", + "ntest_timeout", +] + +[[package]] +name = "ntest_test_cases" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16d0d3f2a488592e5368ebbe996e7f1d44aa13156efad201f5b4d84e150eaa93" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ntest_timeout" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc7c92f190c97f79b4a332f5e81dcf68c8420af2045c936c9be0bc9de6f63b5" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "num-bigint-dig" version = "0.8.4" @@ -1319,7 +1358,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" dependencies = [ "fixedbitset", - "indexmap 2.2.2", + "indexmap 2.6.0", ] [[package]] @@ -1507,6 +1546,7 @@ dependencies = [ name = "pg_statement_splitter" version = "0.0.0" dependencies = [ + "ntest", "pg_lexer", "pg_query", "text-size", @@ -1668,6 +1708,15 @@ dependencies = [ "syn 2.0.71", ] +[[package]] +name = "proc-macro-crate" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +dependencies = [ + "toml_edit", +] + [[package]] name = "proc-macro2" version = "1.0.86" @@ -2163,7 +2212,7 @@ dependencies = [ "futures-util", "hashlink", "hex", - "indexmap 2.2.2", + "indexmap 2.6.0", "log", "memchr", "once_cell", @@ -2452,6 +2501,23 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "toml_datetime" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" + +[[package]] +name = "toml_edit" +version = "0.22.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +dependencies = [ + "indexmap 2.6.0", + "toml_datetime", + "winnow", +] + [[package]] name = "tracing" version = "0.1.40" @@ -2858,6 +2924,15 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" +[[package]] +name = "winnow" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +dependencies = [ + "memchr", +] + [[package]] name = "write-json" version = "0.1.4" diff --git a/crates/pg_base_db/src/change.rs b/crates/pg_base_db/src/change.rs index 26a926ff..a181eee1 100644 --- a/crates/pg_base_db/src/change.rs +++ b/crates/pg_base_db/src/change.rs @@ -42,9 +42,9 @@ impl ChangedStatement { } } -fn apply_text_change(text: &String, range: Option, change_text: &String) -> String { +fn apply_text_change(text: &str, range: Option, change_text: &str) -> String { if range.is_none() { - return change_text.clone(); + return change_text.to_string(); } let range = range.unwrap(); @@ -53,7 +53,7 @@ fn apply_text_change(text: &String, range: Option, change_text: &Stri let mut new_text = String::new(); new_text.push_str(&text[..start]); - new_text.push_str(&change_text); + new_text.push_str(change_text); new_text.push_str(&text[end..]); new_text @@ -97,7 +97,7 @@ impl Change { self.range.is_some() && self.text.len() < self.range.unwrap().len().into() } - pub fn apply_to_text(&self, text: &String) -> String { + pub fn apply_to_text(&self, text: &str) -> String { if self.range.is_none() { return self.text.clone(); } @@ -122,14 +122,10 @@ impl Change { changed_statements.extend( doc.drain_statements() .into_iter() - .map(|s| StatementChange::Deleted(s)), + .map(StatementChange::Deleted), ); // TODO also use errors returned by extract sql statement ranges - doc.statement_ranges = pg_statement_splitter::split(&self.text) - .ranges - .iter() - .map(|r| r.clone()) - .collect(); + doc.statement_ranges = pg_statement_splitter::split(&self.text).ranges.to_vec(); doc.text = self.text.clone(); doc.line_index = LineIndex::new(&doc.text); @@ -155,7 +151,7 @@ impl Change { changed_statements.push(StatementChange::Modified(ChangedStatement { statement: StatementRef { idx: pos, - text: doc.text[r.clone()].to_string(), + text: doc.text[*r].to_string(), document_url: doc.url.clone(), }, // change must be relative to statement @@ -166,15 +162,9 @@ impl Change { // if addition, expand the range // if deletion, shrink the range if self.is_addition() { - *r = TextRange::new( - r.start(), - r.end() + TextSize::from(self.diff_size()), - ); + *r = TextRange::new(r.start(), r.end() + self.diff_size()); } else if self.is_deletion() { - *r = TextRange::new( - r.start(), - r.end() - TextSize::from(self.diff_size()), - ); + *r = TextRange::new(r.start(), r.end() - self.diff_size()); } } else if self.is_addition() { *r += self.diff_size(); @@ -206,7 +196,7 @@ impl Change { { changed_statements.push(StatementChange::Deleted(StatementRef { idx, - text: doc.text[r.clone()].to_string(), + text: doc.text[*r].to_string(), document_url: doc.url.clone(), })); @@ -344,15 +334,14 @@ mod tests { assert_eq!(d.statement_ranges.len(), 2); for r in &pg_statement_splitter::split(&d.text).ranges { - assert_eq!( - d.statement_ranges.iter().position(|x| r == x).is_some(), - true, + assert!( + d.statement_ranges.iter().any(|x| r == x), "should have stmt with range {:#?}", r ); } - assert_eq!(d.statement_ranges[0], TextRange::new(0.into(), 26.into())); + assert_eq!(d.statement_ranges[0], TextRange::new(0.into(), 25.into())); assert_eq!(d.statement_ranges[1], TextRange::new(26.into(), 35.into())); } @@ -364,8 +353,8 @@ mod tests { assert_eq!(d.statement_ranges.len(), 2); - let stmt_1_range = d.statement_ranges[0].clone(); - let stmt_2_range = d.statement_ranges[1].clone(); + let stmt_1_range = d.statement_ranges[0]; + let stmt_2_range = d.statement_ranges[1]; let update_text = " contacts;"; @@ -522,8 +511,8 @@ mod tests { assert_eq!(d.statement_ranges.len(), 2); - let stmt_1_range = d.statement_ranges[0].clone(); - let stmt_2_range = d.statement_ranges[1].clone(); + let stmt_1_range = d.statement_ranges[0]; + let stmt_2_range = d.statement_ranges[1]; let update_text = ",test"; diff --git a/crates/pg_base_db/src/document.rs b/crates/pg_base_db/src/document.rs index a9838833..a8658cd2 100644 --- a/crates/pg_base_db/src/document.rs +++ b/crates/pg_base_db/src/document.rs @@ -1,4 +1,4 @@ -use std::{hash::Hash, hash::Hasher, ops::RangeBounds, usize}; +use std::{hash::Hash, hash::Hasher, ops::RangeBounds}; use line_index::LineIndex; use text_size::{TextRange, TextSize}; @@ -44,18 +44,11 @@ impl Document { pub fn new(url: PgLspPath, text: Option) -> Document { Document { version: 0, - line_index: LineIndex::new(&text.as_ref().unwrap_or(&"".to_string())), + line_index: LineIndex::new(text.as_ref().unwrap_or(&"".to_string())), // TODO: use errors returned by split - statement_ranges: text.as_ref().map_or_else( - || Vec::new(), - |f| { - pg_statement_splitter::split(&f) - .ranges - .iter() - .map(|range| range.clone()) - .collect() - }, - ), + statement_ranges: text.as_ref().map_or_else(Vec::new, |f| { + pg_statement_splitter::split(f).ranges.to_vec() + }), text: text.unwrap_or("".to_string()), url, } @@ -99,7 +92,7 @@ impl Document { .enumerate() .map(|(idx, range)| StatementRef { document_url: self.url.clone(), - text: self.text[range.clone()].to_string(), + text: self.text[range].to_string(), idx, }) .collect() @@ -112,10 +105,10 @@ impl Document { .enumerate() .map(|(idx, range)| { ( - range.clone(), + *range, StatementRef { document_url: self.url.clone(), - text: self.text[range.clone()].to_string(), + text: self.text[*range].to_string(), idx, }, ) @@ -130,7 +123,7 @@ impl Document { .enumerate() .map(|(idx, range)| StatementRef { document_url: self.url.clone(), - text: self.text[range.clone()].to_string(), + text: self.text[*range].to_string(), idx, }) .collect() @@ -142,7 +135,7 @@ impl Document { .get(pos) .map(|range| StatementRef { document_url: self.url.clone(), - text: self.text[range.clone()].to_string(), + text: self.text[*range].to_string(), idx: pos, }) .unwrap() @@ -154,10 +147,10 @@ impl Document { .get(pos) .map(|range| { ( - range.clone(), + *range, StatementRef { document_url: self.url.clone(), - text: self.text[range.clone()].to_string(), + text: self.text[*range].to_string(), idx: pos, }, ) diff --git a/crates/pg_lexer/src/lib.rs b/crates/pg_lexer/src/lib.rs index ece57fb3..df24f8d8 100644 --- a/crates/pg_lexer/src/lib.rs +++ b/crates/pg_lexer/src/lib.rs @@ -65,7 +65,7 @@ static PATTERN_LEXER: LazyLock = fn whitespace_tokens(input: &str) -> VecDeque { let mut tokens = VecDeque::new(); - for cap in PATTERN_LEXER.captures_iter(&input) { + for cap in PATTERN_LEXER.captures_iter(input) { if let Some(whitespace) = cap.name("whitespace") { tokens.push_back(Token { token_type: TokenType::Whitespace, @@ -139,8 +139,8 @@ pub fn lex(text: &str) -> Vec { kind: SyntaxKind::from(&pg_query_token), text: token_text, span: TextRange::new( - TextSize::try_from(u32::try_from(pg_query_token.start).unwrap()).unwrap(), - TextSize::try_from(u32::try_from(pg_query_token.end).unwrap()).unwrap(), + TextSize::from(u32::try_from(pg_query_token.start).unwrap()), + TextSize::from(u32::try_from(pg_query_token.end).unwrap()), ), }); pos += len; diff --git a/crates/pg_lsp/src/server/debouncer/thread.rs b/crates/pg_lsp/src/server/debouncer/thread.rs index 1aa85939..3d20aaed 100644 --- a/crates/pg_lsp/src/server/debouncer/thread.rs +++ b/crates/pg_lsp/src/server/debouncer/thread.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use std::thread::{self, JoinHandle}; diff --git a/crates/pg_statement_splitter/Cargo.toml b/crates/pg_statement_splitter/Cargo.toml index 15a30680..af00de08 100644 --- a/crates/pg_statement_splitter/Cargo.toml +++ b/crates/pg_statement_splitter/Cargo.toml @@ -9,4 +9,5 @@ text-size = "1.1.1" [dev-dependencies] pg_query = "0.8" +ntest = "0.9.3" diff --git a/crates/pg_statement_splitter/src/is_at_stmt_start.rs b/crates/pg_statement_splitter/src/is_at_stmt_start.rs deleted file mode 100644 index ec1b83ea..00000000 --- a/crates/pg_statement_splitter/src/is_at_stmt_start.rs +++ /dev/null @@ -1,1015 +0,0 @@ -use std::collections::HashMap; -use std::sync::LazyLock; - -use super::Parser; -use pg_lexer::SyntaxKind; - -pub enum SyntaxToken { - Required(SyntaxKind), - Optional(SyntaxKind), -} - -#[derive(Debug, Clone, Hash)] -pub enum TokenStatement { - // The respective token is the last token of the statement - EoS(SyntaxKind), - Any(SyntaxKind), -} - -impl TokenStatement { - fn is_eos(&self) -> bool { - match self { - TokenStatement::EoS(_) => true, - _ => false, - } - } - - fn kind(&self) -> SyntaxKind { - match self { - TokenStatement::EoS(k) => k.to_owned(), - TokenStatement::Any(k) => k.to_owned(), - } - } -} - -impl PartialEq for TokenStatement { - fn eq(&self, other: &Self) -> bool { - let a = match self { - TokenStatement::EoS(s) => s, - TokenStatement::Any(s) => s, - }; - - let b = match other { - TokenStatement::EoS(s) => s, - TokenStatement::Any(s) => s, - }; - - return a == b; - } -} - -// vector of hashmaps, where each hashmap returns the list of possible statements for a token at -// the respective index. -// -// For example, at idx 0, the hashmap contains a superset of -// ``` -//{ -// Create: [ -// IndexStmt, -// CreateFunctionStmt, -// CreateStmt, -// ViewStmt, -// ], -// Select: [ -// SelectStmt, -// ], -// }, -// ``` -// -// the idea is to trim down the possible options for each token, until only one statement is left. -// -// The vector is lazily constructed out of another vector of tuples, where each tuple contains a -// statement, and a list of `SyntaxToken`s that are to be found at the start of the statement. -pub static STATEMENT_START_TOKEN_MAPS: LazyLock>>> = - LazyLock::new(|| { - let mut m: Vec<(SyntaxKind, &'static [SyntaxToken])> = Vec::new(); - - m.push(( - SyntaxKind::InsertStmt, - &[ - SyntaxToken::Required(SyntaxKind::Insert), - SyntaxToken::Required(SyntaxKind::Into), - ], - )); - - m.push(( - SyntaxKind::DeleteStmt, - &[ - SyntaxToken::Required(SyntaxKind::DeleteP), - SyntaxToken::Required(SyntaxKind::From), - ], - )); - - m.push(( - SyntaxKind::UpdateStmt, - &[SyntaxToken::Required(SyntaxKind::Update)], - )); - - m.push(( - SyntaxKind::MergeStmt, - &[ - SyntaxToken::Required(SyntaxKind::Merge), - SyntaxToken::Required(SyntaxKind::Into), - ], - )); - - m.push(( - SyntaxKind::SelectStmt, - &[SyntaxToken::Required(SyntaxKind::Select)], - )); - - m.push(( - SyntaxKind::AlterTableStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Table), - SyntaxToken::Optional(SyntaxKind::IfP), - SyntaxToken::Optional(SyntaxKind::Exists), - SyntaxToken::Optional(SyntaxKind::Only), - SyntaxToken::Required(SyntaxKind::Ident), - ], - )); - - // ALTER TABLE x RENAME ... is different to e.g. alter table alter column... - m.push(( - SyntaxKind::RenameStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Table), - SyntaxToken::Optional(SyntaxKind::IfP), - SyntaxToken::Optional(SyntaxKind::Exists), - SyntaxToken::Optional(SyntaxKind::Only), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::Rename), - ], - )); - - m.push(( - SyntaxKind::AlterDomainStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::DomainP), - ], - )); - - m.push(( - SyntaxKind::AlterDefaultPrivilegesStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Default), - SyntaxToken::Required(SyntaxKind::Privileges), - ], - )); - - m.push(( - SyntaxKind::ClusterStmt, - &[SyntaxToken::Required(SyntaxKind::Cluster)], - )); - - m.push(( - SyntaxKind::CopyStmt, - &[SyntaxToken::Required(SyntaxKind::Copy)], - )); - - // CREATE [ [ GLOBAL | LOCAL ] { TEMPORARY | TEMP } | UNLOGGED ] TABLE - // this is overly simplified, but it should be good enough for now - m.push(( - SyntaxKind::CreateStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Global), - SyntaxToken::Optional(SyntaxKind::Local), - SyntaxToken::Optional(SyntaxKind::Temporary), - SyntaxToken::Optional(SyntaxKind::Temp), - SyntaxToken::Optional(SyntaxKind::Unlogged), - SyntaxToken::Optional(SyntaxKind::IfP), - SyntaxToken::Optional(SyntaxKind::Not), - SyntaxToken::Optional(SyntaxKind::Exists), - SyntaxToken::Required(SyntaxKind::Table), - SyntaxToken::Required(SyntaxKind::Ident), - ], - )); - - // CREATE [ OR REPLACE ] AGGREGATE - m.push(( - SyntaxKind::DefineStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Or), - SyntaxToken::Optional(SyntaxKind::Replace), - SyntaxToken::Required(SyntaxKind::Aggregate), - ], - )); - - // CREATE OPERATOR - m.push(( - SyntaxKind::DefineStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Operator), - ], - )); - - // CREATE TYPE name - m.push(( - SyntaxKind::DefineStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::TypeP), - SyntaxToken::Required(SyntaxKind::Ident), - ], - )); - - // CREATE TYPE name AS - m.push(( - SyntaxKind::CompositeTypeStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::TypeP), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::As), - ], - )); - - // CREATE TYPE name AS ENUM - m.push(( - SyntaxKind::CreateEnumStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::TypeP), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::As), - SyntaxToken::Required(SyntaxKind::EnumP), - ], - )); - - // CREATE TYPE name AS RANGE - m.push(( - SyntaxKind::CreateRangeStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::TypeP), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::As), - SyntaxToken::Required(SyntaxKind::Range), - ], - )); - - // m.push(( - // SyntaxKind::DropStmt, - // &[ - // SyntaxToken::Required(SyntaxKind::Drop), - // ], - // )); - - m.push(( - SyntaxKind::TruncateStmt, - &[SyntaxToken::Required(SyntaxKind::Truncate)], - )); - - m.push(( - SyntaxKind::CommentStmt, - &[ - SyntaxToken::Required(SyntaxKind::Comment), - SyntaxToken::Required(SyntaxKind::On), - ], - )); - - m.push(( - SyntaxKind::FetchStmt, - &[SyntaxToken::Required(SyntaxKind::Fetch)], - )); - - // CREATE [ UNIQUE ] INDEX - m.push(( - SyntaxKind::IndexStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Unique), - SyntaxToken::Required(SyntaxKind::Index), - ], - )); - - // CREATE [ OR REPLACE ] FUNCTION - m.push(( - SyntaxKind::CreateFunctionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Or), - SyntaxToken::Optional(SyntaxKind::Replace), - SyntaxToken::Required(SyntaxKind::Function), - ], - )); - - m.push(( - SyntaxKind::AlterFunctionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Function), - ], - )); - - m.push((SyntaxKind::DoStmt, &[SyntaxToken::Required(SyntaxKind::Do)])); - - // CREATE [ OR REPLACE ] RULE - m.push(( - SyntaxKind::RuleStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Or), - SyntaxToken::Optional(SyntaxKind::Replace), - SyntaxToken::Required(SyntaxKind::Rule), - ], - )); - - m.push(( - SyntaxKind::NotifyStmt, - &[SyntaxToken::Required(SyntaxKind::Notify)], - )); - m.push(( - SyntaxKind::ListenStmt, - &[SyntaxToken::Required(SyntaxKind::Listen)], - )); - m.push(( - SyntaxKind::UnlistenStmt, - &[SyntaxToken::Required(SyntaxKind::Unlisten)], - )); - - // TransactionStmt can be Begin or Commit - m.push(( - SyntaxKind::TransactionStmt, - &[SyntaxToken::Required(SyntaxKind::BeginP)], - )); - m.push(( - SyntaxKind::TransactionStmt, - &[SyntaxToken::Required(SyntaxKind::Commit)], - )); - - // CREATE [ OR REPLACE ] [ TEMP | TEMPORARY ] [ RECURSIVE ] VIEW - // this is overly simplified, but it should be good enough for now - m.push(( - SyntaxKind::ViewStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Or), - SyntaxToken::Optional(SyntaxKind::Replace), - SyntaxToken::Optional(SyntaxKind::Temporary), - SyntaxToken::Optional(SyntaxKind::Temp), - SyntaxToken::Optional(SyntaxKind::Recursive), - SyntaxToken::Required(SyntaxKind::View), - ], - )); - - m.push(( - SyntaxKind::LoadStmt, - &[SyntaxToken::Required(SyntaxKind::Load)], - )); - - m.push(( - SyntaxKind::CreateDomainStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::DomainP), - ], - )); - - m.push(( - SyntaxKind::CreatedbStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Database), - ], - )); - - m.push(( - SyntaxKind::DropdbStmt, - &[ - SyntaxToken::Required(SyntaxKind::Drop), - SyntaxToken::Required(SyntaxKind::Database), - ], - )); - - m.push(( - SyntaxKind::VacuumStmt, - &[SyntaxToken::Required(SyntaxKind::Vacuum)], - )); - - m.push(( - SyntaxKind::ExplainStmt, - &[SyntaxToken::Required(SyntaxKind::Explain)], - )); - - // CREATE [ [ GLOBAL | LOCAL ] { TEMPORARY | TEMP } ] TABLE AS - // this is overly simplified, but it should be good enough for now - m.push(( - SyntaxKind::CreateTableAsStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Global), - SyntaxToken::Optional(SyntaxKind::Local), - SyntaxToken::Optional(SyntaxKind::Temporary), - SyntaxToken::Optional(SyntaxKind::Temp), - SyntaxToken::Required(SyntaxKind::Table), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::As), - ], - )); - - m.push(( - SyntaxKind::CreateSeqStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Temporary), - SyntaxToken::Optional(SyntaxKind::Temp), - SyntaxToken::Optional(SyntaxKind::Unlogged), - SyntaxToken::Required(SyntaxKind::Sequence), - ], - )); - - m.push(( - SyntaxKind::AlterSeqStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Sequence), - ], - )); - - m.push(( - SyntaxKind::VariableSetStmt, - &[SyntaxToken::Required(SyntaxKind::Set)], - )); - - m.push(( - SyntaxKind::VariableShowStmt, - &[SyntaxToken::Required(SyntaxKind::Show)], - )); - - m.push(( - SyntaxKind::DiscardStmt, - &[SyntaxToken::Required(SyntaxKind::Discard)], - )); - - // CREATE [ OR REPLACE ] [ CONSTRAINT ] TRIGGER - m.push(( - SyntaxKind::CreateTrigStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Or), - SyntaxToken::Optional(SyntaxKind::Replace), - SyntaxToken::Optional(SyntaxKind::Constraint), - SyntaxToken::Required(SyntaxKind::Trigger), - ], - )); - - m.push(( - SyntaxKind::CreateRoleStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Role), - ], - )); - - m.push(( - SyntaxKind::AlterRoleStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Role), - ], - )); - - m.push(( - SyntaxKind::DropRoleStmt, - &[ - SyntaxToken::Required(SyntaxKind::Drop), - SyntaxToken::Required(SyntaxKind::Role), - ], - )); - - m.push(( - SyntaxKind::LockStmt, - &[SyntaxToken::Required(SyntaxKind::LockP)], - )); - - m.push(( - SyntaxKind::ConstraintsSetStmt, - &[ - SyntaxToken::Required(SyntaxKind::Set), - SyntaxToken::Required(SyntaxKind::Constraints), - ], - )); - - m.push(( - SyntaxKind::ReindexStmt, - &[SyntaxToken::Required(SyntaxKind::Reindex)], - )); - - m.push(( - SyntaxKind::CheckPointStmt, - &[SyntaxToken::Required(SyntaxKind::Checkpoint)], - )); - - m.push(( - SyntaxKind::CreateSchemaStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Schema), - ], - )); - - m.push(( - SyntaxKind::AlterDatabaseStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Database), - SyntaxToken::Required(SyntaxKind::Ident), - ], - )); - - m.push(( - SyntaxKind::AlterDatabaseRefreshCollStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Database), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::Refresh), - SyntaxToken::Required(SyntaxKind::Collation), - SyntaxToken::Required(SyntaxKind::VersionP), - ], - )); - - m.push(( - SyntaxKind::AlterDatabaseSetStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Database), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::Set), - ], - )); - - m.push(( - SyntaxKind::AlterDatabaseSetStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Database), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::Reset), - ], - )); - - m.push(( - SyntaxKind::CreateConversionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Default), - SyntaxToken::Required(SyntaxKind::ConversionP), - ], - )); - - m.push(( - SyntaxKind::CreateCastStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Cast), - ], - )); - - m.push(( - SyntaxKind::CreateOpClassStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Operator), - SyntaxToken::Required(SyntaxKind::Class), - ], - )); - - m.push(( - SyntaxKind::CreateOpFamilyStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Operator), - SyntaxToken::Required(SyntaxKind::Family), - ], - )); - - m.push(( - SyntaxKind::AlterOpFamilyStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Operator), - SyntaxToken::Required(SyntaxKind::Family), - ], - )); - - m.push(( - SyntaxKind::PrepareStmt, - &[SyntaxToken::Required(SyntaxKind::Prepare)], - )); - - // m.push(( - // SyntaxKind::ExecuteStmt, - // &[SyntaxToken::Required(SyntaxKind::Execute)], - // )); - - m.push(( - SyntaxKind::DeallocateStmt, - &[SyntaxToken::Required(SyntaxKind::Deallocate)], - )); - - m.push(( - SyntaxKind::CreateTableSpaceStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Tablespace), - ], - )); - - m.push(( - SyntaxKind::DropTableSpaceStmt, - &[ - SyntaxToken::Required(SyntaxKind::Drop), - SyntaxToken::Required(SyntaxKind::Tablespace), - ], - )); - - m.push(( - SyntaxKind::AlterOperatorStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Operator), - ], - )); - - m.push(( - SyntaxKind::AlterTypeStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::TypeP), - ], - )); - - m.push(( - SyntaxKind::DropOwnedStmt, - &[ - SyntaxToken::Required(SyntaxKind::Drop), - SyntaxToken::Required(SyntaxKind::Owned), - SyntaxToken::Required(SyntaxKind::By), - ], - )); - - m.push(( - SyntaxKind::ReassignOwnedStmt, - &[ - SyntaxToken::Required(SyntaxKind::Reassign), - SyntaxToken::Required(SyntaxKind::Owned), - SyntaxToken::Required(SyntaxKind::By), - ], - )); - - m.push(( - SyntaxKind::CreateFdwStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Foreign), - SyntaxToken::Required(SyntaxKind::DataP), - SyntaxToken::Required(SyntaxKind::Wrapper), - ], - )); - - m.push(( - SyntaxKind::AlterFdwStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Foreign), - SyntaxToken::Required(SyntaxKind::DataP), - SyntaxToken::Required(SyntaxKind::Wrapper), - ], - )); - - m.push(( - SyntaxKind::CreateForeignServerStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Server), - ], - )); - - m.push(( - SyntaxKind::AlterForeignServerStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Server), - ], - )); - - m.push(( - SyntaxKind::CreateUserMappingStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::User), - SyntaxToken::Required(SyntaxKind::Mapping), - ], - )); - - m.push(( - SyntaxKind::AlterUserMappingStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::User), - SyntaxToken::Required(SyntaxKind::Mapping), - SyntaxToken::Required(SyntaxKind::For), - ], - )); - - m.push(( - SyntaxKind::DropUserMappingStmt, - &[ - SyntaxToken::Required(SyntaxKind::Drop), - SyntaxToken::Required(SyntaxKind::User), - SyntaxToken::Required(SyntaxKind::Mapping), - ], - )); - - m.push(( - SyntaxKind::SecLabelStmt, - &[ - SyntaxToken::Required(SyntaxKind::Security), - SyntaxToken::Required(SyntaxKind::Label), - ], - )); - - m.push(( - SyntaxKind::CreateForeignTableStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Foreign), - SyntaxToken::Required(SyntaxKind::Table), - ], - )); - - m.push(( - SyntaxKind::ImportForeignSchemaStmt, - &[ - SyntaxToken::Required(SyntaxKind::ImportP), - SyntaxToken::Required(SyntaxKind::Foreign), - SyntaxToken::Required(SyntaxKind::Schema), - ], - )); - - m.push(( - SyntaxKind::CreateExtensionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Extension), - ], - )); - - m.push(( - SyntaxKind::AlterExtensionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Extension), - ], - )); - - m.push(( - SyntaxKind::CreateEventTrigStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Event), - SyntaxToken::Required(SyntaxKind::Trigger), - ], - )); - - m.push(( - SyntaxKind::AlterEventTrigStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Event), - SyntaxToken::Required(SyntaxKind::Trigger), - ], - )); - - m.push(( - SyntaxKind::RefreshMatViewStmt, - &[ - SyntaxToken::Required(SyntaxKind::Refresh), - SyntaxToken::Required(SyntaxKind::Materialized), - SyntaxToken::Required(SyntaxKind::View), - ], - )); - - m.push(( - SyntaxKind::AlterSystemStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::SystemP), - ], - )); - - m.push(( - SyntaxKind::CreatePolicyStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Policy), - ], - )); - - m.push(( - SyntaxKind::AlterPolicyStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Policy), - ], - )); - - m.push(( - SyntaxKind::CreateTransformStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Or), - SyntaxToken::Optional(SyntaxKind::Replace), - SyntaxToken::Required(SyntaxKind::Transform), - SyntaxToken::Required(SyntaxKind::For), - ], - )); - - m.push(( - SyntaxKind::CreateAmStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Access), - SyntaxToken::Required(SyntaxKind::Method), - ], - )); - - m.push(( - SyntaxKind::CreatePublicationStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Publication), - ], - )); - - m.push(( - SyntaxKind::AlterPublicationStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Publication), - ], - )); - - m.push(( - SyntaxKind::CreateSubscriptionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Subscription), - ], - )); - - m.push(( - SyntaxKind::AlterSubscriptionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Subscription), - ], - )); - - m.push(( - SyntaxKind::DropSubscriptionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Drop), - SyntaxToken::Required(SyntaxKind::Subscription), - ], - )); - - m.push(( - SyntaxKind::CreateStatsStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Statistics), - ], - )); - - m.push(( - SyntaxKind::AlterCollationStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Collation), - ], - )); - - m.push(( - SyntaxKind::CallStmt, - &[SyntaxToken::Required(SyntaxKind::Call)], - )); - - m.push(( - SyntaxKind::AlterStatsStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Statistics), - ], - )); - - let mut vec: Vec>> = Vec::new(); - - m.iter().for_each(|(statement, tokens)| { - let mut left_pull: usize = 0; - tokens.iter().enumerate().for_each(|(idx, token)| { - if vec.len() <= idx { - vec.push(HashMap::new()); - } - - let is_last = idx == tokens.len() - 1; - - match token { - SyntaxToken::Required(t) => { - for i in (idx - left_pull)..(idx + 1) { - let list_entry = vec[i].entry(t.to_owned()); - list_entry - .and_modify(|list| { - list.push(if is_last { - TokenStatement::EoS(statement.to_owned()) - } else { - TokenStatement::Any(statement.to_owned()) - }); - }) - .or_insert(vec![if is_last { - TokenStatement::EoS(statement.to_owned()) - } else { - TokenStatement::Any(statement.to_owned()) - }]); - } - } - SyntaxToken::Optional(t) => { - if is_last { - panic!("Optional token cannot be last token"); - } - for i in (idx - left_pull)..(idx + 1) { - let list_entry = vec[i].entry(t.to_owned()); - list_entry - .and_modify(|list| { - list.push(TokenStatement::Any(statement.to_owned())); - }) - .or_insert(vec![TokenStatement::Any(statement.to_owned())]); - } - left_pull += 1; - } - } - }); - }); - - vec - }); - -// TODO: complete the hashmap above with all statements: -// RETURN statement (inside SQL function body) -// ReturnStmt, -// SetOperationStmt, -// -// TODO: parsing ambiguity, check docs for solution -// GrantStmt(super::GrantStmt), -// GrantRoleStmt(super::GrantRoleStmt), -// ClosePortalStmt, -// CreatePlangStmt, -// AlterRoleSetStmt, -// DeclareCursorStmt, -// AlterObjectDependsStmt, -// AlterObjectSchemaStmt, -// AlterOwnerStmt, -// AlterEnumStmt, -// AlterTsdictionaryStmt, -// AlterTsconfigurationStmt, -// AlterTableSpaceOptionsStmt, -// AlterTableMoveAllStmt, -// AlterExtensionContentsStmt, -// ReplicaIdentityStmt, -// - -/// Returns the statement at which the parser is currently at, if any -pub fn is_at_stmt_start(parser: &mut Parser) -> Option { - let mut options = Vec::new(); - for i in 0..STATEMENT_START_TOKEN_MAPS.len() { - // important, else infinite loop: only ignore whitespaces after first token - let token = parser.nth(i, i != 0).kind; - if let Some(result) = STATEMENT_START_TOKEN_MAPS[i].get(&token) { - if i == 0 { - options = result.clone(); - } else { - options = result - .iter() - .filter(|o| options.contains(o)) - .cloned() - .collect(); - } - } else if options.len() > 1 { - // no result is found, and there is currently more than one option - // filter the options for all statements that are complete at this point - options.retain(|o| o.is_eos()); - } - - if options.len() == 0 { - break; - } else if options.len() == 1 && options.get(0).unwrap().is_eos() { - break; - } - } - if options.len() == 0 { - None - } else if options.len() == 1 && options.get(0).unwrap().is_eos() { - Some(options.get(0).unwrap().kind()) - } else { - panic!("Ambiguous statement"); - } -} diff --git a/crates/pg_statement_splitter/src/lib.rs b/crates/pg_statement_splitter/src/lib.rs index adaea475..ab4bafa8 100644 --- a/crates/pg_statement_splitter/src/lib.rs +++ b/crates/pg_statement_splitter/src/lib.rs @@ -1,137 +1,203 @@ -///! Postgres Statement Splitter -///! -///! This crate provides a function to split a SQL source string into individual statements. -///! -///! TODO: -///! Instead of relying on statement start tokens, we need to include as many tokens as -///! possible. For example, a `CREATE TRIGGER` statement includes an `EXECUTE [ PROCEDURE | -///! FUNCTION ]` clause, but `EXECUTE` is also a statement start token for an `EXECUTE` statement. -/// We should expand the definition map to include an `Any*`, which must be followed by at least -/// one required token and allows the parser to search for the end tokens of the statement. This -/// will hopefully be enough to reduce collisions to zero. -mod is_at_stmt_start; +//! Postgres Statement Splitter +//! +//! This crate provides a function to split a SQL source string into individual statements. mod parser; mod syntax_error; -use is_at_stmt_start::{is_at_stmt_start, TokenStatement, STATEMENT_START_TOKEN_MAPS}; - -use parser::{Parse, Parser}; - -use pg_lexer::{lex, SyntaxKind}; +use parser::{source, Parse, Parser}; pub fn split(sql: &str) -> Parse { - let mut parser = Parser::new(lex(sql)); - - while !parser.eof() { - match is_at_stmt_start(&mut parser) { - Some(stmt) => { - parser.start_stmt(); - - // advance over all start tokens of the statement - for i in 0..STATEMENT_START_TOKEN_MAPS.len() { - parser.eat_whitespace(); - let token = parser.nth(0, false); - if let Some(result) = STATEMENT_START_TOKEN_MAPS[i].get(&token.kind) { - let is_in_results = result - .iter() - .find(|x| match x { - TokenStatement::EoS(y) | TokenStatement::Any(y) => y == &stmt, - }) - .is_some(); - if i == 0 && !is_in_results { - panic!("Expected statement start"); - } else if is_in_results { - parser.expect(token.kind); - } else { - break; - } - } - } - - // move until the end of the statement, or until the next statement start - let mut is_sub_stmt = 0; - let mut is_sub_trx = 0; - let mut ignore_next_non_whitespace = false; - while !parser.at(SyntaxKind::Ascii59) && !parser.eof() { - match parser.nth(0, false).kind { - SyntaxKind::All => { - // ALL is never a statement start, but needs to be skipped when combining queries - // (e.g. UNION ALL) - parser.advance(); - } - SyntaxKind::BeginP => { - // BEGIN, consume until END - is_sub_trx += 1; - parser.advance(); - } - SyntaxKind::EndP => { - is_sub_trx -= 1; - parser.advance(); - } - // opening brackets "(", consume until closing bracket ")" - SyntaxKind::Ascii40 => { - is_sub_stmt += 1; - parser.advance(); - } - SyntaxKind::Ascii41 => { - is_sub_stmt -= 1; - parser.advance(); - } - SyntaxKind::As - | SyntaxKind::Union - | SyntaxKind::Intersect - | SyntaxKind::Except => { - // ignore the next non-whitespace token - ignore_next_non_whitespace = true; - parser.advance(); - } - _ => { - // if another stmt FIRST is encountered, break - // ignore if parsing sub stmt - if ignore_next_non_whitespace == false - && is_sub_stmt == 0 - && is_sub_trx == 0 - && is_at_stmt_start(&mut parser).is_some() - { - break; - } else { - if ignore_next_non_whitespace == true && !parser.at_whitespace() { - ignore_next_non_whitespace = false; - } - parser.advance(); - } - } - } - } - - parser.expect(SyntaxKind::Ascii59); - - parser.close_stmt(); - } - None => { - parser.advance(); - } - } - } + let mut parser = Parser::new(sql); + + source(&mut parser); parser.finish() } #[cfg(test)] mod tests { + use ntest::timeout; + use pg_lexer::SyntaxKind; + use syntax_error::SyntaxError; + use text_size::TextRange; + use super::*; + struct Tester { + input: String, + parse: Parse, + } + + impl From<&str> for Tester { + fn from(input: &str) -> Self { + Tester { + parse: split(input), + input: input.to_string(), + } + } + } + + impl Tester { + fn expect_statements(&self, expected: Vec<&str>) -> &Self { + assert_eq!( + self.parse.ranges.len(), + expected.len(), + "Expected {} statements, got {}: {:?}", + expected.len(), + self.parse.ranges.len(), + self.parse + .ranges + .iter() + .map(|r| &self.input[*r]) + .collect::>() + ); + + for (range, expected) in self.parse.ranges.iter().zip(expected.iter()) { + assert_eq!(*expected, self.input[*range].to_string()); + } + + self + } + + fn expect_errors(&self, expected: Vec) -> &Self { + assert_eq!( + self.parse.errors.len(), + expected.len(), + "Expected {} errors, got {}: {:?}", + expected.len(), + self.parse.errors.len(), + self.parse.errors + ); + + for (err, expected) in self.parse.errors.iter().zip(expected.iter()) { + assert_eq!(expected, err); + } + + self + } + } + #[test] - fn test_splitter() { - let input = "select 1 from contact;\nselect 1;\nalter table test drop column id;"; - - let res = split(input); - assert_eq!(res.ranges.len(), 3); - assert_eq!("select 1 from contact;", input[res.ranges[0]].to_string()); - assert_eq!("select 1;", input[res.ranges[1]].to_string()); - assert_eq!( + #[timeout(1000)] + fn basic() { + Tester::from("select 1 from contact; select 1;") + .expect_statements(vec!["select 1 from contact;", "select 1;"]); + } + + #[test] + fn no_semicolons() { + Tester::from("select 1 from contact\nselect 1") + .expect_statements(vec!["select 1 from contact", "select 1"]); + } + + #[test] + fn double_newlines() { + Tester::from("select 1 from contact\n\nselect 1\n\nselect 3").expect_statements(vec![ + "select 1 from contact", + "select 1", + "select 3", + ]); + } + + #[test] + fn insert_expect_error() { + Tester::from("\ninsert select 1\n\nselect 3") + .expect_statements(vec!["insert select 1", "select 3"]) + .expect_errors(vec![SyntaxError::new( + format!("Expected {:?}", SyntaxKind::Into), + TextRange::new(8.into(), 14.into()), + )]); + } + + #[test] + fn insert_with_select() { + Tester::from("\ninsert into tbl (id) select 1\n\nselect 3") + .expect_statements(vec!["insert into tbl (id) select 1", "select 3"]); + } + + #[test] + fn case() { + Tester::from("select case when select 2 then 1 else 0 end") + .expect_statements(vec!["select case when select 2 then 1 else 0 end"]); + } + + #[test] + #[timeout(1000)] + fn simple_select() { + Tester::from( + " +select id, name, test1231234123, unknown from co; + +select 14433313331333 + +alter table test drop column id; + +select lower('test'); +", + ) + .expect_statements(vec![ + "select id, name, test1231234123, unknown from co;", + "select 14433313331333", "alter table test drop column id;", - input[res.ranges[2]].to_string() + "select lower('test');", + ]); + } + + #[test] + fn create_rule() { + Tester::from( + "create rule log_employee_insert as +on insert to employees +do also insert into employee_log (action, employee_id, log_time) +values ('insert', new.id, now());", + ) + .expect_statements(vec![ + "create rule log_employee_insert as +on insert to employees +do also insert into employee_log (action, employee_id, log_time) +values ('insert', new.id, now());", + ]); + } + + #[test] + fn insert_into() { + Tester::from("randomness\ninsert into tbl (id) values (1)\nselect 3").expect_statements( + vec!["randomness", "insert into tbl (id) values (1)\nselect 3"], + ); + } + + #[test] + fn update() { + Tester::from("more randomness\nupdate tbl set col = '1'\n\nselect 3").expect_statements( + vec!["more randomness", "update tbl set col = '1'", "select 3"], ); } + + #[test] + fn delete_from() { + Tester::from("more randomness\ndelete from test where id = 1\n\nselect 3") + .expect_statements(vec![ + "more randomness", + "delete from test where id = 1", + "select 3", + ]); + } + + #[test] + fn unknown() { + Tester::from("random stuff\n\nmore randomness\n\nselect 3").expect_statements(vec![ + "random stuff", + "more randomness", + "select 3", + ]); + } + + #[test] + fn unknown_2() { + Tester::from("random stuff\nselect 1\n\nselect 3").expect_statements(vec![ + "random stuff", + "select 1", + "select 3", + ]); + } } diff --git a/crates/pg_statement_splitter/src/parser.rs b/crates/pg_statement_splitter/src/parser.rs index 1b3d0f8b..33fcfaf7 100644 --- a/crates/pg_statement_splitter/src/parser.rs +++ b/crates/pg_statement_splitter/src/parser.rs @@ -1,11 +1,17 @@ -use std::cmp::min; +mod common; +mod data; +mod ddl; +mod dml; -use pg_lexer::{SyntaxKind, Token, TokenType, WHITESPACE_TOKENS}; +pub use common::source; + +use pg_lexer::{lex, SyntaxKind, Token, WHITESPACE_TOKENS}; use text_size::{TextRange, TextSize}; use crate::syntax_error::SyntaxError; /// Main parser that exposes the `cstree` api, and collects errors and statements +/// It is modelled after a Pratt Parser. For a gentle introduction to Pratt Parsing, see https://matklad.github.io/2020/04/13/simple-but-powerful-pratt-parsing.html pub struct Parser { /// The ranges of the statements ranges: Vec<(usize, usize)>, @@ -15,12 +21,10 @@ pub struct Parser { current_stmt_start: Option, /// The tokens to parse pub tokens: Vec, - /// The current position in the token stream - pub pos: usize, - /// index from which whitespace tokens are buffered - pub whitespace_token_buffer: Option, eof_token: Token, + + next_pos: usize, } /// Result of Building @@ -33,15 +37,35 @@ pub struct Parse { } impl Parser { - pub fn new(tokens: Vec) -> Self { + pub fn new(sql: &str) -> Self { + let tokens = lex(sql); + + let eof_token = Token::eof(usize::from( + tokens + .last() + .map(|t| t.span.start()) + .unwrap_or(TextSize::from(0)), + )); + + // next_pos should be the initialised with the first valid token already + let mut next_pos = 0; + loop { + let token = tokens.get(next_pos).unwrap_or(&eof_token); + + if is_irrelevant_token(token) { + next_pos += 1; + } else { + break; + } + } + Self { - eof_token: Token::eof(usize::from(tokens.last().unwrap().span.end())), ranges: Vec::new(), + eof_token, errors: Vec::new(), current_stmt_start: None, tokens, - pos: 0, - whitespace_token_buffer: None, + next_pos, } } @@ -52,145 +76,116 @@ impl Parser { .iter() .map(|(start, end)| { let from = self.tokens.get(*start); - let to = self.tokens.get(end - 1); - // get text range from token range - let text_start = from.unwrap().span.start(); - let text_end = to.unwrap().span.end(); - - TextRange::new( - TextSize::try_from(text_start).unwrap(), - TextSize::try_from(text_end).unwrap(), - ) + let to = self.tokens.get(*end).unwrap_or(&self.eof_token); + + TextRange::new(from.unwrap().span.start(), to.span.end()) }) .collect(), errors: self.errors, } } + /// Start statement pub fn start_stmt(&mut self) { assert!(self.current_stmt_start.is_none()); - self.current_stmt_start = Some(self.pos); + self.current_stmt_start = Some(self.next_pos); } + /// Close statement pub fn close_stmt(&mut self) { - assert!(self.current_stmt_start.is_some()); - self.ranges - .push((self.current_stmt_start.take().unwrap(), self.pos)); - } + assert!(self.next_pos > 0); - /// collects an SyntaxError with an `error` message at `pos` - pub fn error_at_pos(&mut self, error: String, pos: usize) { - self.errors.push(SyntaxError::new_at_offset( - error, - self.tokens - .get(min(self.tokens.len() - 1, pos)) - .unwrap() - .span - .start(), - )); - } + // go back the positions until we find the first relevant token + let mut end_token_pos = self.next_pos - 1; + loop { + let token = self.tokens.get(end_token_pos); - /// applies token and advances - pub fn advance(&mut self) { - assert!(!self.eof()); - if self.nth(0, false).kind == SyntaxKind::Whitespace { - if self.whitespace_token_buffer.is_none() { - self.whitespace_token_buffer = Some(self.pos); + if end_token_pos == 0 || token.is_none() { + break; } - } else { - self.flush_token_buffer(); - } - self.pos += 1; - } - /// flush token buffer and applies all tokens - pub fn flush_token_buffer(&mut self) { - if self.whitespace_token_buffer.is_none() { - return; - } - while self.whitespace_token_buffer.unwrap() < self.pos { - self.whitespace_token_buffer = Some(self.whitespace_token_buffer.unwrap() + 1); - } - self.whitespace_token_buffer = None; - } + if !is_irrelevant_token(token.unwrap()) { + break; + } - pub fn eat(&mut self, kind: SyntaxKind) -> bool { - if self.at(kind) { - self.advance(); - true - } else { - false + end_token_pos -= 1; } - } - pub fn at_whitespace(&self) -> bool { - self.nth(0, false).kind == SyntaxKind::Whitespace + self.ranges.push(( + self.current_stmt_start.expect("Expected active statement"), + end_token_pos, + )); + + self.current_stmt_start = None; } - pub fn eat_whitespace(&mut self) { - while self.nth(0, false).token_type == TokenType::Whitespace { - self.advance(); + fn advance(&mut self) -> &Token { + let mut first_relevant_token = None; + loop { + let token = self.tokens.get(self.next_pos).unwrap_or(&self.eof_token); + + // we need to continue with next_pos until the next relevant token after we already + // found the first one + if !is_irrelevant_token(token) { + if let Some(t) = first_relevant_token { + return t; + } + first_relevant_token = Some(token); + } + + self.next_pos += 1; } } - pub fn eof(&self) -> bool { - self.pos == self.tokens.len() + fn peek(&self) -> &Token { + match self.tokens.get(self.next_pos) { + Some(token) => token, + None => &self.eof_token, + } } - /// lookahead method. - /// - /// if `ignore_whitespace` is true, it will skip all whitespace tokens - pub fn nth(&self, lookahead: usize, ignore_whitespace: bool) -> &Token { - if ignore_whitespace { - let mut idx = 0; - let mut non_whitespace_token_ctr = 0; - loop { - match self.tokens.get(self.pos + idx) { - Some(token) => { - if !WHITESPACE_TOKENS.contains(&token.kind) { - if non_whitespace_token_ctr == lookahead { - return token; - } - non_whitespace_token_ctr += 1; - } - idx += 1; - } - None => { - return &self.eof_token; - } - } + fn look_back(&self) -> Option<&Token> { + // we need to look back to the last relevant token + let mut look_back_pos = self.next_pos - 1; + loop { + let token = self.tokens.get(look_back_pos); + + if look_back_pos == 0 || token.is_none() { + return None; } - } else { - match self.tokens.get(self.pos + lookahead) { - Some(token) => token, - None => &self.eof_token, + + if !is_irrelevant_token(token.unwrap()) { + return token; } + + look_back_pos -= 1; } } - /// checks if the current token is of `kind` - pub fn at(&self, kind: SyntaxKind) -> bool { - self.nth(0, false).kind == kind + /// checks if the current token is of `kind` and advances if true + /// returns true if the current token is of `kind` + pub fn eat(&mut self, kind: SyntaxKind) -> bool { + if self.peek().kind == kind { + self.advance(); + true + } else { + false + } } pub fn expect(&mut self, kind: SyntaxKind) { if self.eat(kind) { return; } - if self.whitespace_token_buffer.is_some() { - self.error_at_pos( - format!( - "Expected {:#?}, found {:#?}", - kind, - self.tokens[self.whitespace_token_buffer.unwrap()].kind - ), - self.whitespace_token_buffer.unwrap(), - ); - } else { - self.error_at_pos( - format!("Expected {:#?}, found {:#?}", kind, self.nth(0, false)), - self.pos + 1, - ); - } + + self.errors.push(SyntaxError::new( + format!("Expected {:#?}", kind), + self.peek().span, + )); } } + +fn is_irrelevant_token(t: &Token) -> bool { + return WHITESPACE_TOKENS.contains(&t.kind) + && (t.kind != SyntaxKind::Newline || t.text.chars().count() == 1); +} diff --git a/crates/pg_statement_splitter/src/parser/common.rs b/crates/pg_statement_splitter/src/parser/common.rs new file mode 100644 index 00000000..b723af28 --- /dev/null +++ b/crates/pg_statement_splitter/src/parser/common.rs @@ -0,0 +1,174 @@ +use pg_lexer::{SyntaxKind, Token, TokenType}; + +use super::{ + data::at_statement_start, + ddl::{alter, create}, + dml::{cte, delete, insert, select, update}, + Parser, +}; + +pub fn source(p: &mut Parser) { + loop { + match p.peek() { + Token { + kind: SyntaxKind::Eof, + .. + } => { + break; + } + Token { + // we might want to ignore TokenType::NoKeyword here too + // but this will lead to invalid statements to not being picked up + token_type: TokenType::Whitespace, + .. + } => { + p.advance(); + } + _ => { + statement(p); + } + } + } +} + +pub(crate) fn statement(p: &mut Parser) { + p.start_stmt(); + match p.peek().kind { + SyntaxKind::With => { + cte(p); + } + SyntaxKind::Select => { + select(p); + } + SyntaxKind::Insert => { + insert(p); + } + SyntaxKind::Update => { + update(p); + } + SyntaxKind::DeleteP => { + delete(p); + } + SyntaxKind::Create => { + create(p); + } + SyntaxKind::Alter => { + alter(p); + } + _ => { + unknown(p, &[]); + } + } + p.close_stmt(); +} + +pub(crate) fn parenthesis(p: &mut Parser) { + p.expect(SyntaxKind::Ascii40); + + loop { + match p.peek().kind { + SyntaxKind::Ascii41 | SyntaxKind::Eof => { + p.advance(); + break; + } + _ => { + p.advance(); + } + } + } +} + +pub(crate) fn case(p: &mut Parser) { + p.expect(SyntaxKind::Case); + + loop { + match p.peek().kind { + SyntaxKind::EndP => { + p.advance(); + break; + } + _ => { + p.advance(); + } + } + } +} + +pub(crate) fn unknown(p: &mut Parser, exclude: &[SyntaxKind]) { + loop { + match p.peek() { + Token { + kind: SyntaxKind::Ascii59, + .. + } => { + p.advance(); + break; + } + Token { + kind: SyntaxKind::Newline | SyntaxKind::Eof, + .. + } => { + break; + } + Token { + kind: SyntaxKind::Case, + .. + } => { + case(p); + } + Token { + kind: SyntaxKind::Ascii40, + .. + } => { + parenthesis(p); + } + t => match at_statement_start(t.kind, exclude) { + Some(SyntaxKind::Select) => { + let prev = p.look_back().map(|t| t.kind); + if [ + // for create view / table as + SyntaxKind::As, + // for create rule + SyntaxKind::On, + // for create rule + SyntaxKind::Also, + // for create rule + SyntaxKind::Instead, + ] + .iter() + .all(|x| Some(x) != prev.as_ref()) + { + break; + } + + p.advance(); + } + Some(SyntaxKind::Insert) | Some(SyntaxKind::Update) | Some(SyntaxKind::DeleteP) => { + let prev = p.look_back().map(|t| t.kind); + if [ + // for create trigger + SyntaxKind::After, + // for create rule + SyntaxKind::On, + // for create rule + SyntaxKind::Also, + // for create rule + SyntaxKind::Instead, + ] + .iter() + .all(|x| Some(x) != prev.as_ref()) + { + break; + } + p.advance(); + } + Some(_) => { + break; + } + None => { + p.advance(); + } + }, + } + } +} diff --git a/crates/pg_statement_splitter/src/parser/data.rs b/crates/pg_statement_splitter/src/parser/data.rs new file mode 100644 index 00000000..543896dd --- /dev/null +++ b/crates/pg_statement_splitter/src/parser/data.rs @@ -0,0 +1,22 @@ +use pg_lexer::SyntaxKind; + +// All tokens listed here must be explicitly handled in the `unknown` function to ensure that we do +// not break in the middle of another statement that contains a statement start token. +// +// All of these statements must have a dedicated parser function called from the `statement` function +static STATEMENT_START_TOKENS: &[SyntaxKind] = &[ + SyntaxKind::With, + SyntaxKind::Select, + SyntaxKind::Insert, + SyntaxKind::Update, + SyntaxKind::DeleteP, + SyntaxKind::Create, + SyntaxKind::Alter, +]; + +pub(crate) fn at_statement_start(kind: SyntaxKind, exclude: &[SyntaxKind]) -> Option<&SyntaxKind> { + STATEMENT_START_TOKENS + .iter() + .filter(|&x| !exclude.contains(x)) + .find(|&x| x == &kind) +} diff --git a/crates/pg_statement_splitter/src/parser/ddl.rs b/crates/pg_statement_splitter/src/parser/ddl.rs new file mode 100644 index 00000000..80119b6f --- /dev/null +++ b/crates/pg_statement_splitter/src/parser/ddl.rs @@ -0,0 +1,15 @@ +use pg_lexer::SyntaxKind; + +use super::{common::unknown, Parser}; + +pub(crate) fn create(p: &mut Parser) { + p.expect(SyntaxKind::Create); + + unknown(p, &[]); +} + +pub(crate) fn alter(p: &mut Parser) { + p.expect(SyntaxKind::Alter); + + unknown(p, &[]); +} diff --git a/crates/pg_statement_splitter/src/parser/dml.rs b/crates/pg_statement_splitter/src/parser/dml.rs new file mode 100644 index 00000000..40e59cea --- /dev/null +++ b/crates/pg_statement_splitter/src/parser/dml.rs @@ -0,0 +1,48 @@ +use pg_lexer::SyntaxKind; + +use super::{ + common::{parenthesis, statement, unknown}, + Parser, +}; + +pub(crate) fn cte(p: &mut Parser) { + p.expect(SyntaxKind::With); + + loop { + p.expect(SyntaxKind::Ident); + p.expect(SyntaxKind::As); + parenthesis(p); + + if !p.eat(SyntaxKind::Ascii44) { + break; + } + } + + statement(p); +} + +pub(crate) fn select(p: &mut Parser) { + p.expect(SyntaxKind::Select); + + unknown(p, &[]); +} + +pub(crate) fn insert(p: &mut Parser) { + p.expect(SyntaxKind::Insert); + p.expect(SyntaxKind::Into); + + unknown(p, &[SyntaxKind::Select]); +} + +pub(crate) fn update(p: &mut Parser) { + p.expect(SyntaxKind::Update); + + unknown(p, &[]); +} + +pub(crate) fn delete(p: &mut Parser) { + p.expect(SyntaxKind::DeleteP); + p.expect(SyntaxKind::From); + + unknown(p, &[]); +} diff --git a/crates/pg_statement_splitter/tests/skipped.txt b/crates/pg_statement_splitter/tests/skipped.txt deleted file mode 100644 index 480089b9..00000000 --- a/crates/pg_statement_splitter/tests/skipped.txt +++ /dev/null @@ -1,12 +0,0 @@ -brin -brin_bloom -brin_multi -collate.icu.utf8 -collate.linux.utf8 -collate -copy2 -create_table_like -drop_operator -replica_identity -unicode -xmlmap diff --git a/crates/pg_statement_splitter/tests/statement_splitter_tests.rs b/crates/pg_statement_splitter/tests/statement_splitter_tests.rs index fb639fef..b4ea1de6 100644 --- a/crates/pg_statement_splitter/tests/statement_splitter_tests.rs +++ b/crates/pg_statement_splitter/tests/statement_splitter_tests.rs @@ -1,84 +1,6 @@ use std::fs::{self}; const DATA_DIR_PATH: &str = "tests/data/"; -const POSTGRES_REGRESS_PATH: &str = "../../libpg_query/test/sql/postgres_regress/"; -const SKIPPED_REGRESS_TESTS: &str = include_str!("skipped.txt"); - -#[test] -fn test_postgres_regress() { - // all postgres regress tests are valid and complete statements, so we can use `split_with_parser` and compare with our own splitter - - let mut paths: Vec<_> = fs::read_dir(POSTGRES_REGRESS_PATH) - .unwrap() - .map(|r| r.unwrap()) - .collect(); - paths.sort_by_key(|dir| dir.path()); - - for f in paths.iter() { - let path = f.path(); - - let test_name = path.file_stem().unwrap().to_str().unwrap(); - - // these require fixes in the parser - if SKIPPED_REGRESS_TESTS - .lines() - .collect::>() - .contains(&test_name) - { - continue; - } - - // remove \commands because pg_query doesn't support them - let contents = fs::read_to_string(&path) - .unwrap() - .lines() - .filter(|l| !l.starts_with("\\") && !l.ends_with("\\gset")) - .collect::>() - .join(" "); - - let libpg_query_split = pg_query::split_with_parser(&contents).unwrap(); - - let parser_split = pg_statement_splitter::split(&contents); - - assert_eq!( - parser_split.errors.len(), - 0, - "Unexpected errors when parsing file {}:\n{:#?}", - test_name, - parser_split.errors - ); - - assert_eq!( - libpg_query_split.len(), - parser_split.ranges.len(), - "Mismatch in statement count for file {}: Expected {} statements, got {}", - test_name, - libpg_query_split.len(), - parser_split.ranges.len() - ); - - for (libpg_query_stmt, parser_range) in - libpg_query_split.iter().zip(parser_split.ranges.iter()) - { - let parser_stmt = &contents[parser_range.clone()].trim(); - - let libpg_query_stmt = if libpg_query_stmt.ends_with(';') { - libpg_query_stmt.to_string() - } else { - format!("{};", libpg_query_stmt.trim()) - }; - - let libpg_query_stmt_trimmed = libpg_query_stmt.trim(); - let parser_stmt_trimmed = parser_stmt.trim(); - - assert_eq!( - libpg_query_stmt_trimmed, parser_stmt_trimmed, - "Mismatch in statement {}:\nlibg_query: '{}'\nsplitter: '{}'", - test_name, libpg_query_stmt_trimmed, parser_stmt_trimmed - ); - } - } -} #[test] fn test_statement_splitter() {