Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deserialization for str unions #1239

Merged
merged 7 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions src/hera/shared/_type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,26 @@
return t


def origin_type_issubclass(cls: Any, type_: type) -> bool:
"""Return True if cls can be considered as a subclass of type_."""
unwrapped_type = unwrap_annotation(cls)
def origin_type_issubtype(annotation: Any, type_: Union[type, Tuple[type, ...]]) -> bool:
"""Return True if annotation is a subtype of type_.

type_ may be a tuple of types, in which case return True if annotation is a subtype
of the union of the types in the tuple.
"""
unwrapped_type = unwrap_annotation(annotation)
origin_type = get_unsubscripted_type(unwrapped_type)

Check warning on line 102 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L101-L102

Added lines #L101 - L102 were not covered by tests
if origin_type is Union or origin_type is UnionType:
return all(origin_type_issubtype(arg, type_) for arg in get_args(unwrapped_type))
return isinstance(origin_type, type) and issubclass(origin_type, type_)

Check warning on line 105 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L105

Added line #L105 was not covered by tests


def origin_type_issupertype(annotation: Any, type_: type) -> bool:
"""Return True if annotation is a supertype of type_."""
unwrapped_type = unwrap_annotation(annotation)

Check warning on line 110 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L110

Added line #L110 was not covered by tests
origin_type = get_unsubscripted_type(unwrapped_type)
if origin_type is Union or origin_type is UnionType:
return any(origin_type_issubclass(arg, type_) for arg in get_args(cls))
return issubclass(origin_type, type_)
return any(origin_type_issupertype(arg, type_) for arg in get_args(unwrapped_type))
return isinstance(origin_type, type) and issubclass(type_, origin_type)

Check warning on line 114 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L114

Added line #L114 was not covered by tests


def is_subscripted(t: Any) -> bool:
Expand Down
10 changes: 8 additions & 2 deletions src/hera/workflows/_runner/script_annotations_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,21 @@
import inspect
import json
import os
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast

if sys.version_info >= (3, 10):
from types import NoneType
else:
NoneType = type(None)

from hera.shared._pydantic import BaseModel, get_field_annotations, get_fields
from hera.shared._type_util import (
get_unsubscripted_type,
get_workflow_annotation,
is_subscripted,
origin_type_issubclass,
origin_type_issubtype,
unwrap_annotation,
)
from hera.shared.serialization import serialize
Expand Down Expand Up @@ -138,7 +144,7 @@ def map_runner_input(
input_model_obj = {}

def load_parameter_value(value: str, value_type: type) -> Any:
if origin_type_issubclass(value_type, str):
if origin_type_issubtype(value_type, (str, NoneType)):
return value

try:
Expand Down
10 changes: 8 additions & 2 deletions src/hera/workflows/_runner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@
import inspect
import json
import os
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, cast

if sys.version_info >= (3, 10):
from types import NoneType
else:
NoneType = type(None)

from hera.shared._pydantic import _PYDANTIC_VERSION
from hera.shared._type_util import (
get_workflow_annotation,
origin_type_issubclass,
origin_type_issubtype,
unwrap_annotation,
)
from hera.shared.serialization import serialize
Expand Down Expand Up @@ -125,7 +131,7 @@
def _is_str_kwarg_of(key: str, f: Callable) -> bool:
"""Check if param `key` of function `f` has a type annotation that can be interpreted as a subclass of str."""
if func_param_annotation := _get_function_param_annotation(key, f):
return origin_type_issubclass(func_param_annotation, str)
return origin_type_issubtype(func_param_annotation, (str, NoneType))

Check warning on line 134 in src/hera/workflows/_runner/util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_runner/util.py#L134

Added line #L134 was not covered by tests
Comment on lines -128 to +134
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 Thanks

return False


Expand Down
6 changes: 4 additions & 2 deletions src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
_flag_enabled,
)
from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator
from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issubclass
from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issupertype
from hera.shared.serialization import serialize
from hera.workflows._context import _context
from hera.workflows._meta_mixins import CallableTemplateMixin
Expand Down Expand Up @@ -540,7 +540,9 @@ class will be used as inputs, rather than the class itself.
else:
default = MISSING

