Skip to content

Commit

Permalink
Support empty constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
niknetniko committed Jun 16, 2023
1 parent f20ac55 commit 4415a5c
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 13 deletions.
61 changes: 48 additions & 13 deletions tested/dsl/ast_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import ast
import dataclasses
from typing import Optional

from pydantic import ValidationError

Expand Down Expand Up @@ -62,6 +63,24 @@ class InvalidDslError(Exception):
pass


def _is_and_get_allowed_empty(node: ast.Call) -> Optional[Value]:
"""
Check if we allow this cast without params to represent an "empty" value.
Returns the empty value if allowed, otherwise None.
"""
assert isinstance(node.func, ast.Name)
if node.func.id in AdvancedSequenceTypes.__members__.values():
# noinspection PyTypeChecker
return SequenceType(type=node.func.id, data=[])
elif node.func.id in BasicSequenceTypes.__members__.values():
# noinspection PyTypeChecker
return SequenceType(type=node.func.id, data=[])
elif node.func.id in BasicObjectTypes.__members__.values():
return ObjectType(type=BasicObjectTypes.MAP, data=[])
else:
return None


def _is_type_cast(node: ast.expr) -> bool:
"""
Check if this is a cast to a specific type or not.
Expand Down Expand Up @@ -165,20 +184,36 @@ def _convert_constant(node: ast.Constant) -> Value:
def _convert_expression(node: ast.expr, is_return: bool) -> Expression:
if _is_type_cast(node):
assert isinstance(node, ast.Call)
if n := len(node.args) != 1:
raise InvalidDslError(
f"A cast function must have exactly one argument, found {n}"
)
# We have a cast, so extract the value, but modify the type later on.
subexpression = node.args[0]
value = _convert_expression(subexpression, is_return)

if not isinstance(value, get_args(Value)):
raise InvalidDslError(
"The argument of a cast function must resolve to a value."
)

assert isinstance(node.func, ast.Name)
# "Casts" of sequence types can also be used a constructor for an empty sequence.
# For example, "set()", "map()", ...
nr_of_args = len(node.args)
if (empty_value := _is_and_get_allowed_empty(node)) and nr_of_args == 0:
value = empty_value
else:
assert isinstance(node.func, ast.Name)
if nr_of_args != 1:
if _is_and_get_allowed_empty(node) is not None:
error = f"""
The cast function '{node.func.id}' must have either zero or one arguments:
- Zero if you want to use it to represent an empty value, e.g. '{node.func.id}()'
- One if you want to cast another value to the type '{node.func.id}'.
"""
raise InvalidDslError(error)
else:
error = f"""
The cast function '{node.func.id}' must have exact one argument, but found {nr_of_args}.
For example, '{node.func.id}(...)', where '...' is the value.
"""
raise InvalidDslError(error)
# We have a cast, so extract the value, but modify the type later on.
subexpression = node.args[0]
value = _convert_expression(subexpression, is_return)

if not isinstance(value, get_args(Value)):
raise InvalidDslError(
"The argument of a cast function must resolve to a value."
)
return dataclasses.replace(value, type=node.func.id)
elif isinstance(node, ast.Call):
if is_return:
Expand Down
63 changes: 63 additions & 0 deletions tests/test_dsl_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
AdvancedNumericTypes,
AdvancedSequenceTypes,
BasicNumericTypes,
BasicObjectTypes,
BasicSequenceTypes,
)
from tested.dsl import translate_to_test_suite
from tested.serialisation import Assignment, FunctionCall, ObjectType, SequenceType
Expand Down Expand Up @@ -430,3 +432,64 @@ def test_statement_raw_return():
assert isinstance(test.output.result.value, SequenceType)
for element in test.output.result.value.data:
assert element.type == AdvancedSequenceTypes.TUPLE


@pytest.mark.parametrize(
"function_name,result",
[
("set", BasicSequenceTypes.SET),
("sequence", BasicSequenceTypes.SEQUENCE),
("array", AdvancedSequenceTypes.ARRAY),
("tuple", AdvancedSequenceTypes.TUPLE),
("map", BasicObjectTypes.MAP),
],
)
def test_empty_constructor(function_name, result):
yaml_str = f"""
- tab: 'Test'
contexts:
- testcases:
- statement: 'test()'
return_raw: '{function_name}()'
"""
json_str = translate_to_test_suite(yaml_str)
suite = parse_test_suite(json_str)
assert len(suite.tabs) == 1
tab = suite.tabs[0]
assert len(tab.contexts) == 1
testcases = tab.contexts[0].testcases
assert len(testcases) == 1
test = testcases[0]
assert isinstance(test.input, FunctionCall)
assert test.output.result.value.type == result
assert len(test.output.result.value.data) == 0


@pytest.mark.parametrize(
"function_name,result",
[
("set", BasicSequenceTypes.SET),
("sequence", BasicSequenceTypes.SEQUENCE),
("array", AdvancedSequenceTypes.ARRAY),
("tuple", AdvancedSequenceTypes.TUPLE),
],
)
def test_empty_constructor_with_param(function_name, result):
yaml_str = f"""
- tab: 'Test'
contexts:
- testcases:
- statement: 'test()'
return_raw: '{function_name}([])'
"""
json_str = translate_to_test_suite(yaml_str)
suite = parse_test_suite(json_str)
assert len(suite.tabs) == 1
tab = suite.tabs[0]
assert len(tab.contexts) == 1
testcases = tab.contexts[0].testcases
assert len(testcases) == 1
test = testcases[0]
assert isinstance(test.input, FunctionCall)
assert test.output.result.value.type == result
assert len(test.output.result.value.data) == 0

0 comments on commit 4415a5c

Please sign in to comment.