Skip to content

Commit

Permalink
Add str tests, fix for falsey boolean coercion
Browse files Browse the repository at this point in the history
Signed-off-by: Elliot Gunton <egunton@bloomberg.net>
  • Loading branch information
elliotgunton committed Feb 26, 2024
1 parent c6c9d45 commit cb8417c
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 6 deletions.
4 changes: 3 additions & 1 deletion src/hera/workflows/io/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def _get_parameters(cls, object_override: "Optional[RunnerInput]" = None) -> Lis
else:
# Create a Parameter from basic type annotations
default = getattr(object_override, field) if object_override else field_info.default
parameters.append(Parameter(name=field, default=default or MISSING))
if default is None:
default = MISSING
parameters.append(Parameter(name=field, default=default))

return parameters

Expand Down
5 changes: 2 additions & 3 deletions src/hera/workflows/io/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from collections import ChainMap
from typing import Any, List, Optional, Union


from hera.shared.serialization import MISSING, serialize
from hera.workflows.artifact import Artifact
from hera.workflows.parameter import Parameter
Expand Down Expand Up @@ -40,7 +39,7 @@ def _get_parameters(cls, object_override: "Optional[RunnerInput]" = None) -> Lis
parameters = []
annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()}

for field, field_info in cls.model_fields.items():
for field, field_info in cls.model_fields.items(): # type: ignore
if get_origin(annotations[field]) is Annotated:
if isinstance(get_args(annotations[field])[1], Parameter):
param = get_args(annotations[field])[1]
Expand All @@ -56,7 +55,7 @@ def _get_parameters(cls, object_override: "Optional[RunnerInput]" = None) -> Lis
if default == PydanticUndefined:
default = MISSING

parameters.append(Parameter(name=field, default=default or MISSING))
parameters.append(Parameter(name=field, default=default))

return parameters

Expand Down
5 changes: 4 additions & 1 deletion tests/script_annotations/pydantic_io_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def pydantic_io_with_defaults(

@script(constructor="runner")
def pydantic_io_within_generic(
my_inputs: List[ParamOnlyInput] = [ParamOnlyInput(no_default_param=1), ParamOnlyInput(my_int=2, no_default_param=2)],
my_inputs: List[ParamOnlyInput] = [
ParamOnlyInput(no_default_param=1),
ParamOnlyInput(my_int=2, no_default_param=2),
],
) -> ParamOnlyOutput:
pass

Expand Down
41 changes: 41 additions & 0 deletions tests/script_annotations/pydantic_io_v1_strs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from pathlib import Path

from hera.workflows import Parameter, Workflow, script

try:
from hera.workflows.io.v2 import ( # type: ignore
RunnerInput,
RunnerOutput,
)
except ImportError:
from hera.workflows.io.v1 import ( # type: ignore
RunnerInput,
RunnerOutput,
)

try:
from typing import Annotated # type: ignore
except ImportError:
from typing_extensions import Annotated # type: ignore


class ParamOnlyInput(RunnerInput):
my_str: str
my_empty_default_str: str = ""
my_annotated_str: Annotated[str, Parameter(name="alt-name")] = "hello world!"


class ParamOnlyOutput(RunnerOutput):
my_output_str: str = "my-default-str"
another_output: Annotated[Path, Parameter(name="second-output")]


@script(constructor="runner")
def pydantic_io_params(
my_input: ParamOnlyInput,
) -> ParamOnlyOutput:
pass


with Workflow(generate_name="pydantic-io-") as w:
pydantic_io_params()
5 changes: 4 additions & 1 deletion tests/script_annotations/pydantic_io_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def pydantic_io_with_defaults(

@script(constructor="runner")
def pydantic_io_within_generic(
my_inputs: List[ParamOnlyInput] = [ParamOnlyInput(no_default_param=1), ParamOnlyInput(my_int=2,no_default_param=2)],
my_inputs: List[ParamOnlyInput] = [
ParamOnlyInput(no_default_param=1),
ParamOnlyInput(my_int=2, no_default_param=2),
],
) -> ParamOnlyOutput:
pass

Expand Down
31 changes: 31 additions & 0 deletions tests/script_annotations/pydantic_io_v2_strs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pathlib import Path

from hera.workflows import Parameter, Workflow, script
from hera.workflows.io.v1 import RunnerInput, RunnerOutput

try:
from typing import Annotated # type: ignore
except ImportError:
from typing_extensions import Annotated # type: ignore


class ParamOnlyInput(RunnerInput):
my_str: str
my_empty_default_str: str = ""
my_annotated_str: Annotated[str, Parameter(name="alt-name")] = "hello world!"


class ParamOnlyOutput(RunnerOutput):
my_output_str: str = "my-default-str"
another_output: Annotated[Path, Parameter(name="second-output")]


@script(constructor="runner")
def pydantic_io_params(
my_input: ParamOnlyInput,
) -> ParamOnlyOutput:
pass


with Workflow(generate_name="pydantic-io-") as w:
pydantic_io_params()
53 changes: 53 additions & 0 deletions tests/test_script_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,59 @@ def test_script_pydantic_io(pydantic_mode, function_name, expected_input, expect
assert template["outputs"] == expected_output


@pytest.mark.parametrize(
"pydantic_mode",
[
1,
_PYDANTIC_VERSION,
],
)
@pytest.mark.parametrize(
"function_name,expected_input,expected_output",
[
pytest.param(
"pydantic_io_params",
{
"parameters": [
{"name": "my_str"},
{"name": "my_empty_default_str", "default": ""},
{"name": "alt-name", "default": "hello world!"},
]
},
{
"parameters": [
{"name": "my_output_str", "valueFrom": {"path": "/tmp/hera-outputs/parameters/my_output_str"}},
{"name": "second-output", "valueFrom": {"path": "/tmp/hera-outputs/parameters/second-output"}},
],
},
id="param-only-io",
),
],
)
def test_script_pydantic_io_strs(pydantic_mode, function_name, expected_input, expected_output, global_config_fixture):
"""Test that output annotations work correctly by asserting correct inputs and outputs on the built workflow."""
# GIVEN
global_config_fixture.experimental_features["script_annotations"] = True
global_config_fixture.experimental_features["script_pydantic_io"] = True
# Force a reload of the test module, as the runner performs "importlib.import_module", which
# may fetch a cached version
module_name = f"tests.script_annotations.pydantic_io_v{pydantic_mode}_strs"

module = importlib.import_module(module_name)
importlib.reload(module)
workflow = importlib.import_module(module.__name__).w

# WHEN
workflow_dict = workflow.to_dict()
assert workflow == Workflow.from_dict(workflow_dict)
assert workflow == Workflow.from_yaml(workflow.to_yaml())

# THEN
template = next(filter(lambda t: t["name"] == function_name.replace("_", "-"), workflow_dict["spec"]["templates"]))
assert template["inputs"] == expected_input
assert template["outputs"] == expected_output


def test_script_pydantic_invalid_outputs(global_config_fixture):
"""Test that output annotations work correctly by asserting correct inputs and outputs on the built workflow."""
# GIVEN
Expand Down

0 comments on commit cb8417c

Please sign in to comment.