refactor utils, add tests, move exceptions into separate module #264

Merged
merged 18 commits on Aug 1, 2024
6 changes: 4 additions & 2 deletions docs/README.rst
@@ -96,8 +96,10 @@ How to use
Extract additional information from HQL (& other dialects)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In some dialects like HQL there is a lot of additional information about table like, fore example, is it external table, STORED AS, location & etc. This property will be always empty in 'classic' SQL DB like PostgreSQL or MySQL and this is the reason, why by default this information are 'hidden'.
Also some fields hidden in HQL, because they are simple not exists in HIVE, for example 'deferrable_initially'
In some dialects like HQL there is a lot of additional information about table like, fore example, is it external table,
STORED AS, location & etc. This property will be always empty in 'classic' SQL DB like PostgreSQL or MySQL
and this is the reason, why by default this information is 'hidden'.
Also some fields are hidden in HQL, because they are simple not exists in HIVE, for example 'deferrable_initially'
To get this 'hql' specific details about table in output please use 'output_mode' argument in run() method.

example:
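The concrete example block is not included in this view; a minimal sketch of the call the README describes, assuming the library's public `DDLParser` API, might look like this:

```python
from simple_ddl_parser import DDLParser

ddl = """
CREATE EXTERNAL TABLE IF NOT EXISTS database.table_name (
    day_long_nm string,
    calendar_dt date
)
STORED AS PARQUET
LOCATION 'hdfs://data/raw/table_name'
"""

# output_mode="hql" keeps the HQL-specific details (external, stored_as, location, ...)
# that stay hidden in the default output mode.
result = DDLParser(ddl).run(output_mode="hql")
print(result)
```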
6 changes: 2 additions & 4 deletions simple_ddl_parser/ddl_parser.py
@@ -16,13 +16,11 @@
Snowflake,
SparkSQL,
)
# "DDLParserError" is an alias for backward compatibility
from simple_ddl_parser.exception import SimpleDDLParserException as DDLParserError
from simple_ddl_parser.parser import Parser


class DDLParserError(Exception):
pass


class Dialects(
SparkSQL,
Snowflake,
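Because `DDLParserError` is now only an alias, code written against the old name keeps working. A quick sanity check (an illustrative sketch, not part of the PR):

```python
from simple_ddl_parser.ddl_parser import DDLParserError
from simple_ddl_parser.exception import SimpleDDLParserException

# The old name and the new base exception are the same class object.
assert DDLParserError is SimpleDDLParserException
```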
9 changes: 9 additions & 0 deletions simple_ddl_parser/exception.py
@@ -0,0 +1,9 @@
__all__ = [
"SimpleDDLParserException",
]


class SimpleDDLParserException(Exception):
""" Base exception in simple ddl parser library """
pass
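Having a single library-specific base exception lets callers handle every parser error in one place. A usage sketch, based on the behaviour exercised by the updated test further down (with silent=False, statements the parser cannot handle raise):

```python
from simple_ddl_parser import DDLParser, SimpleDDLParserException

ddl = """
CREATE TABLE foo (
    created_timestamp TIMESTAMPTZ NOT NULL DEFAULT ALTER (now() at time zone 'utc')
);
"""

try:
    DDLParser(ddl, silent=False).run(group_by_type=True)
except SimpleDDLParserException as exc:
    print(f"DDL could not be parsed: {exc}")
```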

2 changes: 1 addition & 1 deletion simple_ddl_parser/output/core.py
@@ -123,7 +123,7 @@ def group_by_type_result(self) -> None:
else:
_type.extend(item["comments"])
break
if result_as_dict["comments"] == []:
if not result_as_dict["comments"]:
del result_as_dict["comments"]

self.final_result = result_as_dict
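The truthiness check behaves like the old `== []` comparison for an empty list and additionally treats None (or any other falsy value) as "no comments"; a toy illustration:

```python
# `not comments` covers both [] and None, while `comments == []` matches only an empty list.
for comments in ([], None, ["a comment"]):
    print(comments, "-> drop key" if not comments else "-> keep key")
```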
39 changes: 19 additions & 20 deletions simple_ddl_parser/output/table_data.py
@@ -3,6 +3,21 @@
from simple_ddl_parser.output.base_data import BaseData
from simple_ddl_parser.output.dialects import CommonDialectsFieldsMixin, dialect_by_name

__all__ = [
"TableData",
]


def _pre_process_kwargs(kwargs: dict, aliased_fields: dict) -> None:
for alias, field_name in aliased_fields.items():
if alias in kwargs:
kwargs[field_name] = kwargs[alias]
del kwargs[alias]

# todo: need to figure out how workaround it normally
if kwargs.get("fields_terminated_by") == "_ddl_parser_comma_only_str":
kwargs["fields_terminated_by"] = "','"


class TableData:
cls_prefix = "Dialect"
@@ -13,34 +28,18 @@ def get_dialect_class(cls, kwargs: dict):

if output_mode and output_mode != "sql":
main_cls = dialect_by_name.get(output_mode)
cls = dataclass(
return dataclass(
type(
f"{main_cls.__name__}{cls.cls_prefix}",
(main_cls, CommonDialectsFieldsMixin),
{},
)
)
else:
cls = BaseData

return cls

@staticmethod
def pre_process_kwargs(kwargs: dict, aliased_fields: dict) -> dict:
for alias, field_name in aliased_fields.items():
if alias in kwargs:
kwargs[field_name] = kwargs[alias]
del kwargs[alias]

# todo: need to figure out how workaround it normally
if (
"fields_terminated_by" in kwargs
and "_ddl_parser_comma_only_str" == kwargs["fields_terminated_by"]
):
kwargs["fields_terminated_by"] = "','"
return BaseData

@classmethod
def pre_load_mods(cls, main_cls, kwargs):
def pre_load_mods(cls, main_cls, kwargs) -> dict:
if kwargs.get("output_mode") == "bigquery":
if kwargs.get("schema"):
kwargs["dataset"] = kwargs["schema"]
@@ -55,7 +54,7 @@ def pre_load_mods(cls, main_cls, kwargs):
for name, value in cls_fields.items()
if value.metadata and "alias" in value.metadata
}
cls.pre_process_kwargs(kwargs, aliased_fields)
_pre_process_kwargs(kwargs, aliased_fields)
table_main_args = {
k.lower(): v for k, v in kwargs.items() if k.lower() in cls_fields
}
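get_dialect_class still builds the per-dialect result class on the fly by combining a dialect class with the common-fields mixin through type() and dataclass(). A self-contained sketch of that pattern, using made-up stand-in classes rather than the library's real dialect classes:

```python
from dataclasses import dataclass


@dataclass
class HQLFields:  # stand-in for a dialect class from dialect_by_name
    external: bool = False
    stored_as: str = ""


@dataclass
class CommonDialectsFields:  # stand-in for CommonDialectsFieldsMixin
    table_name: str = ""


# Same shape as: return dataclass(type(f"{main_cls.__name__}{cls.cls_prefix}", (main_cls, mixin), {}))
HQLDialect = dataclass(type("HQLDialect", (HQLFields, CommonDialectsFields), {}))

row = HQLDialect(table_name="events", external=True, stored_as="PARQUET")
print(row)  # HQLDialect(table_name='events', external=True, stored_as='PARQUET')
```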
6 changes: 2 additions & 4 deletions simple_ddl_parser/parser.py
@@ -6,12 +6,10 @@

from ply import lex, yacc

from simple_ddl_parser.exception import SimpleDDLParserException
from simple_ddl_parser.output.core import Output, dump_data_to_file
from simple_ddl_parser.output.dialects import dialect_by_name
from simple_ddl_parser.utils import (
SimpleDDLParserException,
find_first_unpair_closed_par,
)
from simple_ddl_parser.utils import find_first_unpair_closed_par

# open comment
OP_COM = "/*"
20 changes: 10 additions & 10 deletions simple_ddl_parser/tokens.py
@@ -150,8 +150,8 @@


tokens = tuple(
set(
[
{
*[
"ID",
"DOT",
"STRING_BASE",
@@ -161,14 +161,14 @@
"LT",
"RT",
"COMMAT",
]
+ list(definition_statements.values())
+ list(common_statements.values())
+ list(columns_definition.values())
+ list(sequence_reserved.values())
+ list(after_columns_tokens.values())
+ list(alter_tokens.values())
)
],
*definition_statements.values(),
*common_statements.values(),
*columns_definition.values(),
*sequence_reserved.values(),
*after_columns_tokens.values(),
*alter_tokens.values(),
}
)

symbol_tokens = {
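The new token construction replaces set([...] + list(...) + ...) with a set literal plus iterable unpacking; both spell the same token set, as this toy comparison with stand-in dicts (not the real token tables) shows:

```python
definition_statements = {"CREATE": "CREATE", "TABLE": "TABLE"}
columns_definition = {"NOT": "NOT", "NULL": "NULL"}

old_style = set(["ID", "DOT"] + list(definition_statements.values()) + list(columns_definition.values()))
new_style = {*["ID", "DOT"], *definition_statements.values(), *columns_definition.values()}

assert old_style == new_style
```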
96 changes: 59 additions & 37 deletions simple_ddl_parser/utils.py
@@ -1,61 +1,83 @@
import re
from typing import List
from typing import List, Tuple, Optional, Union, Any

# Backward compatibility import
from simple_ddl_parser.exception import SimpleDDLParserException

def remove_par(p_list: List[str]) -> List[str]:
remove_list = ["(", ")"]
for symbol in remove_list:
while symbol in p_list:
p_list.remove(symbol)
__all__ = [
"remove_par",
"check_spec",
"find_first_unpair_closed_par",
"normalize_name",
"get_table_id",
"SimpleDDLParserException"
]

_parentheses = ('(', ')')


def remove_par(p_list: List[Union[str, Any]]) -> List[Union[str, Any]]:
"""
Remove the parentheses from the given list

Warn: p_list may contain unhashable types, such as 'dict'.
"""
j = 0
for i in range(len(p_list)):
if p_list[i] not in _parentheses:
p_list[j] = p_list[i]
j += 1
while j < len(p_list):
p_list.pop()
return p_list


spec_mapper = {
_spec_mapper = {
"'pars_m_t'": "'\t'",
"'pars_m_n'": "'\n'",
"'pars_m_dq'": '"',
"pars_m_single": "'",
}


def check_spec(value: str) -> str:
replace_value = spec_mapper.get(value)
if not replace_value:
for item in spec_mapper:
if item in value:
replace_value = value.replace(item, spec_mapper[item])
break
else:
replace_value = value
return replace_value


def find_first_unpair_closed_par(str_: str) -> int:
stack = []
n = -1
for i in str_:
n += 1
if i == ")":
if not stack:
return n
else:
stack.pop(-1)
elif i == "(":
stack.append(i)
def check_spec(string: str) -> str:
"""
Replace escape tokens to their representation
"""
if string in _spec_mapper:
return _spec_mapper[string]
for replace_from, replace_to in _spec_mapper.items():
if replace_from in string:
return string.replace(replace_from, replace_to)
return string


def find_first_unpair_closed_par(str_: str) -> Optional[int]:
"""
Returns index of first unpair close parentheses.
Or returns None, if there is no one.
"""
count_open = 0
for i, char in enumerate(str_):
if char == '(':
count_open += 1
if char == ')':
count_open -= 1
if count_open < 0:
return i
return None


def normalize_name(name: str) -> str:
# clean up [] and " symbols from names
"""
Clean up [] and " characters from the given name
"""
clean_up_re = r'[\[\]"]'
return re.sub(clean_up_re, "", name).lower()


def get_table_id(schema_name: str, table_name: str):
def get_table_id(schema_name: str, table_name: str) -> Tuple[str, str]:
table_name = normalize_name(table_name)
if schema_name:
schema_name = normalize_name(schema_name)
return (table_name, schema_name)


class SimpleDDLParserException(Exception):
pass
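The reworked helpers are easiest to see in use; a few calls based on the functions shown in this diff:

```python
from simple_ddl_parser import utils

# remove_par drops parentheses in place and now copes with unhashable items such as dicts.
print(utils.remove_par(["(", "id", "INTEGER", ")", {}]))        # ['id', 'INTEGER', {}]

# find_first_unpair_closed_par returns the index of the first unmatched ')' or None.
print(utils.find_first_unpair_closed_par("decimal(10, 2))"))    # 14
print(utils.find_first_unpair_closed_par("(balanced)"))         # None

# check_spec maps the parser's escape placeholders back to the real characters.
print(repr(utils.check_spec("'pars_m_t'")))                     # "'\t'"

# normalize_name strips [] and " characters and lowercases.
print(utils.normalize_name('[dbo]."Users"'))                    # dbo.users
```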
4 changes: 2 additions & 2 deletions tests/non_statement_tests/test_common.py
@@ -1,6 +1,6 @@
import pytest

from simple_ddl_parser import DDLParser, DDLParserError
from simple_ddl_parser import DDLParser, SimpleDDLParserException
from simple_ddl_parser.output.core import get_table_id


@@ -29,7 +29,7 @@ def test_silent_false_flag():
created_timestamp TIMESTAMPTZ NOT NULL DEFAULT ALTER (now() at time zone 'utc')
);
"""
with pytest.raises(DDLParserError) as e:
with pytest.raises(SimpleDDLParserException) as e:
DDLParser(ddl, silent=False).run(group_by_type=True)

assert "Unknown statement" in e.value[1]
69 changes: 69 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,69 @@
import pytest

from simple_ddl_parser import utils


@pytest.mark.parametrize(
"expression, expected_result",
[
([], []),
(["("], []),
([")"], []),
(["(", ")"], []),
([")", "("], []),
(["(", "A"], ["A"]),
(["A", ")"], ["A"]),
(["(", "A", ")"], ["A"]),
(["A", ")", ")"], ["A"]),
(["(", "(", "A"], ["A"]),
(["A", "B", "C"], ["A", "B", "C"]),
(["A", "(", "(", "B", "C", "("], ["A", "B", "C"]),
(["A", ")", "B", ")", "(", "C"], ["A", "B", "C"]),
(["(", "A", ")", "B", "C", ")"], ["A", "B", "C"]),
([dict()], [dict()]), # Edge case (unhashable types)
]
)
def test_remove_par(expression, expected_result):
assert utils.remove_par(expression) == expected_result


@pytest.mark.parametrize(
"expression, expected_result",
[
("", ""),
("simple", "simple"),

("'pars_m_t'", "'\t'"),
("'pars_m_n'", "'\n'"),
("'pars_m_dq'", '"'),
("pars_m_single", "'"),

("STRING_'pars_m_t'STRING", "STRING_'\t'STRING"),
("STRING_'pars_m_n'STRING", "STRING_'\n'STRING"),
("STRING_'pars_m_dq'STRING", "STRING_\"STRING"),
("STRING_pars_m_singleSTRING", "STRING_'STRING"),

("pars_m_single pars_m_single", "' '"),
("'pars_m_t''pars_m_n'", "'\t''pars_m_n'"), # determined by dict element order
]
)
def test_check_spec(expression, expected_result):
assert utils.check_spec(expression) == expected_result


@pytest.mark.parametrize(
"expression, expected_result",
[
(")", 0),
(")()", 0),
("())", 2),
("()())", 4),
("", None),
("text", None),
("()", None),
("(balanced) (brackets)", None),
("(not)) (balanced) (brackets", 5)
]
)
def test_find_first_unpair_closed_par(expression, expected_result):
assert utils.find_first_unpair_closed_par(expression) == expected_result