Skip to content

Commit

Permalink
Support various output tuple annotation (#1186)
Browse files Browse the repository at this point in the history
**Pull Request Checklist**
- [x] Fixes #1167 
- [x] Tests added
- [x] Documentation/examples added; bug fix
- [x] [Good commit messages](https://cbea.ms/git-commit/) and/or PR
title

**Description of PR**
Currently, hera doesn't support un-annotated tuple output type, and
there's no guard for partially annotated type.

So updating hera like as following.
- Support un-annotated tuple output.
- Raise the error when the output tuple is partially annotated.

---------

Signed-off-by: Ukjae Jeong <jeongukjae@gmail.com>
Signed-off-by: Ukjae Jeong <JeongUkJae@gmail.com>
Co-authored-by: Elliot Gunton <elliotgunton@gmail.com>
Co-authored-by: Alice <alicederyn@gmail.com>
  • Loading branch information
3 people authored Sep 5, 2024
1 parent 2eec4c2 commit e7f38fa
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 4 deletions.
15 changes: 13 additions & 2 deletions src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,8 +555,19 @@ def _extract_return_annotation_output(source: Callable) -> List:
if get_workflow_annotation(return_annotation):
output.append(annotation_args)
elif origin_type is tuple:
for annotated_type in annotation_args:
output.append(get_args(annotated_type))
workflow_args = [
get_args(annotated_type) for annotated_type in annotation_args if get_workflow_annotation(annotated_type)
]

# If all tuple elements are annotated as Parameter/Artifact
if len(workflow_args) == len(annotation_args):
output.extend(workflow_args)
# Only some tuple elements are annotated as Parameter/Artifact
elif workflow_args:
raise ValueError(
f"Function '{source.__name__}' output has partially annotated tuple return type. "
"Tuple elements must be all Annotated as Parameter/Artifact, or contain no Parameter/Artifact annotations for a raw tuple return type."
)
elif (
origin_type is None
and isinstance(return_annotation, type)
Expand Down
19 changes: 17 additions & 2 deletions tests/script_runner/parameter_with_complex_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import sys
from typing import Optional, Union
from typing import Optional, Tuple, Union

if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated

from hera.shared import global_config
from hera.workflows import script
from hera.workflows import Parameter, script

global_config.experimental_features["script_annotations"] = True

Expand Down Expand Up @@ -36,3 +41,13 @@ def optional_int_parameter(my_int: Optional[int] = None) -> Optional[int]:
@script(constructor="runner")
def union_parameter(my_param: Union[str, int] = None) -> Union[str, int]:
return my_param


@script(constructor="runner")
def fn_with_output_tuple(my_string: str) -> Tuple[str, str]:
return my_string, my_string


@script(constructor="runner")
def fn_with_output_tuple_partially_annotated(my_string: str) -> Tuple[str, Annotated[str, Parameter(name="sample")]]:
return my_string, my_string
22 changes: 22 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,6 +1095,11 @@ def test_script_optional_parameter(
[{"name": "my_param", "value": 123}],
"123",
],
[
"tests.script_runner.parameter_with_complex_types:fn_with_output_tuple",
[{"name": "my_string", "value": "123"}],
'["123", "123"]',
],
],
)
def test_script_with_complex_types(
Expand All @@ -1111,3 +1116,20 @@ def test_script_with_complex_types(

# THEN
assert serialize(output) == expected_output


def test_script_partially_annotated_tuple_should_raise_an_error(monkeypatch: pytest.MonkeyPatch):
# GIVEN
monkeypatch.setenv("hera__script_annotations", "")
entrypoint = "tests.script_runner.parameter_with_complex_types:fn_with_output_tuple_partially_annotated"
kwargs_list = [{"name": "my_string", "value": "123"}]

# WHEN/THEN
with pytest.raises(
ValueError,
match=(
"Function 'fn_with_output_tuple_partially_annotated' output has partially annotated tuple return type. "
"Tuple elements must be all Annotated as Parameter/Artifact, or contain no Parameter/Artifact annotations for a raw tuple return type."
),
):
_runner(entrypoint, kwargs_list)

0 comments on commit e7f38fa

Please sign in to comment.