From e659fb3373993bec4af299acf226b5ab0fc2ab9e Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Tue, 17 Sep 2024 10:23:53 +0200 Subject: [PATCH] Add override tag support in mixins (#1568) --- src/snowflake/cli/api/project/definition.py | 11 +++ .../api/project/schemas/project_definition.py | 25 ++++++- .../cli/api/utils/definition_rendering.py | 4 +- tests/api/utils/test_override_tag.py | 74 +++++++++++++++++++ 4 files changed, 109 insertions(+), 5 deletions(-) create mode 100644 tests/api/utils/test_override_tag.py diff --git a/src/snowflake/cli/api/project/definition.py b/src/snowflake/cli/api/project/definition.py index 60dec254e0..0504bd3227 100644 --- a/src/snowflake/cli/api/project/definition.py +++ b/src/snowflake/cli/api/project/definition.py @@ -23,6 +23,7 @@ from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB from snowflake.cli.api.project.schemas.project_definition import ( ProjectProperties, + YamlOverride, ) from snowflake.cli.api.project.util import ( append_to_identifier, @@ -37,6 +38,7 @@ ) from snowflake.cli.api.utils.dict_utils import deep_merge_dicts from snowflake.cli.api.utils.types import Context, Definition +from yaml import MappingNode, SequenceNode DEFAULT_USERNAME = "unknown_user" @@ -50,6 +52,7 @@ def _get_merged_definitions(paths: List[Path]) -> Optional[Definition]: loader.add_constructor( yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _no_duplicates_constructor ) + loader.add_constructor("!override", _override_tag) with spaths[0].open("r", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB) as base_yml: definition = yaml.load(base_yml.read(), Loader=loader) or {} @@ -113,3 +116,11 @@ def _no_duplicates_constructor(loader, node, deep=False): ) mapping[key] = value return loader.construct_mapping(node, deep) + + +def _override_tag(loader, node, deep=False): + if isinstance(node, SequenceNode): + return YamlOverride(data=loader.construct_sequence(node, deep)) + if isinstance(node, MappingNode): + return YamlOverride(data=loader.construct_mapping(node, deep)) + return node.value diff --git a/src/snowflake/cli/api/project/schemas/project_definition.py b/src/snowflake/cli/api/project/schemas/project_definition.py index 072c9c2acf..d14e59f13e 100644 --- a/src/snowflake/cli/api/project/schemas/project_definition.py +++ b/src/snowflake/cli/api/project/schemas/project_definition.py @@ -62,6 +62,11 @@ class ProjectProperties: project_context: Context +@dataclass +class YamlOverride: + data: dict | list + + class _ProjectDefinitionBase(UpdatableModel): def __init__(self, *args, **kwargs): try: @@ -196,7 +201,11 @@ def apply_mixins(cls, data: Dict) -> Dict: @classmethod def _merge_mixins_with_entity( - cls, entity_id: str, entity: dict, entity_mixins_names: list, mixin_defs: dict + cls, + entity_id: str, + entity: dict, + entity_mixins_names: list, + mixin_defs: dict, ) -> dict: # Validate mixins for mixin_name in entity_mixins_names: @@ -215,8 +224,10 @@ def _merge_mixins_with_entity( ) entity_value = entity.get(key) - if entity_value is not None and not isinstance( - entity_value, type(override_value) + if ( + entity_value is not None + and not isinstance(entity_value, YamlOverride) + and not isinstance(entity_value, type(override_value)) ): raise ValueError( f"Value from mixins for property {key} is of type '{type(override_value).__name__}' " @@ -231,13 +242,16 @@ def _merge_mixins_with_entity( def _merge_data( cls, left: dict | list | scalar | None, - right: dict | list | scalar | None, + right: dict | list | scalar | None | YamlOverride, ): """ Merges right data into left. Right and left is expected to be of the same type, if not right is returned. If left is sequence then missing elements from right are appended. If left is dictionary then we update it with data from right. The update is done recursively key by key. """ + if isinstance(right, YamlOverride): + return right.data + if left is None: return right @@ -300,6 +314,9 @@ def get_allowed_fields_for_entity(entity: Dict[str, Any]) -> List[str]: Get the allowed fields for the given entity. """ entity_type = entity.get("type") + if entity_type is None: + raise ValueError("Entity is missing type declaration.") + if entity_type not in v2_entity_model_types_map: return [] diff --git a/src/snowflake/cli/api/utils/definition_rendering.py b/src/snowflake/cli/api/utils/definition_rendering.py index 9716e4186c..1755b31609 100644 --- a/src/snowflake/cli/api/utils/definition_rendering.py +++ b/src/snowflake/cli/api/utils/definition_rendering.py @@ -277,7 +277,9 @@ def _add_defaults_to_definition(original_definition: Definition) -> Definition: with context({"skip_validation_on_templates": True}): # pass a flag to Pydantic to skip validation for templated scalars # populate the defaults - project_definition = build_project_definition(**original_definition) + project_definition = build_project_definition( + **copy.deepcopy(original_definition) + ) definition_with_defaults = project_definition.model_dump( exclude_none=True, warnings=False, by_alias=True diff --git a/tests/api/utils/test_override_tag.py b/tests/api/utils/test_override_tag.py new file mode 100644 index 0000000000..3d680838e9 --- /dev/null +++ b/tests/api/utils/test_override_tag.py @@ -0,0 +1,74 @@ +from textwrap import dedent + +import pytest +from snowflake.cli.api.project.definition import load_project + + +@pytest.mark.parametrize( + "override, expected", + [("", ["A", "B", "entity_value"]), ("!override", ["entity_value"])], +) +def test_override_works_for_sequences(named_temporary_file, override, expected): + text = f"""\ + definition_version: "2" + mixins: + mixin_a: + external_access_integrations: + - A + - B + entities: + my_function: + type: "function" + stage: foo_stage + returns: string + handler: foo.baz + signature: "" + artifacts: [] + external_access_integrations: {override} + - entity_value + meta: + use_mixins: ["mixin_a"] + """ + + with named_temporary_file(suffix=".yml") as p: + p.write_text(dedent(text)) + result = load_project([p]) + + pd = result.project_definition + assert pd.entities["my_function"].external_access_integrations == expected + + +@pytest.mark.parametrize( + "override, expected", + [("", {"A": "a", "B": "b", "entity": "value"}), ("!override", {"entity": "value"})], +) +def test_override_works_for_mapping(named_temporary_file, override, expected): + text = f"""\ + definition_version: "2" + mixins: + mixin_a: + secrets: + A: a + mixin_b: + secrets: + B: b + entities: + my_function: + type: "function" + stage: foo_stage + returns: string + handler: foo.baz + signature: "" + artifacts: [] + secrets: {override} + entity: value + meta: + use_mixins: ["mixin_a", "mixin_b"] + """ + + with named_temporary_file(suffix=".yml") as p: + p.write_text(dedent(text)) + result = load_project([p]) + + pd = result.project_definition + assert pd.entities["my_function"].secrets == expected