diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py new file mode 100644 index 000000000..cabd78cd6 --- /dev/null +++ b/src/hera/shared/_type_util.py @@ -0,0 +1,106 @@ +"""Module that handles types and annotations.""" + +import sys +from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload + +if sys.version_info >= (3, 9): + from typing import Annotated, get_args, get_origin +else: + # Python 3.8 has get_origin() and get_args() but those implementations aren't + # Annotated-aware. + from typing_extensions import Annotated, get_args, get_origin + +if sys.version_info >= (3, 10): + from types import UnionType +else: + UnionType = Union + +from hera.workflows.artifact import Artifact +from hera.workflows.parameter import Parameter + + +def is_annotated(annotation: Any): + """Check annotation has Annotated type or not.""" + return get_origin(annotation) is Annotated + + +def unwrap_annotation(annotation: Any) -> Any: + """If the given annotation is of type Annotated, return the underlying type, otherwise return the annotation.""" + if is_annotated(annotation): + return get_args(annotation)[0] + return annotation + + +T = TypeVar("T") +V = TypeVar("V") + + +@overload +def get_annotated_metadata(annotation: Any, type_: Type[T]) -> List[T]: ... + + +@overload +def get_annotated_metadata(annotation: Any, type_: Tuple[Type[T], Type[V]]) -> List[Union[T, V]]: ... + + +def get_annotated_metadata(annotation, type_): + """If given annotation has metadata typed type_, return the metadata. + + Prefer get_workflow_annotation if you want to call this with Artifact or Parameter. + """ + if not is_annotated(annotation): + return [] + + found = [] + args = get_args(annotation) + for arg in args[1:]: + if isinstance(type_, Iterable): + if any(isinstance(arg, t) for t in type_): + found.append(arg) + else: + if isinstance(arg, type_): + found.append(arg) + return found + + +def get_workflow_annotation(annotation: Any) -> Optional[Union[Artifact, Parameter]]: + """If given annotation has Artifact or Parameter metadata, return it. + + Note that this function will raise the error when multiple Artifact or Parameter metadata are given. + """ + metadata = get_annotated_metadata(annotation, (Artifact, Parameter)) + if not metadata: + return None + if len(metadata) > 1: + raise ValueError("Annotation metadata cannot contain more than one Artifact/Parameter.") + return metadata[0] + + +def get_unsubscripted_type(t: Any) -> Any: + """Return the origin of t, if subscripted, or t itself. + + This can be helpful if you want to use t with isinstance, issubclass, etc., + """ + if origin_type := get_origin(t): + return origin_type + return t + + +def origin_type_issubclass(cls: Any, type_: type) -> bool: + """Return True if cls can be considered as a subclass of type_.""" + unwrapped_type = unwrap_annotation(cls) + origin_type = get_unsubscripted_type(unwrapped_type) + if origin_type is Union or origin_type is UnionType: + return any(origin_type_issubclass(arg, type_) for arg in get_args(cls)) + return issubclass(origin_type, type_) + + +def is_subscripted(t: Any) -> bool: + """Check if given type is subscripted, i.e. a typing object of the form X[Y, Z, ...]. + + >>> is_subscripted(list[str]) + True + >>> is_subscripted(str) + False + """ + return get_origin(t) is not None diff --git a/src/hera/workflows/_meta_mixins.py b/src/hera/workflows/_meta_mixins.py index 302c5a884..915fbc8a0 100644 --- a/src/hera/workflows/_meta_mixins.py +++ b/src/hera/workflows/_meta_mixins.py @@ -10,13 +10,6 @@ from types import ModuleType from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Type, TypeVar, Union, cast -if sys.version_info >= (3, 9): - from typing import Annotated, get_args, get_origin -else: - # Python 3.8 has get_origin() and get_args() but those implementations aren't - # Annotated-aware. - from typing_extensions import Annotated, get_args, get_origin - if sys.version_info >= (3, 10): from inspect import get_annotations from types import NoneType @@ -31,6 +24,7 @@ from hera.shared import BaseMixin, global_config from hera.shared._global_config import _DECORATOR_SYNTAX_FLAG, _flag_enabled from hera.shared._pydantic import BaseModel, get_fields, root_validator +from hera.shared._type_util import get_annotated_metadata from hera.workflows._context import _context from hera.workflows.exceptions import InvalidTemplateCall from hera.workflows.io.v1 import ( @@ -56,6 +50,7 @@ Output as OutputV2, ) + if TYPE_CHECKING: from hera.workflows._mixins import TemplateMixin from hera.workflows.steps import Step @@ -155,7 +150,7 @@ class ModelMapperMixin(BaseMixin): class ModelMapper: def __init__(self, model_path: str, hera_builder: Optional[Callable] = None): - self.model_path = None + self.model_path = [] self.builder = hera_builder if not model_path: @@ -181,18 +176,18 @@ def build_model( assert isinstance(hera_obj, ModelMapperMixin) for attr, annotation in hera_class._get_all_annotations().items(): - if get_origin(annotation) is Annotated and isinstance( - get_args(annotation)[1], ModelMapperMixin.ModelMapper - ): - mapper = get_args(annotation)[1] + if mappers := get_annotated_metadata(annotation, ModelMapperMixin.ModelMapper): + if len(mappers) != 1: + raise ValueError("Expected only one ModelMapper") + # Value comes from builder function if it exists on hera_obj, otherwise directly from the attr value = ( - getattr(hera_obj, mapper.builder.__name__)() - if mapper.builder is not None + getattr(hera_obj, mappers[0].builder.__name__)() + if mappers[0].builder is not None else getattr(hera_obj, attr) ) if value is not None: - _set_model_attr(model, mapper.model_path, value) + _set_model_attr(model, mappers[0].model_path, value) return model @@ -207,12 +202,11 @@ def _from_model(cls, model: BaseModel) -> ModelMapperMixin: hera_obj = cls() for attr, annotation in cls._get_all_annotations().items(): - if get_origin(annotation) is Annotated and isinstance( - get_args(annotation)[1], ModelMapperMixin.ModelMapper - ): - mapper = get_args(annotation)[1] - if mapper.model_path: - value = _get_model_attr(model, mapper.model_path) + if mappers := get_annotated_metadata(annotation, ModelMapperMixin.ModelMapper): + if len(mappers) != 1: + raise ValueError("Expected only one model mapper") + if mappers[0].model_path: + value = _get_model_attr(model, mappers[0].model_path) if value is not None: setattr(hera_obj, attr, value) @@ -497,20 +491,6 @@ def __init__(self, subnode_type: str, output_class: Type[Union[OutputV1, OutputV self.output_class = output_class -def _get_underlying_type(annotation: Type): - real_type = annotation - if get_origin(annotation) is Annotated: - real_type = get_args(annotation)[0] - - if get_origin(real_type) is Union: - args = get_args(real_type) - if len(args) == 2 and any([arg is NoneType for arg in args]): - # This was an "Optional" type, get the real type - real_type = next(iter([arg for arg in args if arg is not NoneType])) - - return real_type - - class TemplateDecoratorFuncsMixin(ContextMixin): from hera.workflows.container import Container from hera.workflows.dag import DAG diff --git a/src/hera/workflows/_mixins.py b/src/hera/workflows/_mixins.py index b1e4887e8..341466b17 100644 --- a/src/hera/workflows/_mixins.py +++ b/src/hera/workflows/_mixins.py @@ -2,7 +2,6 @@ from __future__ import annotations -import sys from typing import ( Any, Callable, @@ -16,15 +15,9 @@ cast, ) -if sys.version_info >= (3, 9): - from typing import Annotated, get_args, get_origin -else: - # Python 3.8 has get_origin() and get_args() but those implementations aren't - # Annotated-aware. - from typing_extensions import Annotated, get_args, get_origin - from hera.shared import BaseMixin, global_config from hera.shared._pydantic import PrivateAttr, get_field_annotations, get_fields, root_validator, validator +from hera.shared._type_util import get_workflow_annotation from hera.shared.serialization import serialize from hera.workflows._context import SubNodeMixin, _context from hera.workflows._meta_mixins import CallableTemplateMixin, HeraBuildObj, HookMixin @@ -745,18 +738,17 @@ def __getattribute__(self, name: str) -> Any: result_templated_str = f"{{{{{subnode_type}.{subnode_name}.outputs.result}}}}" return result_templated_str - if get_origin(annotations[name]) is Annotated: - annotation = get_args(annotations[name])[1] - - if not isinstance(annotation, (Parameter, Artifact)): - return f"{{{{{subnode_type}.{subnode_name}.outputs.parameters.{name}}}}}" + if param_or_artifact := get_workflow_annotation(annotations[name]): + if isinstance(param_or_artifact, Parameter): + return ( + "{{" + f"{subnode_type}.{subnode_name}.outputs.parameters.{param_or_artifact.name}" + "}}" + ) + else: + return ( + "{{" + f"{subnode_type}.{subnode_name}.outputs.artifacts.{param_or_artifact.name}" + "}}" + ) - if isinstance(annotation, Parameter): - return f"{{{{{subnode_type}.{subnode_name}.outputs.parameters.{annotation.name}}}}}" - elif isinstance(annotation, Artifact): - return f"{{{{{subnode_type}.{subnode_name}.outputs.artifacts.{annotation.name}}}}}" - else: - return f"{{{{{subnode_type}.{subnode_name}.outputs.parameters.{name}}}}}" + return "{{" + f"{subnode_type}.{subnode_name}.outputs.parameters.{name}" + "}}" return super().__getattribute__(name) diff --git a/src/hera/workflows/_runner/script_annotations_util.py b/src/hera/workflows/_runner/script_annotations_util.py index d9cd38c7e..28cf20313 100644 --- a/src/hera/workflows/_runner/script_annotations_util.py +++ b/src/hera/workflows/_runner/script_annotations_util.py @@ -3,18 +3,17 @@ import inspect import json import os -import sys from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast -if sys.version_info >= (3, 9): - from typing import Annotated, get_args, get_origin -else: - # Python 3.8 has get_origin() and get_args() but those implementations aren't - # Annotated-aware. - from typing_extensions import Annotated, get_args, get_origin - from hera.shared._pydantic import BaseModel, get_field_annotations, get_fields +from hera.shared._type_util import ( + get_unsubscripted_type, + get_workflow_annotation, + is_subscripted, + origin_type_issubclass, + unwrap_annotation, +) from hera.shared.serialization import serialize from hera.workflows import Artifact, Parameter from hera.workflows.artifact import ArtifactLoader @@ -135,12 +134,10 @@ def map_runner_input( If the field is annotated, we look for the kwarg with the name from the annotation (Parameter or Artifact). Otherwise, we look for the kwarg with the name of the field. """ - from hera.workflows._runner.util import _get_type - input_model_obj = {} def load_parameter_value(value: str, value_type: type) -> Any: - if issubclass(_get_type(value_type), str): + if origin_type_issubclass(value_type, str): return value try: @@ -156,23 +153,17 @@ def map_field( ) -> Any: annotation = runner_input_annotations.get(field) assert annotation is not None, "RunnerInput fields must be type-annotated" - if get_origin(annotation) is Annotated: - # my_field: Annotated[int, Parameter(...)] - ann_type = get_args(annotation)[0] - param_or_artifact = get_args(annotation)[1] - else: - # my_field: int - ann_type = annotation - param_or_artifact = None - - if isinstance(param_or_artifact, Parameter): - assert not param_or_artifact.output - return load_parameter_value( - _get_annotated_input_param_value(field, param_or_artifact, kwargs), - ann_type, - ) - elif isinstance(param_or_artifact, Artifact): - return get_annotated_artifact_value(param_or_artifact) + ann_type = unwrap_annotation(annotation) + + if param_or_artifact := get_workflow_annotation(annotation): + if isinstance(param_or_artifact, Parameter): + assert not param_or_artifact.output + return load_parameter_value( + _get_annotated_input_param_value(field, param_or_artifact, kwargs), + ann_type, + ) + else: + return get_annotated_artifact_value(param_or_artifact) else: return load_parameter_value(kwargs[field], ann_type) @@ -199,19 +190,15 @@ def _map_argo_inputs_to_function(function: Callable, kwargs: Dict[str, str]) -> mapped_kwargs: Dict[str, Any] = {} for func_param_name, func_param in inspect.signature(function).parameters.items(): - if get_origin(func_param.annotation) is Annotated: - func_param_annotation = get_args(func_param.annotation)[1] - - if isinstance(func_param_annotation, Parameter): - mapped_kwargs[func_param_name] = get_annotated_param_value( - func_param_name, func_param_annotation, kwargs - ) - elif isinstance(func_param_annotation, Artifact): - mapped_kwargs[func_param_name] = get_annotated_artifact_value(func_param_annotation) + if param_or_artifact := get_workflow_annotation(func_param.annotation): + if isinstance(param_or_artifact, Parameter): + mapped_kwargs[func_param_name] = get_annotated_param_value(func_param_name, param_or_artifact, kwargs) else: - mapped_kwargs[func_param_name] = kwargs[func_param_name] - elif get_origin(func_param.annotation) is None and issubclass(func_param.annotation, (InputV1, InputV2)): + mapped_kwargs[func_param_name] = get_annotated_artifact_value(param_or_artifact) + + elif not is_subscripted(func_param.annotation) and issubclass(func_param.annotation, (InputV1, InputV2)): mapped_kwargs[func_param_name] = map_runner_input(func_param.annotation, kwargs) + else: mapped_kwargs[func_param_name] = kwargs[func_param_name] return mapped_kwargs @@ -253,19 +240,12 @@ def _save_annotated_return_outputs( _write_to_path(path, value) else: assert isinstance(dest, tuple) - if get_origin(dest[0]) is None: - # Built-in types return None from get_origin, so we can check isinstance directly - if not isinstance(output_value, dest[0]): - raise ValueError( - f"The type of output `{dest[1].name}`, `{type(output_value)}` does not match the annotated type `{dest[0]}`" - ) - else: - # Here, we know get_origin is not None, but its return type is found to be `Optional[Any]` - origin_type = cast(type, get_origin(dest[0])) - if not isinstance(output_value, origin_type): - raise ValueError( - f"The type of output `{dest[1].name}`, `{type(output_value)}` does not match the annotated type `{dest[0]}`" - ) + + type_ = get_unsubscripted_type(dest[0]) + if not isinstance(output_value, type_): + raise ValueError( + f"The type of output `{dest[1].name}`, `{type(output_value)}` does not match the annotated type `{dest[0]}`" + ) if not dest[1].name: raise ValueError("The name was not provided for one of the outputs.") @@ -304,7 +284,7 @@ def _save_dummy_outputs( can be provided by the user or is set to /tmp/hera-outputs by default """ for dest in output_annotations: - if isinstance(dest, (OutputV1, OutputV2)): + if isinstance(dest, type) and issubclass(dest, (OutputV1, OutputV2)): if os.environ.get("hera__script_pydantic_io", None) is None: raise ValueError("hera__script_pydantic_io environment variable is not set") diff --git a/src/hera/workflows/_runner/util.py b/src/hera/workflows/_runner/util.py index 4849d0b46..5acc32ab3 100644 --- a/src/hera/workflows/_runner/util.py +++ b/src/hera/workflows/_runner/util.py @@ -6,20 +6,17 @@ import inspect import json import os -import sys from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union, cast - -if sys.version_info >= (3, 9): - from typing import Annotated, get_args, get_origin -else: - # Python 3.8 has get_origin() and get_args() but those implementations aren't - # Annotated-aware. - from typing_extensions import Annotated, get_args, get_origin +from typing import Any, Callable, Dict, List, Optional, cast from hera.shared._pydantic import _PYDANTIC_VERSION +from hera.shared._type_util import ( + get_workflow_annotation, + origin_type_issubclass, + unwrap_annotation, +) from hera.shared.serialization import serialize -from hera.workflows import Artifact, Parameter +from hera.workflows import Artifact from hera.workflows._runner.script_annotations_util import ( _map_argo_inputs_to_function, _save_annotated_return_outputs, @@ -110,67 +107,42 @@ def _parse(value: str, key: str, f: Callable) -> Any: return value -def _get_type(type_: type) -> type: - if get_origin(type_) is None: - return type_ - origin_type = cast(type, get_origin(type_)) - if origin_type is Annotated: - return get_args(type_)[0] - return origin_type +def _get_function_param_annotation(key: str, f: Callable) -> Optional[type]: + func_param_annotation = inspect.signature(f).parameters[key].annotation + if func_param_annotation is inspect.Parameter.empty: + return None + return func_param_annotation def _get_unannotated_type(key: str, f: Callable) -> Optional[type]: """Get the type of function param without the 'Annotated' outer type.""" - type_ = inspect.signature(f).parameters[key].annotation - if type_ is inspect.Parameter.empty: + type_ = _get_function_param_annotation(key, f) + if type_ is None: return None - if get_origin(type_) is None: - return type_ - - origin_type = cast(type, get_origin(type_)) - if origin_type is Annotated: - return get_args(type_)[0] - - # Type could be a dict/list with subscript type - return type_ + return unwrap_annotation(type_) def _is_str_kwarg_of(key: str, f: Callable) -> bool: """Check if param `key` of function `f` has a type annotation that can be interpreted as a subclass of str.""" - func_param_annotation = inspect.signature(f).parameters[key].annotation - if func_param_annotation is inspect.Parameter.empty: - return False - - type_ = _get_type(func_param_annotation) - if type_ is Union: - # Checking only Union[X, None] or Union[None, X] for given X which is subclass of str. - # Note that Optional[X] is alias of Union[X, None], so Optional is also handled in here. - args = get_args(func_param_annotation) - return len(args) == 2 and ( - (args[0] is type(None) and issubclass(args[1], str)) - or (issubclass(args[0], str) and args[1] is type(None)) - ) - return issubclass(type_, str) + if func_param_annotation := _get_function_param_annotation(key, f): + return origin_type_issubclass(func_param_annotation, str) + return False def _is_artifact_loaded(key: str, f: Callable) -> bool: """Check if param `key` of function `f` is actually an Artifact that has already been loaded.""" - param = inspect.signature(f).parameters[key] - return ( - get_origin(param.annotation) is Annotated - and isinstance(get_args(param.annotation)[1], Artifact) - and get_args(param.annotation)[1].loader == ArtifactLoader.json.value - ) + if param_annotation := _get_function_param_annotation(key, f): + if (artifact := get_workflow_annotation(param_annotation)) and isinstance(artifact, Artifact): + return artifact.loader == ArtifactLoader.json.value + return False def _is_output_kwarg(key: str, f: Callable) -> bool: """Check if param `key` of function `f` is an output Artifact/Parameter.""" - param = inspect.signature(f).parameters[key] - return ( - get_origin(param.annotation) is Annotated - and isinstance(get_args(param.annotation)[1], (Artifact, Parameter)) - and get_args(param.annotation)[1].output - ) + if param_annotation := _get_function_param_annotation(key, f): + if param_or_artifact := get_workflow_annotation(param_annotation): + return bool(param_or_artifact.output) + return False def _runner(entrypoint: str, kwargs_list: List) -> Any: diff --git a/src/hera/workflows/cron_workflow.py b/src/hera/workflows/cron_workflow.py index 8bf3d873a..ea73c561c 100644 --- a/src/hera/workflows/cron_workflow.py +++ b/src/hera/workflows/cron_workflow.py @@ -9,14 +9,13 @@ from typing import Dict, Optional, Type, Union, cast if sys.version_info >= (3, 9): - from typing import Annotated, get_args, get_origin + from typing import Annotated else: - # Python 3.8 has get_origin() and get_args() but those implementations aren't - # Annotated-aware. - from typing_extensions import Annotated, get_args, get_origin + from typing_extensions import Annotated from hera.exceptions import NotFound from hera.shared._pydantic import BaseModel +from hera.shared._type_util import get_annotated_metadata from hera.workflows._meta_mixins import ( ModelMapperMixin, _get_model_attr, @@ -47,22 +46,25 @@ def build_model( assert isinstance(hera_obj, ModelMapperMixin) for attr, annotation in hera_class._get_all_annotations().items(): - if get_origin(annotation) is Annotated and isinstance( - get_args(annotation)[1], ModelMapperMixin.ModelMapper - ): - mapper = get_args(annotation)[1] - if not isinstance(mapper, _CronWorkflowModelMapper) and mapper.model_path[0] == "spec": + if mappers := get_annotated_metadata(annotation, ModelMapperMixin.ModelMapper): + if len(mappers) != 1: + raise ValueError("Expected only one ModelMapper") + if ( + not isinstance(mappers[0], _CronWorkflowModelMapper) + and mappers[0].model_path + and mappers[0].model_path[0] == "spec" + ): # Skip attributes mapped to spec by parent _WorkflowModelMapper continue # Value comes from builder function if it exists on hera_obj, otherwise directly from the attr value = ( - getattr(hera_obj, mapper.builder.__name__)() - if mapper.builder is not None + getattr(hera_obj, mappers[0].builder.__name__)() + if mappers[0].builder is not None else getattr(hera_obj, attr) ) if value is not None: - _set_model_attr(model, mapper.model_path, value) + _set_model_attr(model, mappers[0].model_path, value) return model @@ -167,23 +169,24 @@ def _from_model(cls, model: BaseModel) -> ModelMapperMixin: hera_cron_workflow = cls(schedule="") for attr, annotation in cls._get_all_annotations().items(): - if get_origin(annotation) is Annotated and isinstance( - get_args(annotation)[1], ModelMapperMixin.ModelMapper - ): - mapper = get_args(annotation)[1] - if mapper.model_path: + if mappers := get_annotated_metadata(annotation, ModelMapperMixin.ModelMapper): + if len(mappers) != 1: + raise ValueError("Expected only one ModelMapper") + + if mappers[0].model_path: value = None if ( - isinstance(mapper, _CronWorkflowModelMapper) - or isinstance(mapper, _WorkflowModelMapper) - and mapper.model_path[0] == "metadata" + isinstance(mappers[0], _CronWorkflowModelMapper) + or isinstance(mappers[0], _WorkflowModelMapper) + and mappers[0].model_path[0] == "metadata" ): - value = _get_model_attr(model, mapper.model_path) - elif isinstance(mapper, _WorkflowModelMapper) and mapper.model_path[0] == "spec": + value = _get_model_attr(model, mappers[0].model_path) + + elif isinstance(mappers[0], _WorkflowModelMapper) and mappers[0].model_path[0] == "spec": # We map "spec.workflow_spec" from the model CronWorkflow to "spec" for Hera's Workflow (used # as the parent class of Hera's CronWorkflow) - value = _get_model_attr(model.spec.workflow_spec, mapper.model_path[1:]) + value = _get_model_attr(model.spec.workflow_spec, mappers[0].model_path[1:]) if value is not None: setattr(hera_cron_workflow, attr, value) diff --git a/src/hera/workflows/io/_io_mixins.py b/src/hera/workflows/io/_io_mixins.py index b4d0df667..912f9f6c4 100644 --- a/src/hera/workflows/io/_io_mixins.py +++ b/src/hera/workflows/io/_io_mixins.py @@ -1,28 +1,14 @@ import sys import warnings -from pathlib import Path from typing import TYPE_CHECKING, List, Optional, Union -if sys.version_info >= (3, 9): - from typing import Annotated -else: - from typing_extensions import Annotated - -if sys.version_info >= (3, 10): - from typing import get_args, get_origin -else: - # Python 3.8 has get_origin() and get_args() but those implementations aren't - # Annotated-aware. Python 3.9's versions don't support ParamSpecArgs and - # ParamSpecKwargs. - from typing_extensions import get_args, get_origin - if sys.version_info >= (3, 11): from typing import Self else: from typing_extensions import Self - from hera.shared._pydantic import _PYDANTIC_VERSION, get_field_annotations, get_fields +from hera.shared._type_util import get_workflow_annotation, is_annotated from hera.shared.serialization import MISSING, serialize from hera.workflows._context import _context from hera.workflows.artifact import Artifact @@ -44,7 +30,6 @@ V2BaseModel = V1BaseModel # type: ignore PydanticUndefined = None # type: ignore[assignment] - if TYPE_CHECKING: # We add BaseModel as a parent class of the mixins only when type checking which allows it # to be used with either a V1 BaseModel or a V2 BaseModel @@ -80,24 +65,23 @@ def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Paramet annotations = get_field_annotations(cls) for field, field_info in get_fields(cls).items(): - if get_origin(annotations[field]) is Annotated: + if (param := get_workflow_annotation(annotations[field])) and isinstance(param, Parameter): # Copy so as to not modify the Input fields themselves - param = get_args(annotations[field])[1].copy() - if isinstance(param, Parameter): - if param.name is None: - param.name = field - if param.default is not None: - warnings.warn( - "Using the default field for Parameters in Annotations is deprecated since v5.16" - "and will be removed in a future minor version, use a Python default value instead" - ) - if object_override: - param.default = serialize(getattr(object_override, field)) - elif field_info.default is not None and field_info.default != PydanticUndefined: # type: ignore - # Serialize the value (usually done in Parameter's validator) - param.default = serialize(field_info.default) # type: ignore - parameters.append(param) - else: + param = param.copy() + if param.name is None: + param.name = field + if param.default is not None: + warnings.warn( + "Using the default field for Parameters in Annotations is deprecated since v5.16" + "and will be removed in a future minor version, use a Python default value instead" + ) + if object_override: + param.default = serialize(getattr(object_override, field)) + elif field_info.default is not None and field_info.default != PydanticUndefined: # type: ignore + # Serialize the value (usually done in Parameter's validator) + param.default = serialize(field_info.default) # type: ignore + parameters.append(param) + elif not is_annotated(annotations[field]): # Create a Parameter from basic type annotations default = getattr(object_override, field) if object_override else field_info.default @@ -115,15 +99,14 @@ def _get_artifacts(cls) -> List[Artifact]: annotations = get_field_annotations(cls) for field in get_fields(cls): - if get_origin(annotations[field]) is Annotated: + if (artifact := get_workflow_annotation(annotations[field])) and isinstance(artifact, Artifact): # Copy so as to not modify the Input fields themselves - artifact = get_args(annotations[field])[1].copy() - if isinstance(artifact, Artifact): - if artifact.name is None: - artifact.name = field - if artifact.path is None: - artifact.path = artifact._get_default_inputs_path() - artifacts.append(artifact) + artifact = artifact.copy() + if artifact.name is None: + artifact.name = field + if artifact.path is None: + artifact.path = artifact._get_default_inputs_path() + artifacts.append(artifact) return artifacts @classmethod @@ -138,15 +121,12 @@ def _get_as_templated_arguments(cls) -> Self: annotations = get_field_annotations(cls) for field in cls_fields: - if get_origin(annotations[field]) is Annotated: - annotation = get_args(annotations[field])[1] - if isinstance(annotation, (Artifact, Parameter)): - name = annotation.name - if isinstance(annotation, Parameter): - object_dict[field] = "{{inputs.parameters." + f"{name}" + "}}" - elif isinstance(annotation, Artifact): - object_dict[field] = "{{inputs.artifacts." + f"{name}" + "}}" - else: + if param_or_artifact := get_workflow_annotation(annotations[field]): + if isinstance(param_or_artifact, Parameter): + object_dict[field] = "{{inputs.parameters." + f"{param_or_artifact.name}" + "}}" + else: + object_dict[field] = "{{inputs.artifacts." + f"{param_or_artifact.name}" + "}}" + elif not is_annotated(annotations[field]): object_dict[field] = "{{inputs.parameters." + f"{field}" + "}}" return cls.construct(None, **object_dict) @@ -166,39 +146,17 @@ def _get_as_arguments(self) -> ModelArguments: # If it is a templated string, it will be unaffected as `"{{mystr}}" == serialize("{{mystr}}")`` templated_value = serialize(self_dict[field]) - if get_origin(annotations[field]) is Annotated: - annotation = get_args(annotations[field])[1] - if isinstance(annotation, Parameter) and annotation.name: - params.append(ModelParameter(name=annotation.name, value=templated_value)) - elif isinstance(annotation, Artifact) and annotation.name: - artifacts.append(ModelArtifact(name=annotation.name, from_=templated_value)) - else: + if (param_or_artifact := get_workflow_annotation(annotations[field])) and param_or_artifact.name: + if isinstance(param_or_artifact, Parameter): + params.append(ModelParameter(name=param_or_artifact.name, value=templated_value)) + else: + artifacts.append(ModelArtifact(name=param_or_artifact.name, from_=templated_value)) + elif not is_annotated(annotations[field]): params.append(ModelParameter(name=field, value=templated_value)) return ModelArguments(parameters=params or None, artifacts=artifacts or None) -def _get_output_path(annotation: Union[Parameter, Artifact]) -> Path: - """Get the path from the OutputMixin attribute's annotation. - - Use the default path with the annotation's name if no path present on the object. - """ - default_path = Path("/tmp/hera-outputs") - if isinstance(annotation, Parameter): - if annotation.value_from and annotation.value_from.path: - return Path(annotation.value_from.path) - - assert annotation.name - return default_path / f"parameters/{annotation.name}" - - if isinstance(annotation, Artifact): - if annotation.path: - return Path(annotation.path) - - assert annotation.name - return default_path / f"artifacts/{annotation.name}" - - class OutputMixin(BaseModel): def __new__(cls, **kwargs): if _context.declaring: @@ -219,7 +177,7 @@ def __init__(self, /, **kwargs): @classmethod def _get_outputs(cls, add_missing_path: bool = False) -> List[Union[Artifact, Parameter]]: - outputs = [] + outputs: List[Union[Artifact, Parameter]] = [] annotations = get_field_annotations(cls) model_fields = get_fields(cls) @@ -227,16 +185,20 @@ def _get_outputs(cls, add_missing_path: bool = False) -> List[Union[Artifact, Pa for field in model_fields: if field in {"exit_code", "result"}: continue - if get_origin(annotations[field]) is Annotated: - annotation = get_args(annotations[field])[1] - if isinstance(annotation, Parameter): - if add_missing_path and (annotation.value_from is None or annotation.value_from.path is None): - annotation.value_from = ValueFrom(path=f"/tmp/hera-outputs/parameters/{annotation.name}") - elif isinstance(annotation, Artifact): - if add_missing_path and annotation.path is None: - annotation.path = f"/tmp/hera-outputs/artifacts/{annotation.name}" - outputs.append(annotation) - else: + if param_or_artifact := get_workflow_annotation(annotations[field]): + if isinstance(param_or_artifact, Parameter): + if add_missing_path and ( + param_or_artifact.value_from is None or param_or_artifact.value_from.path is None + ): + param_or_artifact.value_from = ValueFrom( + path=f"/tmp/hera-outputs/parameters/{param_or_artifact.name}" + ) + outputs.append(param_or_artifact) + else: + if add_missing_path and param_or_artifact.path is None: + param_or_artifact.path = f"/tmp/hera-outputs/artifacts/{param_or_artifact.name}" + outputs.append(param_or_artifact) + elif not is_annotated(annotations[field]): # Create a Parameter from basic type annotations default = model_fields[field].default if default is None or default == PydanticUndefined: @@ -253,9 +215,8 @@ def _get_outputs(cls, add_missing_path: bool = False) -> List[Union[Artifact, Pa def _get_output(cls, field_name: str) -> Union[Artifact, Parameter]: annotations = get_field_annotations(cls) annotation = annotations[field_name] - if get_origin(annotation) is Annotated: - if isinstance(get_args(annotation)[1], (Parameter, Artifact)): - return get_args(annotation)[1] + if output := get_workflow_annotation(annotation): + return output # Create a Parameter from basic type annotations default = get_fields(cls)[field_name].default @@ -282,13 +243,14 @@ def _get_as_invocator_output(self) -> List[Union[Artifact, Parameter]]: templated_value = self_dict[field] # a string such as `"{{tasks.task_a.outputs.parameter.my_param}}"` - if get_origin(annotations[field]) is Annotated: - annotation = get_args(annotations[field])[1] - if isinstance(annotation, Parameter) and annotation.name: - outputs.append(Parameter(name=annotation.name, value_from=ValueFrom(parameter=templated_value))) - elif isinstance(annotation, Artifact) and annotation.name: - outputs.append(Artifact(name=annotation.name, from_=templated_value)) - else: + if (param_or_artifact := get_workflow_annotation(annotations[field])) and param_or_artifact.name: + if isinstance(param_or_artifact, Parameter): + outputs.append( + Parameter(name=param_or_artifact.name, value_from=ValueFrom(parameter=templated_value)) + ) + else: + outputs.append(Artifact(name=param_or_artifact.name, from_=templated_value)) + elif not is_annotated(annotations[field]): outputs.append(Parameter(name=field, value_from=ValueFrom(parameter=templated_value))) return outputs diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index 03ef6280c..9a44db23b 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -29,10 +29,10 @@ overload, ) -if sys.version_info >= (3, 9): - from typing import Annotated +if sys.version_info >= (3, 10): + from types import NoneType else: - from typing_extensions import Annotated + NoneType = type(None) from typing_extensions import ParamSpec, get_args, get_origin @@ -45,6 +45,7 @@ _flag_enabled, ) from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator +from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issubclass from hera.shared.serialization import serialize from hera.workflows._context import _context from hera.workflows._meta_mixins import CallableTemplateMixin @@ -396,14 +397,15 @@ def append_annotation(annotation: Union[Artifact, Parameter]): parameters.append(annotation) return_annotation = inspect.signature(source).return_annotation - if get_origin(return_annotation) is Annotated: - append_annotation(get_args(return_annotation)[1]) + if param_or_artifact := get_workflow_annotation(return_annotation): + append_annotation(param_or_artifact) elif get_origin(return_annotation) is tuple: for annotation in get_args(return_annotation): if isinstance(annotation, type) and issubclass(annotation, (OutputV1, OutputV2)): raise ValueError("Output cannot be part of a tuple output") - append_annotation(get_args(annotation)[1]) + if param_or_artifact := get_workflow_annotation(annotation): + append_annotation(param_or_artifact) elif return_annotation and issubclass(return_annotation, (OutputV1, OutputV2)): if not _flag_enabled(_SCRIPT_PYDANTIC_IO_FLAG): raise ValueError( @@ -431,14 +433,8 @@ def _get_outputs_from_parameter_annotations( artifacts: List[Artifact] = [] for name, p in inspect.signature(source).parameters.items(): - if get_origin(p.annotation) is not Annotated: - continue - annotation = get_args(p.annotation)[1] - - if not isinstance(annotation, (Artifact, Parameter)): - raise ValueError(f"The output {type(annotation)} cannot be used as an annotation.") - - if not annotation.output: + annotation = get_workflow_annotation(p.annotation) + if not annotation or not annotation.output: continue new_object = annotation.copy() @@ -473,7 +469,7 @@ class will be used as inputs, rather than the class itself. artifacts = [] for func_param in inspect.signature(source).parameters.values(): - if get_origin(func_param.annotation) is None and issubclass(func_param.annotation, (InputV1, InputV2)): + if not is_subscripted(func_param.annotation) and issubclass(func_param.annotation, (InputV1, InputV2)): if not _flag_enabled(_SCRIPT_PYDANTIC_IO_FLAG): raise ValueError( ( @@ -499,37 +495,12 @@ class will be used as inputs, rather than the class itself. artifacts.extend(input_class._get_artifacts()) - elif get_origin(func_param.annotation) is not Annotated or not isinstance( - get_args(func_param.annotation)[1], (Artifact, Parameter) - ): - if ( - func_param.default != inspect.Parameter.empty - and func_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - default = func_param.default - else: - default = MISSING - - type_ = get_origin(func_param.annotation) - args = get_args(func_param.annotation) - if type_ is Annotated: - type_ = get_origin(args[0]) - args = get_args(args[0]) - - if (type_ is Union and len(args) == 2 and type(None) in args) and ( - default is MISSING or default is not None - ): - raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.") - - parameters.append(Parameter(name=func_param.name, default=default)) - else: - annotation = get_args(func_param.annotation)[1] - - if annotation.output: + elif param_or_artifact := get_workflow_annotation(func_param.annotation): + if param_or_artifact.output: continue # Create a new object so we don't modify the Workflow itself - new_object = annotation.copy() + new_object = param_or_artifact.copy() if not new_object.name: new_object.name = func_param.name @@ -557,6 +528,19 @@ class will be used as inputs, rather than the class itself. ) new_object.default = serialize(func_param.default) parameters.append(new_object) + else: + if ( + func_param.default != inspect.Parameter.empty + and func_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ): + default = func_param.default + else: + default = MISSING + + if origin_type_issubclass(func_param.annotation, NoneType) and (default is MISSING or default is not None): + raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.") + + parameters.append(Parameter(name=func_param.name, default=default)) return parameters, artifacts @@ -568,7 +552,7 @@ def _extract_return_annotation_output(source: Callable) -> List: return_annotation = inspect.signature(source).return_annotation origin_type = get_origin(return_annotation) annotation_args = get_args(return_annotation) - if origin_type is Annotated: + if get_workflow_annotation(return_annotation): output.append(annotation_args) elif origin_type is tuple: for annotated_type in annotation_args: @@ -591,11 +575,8 @@ def _extract_all_output_annotations(source: Callable) -> List: output = [] for _, func_param in inspect.signature(source).parameters.items(): - if get_origin(func_param.annotation) is Annotated: - annotation_args = get_args(func_param.annotation) - annotated_type = annotation_args[1] - if isinstance(annotated_type, (Artifact, Parameter)) and annotated_type.output: - output.append(annotation_args) + if (annotated := get_workflow_annotation(func_param.annotation)) and annotated.output: + output.append(annotated) output.extend(_extract_return_annotation_output(source)) diff --git a/tests/script_runner/parameter_inputs.py b/tests/script_runner/parameter_inputs.py index 754afd2a1..980d756a9 100644 --- a/tests/script_runner/parameter_inputs.py +++ b/tests/script_runner/parameter_inputs.py @@ -45,6 +45,14 @@ def annotated_basic_types( return Output(output=[Input(a=a_but_kebab, b=b_but_kebab)]) +@script() +def annotated_basic_types_with_other_metadata( + a_but_kebab: Annotated[int, "Should skip this one", Parameter(name="a-but-kebab")] = 2, + b_but_kebab: Annotated[str, "should", "skip", Parameter(name="b-but-kebab"), "this", "one"] = "foo", +) -> Output: + return Output(output=[Input(a=a_but_kebab, b=b_but_kebab)]) + + @script() def annotated_object(annotated_input_value: Annotated[Input, Parameter(name="input-value")]) -> Output: return Output(output=[annotated_input_value]) diff --git a/tests/script_runner/parameter_with_complex_types.py b/tests/script_runner/parameter_with_complex_types.py index f64b55073..cda68fd7c 100644 --- a/tests/script_runner/parameter_with_complex_types.py +++ b/tests/script_runner/parameter_with_complex_types.py @@ -23,6 +23,10 @@ def optional_str_parameter_using_union(my_string: Union[None, str] = None) -> Un def optional_str_parameter_using_or(my_string: str | None = None) -> str | None: return my_string + @script(constructor="runner") + def optional_str_parameter_using_multiple_or(my_string: str | int | None = None) -> str: + return my_string + @script(constructor="runner") def optional_int_parameter(my_int: Optional[int] = None) -> Optional[int]: diff --git a/tests/script_runner/pydantic_io_v2_invalid.py b/tests/script_runner/pydantic_io_v2_invalid.py new file mode 100644 index 000000000..2c9eec2bd --- /dev/null +++ b/tests/script_runner/pydantic_io_v2_invalid.py @@ -0,0 +1,43 @@ +import sys + +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + +from tests.helper import ARTIFACT_PATH + +from hera.shared import global_config +from hera.workflows import Artifact, script + +try: + from hera.workflows.io.v2 import Input, Output +except ImportError: + from hera.workflows.io.v1 import Input, Output + +global_config.experimental_features["script_annotations"] = True +global_config.experimental_features["script_pydantic_io"] = True + + +class MultipleAnnotationInput(Input): + str_path_artifact: Annotated[ + str, + Artifact(name="str-path-artifact", path=ARTIFACT_PATH + "/str-path", loader=None), + Artifact(name="str-path-artifact", path=ARTIFACT_PATH + "/path", loader=None), + ] + + +class MultipleAnnotationOutput(Output): + an_artifact: Annotated[str, Artifact(name="artifact-str-output"), Artifact(name="artifact-str-output")] + + +@script(constructor="runner") +def pydantic_input_invalid( + my_input: MultipleAnnotationInput, +) -> str: + return "Should not run" + + +@script(constructor="runner") +def pydantic_output_invalid() -> MultipleAnnotationOutput: + return MultipleAnnotationOutput(an_artifact="test") diff --git a/tests/test_runner.py b/tests/test_runner.py index 91b146670..4cc03ef52 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -167,6 +167,13 @@ def test_runner_parameter_inputs( _PYDANTIC_VERSION, id="str-param-given-int", ), + pytest.param( + "tests.script_runner.parameter_inputs:annotated_basic_types_with_other_metadata", + [{"name": "a-but-kebab", "value": "3"}, {"name": "b-but-kebab", "value": "1"}], + '{"output": [{"a": 3, "b": "1"}]}', + _PYDANTIC_VERSION, + id="str-param-given-int", + ), pytest.param( "tests.script_runner.parameter_inputs:annotated_object", [{"name": "input-value", "value": '{"a": 3, "b": "bar"}'}], @@ -981,6 +988,45 @@ def test_runner_pydantic_output_with_result( assert Path(tmp_path / file["subpath"]).read_text() == file["value"] +@pytest.mark.parametrize("pydantic_mode", [1, _PYDANTIC_VERSION]) +@pytest.mark.parametrize( + "entrypoint,error_type,error_match", + [ + pytest.param( + "tests.script_runner.pydantic_io_v2_invalid:pydantic_input_invalid", + ValueError, + "Annotation metadata cannot contain more than one Artifact/Parameter.", + id="invalid input annotation", + ), + pytest.param( + "tests.script_runner.pydantic_io_v2_invalid:pydantic_output_invalid", + ValueError, + "Annotation metadata cannot contain more than one Artifact/Parameter.", + id="invalid output annotation", + ), + ], +) +def test_runner_pydantic_with_invalid_annotations( + entrypoint, + error_type, + error_match, + pydantic_mode, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +): + # GIVEN + monkeypatch.setenv("hera__pydantic_mode", str(pydantic_mode)) + monkeypatch.setenv("hera__script_annotations", "") + monkeypatch.setenv("hera__script_pydantic_io", "") + + outputs_directory = str(tmp_path / "tmp/hera-outputs") + monkeypatch.setenv("hera__outputs_directory", outputs_directory) + + # WHEN / THEN + with pytest.raises(error_type, match=error_match): + _runner(entrypoint, []) + + @pytest.mark.parametrize( "entrypoint", [ @@ -989,7 +1035,10 @@ def test_runner_pydantic_output_with_result( ] + ( # Union types using OR operator are allowed since python 3.10. - ["tests.script_runner.parameter_with_complex_types:optional_str_parameter_using_or"] + [ + "tests.script_runner.parameter_with_complex_types:optional_str_parameter_using_or", + "tests.script_runner.parameter_with_complex_types:optional_str_parameter_using_multiple_or", + ] if sys.version_info[0] >= 3 and sys.version_info[1] >= 10 else [] ), diff --git a/tests/test_unit/test_script.py b/tests/test_unit/test_script.py index b3b1eb37d..191152d0b 100644 --- a/tests/test_unit/test_script.py +++ b/tests/test_unit/test_script.py @@ -11,7 +11,8 @@ from hera.workflows import Workflow, script from hera.workflows.artifact import Artifact -from hera.workflows.script import _get_inputs_from_callable +from hera.workflows.parameter import Parameter +from hera.workflows.script import _get_inputs_from_callable, _get_outputs_from_return_annotation def test_get_inputs_from_callable_simple_params(): @@ -185,3 +186,21 @@ def unknown_annotations_ignored(my_optional_string: Optional[str] = "123") -> st with pytest.raises(ValueError, match="Optional parameter 'my_optional_string' must have a default value of None."): _get_inputs_from_callable(unknown_annotations_ignored) + + +def test_invalid_script_when_multiple_input_workflow_annotations_are_given(): + @script() + def invalid_script(a_str: Annotated[str, Artifact(name="a_str"), Parameter(name="a_str")] = "123") -> str: + return "Got: {}".format(a_str) + + with pytest.raises(ValueError, match="Annotation metadata cannot contain more than one Artifact/Parameter."): + _get_inputs_from_callable(invalid_script) + + +def test_invalid_script_when_multiple_output_workflow_annotations_are_given(): + @script() + def invalid_script(a_str: str = "123") -> Annotated[str, Artifact(name="a_str"), Artifact(name="b_str")]: + return "Got: {}".format(a_str) + + with pytest.raises(ValueError, match="Annotation metadata cannot contain more than one Artifact/Parameter."): + _get_outputs_from_return_annotation(invalid_script, None) diff --git a/tests/test_unit/test_shared_type_utils.py b/tests/test_unit/test_shared_type_utils.py new file mode 100644 index 000000000..d543997d8 --- /dev/null +++ b/tests/test_unit/test_shared_type_utils.py @@ -0,0 +1,114 @@ +from typing import List, Optional, Union + +import pytest +from annotated_types import Gt + +from hera.shared._type_util import ( + get_annotated_metadata, + get_unsubscripted_type, + get_workflow_annotation, + is_annotated, + origin_type_issubclass, + unwrap_annotation, +) +from hera.workflows import Artifact, Parameter + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + + +@pytest.mark.parametrize("annotation, expected", [[Annotated[str, "some metadata"], True], [str, False]]) +def test_is_annotated(annotation, expected): + assert is_annotated(annotation) == expected + + +@pytest.mark.parametrize( + "annotation, expected", + [ + [Annotated[str, Parameter(name="a_str")], str], + [Annotated[str, "some metadata"], str], + [str, str], + ], +) +def test_unwrap_annotation(annotation, expected): + assert unwrap_annotation(annotation) == expected + + +@pytest.mark.parametrize( + "annotation, t, expected", + [ + # Not annotated one. + [str, Parameter, []], + [Annotated[str, Parameter(name="a_str")], Parameter, [Parameter(name="a_str")]], + [Annotated[str, "some metadata"], Parameter, []], + # Must support variadic annotated + [Annotated[str, "some metadata", Parameter(name="a_str")], Parameter, [Parameter(name="a_str")]], + # Must support multiple types + [Annotated[str, "some metadata", Parameter(name="a_str")], (Parameter, Artifact), [Parameter(name="a_str")]], + # Must consume in order + [ + Annotated[str, "some metadata", Artifact(name="a_str"), Parameter(name="a_str")], + (Parameter, Artifact), + [Artifact(name="a_str"), Parameter(name="a_str")], + ], + ], +) +def test_get_annotated_metadata(annotation, t, expected): + assert get_annotated_metadata(annotation, t) == expected + + +@pytest.mark.parametrize( + "annotation, expected", + [ + [str, None], + [Annotated[str, Parameter(name="a_str")], Parameter(name="a_str")], + [Annotated[str, Artifact(name="a_str")], Artifact(name="a_str")], + [Annotated[int, Gt(10), Artifact(name="a_int")], Artifact(name="a_int")], + [Annotated[int, Artifact(name="a_int"), Gt(30)], Artifact(name="a_int")], + # this can happen when user uses already annotated types. + [Annotated[Annotated[int, Gt(10)], Artifact(name="a_int")], Artifact(name="a_int")], + ], +) +def test_get_workflow_annotation(annotation, expected): + assert get_workflow_annotation(annotation) == expected + + +@pytest.mark.parametrize( + "annotation", + [ + # Duplicated annotation + Annotated[str, Parameter(name="a_str"), Parameter(name="b_str")], + Annotated[str, Parameter(name="a_str"), Artifact(name="a_str")], + # Nested + Annotated[Annotated[str, Parameter(name="a_str")], Artifact(name="b_str")], + ], +) +def test_get_workflow_annotation_should_raise_error(annotation): + with pytest.raises(ValueError): + get_workflow_annotation(annotation) + + +@pytest.mark.parametrize( + "annotation, expected", + [ + [List[str], list], + [Optional[str], Union], + ], +) +def test_get_unsubscripted_type(annotation, expected): + assert get_unsubscripted_type(annotation) is expected + + +@pytest.mark.parametrize( + "annotation, target, expected", + [ + [List[str], str, False], + [Optional[str], str, True], + [str, str, True], + [Union[int, str], int, True], + ], +) +def test_origin_type_issubclass(annotation, target, expected): + assert origin_type_issubclass(annotation, target) is expected