Skip to content

Commit

Permalink
Merge pull request #264 from demitryfly/some-refactorings-1
Browse files Browse the repository at this point in the history
refactor utils, add tests, move exceptions into separate module
  • Loading branch information
xnuinside authored Aug 1, 2024
2 parents 2ef4718 + d2413d5 commit 5533ede
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 80 deletions.
6 changes: 4 additions & 2 deletions docs/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@ How to use
Extract additional information from HQL (& other dialects)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In some dialects like HQL there is a lot of additional information about table like, fore example, is it external table, STORED AS, location & etc. This property will be always empty in 'classic' SQL DB like PostgreSQL or MySQL and this is the reason, why by default this information are 'hidden'.
Also some fields hidden in HQL, because they are simple not exists in HIVE, for example 'deferrable_initially'
In some dialects, such as HQL, there is a lot of additional information about a table, for example whether it is an external table,
STORED AS, location, etc. These properties will always be empty in a 'classic' SQL DB like PostgreSQL or MySQL,
which is why this information is 'hidden' by default.
Some fields are also hidden in HQL because they simply do not exist in Hive, for example 'deferrable_initially'.
To get these HQL-specific details about a table in the output, use the 'output_mode' argument of the run() method.

example:
Expand Down
6 changes: 2 additions & 4 deletions simple_ddl_parser/ddl_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@
Snowflake,
SparkSQL,
)
# "DDLParserError" is an alias for backward compatibility
from simple_ddl_parser.exception import SimpleDDLParserException as DDLParserError
from simple_ddl_parser.parser import Parser


class DDLParserError(Exception):
pass


