From e22dd8d579d199518d9c05037b53563f7c5eee35 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Thu, 5 Sep 2024 12:35:27 +0100 Subject: [PATCH] Add tests for the new Pydantic I/O syntax Signed-off-by: Alice Purcell --- tests/test_pydantic_io_syntax.py | 62 ++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 tests/test_pydantic_io_syntax.py diff --git a/tests/test_pydantic_io_syntax.py b/tests/test_pydantic_io_syntax.py new file mode 100644 index 000000000..cf9e16735 --- /dev/null +++ b/tests/test_pydantic_io_syntax.py @@ -0,0 +1,62 @@ +import pytest + +from hera.shared._global_config import _SCRIPT_PYDANTIC_IO_FLAG +from hera.workflows import Input, Output, Steps, Workflow, script + + +class IntInput(Input): + field: int + + +class IntOutput(Output): + field: int + + +@pytest.fixture(autouse=True) +def enable_pydantic_io(global_config_fixture): + global_config_fixture.experimental_features[_SCRIPT_PYDANTIC_IO_FLAG] = True + + +def test_output_field_contains_argo_template(global_config_fixture): + @script() + def triple(input: IntInput) -> IntOutput: + return IntOutput(field=input.field * 3) + + with Workflow(name="foo"): + with Steps(name="bar"): + result = triple(IntInput(field=5)).field + + assert result == "{{steps.triple.outputs.parameters.field}}" + + +def test_script_can_return_none(): + @script() + def print_field(input: IntInput) -> None: + print(input.field) + + with Workflow(name="foo"): + with Steps(name="bar"): + result = print_field(IntInput(field=5)) + + assert result is None + + +def test_invalid_pydantic_io_outside_of_context(): + @script() + def triple(input: IntInput) -> IntOutput: + return IntOutput(field=input.field * 3) + + with Workflow(name="foo"): + with pytest.raises(SyntaxError, match="Cannot use Pydantic I/O outside of a .* context"): + triple(IntInput(field=5)) + + +def test_invalid_non_pydantic_return_type(): + @script() + def triple(input: IntInput) -> int: + return input.field * 3 + + with Workflow(name="foo"): + with Steps(name="bar"): + with pytest.raises(SyntaxError, match="Cannot use Pydantic input type without a Pydantic output type"): + triple(IntInput(field=5))