if origin_type_issubclass(func_param.annotation, NoneType) and (default is MISSING or default is not None):
if origin_type_issupertype(func_param.annotation, NoneType) and (
default is MISSING or default is not None
):
raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.")

parameters.append(Parameter(name=func_param.name, default=default))
Expand Down
7 changes: 6 additions & 1 deletion tests/script_runner/parameter_inputs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, List
from typing import Any, List, Union

try:
from typing import Annotated
Expand Down Expand Up @@ -76,6 +76,11 @@ def no_type_parameter(my_anything) -> Any:
return my_anything


@script()
def str_or_int_parameter(my_str_or_int: Union[int, str]) -> str:
return f"type given: {type(my_str_or_int).__name__}"


@script()
def str_parameter_expects_jsonstr_dict(my_json_str: str) -> dict:
return json.loads(my_json_str)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,18 @@
{},
id="no-type-dict",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_or_int_parameter",
[{"name": "my_str_or_int", "value": "hi there"}],
"type given: str",
id="str-or-int-given-str",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_or_int_parameter",
[{"name": "my_str_or_int", "value": "3"}],
"type given: int",
id="str-or-int-given-int",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
Expand Down
54 changes: 41 additions & 13 deletions tests/test_unit/test_shared_type_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from typing import List, Optional, Union
import sys
from typing import List, NoReturn, Optional, Union

if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated
if sys.version_info >= (3, 10):
from types import NoneType
else:
NoneType = type(None)

import pytest
from annotated_types import Gt
Expand All @@ -8,16 +18,12 @@
get_unsubscripted_type,
get_workflow_annotation,
is_annotated,
origin_type_issubclass,
origin_type_issubtype,
origin_type_issupertype,
unwrap_annotation,
)
from hera.workflows import Artifact, Parameter

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated


@pytest.mark.parametrize("annotation, expected", [[Annotated[str, "some metadata"], True], [str, False]])
def test_is_annotated(annotation, expected):
Expand Down Expand Up @@ -104,11 +110,33 @@ def test_get_unsubscripted_type(annotation, expected):
@pytest.mark.parametrize(
"annotation, target, expected",
[
[List[str], str, False],
[Optional[str], str, True],
[str, str, True],
[Union[int, str], int, True],
pytest.param(List[str], str, False, id="list-str-not-subtype-of-str"),
pytest.param(NoReturn, str, False, id="special-form-does-not-raise-error"),
pytest.param(Optional[str], str, False, id="optional-str-not-subtype-of-str"),
pytest.param(str, str, True, id="str-is-subtype-of-str"),
pytest.param(Union[int, str], int, False, id="union-int-str-not-subtype-of-str"),
pytest.param(Optional[str], (str, NoneType), True, id="optional-str-is-subtype-of-optional-str"),
pytest.param(Annotated[Optional[str], "foo"], (str, NoneType), True, id="annotated-optional"),
pytest.param(str, (str, NoneType), True, id="str-is-subtype-of-optional-str"),
pytest.param(Union[int, str], (str, NoneType), False, id="union-int-str-not-subtype-of-optional-str"),
],
)
def test_origin_type_issubtype(annotation, target, expected):
assert origin_type_issubtype(annotation, target) is expected


@pytest.mark.parametrize(
"annotation, target, expected",
[
pytest.param(List[str], str, False, id="list-str-not-supertype-of-str"),
pytest.param(NoReturn, str, False, id="special-form-does-not-raise-error"),
pytest.param(Optional[str], str, True, id="optional-str-is-supertype-of-str"),
pytest.param(str, str, True, id="str-is-supertype-of-str"),
pytest.param(Union[int, str], int, True, id="union-int-str-is-supertype-of-int"),
pytest.param(Optional[str], NoneType, True, id="optional-str-is-supertype-of-nonetype"),
pytest.param(Annotated[Optional[str], "foo"], NoneType, True, id="annotated-optional"),
pytest.param(str, NoneType, False, id="str-not-supertype-of-nonetype"),
],
)
def test_origin_type_issubclass(annotation, target, expected):
assert origin_type_issubclass(annotation, target) is expected
def test_origin_type_issupertype(annotation, target, expected):
assert origin_type_issupertype(annotation, target) is expected
Loading