class Dialects(
SparkSQL,
Snowflake,
Expand Down
9 changes: 9 additions & 0 deletions simple_ddl_parser/exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
__all__ = [
"SimpleDDLParserException",
]


class SimpleDDLParserException(Exception):
    """Base exception in simple ddl parser library."""

2 changes: 1 addition & 1 deletion simple_ddl_parser/output/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def group_by_type_result(self) -> None:
else:
_type.extend(item["comments"])
break
if result_as_dict["comments"] == []:
if not result_as_dict["comments"]:
del result_as_dict["comments"]

self.final_result = result_as_dict
Expand Down
39 changes: 19 additions & 20 deletions simple_ddl_parser/output/table_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,21 @@
from simple_ddl_parser.output.base_data import BaseData
from simple_ddl_parser.output.dialects import CommonDialectsFieldsMixin, dialect_by_name

__all__ = [
"TableData",
]


def _pre_process_kwargs(kwargs: dict, aliased_fields: dict) -> None:
    """
    Normalize ``kwargs`` in place.

    Every key of ``kwargs`` that is an alias listed in ``aliased_fields``
    (mapping alias -> field name) is renamed to its field name.
    The placeholder value "_ddl_parser_comma_only_str" stored under
    "fields_terminated_by" is replaced with the quoted comma "','".
    """
    for alias, field_name in aliased_fields.items():
        if alias in kwargs:
            # pop() first, then assign: when an alias equals its field name
            # the value is kept instead of being deleted right after the copy
            kwargs[field_name] = kwargs.pop(alias)

    # todo: need to figure out how workaround it normally
    if kwargs.get("fields_terminated_by") == "_ddl_parser_comma_only_str":
        kwargs["fields_terminated_by"] = "','"


class TableData:
cls_prefix = "Dialect"
Expand All @@ -13,34 +28,18 @@ def get_dialect_class(cls, kwargs: dict):

if output_mode and output_mode != "sql":
main_cls = dialect_by_name.get(output_mode)
cls = dataclass(
return dataclass(
type(
f"{main_cls.__name__}{cls.cls_prefix}",
(main_cls, CommonDialectsFieldsMixin),
{},
)
)
else:
cls = BaseData

return cls

@staticmethod
def pre_process_kwargs(kwargs: dict, aliased_fields: dict) -> dict:
for alias, field_name in aliased_fields.items():
if alias in kwargs:
kwargs[field_name] = kwargs[alias]
del kwargs[alias]

# todo: need to figure out how workaround it normally
if (
"fields_terminated_by" in kwargs
and "_ddl_parser_comma_only_str" == kwargs["fields_terminated_by"]
):
kwargs["fields_terminated_by"] = "','"
return BaseData

@classmethod
def pre_load_mods(cls, main_cls, kwargs):
def pre_load_mods(cls, main_cls, kwargs) -> dict:
if kwargs.get("output_mode") == "bigquery":
if kwargs.get("schema"):
kwargs["dataset"] = kwargs["schema"]
Expand All @@ -55,7 +54,7 @@ def pre_load_mods(cls, main_cls, kwargs):
for name, value in cls_fields.items()
if value.metadata and "alias" in value.metadata
}
cls.pre_process_kwargs(kwargs, aliased_fields)
_pre_process_kwargs(kwargs, aliased_fields)
table_main_args = {
k.lower(): v for k, v in kwargs.items() if k.lower() in cls_fields
}
Expand Down
6 changes: 2 additions & 4 deletions simple_ddl_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@

from ply import lex, yacc

from simple_ddl_parser.exception import SimpleDDLParserException
from simple_ddl_parser.output.core import Output, dump_data_to_file
from simple_ddl_parser.output.dialects import dialect_by_name
from simple_ddl_parser.utils import (
SimpleDDLParserException,
find_first_unpair_closed_par,
)
from simple_ddl_parser.utils import find_first_unpair_closed_par

# open comment
OP_COM = "/*"
Expand Down
20 changes: 10 additions & 10 deletions simple_ddl_parser/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@


tokens = tuple(
set(
[
{
*[
"ID",
"DOT",
"STRING_BASE",
Expand All @@ -161,14 +161,14 @@
"LT",
"RT",
"COMMAT",
]
+ list(definition_statements.values())
+ list(common_statements.values())
+ list(columns_definition.values())
+ list(sequence_reserved.values())
+ list(after_columns_tokens.values())
+ list(alter_tokens.values())
)
],
*definition_statements.values(),
*common_statements.values(),
*columns_definition.values(),
*sequence_reserved.values(),
*after_columns_tokens.values(),
*alter_tokens.values(),
}
)

symbol_tokens = {
Expand Down
96 changes: 59 additions & 37 deletions simple_ddl_parser/utils.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,83 @@
import re
from typing import List
from typing import List, Tuple, Optional, Union, Any

# Backward compatibility import
from simple_ddl_parser.exception import SimpleDDLParserException

def remove_par(p_list: List[str]) -> List[str]:
remove_list = ["(", ")"]
for symbol in remove_list:
while symbol in p_list:
p_list.remove(symbol)
__all__ = [
"remove_par",
"check_spec",
"find_first_unpair_closed_par",
"normalize_name",
"get_table_id",
"SimpleDDLParserException"
]

_parentheses = ('(', ')')


def remove_par(p_list: List[Union[str, Any]]) -> List[Union[str, Any]]:
    """
    Remove the parentheses from the given list.

    The list is filtered in place and the very same list object is returned,
    so callers holding a reference to ``p_list`` see the change as well.
    Warn: p_list may contain unhashable types, such as 'dict', therefore the
    membership test is done against a tuple (equality only, no hashing).
    """
    # slice assignment replaces the content but keeps the list object itself
    p_list[:] = [item for item in p_list if item not in _parentheses]
    return p_list


spec_mapper = {
_spec_mapper = {
"'pars_m_t'": "'\t'",
"'pars_m_n'": "'\n'",
"'pars_m_dq'": '"',
"pars_m_single": "'",
}


def check_spec(value: str) -> str:
replace_value = spec_mapper.get(value)
if not replace_value:
for item in spec_mapper:
if item in value:
replace_value = value.replace(item, spec_mapper[item])
break
else:
replace_value = value
return replace_value


def find_first_unpair_closed_par(str_: str) -> int:
stack = []
n = -1
for i in str_:
n += 1
if i == ")":
if not stack:
return n
else:
stack.pop(-1)
elif i == "(":
stack.append(i)
def check_spec(string: str) -> str:
    """
    Translate parser escape tokens into the characters they stand for.

    A string that is exactly one token is mapped directly. Otherwise the
    first token (in mapping order) found inside the string is replaced
    everywhere it occurs; any other tokens are left untouched.
    A string without tokens is returned unchanged.
    """
    exact_match = _spec_mapper.get(string)
    if exact_match is not None:
        return exact_match

    found_token = next((token for token in _spec_mapper if token in string), None)
    if found_token is None:
        return string
    return string.replace(found_token, _spec_mapper[found_token])


def find_first_unpair_closed_par(str_: str) -> Optional[int]:
    """
    Return the index of the first ")" that has no matching "(" before it.
    Return None when every closing parenthesis is paired.
    """
    depth = 0  # number of "(" seen so far that are still waiting for a ")"
    for position, symbol in enumerate(str_):
        if symbol == "(":
            depth += 1
        elif symbol == ")":
            if depth == 0:
                return position
            depth -= 1
    return None


def normalize_name(name: str) -> str:
    """
    Strip square brackets and double quotes from the name and lowercase it
    """
    without_quoting = re.sub(r'[\[\]"]', "", name)
    return without_quoting.lower()


def get_table_id(schema_name: str, table_name: str):
def get_table_id(schema_name: str, table_name: str) -> Tuple[str, str]:
    """
    Build the (table name, schema name) pair from normalized names.
    A falsy schema name is passed through as is.
    """
    normalized_table = normalize_name(table_name)
    normalized_schema = normalize_name(schema_name) if schema_name else schema_name
    return (normalized_table, normalized_schema)


class SimpleDDLParserException(Exception):
pass
4 changes: 2 additions & 2 deletions tests/non_statement_tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from simple_ddl_parser import DDLParser, DDLParserError
from simple_ddl_parser import DDLParser, SimpleDDLParserException
from simple_ddl_parser.output.core import get_table_id


Expand Down Expand Up @@ -29,7 +29,7 @@ def test_silent_false_flag():
created_timestamp TIMESTAMPTZ NOT NULL DEFAULT ALTER (now() at time zone 'utc')
);
"""
with pytest.raises(DDLParserError) as e:
with pytest.raises(SimpleDDLParserException) as e:
DDLParser(ddl, silent=False).run(group_by_type=True)

assert "Unknown statement" in e.value[1]
Expand Down
69 changes: 69 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest

from simple_ddl_parser import utils


@pytest.mark.parametrize(
    "expression, expected_result",
    [
        ([], []),
        (["("], []),
        ([")"], []),
        (["(", ")"], []),
        ([")", "("], []),
        (["(", "A"], ["A"]),
        (["A", ")"], ["A"]),
        (["(", "A", ")"], ["A"]),
        (["A", ")", ")"], ["A"]),
        (["(", "(", "A"], ["A"]),
        (["A", "B", "C"], ["A", "B", "C"]),
        (["A", "(", "(", "B", "C", "("], ["A", "B", "C"]),
        (["A", ")", "B", ")", "(", "C"], ["A", "B", "C"]),
        (["(", "A", ")", "B", "C", ")"], ["A", "B", "C"]),
        ([dict()], [dict()]),  # Edge case (unhashable types)
    ]
)
def test_remove_par(expression, expected_result):
    """Check that remove_par drops every "(" and ")" item and keeps the other items in order."""
    assert utils.remove_par(expression) == expected_result


@pytest.mark.parametrize(
    "expression, expected_result",
    [
        ("", ""),
        ("simple", "simple"),
        ("'pars_m_t'", "'\t'"),
        ("'pars_m_n'", "'\n'"),
        ("'pars_m_dq'", '"'),
        ("pars_m_single", "'"),
        ("STRING_'pars_m_t'STRING", "STRING_'\t'STRING"),
        ("STRING_'pars_m_n'STRING", "STRING_'\n'STRING"),
        ("STRING_'pars_m_dq'STRING", "STRING_\"STRING"),
        ("STRING_pars_m_singleSTRING", "STRING_'STRING"),
        ("pars_m_single pars_m_single", "' '"),
        # only one kind of token is replaced per call;
        # which one is determined by dict element order
        ("'pars_m_t''pars_m_n'", "'\t''pars_m_n'"),
    ]
)
def test_check_spec(expression, expected_result):
    """Check that check_spec substitutes escape tokens, standalone and embedded in a longer string."""
    assert utils.check_spec(expression) == expected_result


@pytest.mark.parametrize(
    "expression, expected_result",
    [
        (")", 0),
        (")()", 0),
        ("())", 2),
        ("()())", 4),
        ("", None),
        ("text", None),
        ("()", None),
        ("(balanced) (brackets)", None),
        ("(not)) (balanced) (brackets", 5)
    ]
)
def test_find_first_unpair_closed_par(expression, expected_result):
    """Check the index returned for the first unpaired ")" and None when there is no such parenthesis."""
    assert utils.find_first_unpair_closed_par(expression) == expected_result

0 comments on commit 5533ede

Please sign in to comment.