Skip to content

Commit

Permalink
Add override tag support in mixins
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-turbaszek committed Sep 12, 2024
1 parent fd1bb79 commit c26e554
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 6 deletions.
1 change: 0 additions & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
-->

# Unreleased version

## Backward incompatibility
* Dropped support for Python below 3.10 version.
* `snow object stage` commands are removed in favour of `snow stage`.
Expand Down
14 changes: 14 additions & 0 deletions src/snowflake/cli/api/project/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB
from snowflake.cli.api.project.schemas.project_definition import (
ProjectProperties,
YamlOverride,
)
from snowflake.cli.api.project.util import (
append_to_identifier,
Expand All @@ -37,6 +38,7 @@
)
from snowflake.cli.api.utils.dict_utils import deep_merge_dicts
from snowflake.cli.api.utils.types import Context, Definition
from yaml import MappingNode, SequenceNode

DEFAULT_USERNAME = "unknown_user"

Expand All @@ -50,6 +52,10 @@ def _get_merged_definitions(paths: List[Path]) -> Optional[Definition]:
loader.add_constructor(
yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _no_duplicates_constructor
)
loader.add_constructor(
yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _no_duplicates_constructor
)
loader.add_constructor("!override", _override_tag)

with spaths[0].open("r", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB) as base_yml:
definition = yaml.load(base_yml.read(), Loader=loader) or {}
Expand Down Expand Up @@ -113,3 +119,11 @@ def _no_duplicates_constructor(loader, node, deep=False):
)
mapping[key] = value
return loader.construct_mapping(node, deep)


def _override_tag(loader, node, deep=False):
if isinstance(node, SequenceNode):
return YamlOverride(data=loader.construct_sequence(node, deep))
if isinstance(node, MappingNode):
return YamlOverride(data=loader.construct_mapping(node, deep))
return node.value
25 changes: 21 additions & 4 deletions src/snowflake/cli/api/project/schemas/project_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ class ProjectProperties:
project_context: Context


@dataclass
class YamlOverride:
data: dict | list


class _ProjectDefinitionBase(UpdatableModel):
def __init__(self, *args, **kwargs):
try:
Expand Down Expand Up @@ -196,7 +201,11 @@ def apply_mixins(cls, data: Dict) -> Dict:

@classmethod
def _merge_mixins_with_entity(
cls, entity_id: str, entity: dict, entity_mixins_names: list, mixin_defs: dict
cls,
entity_id: str,
entity: dict,
entity_mixins_names: list,
mixin_defs: dict,
) -> dict:
# Validate mixins
for mixin_name in entity_mixins_names:
Expand All @@ -215,8 +224,10 @@ def _merge_mixins_with_entity(
)

entity_value = entity.get(key)
if entity_value is not None and not isinstance(
entity_value, type(override_value)
if (
entity_value is not None
and not isinstance(entity_value, YamlOverride)
and not isinstance(entity_value, type(override_value))
):
raise ValueError(
f"Value from mixins for property {key} is of type '{type(override_value).__name__}' "
Expand All @@ -231,13 +242,16 @@ def _merge_mixins_with_entity(
def _merge_data(
cls,
left: dict | list | scalar | None,
right: dict | list | scalar | None,
right: dict | list | scalar | None | YamlOverride,
):
"""
Merges right data into left. Right and left is expected to be of the same type, if not right is returned.
If left is sequence then missing elements from right are appended.
If left is dictionary then we update it with data from right. The update is done recursively key by key.
"""
if isinstance(right, YamlOverride):
return right.data

if left is None:
return right

Expand Down Expand Up @@ -300,6 +314,9 @@ def get_allowed_fields_for_entity(entity: Dict[str, Any]) -> List[str]:
Get the allowed fields for the given entity.
"""
entity_type = entity.get("type")
if entity_type is None:
raise ValueError("Entity is missing type declaration.")

if entity_type not in v2_entity_model_types_map:
return []

Expand Down
4 changes: 3 additions & 1 deletion src/snowflake/cli/api/utils/definition_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,9 @@ def _add_defaults_to_definition(original_definition: Definition) -> Definition:
with context({"skip_validation_on_templates": True}):
# pass a flag to Pydantic to skip validation for templated scalars
# populate the defaults
project_definition = build_project_definition(**original_definition)
project_definition = build_project_definition(
**copy.deepcopy(original_definition)
)

definition_with_defaults = project_definition.model_dump(
exclude_none=True, warnings=False, by_alias=True
Expand Down
74 changes: 74 additions & 0 deletions tests/api/utils/test_override_tag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from textwrap import dedent

import pytest
from snowflake.cli.api.project.definition import load_project


@pytest.mark.parametrize(
"override, expected",
[("", ["A", "B", "entity_value"]), ("!override", ["entity_value"])],
)
def test_override_works_for_sequences(named_temporary_file, override, expected):
text = f"""\
definition_version: "2"
mixins:
mixin_a:
external_access_integrations:
- A
- B
entities:
my_function:
type: "function"
stage: foo_stage
returns: string
handler: foo.baz
signature: ""
artifacts: []
external_access_integrations: {override}
- entity_value
meta:
use_mixins: ["mixin_a"]
"""

with named_temporary_file(suffix=".yml") as p:
p.write_text(dedent(text))
result = load_project([p])

pd = result.project_definition
assert pd.entities["my_function"].external_access_integrations == expected


@pytest.mark.parametrize(
"override, expected",
[("", {"A": "a", "B": "b", "entity": "value"}), ("!override", {"entity": "value"})],
)
def test_override_works_for_mapping(named_temporary_file, override, expected):
text = f"""\
definition_version: "2"
mixins:
mixin_a:
secrets:
A: a
mixin_b:
secrets:
B: b
entities:
my_function:
type: "function"
stage: foo_stage
returns: string
handler: foo.baz
signature: ""
artifacts: []
secrets: {override}
entity: value
meta:
use_mixins: ["mixin_a", "mixin_b"]
"""

with named_temporary_file(suffix=".yml") as p:
p.write_text(dedent(text))
result = load_project([p])

pd = result.project_definition
assert pd.entities["my_function"].secrets == expected

0 comments on commit c26e554

Please sign in to comment.