diff --git a/src/hera/workflows/io/v1.py b/src/hera/workflows/io/v1.py index 9699dbed8..cd75d182f 100644 --- a/src/hera/workflows/io/v1.py +++ b/src/hera/workflows/io/v1.py @@ -3,7 +3,7 @@ from typing import Any, List, Optional, Union from hera.shared._pydantic import BaseModel, get_fields -from hera.shared.serialization import serialize +from hera.shared.serialization import MISSING, serialize from hera.workflows.artifact import Artifact from hera.workflows.parameter import Parameter @@ -33,27 +33,21 @@ def _get_parameters(cls, object_override: "Optional[RunnerInput]" = None) -> Lis annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()} fields = get_fields(cls) - for field in fields: + for field, field_info in fields.items(): if get_origin(annotations[field]) is Annotated: if isinstance(get_args(annotations[field])[1], Parameter): param = get_args(annotations[field])[1] if object_override: param.default = serialize(getattr(object_override, field)) - elif fields[field].default: + elif field_info.default: # Serialize the value (usually done in Parameter's validator) - param.default = serialize(fields[field].default) + param.default = serialize(field_info.default) parameters.append(param) else: # Create a Parameter from basic type annotations - if object_override: - parameters.append( - Parameter( - name=field, - default=serialize(getattr(object_override, field)), - ) - ) - else: - parameters.append(Parameter(name=field, default=fields[field].default)) + default = getattr(object_override, field) if object_override else field_info.default + parameters.append(Parameter(name=field, default=default or MISSING)) + return parameters @classmethod diff --git a/src/hera/workflows/io/v2.py b/src/hera/workflows/io/v2.py index 6512809cc..a8904c2a1 100644 --- a/src/hera/workflows/io/v2.py +++ b/src/hera/workflows/io/v2.py @@ -5,7 +5,8 @@ from collections import ChainMap from typing import Any, List, Optional, Union -from hera.shared.serialization import serialize + +from hera.shared.serialization import MISSING, serialize from hera.workflows.artifact import Artifact from hera.workflows.parameter import Parameter @@ -23,6 +24,7 @@ if find_spec("pydantic.v1"): from pydantic import BaseModel + from pydantic_core import PydanticUndefined class RunnerInput(BaseModel): """Input model usable by the Hera Runner. @@ -38,22 +40,24 @@ def _get_parameters(cls, object_override: "Optional[RunnerInput]" = None) -> Lis parameters = [] annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()} - for field in cls.model_fields: # type: ignore + for field, field_info in cls.model_fields.items(): if get_origin(annotations[field]) is Annotated: if isinstance(get_args(annotations[field])[1], Parameter): param = get_args(annotations[field])[1] if object_override: param.default = serialize(getattr(object_override, field)) - elif cls.model_fields[field].default: # type: ignore + elif field_info.default: # type: ignore # Serialize the value (usually done in Parameter's validator) - param.default = serialize(cls.model_fields[field].default) # type: ignore + param.default = serialize(field_info.default) # type: ignore parameters.append(param) else: # Create a Parameter from basic type annotations - if object_override: - parameters.append(Parameter(name=field, default=serialize(getattr(object_override, field)))) - else: - parameters.append(Parameter(name=field, default=cls.model_fields[field].default)) # type: ignore + default = getattr(object_override, field) if object_override else field_info.default + if default == PydanticUndefined: + default = MISSING + + parameters.append(Parameter(name=field, default=default or MISSING)) + return parameters @classmethod diff --git a/tests/script_annotations/pydantic_io_v1.py b/tests/script_annotations/pydantic_io_v1.py index ca2663a7f..e04fe9965 100644 --- a/tests/script_annotations/pydantic_io_v1.py +++ b/tests/script_annotations/pydantic_io_v1.py @@ -13,6 +13,7 @@ class ParamOnlyInput(RunnerInput): my_int: int = 1 my_annotated_int: Annotated[int, Parameter(name="another-int", description="my desc")] = 42 + no_default_param: int class ParamOnlyOutput(RunnerOutput): @@ -64,14 +65,14 @@ def pydantic_io( @script(constructor="runner") def pydantic_io_with_defaults( - my_input: ParamOnlyInput = ParamOnlyInput(my_int=2, my_annotated_int=24), + my_input: ParamOnlyInput = ParamOnlyInput(my_int=2, my_annotated_int=24, no_default_param=1), ) -> ParamOnlyOutput: pass @script(constructor="runner") def pydantic_io_within_generic( - my_inputs: List[ParamOnlyInput] = [ParamOnlyInput(), ParamOnlyInput(my_int=2)], + my_inputs: List[ParamOnlyInput] = [ParamOnlyInput(no_default_param=1), ParamOnlyInput(my_int=2, no_default_param=2)], ) -> ParamOnlyOutput: pass diff --git a/tests/script_annotations/pydantic_io_v2.py b/tests/script_annotations/pydantic_io_v2.py index fb2643f74..d83e05ae8 100644 --- a/tests/script_annotations/pydantic_io_v2.py +++ b/tests/script_annotations/pydantic_io_v2.py @@ -23,6 +23,7 @@ class ParamOnlyInput(RunnerInput): my_int: int = 1 my_annotated_int: Annotated[int, Parameter(name="another-int", description="my desc")] = 42 + no_default_param: int class ParamOnlyOutput(RunnerOutput): @@ -74,14 +75,14 @@ def pydantic_io( @script(constructor="runner") def pydantic_io_with_defaults( - my_input: ParamOnlyInput = ParamOnlyInput(my_int=2, my_annotated_int=24), + my_input: ParamOnlyInput = ParamOnlyInput(my_int=2, my_annotated_int=24, no_default_param=1), ) -> ParamOnlyOutput: pass @script(constructor="runner") def pydantic_io_within_generic( - my_inputs: List[ParamOnlyInput] = [ParamOnlyInput(), ParamOnlyInput(my_int=2)], + my_inputs: List[ParamOnlyInput] = [ParamOnlyInput(no_default_param=1), ParamOnlyInput(my_int=2,no_default_param=2)], ) -> ParamOnlyOutput: pass diff --git a/tests/test_script_annotations.py b/tests/test_script_annotations.py index 05ab1ec2d..4f330f1f4 100644 --- a/tests/test_script_annotations.py +++ b/tests/test_script_annotations.py @@ -196,6 +196,7 @@ def test_configmap(global_config_fixture): "parameters": [ {"name": "my_int", "default": "1"}, {"name": "another-int", "default": "42", "description": "my desc"}, + {"name": "no_default_param"}, ] }, { @@ -247,6 +248,7 @@ def test_configmap(global_config_fixture): "parameters": [ {"name": "my_int", "default": "2"}, {"name": "another-int", "default": "24", "description": "my desc"}, + {"name": "no_default_param", "default": "1"}, ], }, { @@ -263,7 +265,7 @@ def test_configmap(global_config_fixture): "parameters": [ { "name": "my_inputs", - "default": '[{"my_int": 1, "my_annotated_int": 42}, {"my_int": 2, "my_annotated_int": 42}]', + "default": '[{"my_int": 1, "my_annotated_int": 42, "no_default_param": 1}, {"my_int": 2, "my_annotated_int": 42, "no_default_param": 2}]', }, ], },