From 2822a72789f216c1937c6da62778f2ed1ca2b1f8 Mon Sep 17 00:00:00 2001 From: Niko Strijbol Date: Mon, 14 Aug 2023 17:09:12 +0200 Subject: [PATCH] Add support for custom types via YAML tags See #402 --- tested/dsl/translate_parser.py | 121 +++++++++++++++++++++++++-------- tests/test_dsl_yaml.py | 44 ++++++++++++ 2 files changed, 138 insertions(+), 27 deletions(-) diff --git a/tested/dsl/translate_parser.py b/tested/dsl/translate_parser.py index ed6e87ce..3e325231 100644 --- a/tested/dsl/translate_parser.py +++ b/tested/dsl/translate_parser.py @@ -1,20 +1,32 @@ import json +from decimal import Decimal from logging import getLogger from pathlib import Path -from typing import Callable, Dict, List, Optional, TextIO, TypeVar, Union, cast +from typing import Any, Callable, Dict, List, Optional, TextIO, TypeVar, Union, cast import yaml +from attrs import define from jsonschema import Draft7Validator from tested.datatypes import ( + AdvancedNumericTypes, + AllTypes, BasicBooleanTypes, + BasicNothingTypes, BasicNumericTypes, BasicObjectTypes, BasicSequenceTypes, BasicStringTypes, + BooleanTypes, + NothingTypes, + NumericTypes, + ObjectTypes, + SequenceTypes, + StringTypes, + resolve_to_basic, ) from tested.dsl.ast_translator import parse_string -from tested.parsing import suite_to_json +from tested.parsing import get_converter, suite_to_json from tested.serialisation import ( BooleanType, NothingType, @@ -44,7 +56,7 @@ TextOutputChannel, ValueOutputChannel, ) -from tested.utils import recursive_dict_merge +from tested.utils import get_args, recursive_dict_merge logger = getLogger(__name__) @@ -53,11 +65,33 @@ YamlObject = Union[YamlDict, list, bool, float, int, str, None] +@define +class TestedType: + value: Any + type: str | AllTypes + + +def custom_type_constructors(loader: yaml.Loader, node: yaml.Node): + tested_tag = node.tag[1:] + if isinstance(node, yaml.MappingNode): + base_result = loader.construct_mapping(node) + elif isinstance(node, yaml.SequenceNode): + base_result = loader.construct_sequence(node) + else: + assert isinstance(node, yaml.ScalarNode) + base_result = loader.construct_scalar(node) + return TestedType(type=tested_tag, value=base_result) + + def _parse_yaml(yaml_stream: Union[str, TextIO]) -> YamlObject: """ Parse a string or stream to YAML. """ - return yaml.load(yaml_stream, Loader=yaml.CSafeLoader) + loader: type[yaml.Loader] = cast(type[yaml.Loader], yaml.CSafeLoader) + for types in get_args(AllTypes): + for actual_type in types: + yaml.add_constructor("!" + actual_type, custom_type_constructors, loader) + return yaml.load(yaml_stream, loader) def _load_schema_validator(): @@ -113,37 +147,70 @@ def _deepen_config_level( return recursive_dict_merge(current_level, new_level_object["config"]) -def _convert_value(value: YamlObject) -> Value: - if value is None: - return NothingType() - elif isinstance(value, str): - return StringType(type=BasicStringTypes.TEXT, data=value) - elif isinstance(value, bool): - return BooleanType(type=BasicBooleanTypes.BOOLEAN, data=value) - elif isinstance(value, int): - return NumberType(type=BasicNumericTypes.INTEGER, data=value) - elif isinstance(value, float): - return NumberType(type=BasicNumericTypes.REAL, data=value) - elif isinstance(value, list): - return SequenceType( - type=BasicSequenceTypes.SEQUENCE, - data=[_convert_value(part_value) for part_value in value], - ) - elif isinstance(value, set): +def _tested_type_to_value(tested_type: TestedType) -> Value: + type_enum = get_converter().structure(tested_type.type, AllTypes) + if isinstance(type_enum, NumericTypes): + # Some special cases for advanced numeric types. + if type_enum == AdvancedNumericTypes.FIXED_PRECISION: + value = Decimal(tested_type.value) + else: + basic_type = resolve_to_basic(type_enum) + if basic_type == BasicNumericTypes.INTEGER: + value = int(tested_type.value) + elif basic_type == BasicNumericTypes.REAL: + value = float(tested_type.value) + else: + raise ValueError(f"Unknown basic numeric type {type_enum}") + return NumberType(type=type_enum, data=value) + elif isinstance(type_enum, StringTypes): + return StringType(type=type_enum, data=tested_type.value) + elif isinstance(type_enum, BooleanTypes): + return BooleanType(type=type_enum, data=bool(tested_type.value)) + elif isinstance(type_enum, NothingTypes): + return NothingType(type=type_enum, data=None) + elif isinstance(type_enum, SequenceTypes): return SequenceType( - type=BasicSequenceTypes.SET, - data=[_convert_value(part_value) for part_value in value], + type=type_enum, + data=[_convert_value(part_value) for part_value in tested_type.value], ) - else: + elif isinstance(type_enum, ObjectTypes): data = [] - for key, val in value.items(): + for key, val in tested_type.value.items(): data.append( ObjectKeyValuePair( - key=StringType(type=BasicStringTypes.TEXT, data=key), + key=_convert_value(key), value=_convert_value(val), ) ) - return ObjectType(type=BasicObjectTypes.MAP, data=data) + return ObjectType(type=type_enum, data=data) + raise ValueError(f"Unknown type {tested_type.type} with value {tested_type.value}") + + +def _convert_value(value: YamlObject) -> Value: + if isinstance(value, TestedType): + tested_type = value + else: + # Convert the value into a "TESTed" type. + if value is None: + tested_type = TestedType(value=None, type=BasicNothingTypes.NOTHING) + elif isinstance(value, str): + tested_type = TestedType(value=value, type=BasicStringTypes.TEXT) + elif isinstance(value, bool): + tested_type = TestedType(type=BasicBooleanTypes.BOOLEAN, value=value) + elif isinstance(value, int): + tested_type = TestedType(type=BasicNumericTypes.INTEGER, value=value) + elif isinstance(value, float): + tested_type = TestedType(type=BasicNumericTypes.REAL, value=value) + elif isinstance(value, list): + tested_type = TestedType(type=BasicSequenceTypes.SEQUENCE, value=value) + elif isinstance(value, set): + tested_type = TestedType(type=BasicSequenceTypes.SET, value=value) + elif isinstance(value, dict): + tested_type = TestedType(type=BasicObjectTypes.MAP, value=value) + else: + raise ValueError(f"Unknown type for value {value}.") + + return _tested_type_to_value(tested_type) def _convert_file(link_file: YamlDict) -> FileUrl: diff --git a/tests/test_dsl_yaml.py b/tests/test_dsl_yaml.py index 18be9369..08147604 100644 --- a/tests/test_dsl_yaml.py +++ b/tests/test_dsl_yaml.py @@ -1,3 +1,4 @@ +import json from pathlib import Path import pytest @@ -9,6 +10,12 @@ BasicObjectTypes, BasicSequenceTypes, BasicStringTypes, + BooleanTypes, + NothingTypes, + NumericTypes, + ObjectTypes, + SequenceTypes, + StringTypes, ) from tested.dsl import translate_to_test_suite from tested.serialisation import ( @@ -27,6 +34,7 @@ ValueOutputChannel, parse_test_suite, ) +from tested.utils import get_args def test_parse_one_tab_ctx(): @@ -722,3 +730,39 @@ def test_yaml_set_tag_is_supported(): NumberType(type=BasicNumericTypes.INTEGER, data=6), ], ) + + +@pytest.mark.parametrize( + "all_types,value", + [ + (NumericTypes, 5), + (StringTypes, "hallo"), + (BooleanTypes, True), + (NothingTypes, None), + (SequenceTypes, [5, 6]), + (ObjectTypes, {"test": 6}), + ], +) +def test_yaml_custom_tags_are_supported(all_types, value): + json_type = json.dumps(value) + for types in get_args(all_types): + for the_type in types: + yaml_str = f""" + - tab: 'Test' + contexts: + - testcases: + - statement: 'test()' + return: !{the_type} {json_type} + """ + json_str = translate_to_test_suite(yaml_str) + suite = parse_test_suite(json_str) + assert len(suite.tabs) == 1 + tab = suite.tabs[0] + assert len(tab.contexts) == 1 + testcases = tab.contexts[0].testcases + assert len(testcases) == 1 + test = testcases[0] + assert isinstance(test.input, FunctionCall) + assert isinstance(test.output.result, ValueOutputChannel) + value = test.output.result.value + assert value.type == the_type