Skip to content

Commit

Permalink
Update script to handle Optional and Union input parameters (#1160)
Browse files Browse the repository at this point in the history
**Pull Request Checklist**
- [x] Fixes #1012 
- [x] Tests added
- [x] Documentation/examples added : I think no needs for document?
- [x] [Good commit messages](https://cbea.ms/git-commit/) and/or PR
title

**Description of PR**
Currently, `script` decorator cannot handle `Union` or `Optional` input
parameters because of the error in `issubclass`, so updated it to be
able to handle it as expected.

---

original PR: #1147

---------

Signed-off-by: Ukjae Jeong <jeongukjae@gmail.com>
Signed-off-by: Ukjae Jeong <JeongUkJae@gmail.com>
Co-authored-by: Elliot Gunton <elliotgunton@gmail.com>
  • Loading branch information
jeongukjae and elliotgunton authored Aug 20, 2024
1 parent 3470e8e commit 523a9b9
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,4 @@ stop-argo: ## Stop the argo server
.PHONY: test-on-cluster
test-on-cluster: ## Run workflow tests (requires local argo cluster)
@(kubectl -n argo port-forward deployment/argo-server 2746:2746 &)
@poetry run python -m pytest tests/test_submission.py -m on_cluster
@poetry run python -m pytest tests/submissions -m on_cluster
12 changes: 10 additions & 2 deletions src/hera/workflows/_runner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, cast
from typing import Any, Callable, Dict, List, Optional, Union, cast

from hera.shared._pydantic import _PYDANTIC_VERSION
from hera.shared.serialization import serialize
Expand Down Expand Up @@ -133,12 +133,20 @@ def _get_unannotated_type(key: str, f: Callable) -> Optional[type]:


def _is_str_kwarg_of(key: str, f: Callable) -> bool:
"""Check if param `key` of function `f` has a type annotation of a subclass of str."""
"""Check if param `key` of function `f` has a type annotation that can be interpreted as a subclass of str."""
func_param_annotation = inspect.signature(f).parameters[key].annotation
if func_param_annotation is inspect.Parameter.empty:
return False

type_ = _get_type(func_param_annotation)
if type_ is Union:
# Checking only Union[X, None] or Union[None, X] for given X which is subclass of str.
# Note that Optional[X] is alias of Union[X, None], so Optional is also handled in here.
args = get_args(func_param_annotation)
return len(args) == 2 and (
(args[0] is type(None) and issubclass(args[1], str))
or (issubclass(args[0], str) and args[1] is type(None))
)
return issubclass(type_, str)


Expand Down
11 changes: 11 additions & 0 deletions src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,17 @@ class will be used as inputs, rather than the class itself.
else:
default = MISSING

type_ = get_origin(func_param.annotation)
args = get_args(func_param.annotation)
if type_ is Annotated:
type_ = get_origin(args[0])
args = get_args(args[0])

if (type_ is Union and len(args) == 2 and type(None) in args) and (
default is MISSING or default is not None
):
raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.")

parameters.append(Parameter(name=func_param.name, default=default))
else:
annotation = get_args(func_param.annotation)[1]
Expand Down
34 changes: 34 additions & 0 deletions tests/script_runner/parameter_with_complex_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import sys
from typing import Optional, Union

from hera.shared import global_config
from hera.workflows import script

global_config.experimental_features["script_annotations"] = True


@script(constructor="runner")
def optional_str_parameter(my_string: Optional[str] = None) -> Optional[str]:
return my_string


@script(constructor="runner")
def optional_str_parameter_using_union(my_string: Union[None, str] = None) -> Union[None, str]:
return my_string


if sys.version_info[0] >= 3 and sys.version_info[1] >= 10:
# Union types using OR operator are allowed since python 3.10.
@script(constructor="runner")
def optional_str_parameter_using_or(my_string: str | None = None) -> str | None:
return my_string


@script(constructor="runner")
def optional_int_parameter(my_int: Optional[int] = None) -> Optional[int]:
return my_int


@script(constructor="runner")
def union_parameter(my_param: Union[str, int] = None) -> Union[str, int]:
return my_param
Empty file added tests/submissions/__init__.py
Empty file.
File renamed without changes.
51 changes: 51 additions & 0 deletions tests/submissions/test_optional_input_parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Optional

import pytest

from hera.workflows import Parameter, Steps, Workflow, WorkflowsService, script
from hera.workflows.models import (
NodeStatus,
Parameter as ModelParameter,
)


@script(outputs=Parameter(name="message-out", value_from={"path": "/tmp/message-out"}))
def print_msg(message: Optional[str] = None):
with open("/tmp/message-out", "w") as f:
f.write("Got: {}".format(message))


def get_workflow() -> Workflow:
with Workflow(
generate_name="optional-param-",
entrypoint="steps",
namespace="argo",
workflows_service=WorkflowsService(
host="https://localhost:2746",
namespace="argo",
verify_ssl=False,
),
) as w:
with Steps(name="steps"):
print_msg(name="step-1", arguments={"message": "Hello world!"})
print_msg(name="step-2", arguments={})
print_msg(name="step-3")

return w


@pytest.mark.on_cluster
def test_create_workflow_with_optional_input_parameter():
model_workflow = get_workflow().create(wait=True)
assert model_workflow.status and model_workflow.status.phase == "Succeeded"

step_and_expected_output = {
"step-1": "Got: Hello world!",
"step-2": "Got: None",
"step-3": "Got: None",
}

for step, expected_output in step_and_expected_output.items():
node: NodeStatus = next(filter(lambda n: n.display_name == step, model_workflow.status.nodes.values()))
message_out: ModelParameter = next(filter(lambda n: n.name == "message-out", node.outputs.parameters))
assert message_out.value == expected_output
84 changes: 84 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import importlib
import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Literal
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -978,3 +979,86 @@ def test_runner_pydantic_output_with_result(
for file in expected_files:
assert Path(tmp_path / file["subpath"]).is_file()
assert Path(tmp_path / file["subpath"]).read_text() == file["value"]


@pytest.mark.parametrize(
"entrypoint",
[
"tests.script_runner.parameter_with_complex_types:optional_str_parameter",
"tests.script_runner.parameter_with_complex_types:optional_str_parameter_using_union",
]
+ (
# Union types using OR operator are allowed since python 3.10.
["tests.script_runner.parameter_with_complex_types:optional_str_parameter_using_or"]
if sys.version_info[0] >= 3 and sys.version_info[1] >= 10
else []
),
)
@pytest.mark.parametrize(
"kwargs_list,expected_output",
[
pytest.param(
[{"name": "my_string", "value": "a string"}],
"a string",
),
pytest.param(
[{"name": "my_string", "value": None}],
"null",
),
],
)
def test_script_optional_parameter(
monkeypatch: pytest.MonkeyPatch,
entrypoint,
kwargs_list,
expected_output,
):
# GIVEN
monkeypatch.setenv("hera__script_annotations", "")

# WHEN
output = _runner(entrypoint, kwargs_list)

# THEN
assert serialize(output) == expected_output


@pytest.mark.parametrize(
"entrypoint,kwargs_list,expected_output",
[
[
"tests.script_runner.parameter_with_complex_types:optional_int_parameter",
[{"name": "my_int", "value": 123}],
"123",
],
[
"tests.script_runner.parameter_with_complex_types:optional_int_parameter",
[{"name": "my_int", "value": None}],
"null",
],
[
"tests.script_runner.parameter_with_complex_types:union_parameter",
[{"name": "my_param", "value": "a string"}],
"a string",
],
[
"tests.script_runner.parameter_with_complex_types:union_parameter",
[{"name": "my_param", "value": 123}],
"123",
],
],
)
def test_script_with_complex_types(
monkeypatch: pytest.MonkeyPatch,
entrypoint,
kwargs_list,
expected_output,
):
# GIVEN
monkeypatch.setenv("hera__script_annotations", "")

# WHEN
output = _runner(entrypoint, kwargs_list)

# THEN
assert serialize(output) == expected_output
68 changes: 68 additions & 0 deletions tests/test_unit/test_script.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from typing import Optional, Union

import pytest

try:
from typing import Annotated # type: ignore
except ImportError:
Expand Down Expand Up @@ -117,3 +121,67 @@ def unknown_annotations_ignored(my_string: Annotated[str, "some metadata"]) -> s

assert parameter.name == "my_string"
assert parameter.default is None


def test_script_optional_parameter():
# GIVEN
@script()
def unknown_annotations_ignored(my_optional_string: Optional[str] = None) -> str:
return "Got: {}".format(my_optional_string)

# WHEN
params, artifacts = _get_inputs_from_callable(unknown_annotations_ignored)

# THEN
assert artifacts == []
assert isinstance(params, list)
assert len(params) == 1
parameter = params[0]

assert parameter.name == "my_optional_string"
assert parameter.default == "null"


def test_invalid_script_when_optional_parameter_does_not_have_default_value():
@script()
def unknown_annotations_ignored(my_optional_string: Optional[str]) -> str:
return "Got: {}".format(my_optional_string)

with pytest.raises(ValueError, match="Optional parameter 'my_optional_string' must have a default value of None."):
_get_inputs_from_callable(unknown_annotations_ignored)


def test_invalid_script_when_optional_parameter_does_not_have_default_value_2():
@script()
def unknown_annotations_ignored(my_optional_string: Annotated[Optional[str], "123"]) -> str:
return "Got: {}".format(my_optional_string)

with pytest.raises(ValueError, match="Optional parameter 'my_optional_string' must have a default value of None."):
_get_inputs_from_callable(unknown_annotations_ignored)


def test_invalid_script_when_optional_parameter_does_not_have_default_value_3():
@script()
def unknown_annotations_ignored(my_optional_string: Union[str, None]) -> str:
return "Got: {}".format(my_optional_string)

with pytest.raises(ValueError, match="Optional parameter 'my_optional_string' must have a default value of None."):
_get_inputs_from_callable(unknown_annotations_ignored)


def test_invalid_script_when_optional_parameter_does_not_have_default_value_4():
@script()
def unknown_annotations_ignored(my_optional_string: Union[None, str]) -> str:
return "Got: {}".format(my_optional_string)

with pytest.raises(ValueError, match="Optional parameter 'my_optional_string' must have a default value of None."):
_get_inputs_from_callable(unknown_annotations_ignored)


def test_invalid_script_when_optional_parameter_does_not_have_default_value_5():
@script()
def unknown_annotations_ignored(my_optional_string: Optional[str] = "123") -> str:
return "Got: {}".format(my_optional_string)

with pytest.raises(ValueError, match="Optional parameter 'my_optional_string' must have a default value of None."):
_get_inputs_from_callable(unknown_annotations_ignored)

0 comments on commit 523a9b9

Please sign in to comment.