From 5631c6a2e5e7c032a69ccb613a9fc459b400484e Mon Sep 17 00:00:00 2001 From: Mitchell Douglass Date: Wed, 12 Jun 2024 03:27:44 -0400 Subject: [PATCH] Support inherited fields for RunnerInput, other niceties (#1093) ## Pull Request Checklist - [ ] Fixes # - [x] Tests added - [ ] Documentation/examples added - [x] [Good commit messages](https://cbea.ms/git-commit/) and/or PR title ## Summary Currently, users cannot leverage inheritance when defining `RunnerInput` models that utilize `Annotated` fields. For example: ``` class ResourceIndex(RunnerInput): resource_name: Annotated[str, Parameter(name="resource-name")] resource_namespace: Annotated[str, Parameter(name="resource-namespace")] class GetStatusInput(ResourceIndex): wait: Annotated[bool, Parameter(name="wait")] timeout_s: Annotated[int, Parameter(name="timeout")] ``` This would fail at runtime with a KeyError at [this line](https://github.com/argoproj-labs/hera/blob/main/src/hera/workflows/_runner/script_annotations_util.py#L152), because apparently `cls.__annotations__` are not inherited by child classes. This PR enables the above inheritance structure by utilizing the existing [get_field_annotations](https://github.com/argoproj-labs/hera/blob/main/src/hera/shared/_pydantic.py#L69) instead of `my_runner_input.__annotations__`. #### Parameter Default Names This PR also changes logic so that `Input` field `Parameters` adopt the name of their field if no name is provided. For example, in ``` class MyInput(Input): my_int: Annotated[int, Parameter(description="this is my int")] ``` The parameter is converted to `Parameter(name="my_int", description="this is my int")`. This avoids an error from [here](https://github.com/argoproj-labs/hera/blob/main/src/hera/workflows/parameter.py#L30) which is thrown when trying to generate a workflow yaml. --------- Signed-off-by: Mitchell Douglass Co-authored-by: Sambhav Kothari --- .../_runner/script_annotations_util.py | 37 +++++---- src/hera/workflows/io/_io_mixins.py | 10 ++- tests/test_unit/test_io_mixins.py | 21 +++++ .../test_unit/test_script_annotations_util.py | 76 +++++++++++++++++++ 4 files changed, 128 insertions(+), 16 deletions(-) create mode 100644 tests/test_unit/test_io_mixins.py diff --git a/src/hera/workflows/_runner/script_annotations_util.py b/src/hera/workflows/_runner/script_annotations_util.py index 14c6eaa55..34ae67e62 100644 --- a/src/hera/workflows/_runner/script_annotations_util.py +++ b/src/hera/workflows/_runner/script_annotations_util.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast -from hera.shared._pydantic import BaseModel, get_fields +from hera.shared._pydantic import BaseModel, get_field_annotations, get_fields from hera.shared.serialization import serialize from hera.workflows import Artifact, Parameter from hera.workflows.artifact import ArtifactLoader @@ -145,24 +145,33 @@ def load_parameter_value(value: str, value_type: type) -> Any: except json.JSONDecodeError: return value + runner_input_annotations = get_field_annotations(runner_input_class) + def map_field( field: str, kwargs: Dict[str, str], ) -> Any: - annotation = runner_input_class.__annotations__[field] + annotation = runner_input_annotations.get(field) + assert annotation is not None, "RunnerInput fields must be type-annotated" if get_origin(annotation) is Annotated: - meta_annotation = get_args(annotation)[1] - if isinstance(meta_annotation, Parameter): - assert not meta_annotation.output - return load_parameter_value( - _get_annotated_input_param_value(field, meta_annotation, kwargs), - get_args(annotation)[0], - ) - - if isinstance(meta_annotation, Artifact): - return get_annotated_artifact_value(meta_annotation) - - return load_parameter_value(kwargs[field], annotation) + # my_field: Annotated[int, Parameter(...)] + ann_type = get_args(annotation)[0] + param_or_artifact = get_args(annotation)[1] + else: + # my_field: int + ann_type = annotation + param_or_artifact = None + + if isinstance(param_or_artifact, Parameter): + assert not param_or_artifact.output + return load_parameter_value( + _get_annotated_input_param_value(field, param_or_artifact, kwargs), + ann_type, + ) + elif isinstance(param_or_artifact, Artifact): + return get_annotated_artifact_value(param_or_artifact) + else: + return load_parameter_value(kwargs[field], ann_type) for field in get_fields(runner_input_class): input_model_obj[field] = map_field(field, kwargs) diff --git a/src/hera/workflows/io/_io_mixins.py b/src/hera/workflows/io/_io_mixins.py index f54e5948e..bfbc3e339 100644 --- a/src/hera/workflows/io/_io_mixins.py +++ b/src/hera/workflows/io/_io_mixins.py @@ -64,8 +64,11 @@ def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Paramet for field, field_info in get_fields(cls).items(): if get_origin(annotations[field]) is Annotated: - param = get_args(annotations[field])[1] + # Copy so as to not modify the Input fields themselves + param = get_args(annotations[field])[1].copy() if isinstance(param, Parameter): + if param.name is None: + param.name = field if object_override: param.default = serialize(getattr(object_override, field)) elif field_info.default is not None and field_info.default != PydanticUndefined: # type: ignore @@ -91,8 +94,11 @@ def _get_artifacts(cls) -> List[Artifact]: for field in get_fields(cls): if get_origin(annotations[field]) is Annotated: - artifact = get_args(annotations[field])[1] + # Copy so as to not modify the Input fields themselves + artifact = get_args(annotations[field])[1].copy() if isinstance(artifact, Artifact): + if artifact.name is None: + artifact.name = field if artifact.path is None: artifact.path = artifact._get_default_inputs_path() artifacts.append(artifact) diff --git a/tests/test_unit/test_io_mixins.py b/tests/test_unit/test_io_mixins.py new file mode 100644 index 000000000..832d37ea3 --- /dev/null +++ b/tests/test_unit/test_io_mixins.py @@ -0,0 +1,21 @@ +from hera.workflows.io import Input +from hera.workflows.parameter import Parameter + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + + +def test_input_mixin_get_parameters(): + class Foo(Input): + foo: Annotated[int, Parameter(name="foo")] + + assert Foo._get_parameters() == [Parameter(name="foo")] + + +def test_input_mixin_get_parameters_default_name(): + class Foo(Input): + foo: Annotated[int, Parameter(description="a foo")] + + assert Foo._get_parameters() == [Parameter(name="foo", description="a foo")] diff --git a/tests/test_unit/test_script_annotations_util.py b/tests/test_unit/test_script_annotations_util.py index 1191d4626..b26126b8b 100644 --- a/tests/test_unit/test_script_annotations_util.py +++ b/tests/test_unit/test_script_annotations_util.py @@ -16,6 +16,11 @@ from hera.workflows.models import ValueFrom from hera.workflows.parameter import Parameter +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + @pytest.mark.parametrize( "destination,expected_path", @@ -174,3 +179,74 @@ class MyInput(Input): a_dict_str=json.dumps({"key": "value"}), a_list_str=json.dumps([1, 2, 3]), ) + + +def test_map_runner_input_annotated_parameter(): + """Test annotated Parameter.""" + + class Foo(Input): + foo: Annotated[str, Parameter(name="bar")] + + kwargs = {"foo": "hello"} + assert map_runner_input(Foo, kwargs) == Foo(foo="hello") + kwargs = {"bar": "there"} + assert map_runner_input(Foo, kwargs) == Foo(foo="there") + + +def test_map_runner_input_output_parameter_disallowed(): + """Test annotated output Parameter is not allowed.""" + + class Foo(Input): + foo: Annotated[str, Parameter(name="bar", output=True)] + + with pytest.raises(AssertionError): + kwargs = {"foo": "hello"} + map_runner_input(Foo, kwargs) + + +def test_map_runner_input_annotated_artifact(tmp_path): + """Test annotated Artifact.""" + + foo_path = tmp_path / "foo" + foo_path.write_text("hello there") + + class Foo(Input): + foo: Annotated[str, Artifact(name="bar", path=str(foo_path), loader=ArtifactLoader.file)] + + assert map_runner_input(Foo, {}) == Foo(foo="hello there") + + +def test_map_runner_input_annotated_inheritance(): + """Test model inheritance with Annotated fields.""" + + class Foo(Input): + foo: Annotated[str, Parameter(name="foo")] + + class FooBar(Foo): + bar: Annotated[str, Parameter(name="bar")] + + kwargs = {"foo": "hello", "bar": "there"} + assert map_runner_input(FooBar, kwargs) == FooBar(**kwargs) + + +def test_map_runner_input_annotated_inheritance_override(): + """Test model inheritance with Annotated fields.""" + + class Foo(Input): + foo: Annotated[str, Parameter(name="foo")] + + class FooBar(Foo): + foo: Annotated[str, Parameter(name="bar")] + + kwargs = {"bar": "hello"} + assert map_runner_input(FooBar, kwargs) == FooBar(foo="hello") + + +def test_map_runner_input_annotated_parameter_noname(): + """Test Annotated Parameter with no name.""" + + class Foo(Input): + foo: Annotated[str, Parameter(description="a parameter")] + + kwargs = {"foo": "hello"} + assert map_runner_input(Foo, kwargs) == Foo(foo="hello")