Skip to content

Commit

Permalink
Fix RunnerInputs - no longer require default value
Browse files Browse the repository at this point in the history
Signed-off-by: Elliot Gunton <egunton@bloomberg.net>
  • Loading branch information
elliotgunton committed Feb 26, 2024
1 parent 338fe6e commit c6c9d45
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 26 deletions.
20 changes: 7 additions & 13 deletions src/hera/workflows/io/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, List, Optional, Union

from hera.shared._pydantic import BaseModel, get_fields
from hera.shared.serialization import serialize
from hera.shared.serialization import MISSING, serialize
from hera.workflows.artifact import Artifact
from hera.workflows.parameter import Parameter

Expand Down Expand Up @@ -33,27 +33,21 @@ def _get_parameters(cls, object_override: "Optional[RunnerInput]" = None) -> Lis
annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()}

fields = get_fields(cls)
for field in fields:
for field, field_info in fields.items():
if get_origin(annotations[field]) is Annotated:
if isinstance(get_args(annotations[field])[1], Parameter):
param = get_args(annotations[field])[1]
if object_override:
param.default = serialize(getattr(object_override, field))
elif fields[field].default:
elif field_info.default:
# Serialize the value (usually done in Parameter's validator)
param.default = serialize(fields[field].default)
param.default = serialize(field_info.default)
parameters.append(param)
else:
# Create a Parameter from basic type annotations
if object_override:
parameters.append(
Parameter(
name=field,
default=serialize(getattr(object_override, field)),
)
)
else:
parameters.append(Parameter(name=field, default=fields[field].default))
default = getattr(object_override, field) if object_override else field_info.default
parameters.append(Parameter(name=field, default=default or MISSING))

return parameters

@classmethod
Expand Down
20 changes: 12 additions & 8 deletions src/hera/workflows/io/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from collections import ChainMap
from typing import Any, List, Optional, Union

from hera.shared.serialization import serialize

from hera.shared.serialization import MISSING, serialize
from hera.workflows.artifact import Artifact
from hera.workflows.parameter import Parameter

Expand All @@ -23,6 +24,7 @@

if find_spec("pydantic.v1"):
from pydantic import BaseModel
from pydantic_core import PydanticUndefined

class RunnerInput(BaseModel):
"""Input model usable by the Hera Runner.
Expand All @@ -38,22 +40,24 @@ def _get_parameters(cls, object_override: "Optional[RunnerInput]" = None) -> Lis
parameters = []
annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()}

for field in cls.model_fields: # type: ignore
for field, field_info in cls.model_fields.items():
if get_origin(annotations[field]) is Annotated:
if isinstance(get_args(annotations[field])[1], Parameter):
param = get_args(annotations[field])[1]
if object_override:
param.default = serialize(getattr(object_override, field))
elif cls.model_fields[field].default: # type: ignore
elif field_info.default: # type: ignore
# Serialize the value (usually done in Parameter's validator)
param.default = serialize(cls.model_fields[field].default) # type: ignore
param.default = serialize(field_info.default) # type: ignore
parameters.append(param)
else:
# Create a Parameter from basic type annotations
if object_override:
parameters.append(Parameter(name=field, default=serialize(getattr(object_override, field))))
else:
parameters.append(Parameter(name=field, default=cls.model_fields[field].default)) # type: ignore
default = getattr(object_override, field) if object_override else field_info.default
if default == PydanticUndefined:
default = MISSING

parameters.append(Parameter(name=field, default=default or MISSING))

return parameters

@classmethod
Expand Down
5 changes: 3 additions & 2 deletions tests/script_annotations/pydantic_io_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
class ParamOnlyInput(RunnerInput):
my_int: int = 1
my_annotated_int: Annotated[int, Parameter(name="another-int", description="my desc")] = 42
no_default_param: int


class ParamOnlyOutput(RunnerOutput):
Expand Down Expand Up @@ -64,14 +65,14 @@ def pydantic_io(

@script(constructor="runner")
def pydantic_io_with_defaults(
my_input: ParamOnlyInput = ParamOnlyInput(my_int=2, my_annotated_int=24),
my_input: ParamOnlyInput = ParamOnlyInput(my_int=2, my_annotated_int=24, no_default_param=1),
) -> ParamOnlyOutput:
pass


@script(constructor="runner")
def pydantic_io_within_generic(
my_inputs: List[ParamOnlyInput] = [ParamOnlyInput(), ParamOnlyInput(my_int=2)],
my_inputs: List[ParamOnlyInput] = [ParamOnlyInput(no_default_param=1), ParamOnlyInput(my_int=2, no_default_param=2)],
) -> ParamOnlyOutput:
pass

Expand Down
5 changes: 3 additions & 2 deletions tests/script_annotations/pydantic_io_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
class ParamOnlyInput(RunnerInput):
my_int: int = 1
my_annotated_int: Annotated[int, Parameter(name="another-int", description="my desc")] = 42
no_default_param: int


class ParamOnlyOutput(RunnerOutput):
Expand Down Expand Up @@ -74,14 +75,14 @@ def pydantic_io(

@script(constructor="runner")
def pydantic_io_with_defaults(
my_input: ParamOnlyInput = ParamOnlyInput(my_int=2, my_annotated_int=24),
my_input: ParamOnlyInput = ParamOnlyInput(my_int=2, my_annotated_int=24, no_default_param=1),
) -> ParamOnlyOutput:
pass


@script(constructor="runner")
def pydantic_io_within_generic(
my_inputs: List[ParamOnlyInput] = [ParamOnlyInput(), ParamOnlyInput(my_int=2)],
my_inputs: List[ParamOnlyInput] = [ParamOnlyInput(no_default_param=1), ParamOnlyInput(my_int=2,no_default_param=2)],
) -> ParamOnlyOutput:
pass

Expand Down
4 changes: 3 additions & 1 deletion tests/test_script_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def test_configmap(global_config_fixture):
"parameters": [
{"name": "my_int", "default": "1"},
{"name": "another-int", "default": "42", "description": "my desc"},
{"name": "no_default_param"},
]
},
{
Expand Down Expand Up @@ -247,6 +248,7 @@ def test_configmap(global_config_fixture):
"parameters": [
{"name": "my_int", "default": "2"},
{"name": "another-int", "default": "24", "description": "my desc"},
{"name": "no_default_param", "default": "1"},
],
},
{
Expand All @@ -263,7 +265,7 @@ def test_configmap(global_config_fixture):
"parameters": [
{
"name": "my_inputs",
"default": '[{"my_int": 1, "my_annotated_int": 42}, {"my_int": 2, "my_annotated_int": 42}]',
"default": '[{"my_int": 1, "my_annotated_int": 42, "no_default_param": 1}, {"my_int": 2, "my_annotated_int": 42, "no_default_param": 2}]',
},
],
},
Expand Down

0 comments on commit c6c9d45

Please sign in to comment.