Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deserialization for str unions #1239

Merged
merged 7 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/hera/shared/_type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,24 @@
return t


def origin_type_issubclass(cls: Any, type_: type) -> bool:
"""Return True if cls can be considered as a subclass of type_."""
unwrapped_type = unwrap_annotation(cls)
def origin_type_issubtype(annotation: Any, type_: Union[type, Tuple[type, ...]]) -> bool:
"""Return True if annotation is a subtype of type_."""
alicederyn marked this conversation as resolved.
Show resolved Hide resolved
unwrapped_type = unwrap_annotation(annotation)

Check warning on line 97 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L97

Added line #L97 was not covered by tests
origin_type = get_unsubscripted_type(unwrapped_type)
if origin_type is Union or origin_type is UnionType:
return any(origin_type_issubclass(arg, type_) for arg in get_args(cls))
return all(origin_type_issubtype(arg, type_) for arg in get_args(annotation))
return issubclass(origin_type, type_)


def origin_type_issupertype(annotation: Any, type_: type) -> bool:
"""Return True if annotation is a supertype of type_."""
unwrapped_type = unwrap_annotation(annotation)
origin_type = get_unsubscripted_type(unwrapped_type)

Check warning on line 107 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L106-L107

Added lines #L106 - L107 were not covered by tests
if origin_type is Union or origin_type is UnionType:
return any(origin_type_issupertype(arg, type_) for arg in get_args(annotation))
return issubclass(type_, origin_type)

Check warning on line 110 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L110

Added line #L110 was not covered by tests


