From b9cbd1a5d4fe8cffe42174bf232d81fe0df39168 Mon Sep 17 00:00:00 2001 From: wugeer <1284057728@qq.com> Date: Tue, 5 Nov 2024 21:25:28 +0800 Subject: [PATCH] Add support for PostgreSQL `LISTEN/NOTIFY` syntax --- src/ast/mod.rs | 28 +++++++++++++++ src/dialect/mod.rs | 10 ++++++ src/dialect/postgresql.rs | 10 ++++++ src/keywords.rs | 2 ++ src/parser/mod.rs | 19 ++++++++++ tests/sqlparser_common.rs | 76 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 145 insertions(+) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index b2672552e..2ef3a460b 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -3295,6 +3295,23 @@ pub enum Statement { include_final: bool, deduplicate: Option, }, + /// ```sql + /// LISTEN + /// ``` + /// listen for a notification channel + /// + /// See Postgres + LISTEN { channel: Ident }, + /// ```sql + /// NOTIFY channel [ , payload ] + /// ``` + /// send a notification event together with an optional “payload” string to channel + /// + /// See Postgres + NOTIFY { + channel: Ident, + payload: Option, + }, } impl fmt::Display for Statement { @@ -4839,6 +4856,17 @@ impl fmt::Display for Statement { } Ok(()) } + Statement::LISTEN { channel } => { + write!(f, "LISTEN {channel}")?; + Ok(()) + } + Statement::NOTIFY { channel, payload } => { + write!(f, "NOTIFY {channel}")?; + if let Some(payload) = payload { + write!(f, ", '{payload}'")?; + } + Ok(()) + } } } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 28e7ac7d1..a90d554a4 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -590,6 +590,16 @@ pub trait Dialect: Debug + Any { fn supports_try_convert(&self) -> bool { false } + + /// Returns true if the dialect supports the `LISTEN` statement + fn supports_listen(&self) -> bool { + false + } + + /// Returns true if the dialect supports the `NOTIFY` statement + fn supports_notify(&self) -> bool { + false + } } /// This represents the operators for which precedence must be defined diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index dc458ec5d..c40c826c4 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -191,6 +191,16 @@ impl Dialect for PostgreSqlDialect { fn supports_explain_with_utility_options(&self) -> bool { true } + + /// see + fn supports_listen(&self) -> bool { + true + } + + /// see + fn supports_notify(&self) -> bool { + true + } } pub fn parse_comment(parser: &mut Parser) -> Result { diff --git a/src/keywords.rs b/src/keywords.rs index e98309681..d60227c99 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -438,6 +438,7 @@ define_keywords!( LIKE_REGEX, LIMIT, LINES, + LISTEN, LN, LOAD, LOCAL, @@ -513,6 +514,7 @@ define_keywords!( NOSUPERUSER, NOT, NOTHING, + NOTIFY, NOWAIT, NO_WRITE_TO_BINLOG, NTH_VALUE, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index c4b92ba4e..73ed932fc 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -532,6 +532,10 @@ impl<'a> Parser<'a> { Keyword::EXECUTE => self.parse_execute(), Keyword::PREPARE => self.parse_prepare(), Keyword::MERGE => self.parse_merge(), + // `LISTEN` and `NOTIFY` are Postgres-specific + // syntaxes. They are used for Postgres statement. + Keyword::LISTEN if self.dialect.supports_listen() => self.parse_listen(), + Keyword::NOTIFY if self.dialect.supports_notify() => self.parse_notify(), // `PRAGMA` is sqlite specific https://www.sqlite.org/pragma.html Keyword::PRAGMA => self.parse_pragma(), Keyword::UNLOAD => self.parse_unload(), @@ -946,6 +950,21 @@ impl<'a> Parser<'a> { Ok(Statement::ReleaseSavepoint { name }) } + pub fn parse_listen(&mut self) -> Result { + let channel = self.parse_identifier(false)?; + Ok(Statement::LISTEN { channel }) + } + + pub fn parse_notify(&mut self) -> Result { + let channel = self.parse_identifier(false)?; + let payload = if self.consume_token(&Token::Comma) { + Some(self.parse_literal_string()?) + } else { + None + }; + Ok(Statement::NOTIFY { channel, payload }) + } + /// Parse an expression prefix. pub fn parse_prefix(&mut self) -> Result { // allow the dialect to override prefix parsing diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 4016e5a69..4d97e4aff 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -11399,3 +11399,79 @@ fn test_show_dbs_schemas_tables_views() { verified_stmt("SHOW MATERIALIZED VIEWS FROM db1"); verified_stmt("SHOW MATERIALIZED VIEWS FROM db1 'abc'"); } + +#[test] +fn parse_listen_channel() { + let dialects = all_dialects_where(|d| d.supports_listen()); + + match dialects.verified_stmt("LISTEN test1") { + Statement::LISTEN { channel } => { + assert_eq!(Ident::new("test1"), channel); + } + _ => unreachable!(), + }; + + assert_eq!( + dialects.parse_sql_statements("LISTEN *").unwrap_err(), + ParserError::ParserError("Expected: identifier, found: *".to_string()) + ); + + let dialects = all_dialects_where(|d| !d.supports_listen()); + + assert_eq!( + dialects.parse_sql_statements("LISTEN test1").unwrap_err(), + ParserError::ParserError("Expected: an SQL statement, found: LISTEN".to_string()) + ); +} + +#[test] +fn parse_notify_channel() { + let dialects = all_dialects_where(|d| d.supports_notify()); + + match dialects.verified_stmt("NOTIFY test1") { + Statement::NOTIFY { channel, payload } => { + assert_eq!(Ident::new("test1"), channel); + assert_eq!(payload, None); + } + _ => unreachable!(), + }; + + match dialects.verified_stmt("NOTIFY test1, 'this is a test notification'") { + Statement::NOTIFY { + channel, + payload: Some(payload), + } => { + assert_eq!(Ident::new("test1"), channel); + assert_eq!("this is a test notification", payload); + } + _ => unreachable!(), + }; + + assert_eq!( + dialects.parse_sql_statements("NOTIFY *").unwrap_err(), + ParserError::ParserError("Expected: identifier, found: *".to_string()) + ); + assert_eq!( + dialects + .parse_sql_statements("NOTIFY test1, *") + .unwrap_err(), + ParserError::ParserError("Expected: literal string, found: *".to_string()) + ); + + let sql_statements = [ + "NOTIFY test1", + "NOTIFY test1, 'this is a test notification'", + ]; + let dialects = all_dialects_where(|d| !d.supports_notify()); + + for &sql in &sql_statements { + assert_eq!( + dialects.parse_sql_statements(sql).unwrap_err(), + ParserError::ParserError("Expected: an SQL statement, found: NOTIFY".to_string()) + ); + assert_eq!( + dialects.parse_sql_statements(sql).unwrap_err(), + ParserError::ParserError("Expected: an SQL statement, found: NOTIFY".to_string()) + ); + } +}