Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add #[recursive] #1522

Merged
merged 14 commits into from
Dec 19, 2024
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ path = "src/lib.rs"

[features]
default = ["std"]
std = []
std = ["recursive"]
# Enable JSON output in the `cli` example:
json_example = ["serde_json", "serde"]
visitor = ["sqlparser_derive"]

[dependencies]
bigdecimal = { version = "0.4.1", features = ["serde"], optional = true }
log = "0.4"
recursive = { version = "0.1.1", optional = true}

serde = { version = "1.0", features = ["derive"], optional = true }
# serde_json is only used in examples/cli, but we have to put it outside
# of dev-dependencies because of
Expand Down
3 changes: 3 additions & 0 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_

let expanded = quote! {
// The generated impl.
// Note that it uses [`recursive::recursive`] to protect from stack overflow.
// See tests in https://github.com/apache/datafusion-sqlparser-rs/pull/1522/ for more info.
impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause {
#[cfg_attr(feature = "std", recursive::recursive)]
fn visit<V: sqlparser::ast::#visitor_trait>(
&#modifier self,
visitor: &mut V
Expand Down
40 changes: 40 additions & 0 deletions sqlparser_bench/benches/sqlparser_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,46 @@ fn basic_queries(c: &mut Criterion) {
group.bench_function("sqlparser::with_select", |b| {
b.iter(|| Parser::parse_sql(&dialect, with_query));
});

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For large_statement, making separated test would make differentiating potential(future) regression easier.
It's a suggestion(fine to leave it as it is).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understood your comment, sorry. I added tests for parsing large statements in tests/sqlparser_common.rs. Do you think we should test something else?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thinks it would be better to add as separated test.

let large_statement = {
let expressions = (0..1000)
.map(|n| format!("FN_{}(COL_{})", n, n))
.collect::<Vec<_>>()
.join(", ");
let tables = (0..1000)
.map(|n| format!("TABLE_{}", n))
.collect::<Vec<_>>()
.join(" JOIN ");
let where_condition = (0..1000)
.map(|n| format!("COL_{} = {}", n, n))
.collect::<Vec<_>>()
.join(" OR ");
let order_condition = (0..1000)
.map(|n| format!("COL_{} DESC", n))
.collect::<Vec<_>>()
.join(", ");

format!(
"SELECT {} FROM {} WHERE {} ORDER BY {}",
expressions, tables, where_condition, order_condition
)
};

group.bench_function("parse_large_statement", |b| {
b.iter(|| Parser::parse_sql(&dialect, criterion::black_box(large_statement.as_str())));
});

let large_statement = Parser::parse_sql(&dialect, large_statement.as_str())
.unwrap()
.pop()
.unwrap();

group.bench_function("format_large_statement", |b| {
b.iter(|| {
let formatted_query = large_statement.to_string();
assert_eq!(formatted_query, large_statement);
});
});
}

criterion_group!(benches, basic_queries);
Expand Down
1 change: 1 addition & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,7 @@ impl fmt::Display for CastFormat {
}

impl fmt::Display for Expr {
#[cfg_attr(feature = "std", recursive::recursive)]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Expr::Identifier(s) => write!(f, "{s}"),
Expand Down
25 changes: 25 additions & 0 deletions src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -884,4 +884,29 @@ mod tests {
assert_eq!(actual, expected)
}
}

struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes

impl Visitor for QuickVisitor {
type Break = ();
}

#[test]
fn overflow() {
let cond = (0..1000)
blaginin marked this conversation as resolved.
Show resolved Hide resolved
.map(|n| format!("X = {}", n))
.collect::<Vec<_>>()
.join(" OR ");
let sql = format!("SELECT x where {0}", cond);

let dialect = GenericDialect {};
let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap();
let s = Parser::new(&dialect)
.with_tokens(tokens)
.parse_statement()
.unwrap();

let mut visitor = QuickVisitor {};
s.visit(&mut visitor);
}
}
6 changes: 6 additions & 0 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ mod recursion {
/// Note: Uses an [`std::rc::Rc`] and [`std::cell::Cell`] in order to satisfy the Rust
/// borrow checker so the automatic [`DepthGuard`] decrement a
/// reference to the counter.
///
/// Note: when "std" feature is enabled, this crate uses additional stack overflow protection
/// for some of its recursive methods. See [`recursive::recursive`] for more information.
pub(crate) struct RecursionCounter {
remaining_depth: Rc<Cell<usize>>,
}
Expand Down Expand Up @@ -326,6 +329,9 @@ impl<'a> Parser<'a> {
/// # Ok(())
/// # }
/// ```
///
/// Note: when "std" feature is enabled, this crate uses additional stack overflow protection
// for some of its recursive methods. See [`recursive::recursive`] for more information.
pub fn with_recursion_limit(mut self, recursion_limit: usize) -> Self {
self.recursion_counter = RecursionCounter::new(recursion_limit);
self
Expand Down
13 changes: 13 additions & 0 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12440,3 +12440,16 @@ fn test_reserved_keywords_for_identifiers() {
let sql = "SELECT MAX(interval) FROM tbl";
dialects.parse_sql_statements(sql).unwrap();
}

#[test]
fn overflow() {
let expr = std::iter::repeat("1")
.take(1000)
blaginin marked this conversation as resolved.
Show resolved Hide resolved
.collect::<Vec<_>>()
.join(" + ");
let sql = format!("SELECT {}", expr);

let mut statements = Parser::parse_sql(&GenericDialect {}, sql.as_str()).unwrap();
let statement = statements.pop().unwrap();
assert_eq!(statement.to_string(), sql);
}