def is_subscripted(t: Any) -> bool:
"""Check if given type is subscripted, i.e. a typing object of the form X[Y, Z, ...].

Expand Down
10 changes: 8 additions & 2 deletions src/hera/workflows/_runner/script_annotations_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,21 @@
import inspect
import json
import os
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast

if sys.version_info >= (3, 10):
from types import NoneType
else:
NoneType = type(None)

from hera.shared._pydantic import BaseModel, get_field_annotations, get_fields
from hera.shared._type_util import (
get_unsubscripted_type,
get_workflow_annotation,
is_subscripted,
origin_type_issubclass,
origin_type_issubtype,
unwrap_annotation,
)
from hera.shared.serialization import serialize
Expand Down Expand Up @@ -138,7 +144,7 @@ def map_runner_input(
input_model_obj = {}

def load_parameter_value(value: str, value_type: type) -> Any:
if origin_type_issubclass(value_type, str):
if origin_type_issubtype(value_type, (str, NoneType)):
return value

try:
Expand Down
10 changes: 8 additions & 2 deletions src/hera/workflows/_runner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@
import inspect
import json
import os
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, cast

if sys.version_info >= (3, 10):
from types import NoneType
else:
NoneType = type(None)

from hera.shared._pydantic import _PYDANTIC_VERSION
from hera.shared._type_util import (
get_workflow_annotation,
origin_type_issubclass,
origin_type_issubtype,
unwrap_annotation,
)
from hera.shared.serialization import serialize
Expand Down Expand Up @@ -125,7 +131,7 @@
def _is_str_kwarg_of(key: str, f: Callable) -> bool:
"""Check if param `key` of function `f` has a type annotation that can be interpreted as a subclass of str."""
if func_param_annotation := _get_function_param_annotation(key, f):
return origin_type_issubclass(func_param_annotation, str)
return origin_type_issubtype(func_param_annotation, (str, NoneType))

Check warning on line 134 in src/hera/workflows/_runner/util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_runner/util.py#L134

Added line #L134 was not covered by tests
Comment on lines -128 to +134
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 Thanks

return False


Expand Down
6 changes: 4 additions & 2 deletions src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
_flag_enabled,
)
from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator
from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issubclass
from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issupertype
from hera.shared.serialization import serialize
from hera.workflows._context import _context
from hera.workflows._meta_mixins import CallableTemplateMixin
Expand Down Expand Up @@ -540,7 +540,9 @@ class will be used as inputs, rather than the class itself.
else:
default = MISSING

if origin_type_issubclass(func_param.annotation, NoneType) and (default is MISSING or default is not None):
if origin_type_issupertype(func_param.annotation, NoneType) and (
default is MISSING or default is not None
):
raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.")

parameters.append(Parameter(name=func_param.name, default=default))
Expand Down
7 changes: 6 additions & 1 deletion tests/script_runner/parameter_inputs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, List
from typing import Any, List, Union

try:
from typing import Annotated
Expand Down Expand Up @@ -76,6 +76,11 @@ def no_type_parameter(my_anything) -> Any:
return my_anything


@script()
def str_or_int_parameter(my_str_or_int: Union[str, int]) -> str:
return f"type given: {type(my_str_or_int).__name__}"


@script()
def str_parameter_expects_jsonstr_dict(my_json_str: str) -> dict:
return json.loads(my_json_str)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,24 @@
{},
id="no-type-dict",
),
*(
[
pytest.param(
"tests.script_runner.parameter_inputs:str_or_int_parameter",
[{"name": "my_str_or_int", "value": "hi there"}],
"type given: str",
id="str-or-int-given-str",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_or_int_parameter",
[{"name": "my_str_or_int", "value": "3"}],
"type given: int",
id="str-or-int-given-int",
),
]
if _PYDANTIC_VERSION > 1
alicederyn marked this conversation as resolved.
Show resolved Hide resolved
alicederyn marked this conversation as resolved.
Show resolved Hide resolved
else []
),
pytest.param(
"tests.script_runner.parameter_inputs:str_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
Expand Down
48 changes: 36 additions & 12 deletions tests/test_unit/test_shared_type_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
import sys
from typing import List, Optional, Union

if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated
if sys.version_info >= (3, 10):
from types import NoneType
else:
NoneType = type(None)

import pytest
from annotated_types import Gt

Expand All @@ -8,16 +18,12 @@
get_unsubscripted_type,
get_workflow_annotation,
is_annotated,
origin_type_issubclass,
origin_type_issubtype,
origin_type_issupertype,
unwrap_annotation,
)
from hera.workflows import Artifact, Parameter

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated


@pytest.mark.parametrize("annotation, expected", [[Annotated[str, "some metadata"], True], [str, False]])
def test_is_annotated(annotation, expected):
Expand Down Expand Up @@ -104,11 +110,29 @@ def test_get_unsubscripted_type(annotation, expected):
@pytest.mark.parametrize(
"annotation, target, expected",
[
[List[str], str, False],
[Optional[str], str, True],
[str, str, True],
[Union[int, str], int, True],
pytest.param(List[str], str, False, id="list-str-not-subtype-of-str"),
pytest.param(Optional[str], str, False, id="optional-str-not-subtype-of-str"),
pytest.param(str, str, True, id="str-is-subtype-of-str"),
pytest.param(Union[int, str], int, False, id="union-int-str-not-subtype-of-str"),
pytest.param(Optional[str], (str, NoneType), True, id="optional-str-is-subtype-of-optional-str"),
pytest.param(str, (str, NoneType), True, id="str-is-subtype-of-optional-str"),
pytest.param(Union[int, str], (str, NoneType), False, id="union-int-str-not-subtype-of-optional-str"),
],
)
def test_origin_type_issubtype(annotation, target, expected):
assert origin_type_issubtype(annotation, target) is expected


@pytest.mark.parametrize(
"annotation, target, expected",
[
pytest.param(List[str], str, False, id="list-str-not-supertype-of-str"),
pytest.param(Optional[str], str, True, id="optional-str-is-supertype-of-str"),
pytest.param(str, str, True, id="str-is-supertype-of-str"),
pytest.param(Union[int, str], int, True, id="union-int-str-is-supertype-of-int"),
pytest.param(Optional[str], NoneType, True, id="optional-str-is-supertype-of-nonetype"),
pytest.param(str, NoneType, False, id="str-not-supertype-of-nonetype"),
],
)
def test_origin_type_issubclass(annotation, target, expected):
assert origin_type_issubclass(annotation, target) is expected
def test_origin_type_issupertype(annotation, target, expected):
assert origin_type_issupertype(annotation, target) is expected
Loading