Skip to content

Commit

Permalink
fix: refactor unsupported DDL handling
Browse files Browse the repository at this point in the history
  • Loading branch information
tconbeer committed Nov 22, 2024
1 parent b6d5af4 commit c9acb3e
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ All notable changes to this project will be documented in this file.
- sqlfmt no longer adds a space between the function name and parens for `filter()`, `isnull()`, and `rlike('foo', 'bar')` (but it also permits `filter ()`, `isnull ()`, and `rlike ('foo')` to support dialects where those are operators, not function names) ([#641](https://github.com/tconbeer/sqlfmt/issues/641), [#478](https://github.com/tconbeer/sqlfmt/issues/478) - thank you [@williamscs](https://github.com/williamscs), [@hongtron](https://github.com/hongtron), and [@chwiese](https://github.com/chwiese)!).
- sqlfmt now supports Spark type-hinted numeric literals like `32y` and `+3.2e6bd` and will not introduce a space between the digits and their type suffix ([#640](https://github.com/tconbeer/sqlfmt/issues/640) - thank you [@ShaneMazur](https://github.com/ShaneMazur)!).
- sqlfmt now supports Databricks query hint comments like `/*+ COALESCE(3) */` ([#639](https://github.com/tconbeer/sqlfmt/issues/639) - thank you [@wr-atlas](https://github.com/wr-atlas)!).
- sqlfmt now no-ops instead of raising an error when it encounters `create row access policy` statements with `grant` sub-statements (it also handles unsupported DDL more robustly in general) ([#633](https://github.com/tconbeer/sqlfmt/issues/633)).

## [0.23.3] - 2024-11-12

Expand Down
11 changes: 0 additions & 11 deletions src/sqlfmt/node_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,17 +275,6 @@ def disable_formatting(
):
formatting_disabled.pop()

# formatting can be disabled because of unsupported
# ddl. When we hit a semicolon we need to pop
# all of the formatting disabled tokens caused by ddl
# off the stack
if token.type is TokenType.SEMICOLON:
while (
formatting_disabled
and "fmt:" not in formatting_disabled[-1].token.lower()
):
formatting_disabled.pop()

return formatting_disabled

def append_newline(self, line: Line) -> None:
Expand Down
8 changes: 6 additions & 2 deletions src/sqlfmt/rules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sqlfmt.rules.function import FUNCTION as FUNCTION
from sqlfmt.rules.grant import GRANT as GRANT
from sqlfmt.rules.jinja import JINJA as JINJA # noqa
from sqlfmt.rules.unsupported import UNSUPPORTED as UNSUPPORTED
from sqlfmt.rules.warehouse import WAREHOUSE as WAREHOUSE
from sqlfmt.token import TokenType

Expand Down Expand Up @@ -77,7 +78,7 @@
r"interval",
r"is(\s+not)?(\s+distinct\s+from)?",
r"isnull",
r"(not\s+)?(r|i)?like(\s+(any|all))?",
r"(not\s+)?i?like(\s+(any|all))?",
r"over",
r"(un)?pivot",
r"notnull",
Expand Down Expand Up @@ -362,7 +363,10 @@
+ group(r"\W", r"$"),
action=partial(
actions.handle_nonreserved_top_level_keyword,
action=partial(actions.add_node_to_buffer, token_type=TokenType.FMT_OFF),
action=partial(
actions.lex_ruleset,
new_ruleset=UNSUPPORTED,
),
),
),
]
76 changes: 76 additions & 0 deletions src/sqlfmt/rules/unsupported.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from functools import partial

from sqlfmt import actions
from sqlfmt.rule import Rule
from sqlfmt.rules.common import NEWLINE, SQL_COMMENT, SQL_QUOTED_EXP, group
from sqlfmt.rules.jinja import JINJA
from sqlfmt.token import TokenType

# Ruleset for lexing statements that sqlfmt does not format (judging by the
# ruleset name and the "unsupported_line" entry, which buffers text as DATA).
# Every row of the table is (name, priority, pattern, action); rows are listed
# in ascending priority value.
UNSUPPORTED = [
    Rule(
        name=rule_name,
        priority=rule_priority,
        pattern=rule_pattern,
        action=rule_action,
    )
    for rule_name, rule_priority, rule_pattern, rule_action in (
        # Only the opening of a jinja tag is matched here, which lets the
        # parser deal with nesting in a more powerful way than regex allows.
        (
            "jinja_start",
            120,
            group(r"\{[{%#]"),
            partial(actions.lex_ruleset, new_ruleset=JINJA),
        ),
        # The end of a jinja tag should never be matched by itself.
        (
            "jinja_end",
            130,
            group(r"[#}%]\}"),
            actions.raise_sqlfmt_bracket_error,
        ),
        # Quoted expressions are buffered as QUOTED_NAME nodes.
        (
            "quoted_name",
            200,
            SQL_QUOTED_EXP,
            partial(actions.add_node_to_buffer, token_type=TokenType.QUOTED_NAME),
        ),
        # Comments matched by SQL_COMMENT go to the comment buffer.
        (
            "comment",
            300,
            SQL_COMMENT,
            actions.add_comment_to_buffer,
        ),
        # An opening "/*" is paired with "comment_end" by the
        # potentially-nested-token handler and lexed as a COMMENT.
        (
            "comment_start",
            310,
            group(r"/\*"),
            partial(
                actions.handle_potentially_nested_tokens,
                start_name="comment_start",
                end_name="comment_end",
                token_type=TokenType.COMMENT,
            ),
        ),
        # A bare "*/" raises a bracket error.
        (
            "comment_end",
            320,
            group(r"\*/"),
            actions.raise_sqlfmt_bracket_error,
        ),
        (
            "semicolon",
            400,
            group(r";"),
            actions.handle_semicolon,
        ),
        (
            "newline",
            999,
            group(NEWLINE),
            actions.handle_newline,
        ),
        # Catch-all: a lazy run of characters that are neither ";" nor a
        # newline, followed by a group built from ";", NEWLINE, and "$".
        # The text is added to the buffer as a DATA node.
        (
            "unsupported_line",
            1000,
            group(r"[^;\n]+?") + group(r";", NEWLINE, r"$"),
            partial(
                actions.handle_reserved_keyword,
                action=partial(
                    actions.add_node_to_buffer, token_type=TokenType.DATA
                ),
            ),
        ),
    )
]
4 changes: 4 additions & 0 deletions tests/data/preformatted/401_create_row_access_policy.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
create or replace row access policy foo
on foo.bar.baz
grant to ('user1', 'user2')
filter using ( foo = 'bar' )
1 change: 1 addition & 0 deletions tests/functional_tests/test_general_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"preformatted/302_jinjafmt_multiline_str.sql",
"preformatted/303_jinjafmt_more_mutliline_str.sql",
"preformatted/400_create_table.sql",
"preformatted/401_create_row_access_policy.sql",
"unformatted/100_select_case.sql",
"unformatted/101_multiline.sql",
"unformatted/102_lots_of_comments.sql",
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,8 @@ def test_handle_unsupported_ddl(default_analyzer: Analyzer) -> None:
query = default_analyzer.parse_query(source_string=source_string.lstrip())
assert len(query.lines) == 3
first_create_line = query.lines[0]
assert len(first_create_line.nodes) == 9
assert first_create_line.nodes[0].token.type is TokenType.FMT_OFF
assert len(first_create_line.nodes) == 3 # data, semicolon, newline
assert first_create_line.nodes[0].token.type is TokenType.DATA
assert first_create_line.nodes[-2].token.type is TokenType.SEMICOLON

select_line = query.lines[1]
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/test_node_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def test_disabled_formatting(default_mode: Mode) -> None:
assert create_publication_line.formatting_disabled
assert create_publication_line.nodes
create_token = create_publication_line.nodes[0].token
assert create_token.type is TokenType.FMT_OFF
assert create_token.type is TokenType.DATA
assert create_token in create_publication_line.nodes[0].formatting_disabled
assert len(create_publication_line.nodes[0].formatting_disabled) == 3
semicolon_node = create_publication_line.nodes[-2]
Expand Down

0 comments on commit c9acb3e

Please sign in to comment.