Skip to content

Commit

Permalink
Support inherited fields for RunnerInput, other niceties (#1093)
Browse files Browse the repository at this point in the history
## Pull Request Checklist
- [ ] Fixes #<!--issue number goes here-->
- [x] Tests added
- [ ] Documentation/examples added
- [x] [Good commit messages](https://cbea.ms/git-commit/) and/or PR
title

## Summary

Currently, users cannot leverage inheritance when defining `RunnerInput`
models that utilize `Annotated` fields. For example:

```
class ResourceIndex(RunnerInput):
    resource_name: Annotated[str, Parameter(name="resource-name")]
    resource_namespace: Annotated[str, Parameter(name="resource-namespace")]

class GetStatusInput(ResourceIndex):
    wait: Annotated[bool, Parameter(name="wait")]
    timeout_s: Annotated[int, Parameter(name="timeout")]
```

This would fail at runtime with a KeyError at [this
line](https://github.com/argoproj-labs/hera/blob/main/src/hera/workflows/_runner/script_annotations_util.py#L152),
because apparently `cls.__annotations__` are not inherited by child
classes.

This PR enables the above inheritance structure by utilizing the
existing
[get_field_annotations](https://github.com/argoproj-labs/hera/blob/main/src/hera/shared/_pydantic.py#L69)
instead of `my_runner_input.__annotations__`.

#### Parameter Default Names

This PR also changes logic so that `Input` field `Parameters` adopt the
name of their field if no name is provided. For example, in

```
class MyInput(Input):
    my_int: Annotated[int, Parameter(description="this is my int")]
```

The parameter is converted to `Parameter(name="my_int",
description="this is my int")`. This avoids an error from
[here](https://github.com/argoproj-labs/hera/blob/main/src/hera/workflows/parameter.py#L30)
which is thrown when trying to generate a workflow yaml.

---------

Signed-off-by: Mitchell Douglass <mdouglass15@bloomberg.net>
Co-authored-by: Sambhav Kothari <skothari44@bloomberg.net>
  • Loading branch information
mitrydoug and sambhav authored Jun 12, 2024
1 parent cc222ad commit 5631c6a
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 16 deletions.
37 changes: 23 additions & 14 deletions src/hera/workflows/_runner/script_annotations_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast

from hera.shared._pydantic import BaseModel, get_fields
from hera.shared._pydantic import BaseModel, get_field_annotations, get_fields
from hera.shared.serialization import serialize
from hera.workflows import Artifact, Parameter
from hera.workflows.artifact import ArtifactLoader
Expand Down Expand Up @@ -145,24 +145,33 @@ def load_parameter_value(value: str, value_type: type) -> Any:
except json.JSONDecodeError:
return value

runner_input_annotations = get_field_annotations(runner_input_class)

def map_field(
field: str,
kwargs: Dict[str, str],
) -> Any:
annotation = runner_input_class.__annotations__[field]
annotation = runner_input_annotations.get(field)
assert annotation is not None, "RunnerInput fields must be type-annotated"
if get_origin(annotation) is Annotated:
meta_annotation = get_args(annotation)[1]
if isinstance(meta_annotation, Parameter):
assert not meta_annotation.output
return load_parameter_value(
_get_annotated_input_param_value(field, meta_annotation, kwargs),
get_args(annotation)[0],
)

if isinstance(meta_annotation, Artifact):
return get_annotated_artifact_value(meta_annotation)

return load_parameter_value(kwargs[field], annotation)
# my_field: Annotated[int, Parameter(...)]
ann_type = get_args(annotation)[0]
param_or_artifact = get_args(annotation)[1]
else:
# my_field: int
ann_type = annotation
param_or_artifact = None

if isinstance(param_or_artifact, Parameter):
assert not param_or_artifact.output
return load_parameter_value(
_get_annotated_input_param_value(field, param_or_artifact, kwargs),
ann_type,
)
elif isinstance(param_or_artifact, Artifact):
return get_annotated_artifact_value(param_or_artifact)
else:
return load_parameter_value(kwargs[field], ann_type)

for field in get_fields(runner_input_class):
input_model_obj[field] = map_field(field, kwargs)
Expand Down
10 changes: 8 additions & 2 deletions src/hera/workflows/io/_io_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,11 @@ def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Paramet

for field, field_info in get_fields(cls).items():
if get_origin(annotations[field]) is Annotated:
param = get_args(annotations[field])[1]
# Copy so as to not modify the Input fields themselves
param = get_args(annotations[field])[1].copy()
if isinstance(param, Parameter):
if param.name is None:
param.name = field
if object_override:
param.default = serialize(getattr(object_override, field))
elif field_info.default is not None and field_info.default != PydanticUndefined: # type: ignore
Expand All @@ -91,8 +94,11 @@ def _get_artifacts(cls) -> List[Artifact]:

for field in get_fields(cls):
if get_origin(annotations[field]) is Annotated:
artifact = get_args(annotations[field])[1]
# Copy so as to not modify the Input fields themselves
artifact = get_args(annotations[field])[1].copy()
if isinstance(artifact, Artifact):
if artifact.name is None:
artifact.name = field
if artifact.path is None:
artifact.path = artifact._get_default_inputs_path()
artifacts.append(artifact)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_unit/test_io_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from hera.workflows.io import Input
from hera.workflows.parameter import Parameter

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated


def test_input_mixin_get_parameters():
class Foo(Input):
foo: Annotated[int, Parameter(name="foo")]

assert Foo._get_parameters() == [Parameter(name="foo")]


def test_input_mixin_get_parameters_default_name():
class Foo(Input):
foo: Annotated[int, Parameter(description="a foo")]

assert Foo._get_parameters() == [Parameter(name="foo", description="a foo")]
76 changes: 76 additions & 0 deletions tests/test_unit/test_script_annotations_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
from hera.workflows.models import ValueFrom
from hera.workflows.parameter import Parameter

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated


@pytest.mark.parametrize(
"destination,expected_path",
Expand Down Expand Up @@ -174,3 +179,74 @@ class MyInput(Input):
a_dict_str=json.dumps({"key": "value"}),
a_list_str=json.dumps([1, 2, 3]),
)


def test_map_runner_input_annotated_parameter():
"""Test annotated Parameter."""

class Foo(Input):
foo: Annotated[str, Parameter(name="bar")]

kwargs = {"foo": "hello"}
assert map_runner_input(Foo, kwargs) == Foo(foo="hello")
kwargs = {"bar": "there"}
assert map_runner_input(Foo, kwargs) == Foo(foo="there")


def test_map_runner_input_output_parameter_disallowed():
"""Test annotated output Parameter is not allowed."""

class Foo(Input):
foo: Annotated[str, Parameter(name="bar", output=True)]

with pytest.raises(AssertionError):
kwargs = {"foo": "hello"}
map_runner_input(Foo, kwargs)


def test_map_runner_input_annotated_artifact(tmp_path):
"""Test annotated Artifact."""

foo_path = tmp_path / "foo"
foo_path.write_text("hello there")

class Foo(Input):
foo: Annotated[str, Artifact(name="bar", path=str(foo_path), loader=ArtifactLoader.file)]

assert map_runner_input(Foo, {}) == Foo(foo="hello there")


def test_map_runner_input_annotated_inheritance():
"""Test model inheritance with Annotated fields."""

class Foo(Input):
foo: Annotated[str, Parameter(name="foo")]

class FooBar(Foo):
bar: Annotated[str, Parameter(name="bar")]

kwargs = {"foo": "hello", "bar": "there"}
assert map_runner_input(FooBar, kwargs) == FooBar(**kwargs)


def test_map_runner_input_annotated_inheritance_override():
"""Test model inheritance with Annotated fields."""

class Foo(Input):
foo: Annotated[str, Parameter(name="foo")]

class FooBar(Foo):
foo: Annotated[str, Parameter(name="bar")]

kwargs = {"bar": "hello"}
assert map_runner_input(FooBar, kwargs) == FooBar(foo="hello")


def test_map_runner_input_annotated_parameter_noname():
"""Test Annotated Parameter with no name."""

class Foo(Input):
foo: Annotated[str, Parameter(description="a parameter")]

kwargs = {"foo": "hello"}
assert map_runner_input(Foo, kwargs) == Foo(foo="hello")

0 comments on commit 5631c6a

Please sign in to comment.