Skip to content

Commit

Permalink
Allow Artifact/Parameter in any position in Annotated metadata (#1168)
Browse files Browse the repository at this point in the history
**Pull Request Checklist**
- [x] ~~Fixes #<!--issue number goes here-->~~ this PR follows up on
#1160
- [x] Tests added
- [x] Documentation/examples added : this is refactoring. no need
- [x] [Good commit messages](https://cbea.ms/git-commit/) and/or PR
title

**Description of PR**

Currently, we are checking annotated, parameter/artifact, or other types
in different ways.
In this PR, I added `hera._utils.type_util` module and unified them into
one place.
Followings can be expected:
- If there's another metadata along with Parameter or Artifact, they
should be ignored. (e.g. `Annotated[type, metadata1, metadata2,
Parameter(...)]`, this can be happened when `Annotated`s are nested)
- `Optional` parameter handling should be more precise which is added in
#1160.

Suggestions on function names/module names are welcome.

---------

Signed-off-by: Ukjae Jeong <jeongukjae@gmail.com>
Signed-off-by: Ukjae Jeong <JeongUkJae@gmail.com>
Co-authored-by: Elliot Gunton <elliotgunton@gmail.com>
Co-authored-by: Sambhav Kothari <skothari44@bloomberg.net>
Co-authored-by: Alice <Alice.Purcell.39@gmail.com>
  • Loading branch information
4 people authored Sep 2, 2024
1 parent f84bfc8 commit 777fd32
Show file tree
Hide file tree
Showing 14 changed files with 547 additions and 334 deletions.
106 changes: 106 additions & 0 deletions src/hera/shared/_type_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Module that handles types and annotations."""

import sys
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload

if sys.version_info >= (3, 9):
from typing import Annotated, get_args, get_origin
else:
# Python 3.8 has get_origin() and get_args() but those implementations aren't
# Annotated-aware.
from typing_extensions import Annotated, get_args, get_origin

if sys.version_info >= (3, 10):
from types import UnionType
else:
UnionType = Union

from hera.workflows.artifact import Artifact
from hera.workflows.parameter import Parameter


def is_annotated(annotation: Any):
"""Check annotation has Annotated type or not."""
return get_origin(annotation) is Annotated


def unwrap_annotation(annotation: Any) -> Any:
"""If the given annotation is of type Annotated, return the underlying type, otherwise return the annotation."""
if is_annotated(annotation):
return get_args(annotation)[0]
return annotation


T = TypeVar("T")
V = TypeVar("V")


@overload
def get_annotated_metadata(annotation: Any, type_: Type[T]) -> List[T]: ...


@overload
def get_annotated_metadata(annotation: Any, type_: Tuple[Type[T], Type[V]]) -> List[Union[T, V]]: ...


def get_annotated_metadata(annotation, type_):
"""If given annotation has metadata typed type_, return the metadata.
Prefer get_workflow_annotation if you want to call this with Artifact or Parameter.
"""
if not is_annotated(annotation):
return []

found = []
args = get_args(annotation)
for arg in args[1:]:
if isinstance(type_, Iterable):
if any(isinstance(arg, t) for t in type_):
found.append(arg)
else:
if isinstance(arg, type_):
found.append(arg)
return found


def get_workflow_annotation(annotation: Any) -> Optional[Union[Artifact, Parameter]]:
"""If given annotation has Artifact or Parameter metadata, return it.
Note that this function will raise the error when multiple Artifact or Parameter metadata are given.
"""
metadata = get_annotated_metadata(annotation, (Artifact, Parameter))
if not metadata:
return None
if len(metadata) > 1:
raise ValueError("Annotation metadata cannot contain more than one Artifact/Parameter.")
return metadata[0]


def get_unsubscripted_type(t: Any) -> Any:
"""Return the origin of t, if subscripted, or t itself.
This can be helpful if you want to use t with isinstance, issubclass, etc.,
"""
if origin_type := get_origin(t):
return origin_type
return t


def origin_type_issubclass(cls: Any, type_: type) -> bool:
"""Return True if cls can be considered as a subclass of type_."""
unwrapped_type = unwrap_annotation(cls)
origin_type = get_unsubscripted_type(unwrapped_type)
if origin_type is Union or origin_type is UnionType:
return any(origin_type_issubclass(arg, type_) for arg in get_args(cls))
return issubclass(origin_type, type_)


def is_subscripted(t: Any) -> bool:
"""Check if given type is subscripted, i.e. a typing object of the form X[Y, Z, ...].
>>> is_subscripted(list[str])
True
>>> is_subscripted(str)
False
"""
return get_origin(t) is not None
50 changes: 15 additions & 35 deletions src/hera/workflows/_meta_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Type, TypeVar, Union, cast

if sys.version_info >= (3, 9):
from typing import Annotated, get_args, get_origin
else:
# Python 3.8 has get_origin() and get_args() but those implementations aren't
# Annotated-aware.
from typing_extensions import Annotated, get_args, get_origin

if sys.version_info >= (3, 10):
from inspect import get_annotations
from types import NoneType
Expand All @@ -31,6 +24,7 @@
from hera.shared import BaseMixin, global_config
from hera.shared._global_config import _DECORATOR_SYNTAX_FLAG, _flag_enabled
from hera.shared._pydantic import BaseModel, get_fields, root_validator
from hera.shared._type_util import get_annotated_metadata
from hera.workflows._context import _context
from hera.workflows.exceptions import InvalidTemplateCall
from hera.workflows.io.v1 import (
Expand All @@ -56,6 +50,7 @@
Output as OutputV2,
)


if TYPE_CHECKING:
from hera.workflows._mixins import TemplateMixin
from hera.workflows.steps import Step
Expand Down Expand Up @@ -155,7 +150,7 @@ class ModelMapperMixin(BaseMixin):

class ModelMapper:
def __init__(self, model_path: str, hera_builder: Optional[Callable] = None):
self.model_path = None
self.model_path = []
self.builder = hera_builder

if not model_path:
Expand All @@ -181,18 +176,18 @@ def build_model(
assert isinstance(hera_obj, ModelMapperMixin)

for attr, annotation in hera_class._get_all_annotations().items():
if get_origin(annotation) is Annotated and isinstance(
get_args(annotation)[1], ModelMapperMixin.ModelMapper
):
mapper = get_args(annotation)[1]
if mappers := get_annotated_metadata(annotation, ModelMapperMixin.ModelMapper):
if len(mappers) != 1:
raise ValueError("Expected only one ModelMapper")

# Value comes from builder function if it exists on hera_obj, otherwise directly from the attr
value = (
getattr(hera_obj, mapper.builder.__name__)()
if mapper.builder is not None
getattr(hera_obj, mappers[0].builder.__name__)()
if mappers[0].builder is not None
else getattr(hera_obj, attr)
)
if value is not None:
_set_model_attr(model, mapper.model_path, value)
_set_model_attr(model, mappers[0].model_path, value)

return model

Expand All @@ -207,12 +202,11 @@ def _from_model(cls, model: BaseModel) -> ModelMapperMixin:
hera_obj = cls()

for attr, annotation in cls._get_all_annotations().items():
if get_origin(annotation) is Annotated and isinstance(
get_args(annotation)[1], ModelMapperMixin.ModelMapper
):
mapper = get_args(annotation)[1]
if mapper.model_path:
value = _get_model_attr(model, mapper.model_path)
if mappers := get_annotated_metadata(annotation, ModelMapperMixin.ModelMapper):
if len(mappers) != 1:
raise ValueError("Expected only one model mapper")
if mappers[0].model_path:
value = _get_model_attr(model, mappers[0].model_path)
if value is not None:
setattr(hera_obj, attr, value)

Expand Down Expand Up @@ -497,20 +491,6 @@ def __init__(self, subnode_type: str, output_class: Type[Union[OutputV1, OutputV
self.output_class = output_class


def _get_underlying_type(annotation: Type):
real_type = annotation
if get_origin(annotation) is Annotated:
real_type = get_args(annotation)[0]

if get_origin(real_type) is Union:
args = get_args(real_type)
if len(args) == 2 and any([arg is NoneType for arg in args]):
# This was an "Optional" type, get the real type
real_type = next(iter([arg for arg in args if arg is not NoneType]))

return real_type


class TemplateDecoratorFuncsMixin(ContextMixin):
from hera.workflows.container import Container
from hera.workflows.dag import DAG
Expand Down
30 changes: 11 additions & 19 deletions src/hera/workflows/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import sys
from typing import (
Any,
Callable,
Expand All @@ -16,15 +15,9 @@
cast,
)

if sys.version_info >= (3, 9):
from typing import Annotated, get_args, get_origin
else:
# Python 3.8 has get_origin() and get_args() but those implementations aren't
# Annotated-aware.
from typing_extensions import Annotated, get_args, get_origin

from hera.shared import BaseMixin, global_config
from hera.shared._pydantic import PrivateAttr, get_field_annotations, get_fields, root_validator, validator
from hera.shared._type_util import get_workflow_annotation
from hera.shared.serialization import serialize
from hera.workflows._context import SubNodeMixin, _context
from hera.workflows._meta_mixins import CallableTemplateMixin, HeraBuildObj, HookMixin
Expand Down Expand Up @@ -745,18 +738,17 @@ def __getattribute__(self, name: str) -> Any:
result_templated_str = f"{{{{{subnode_type}.{subnode_name}.outputs.result}}}}"
return result_templated_str

if get_origin(annotations[name]) is Annotated:
annotation = get_args(annotations[name])[1]

if not isinstance(annotation, (Parameter, Artifact)):
return f"{{{{{subnode_type}.{subnode_name}.outputs.parameters.{name}}}}}"
if param_or_artifact := get_workflow_annotation(annotations[name]):
if isinstance(param_or_artifact, Parameter):
return (
"{{" + f"{subnode_type}.{subnode_name}.outputs.parameters.{param_or_artifact.name}" + "}}"
)
else:
return (
"{{" + f"{subnode_type}.{subnode_name}.outputs.artifacts.{param_or_artifact.name}" + "}}"
)

if isinstance(annotation, Parameter):
return f"{{{{{subnode_type}.{subnode_name}.outputs.parameters.{annotation.name}}}}}"
elif isinstance(annotation, Artifact):
return f"{{{{{subnode_type}.{subnode_name}.outputs.artifacts.{annotation.name}}}}}"
else:
return f"{{{{{subnode_type}.{subnode_name}.outputs.parameters.{name}}}}}"
return "{{" + f"{subnode_type}.{subnode_name}.outputs.parameters.{name}" + "}}"

return super().__getattribute__(name)

Expand Down
Loading

0 comments on commit 777fd32

Please sign in to comment.