Skip to content

Commit

Permalink
Implement get_parameter API on IO mixin and add tests (#876)
Browse files Browse the repository at this point in the history
**Pull Request Checklist**
- [x] Fixes #816
- [x] Tests added
- [x] Documentation/examples added
- [x] [Good commit messages](https://cbea.ms/git-commit/) and/or PR
title

Implements the `get_parameter` API on the `IOMixin`. This allows clients
that need parameters in `arguments` fields to get the parameter object
from DAGs, Steps, etc. This helps avoid the need to write
`{{inputs.parameters.whatever}}` explicitly. `get_parameter` _assumes_
that clients want to use a parameter as input, so the value field is set
accordingly

Signed-off-by: Flaviu Vadan <flaviuvadan@gmail.com>
  • Loading branch information
flaviuvadan authored Nov 23, 2023
1 parent c8980bd commit 614d364
Show file tree
Hide file tree
Showing 11 changed files with 277 additions and 17 deletions.
101 changes: 101 additions & 0 deletions docs/examples/workflows/callable_dag_with_param_get.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Callable Dag With Param Get






=== "Hera"

```python linenums="1"
from typing_extensions import Annotated

from hera.workflows import DAG, Parameter, Workflow, script


@script(constructor="runner")
def hello_with_output(name: str) -> Annotated[str, Parameter(name="output-message")]:
return "Hello, {name}!".format(name=name)


with Workflow(
generate_name="callable-dag-",
entrypoint="calling-dag",
) as w:
with DAG(
name="my-dag-with-outputs",
inputs=Parameter(name="my-dag-input"),
outputs=Parameter(
name="my-dag-output",
value_from={"parameter": "{{hello.outputs.parameters.output-message}}"},
),
) as my_dag:
# Here, get_parameter searches through the *inputs* of my_dag
hello_with_output(name="hello", arguments={"name": f"hello {my_dag.get_parameter('my-dag-input')}"})

with DAG(name="calling-dag") as d:
t1 = my_dag(name="call-1", arguments={"my-dag-input": "call-1"})
# Here, t1 is a Task from the called dag, so get_parameter is called on the Task to get the output parameter! 🚀
t2 = my_dag(name="call-2", arguments=t1.get_parameter("my-dag-output").with_name("my-dag-input"))
t1 >> t2
```

=== "YAML"

```yaml linenums="1"
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: callable-dag-
spec:
entrypoint: calling-dag
templates:
- dag:
tasks:
- arguments:
parameters:
- name: name
value: hello {{inputs.parameters.my-dag-input}}
name: hello
template: hello-with-output
inputs:
parameters:
- name: my-dag-input
name: my-dag-with-outputs
outputs:
parameters:
- name: my-dag-output
valueFrom:
parameter: '{{hello.outputs.parameters.output-message}}'
- inputs:
parameters:
- name: name
name: hello-with-output
script:
args:
- -m
- hera.workflows.runner
- -e
- examples.workflows.callable_dag_with_param_get:hello_with_output
command:
- python
image: python:3.8
source: '{{inputs.parameters}}'
- dag:
tasks:
- arguments:
parameters:
- name: my-dag-input
value: call-1
name: call-1
template: my-dag-with-outputs
- arguments:
parameters:
- name: my-dag-input
value: '{{tasks.call-1.outputs.parameters.my-dag-output}}'
depends: call-1
name: call-2
template: my-dag-with-outputs
name: calling-dag
```

Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ The upstream example can be [found here](https://github.com/argoproj/argo-workfl
"exit_code": f"{g.item.exit_code:$}",
"message": f"{g.item.message:$}",
},
with_param="{{inputs.parameters.step_params}}",
with_param=s.get_parameter("step_params"),
)
```

Expand Down
8 changes: 4 additions & 4 deletions docs/examples/workflows/upstream/parallelism_nested.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ The upstream example can be [found here](https://github.com/argoproj/argo-workfl
one_job(
name="seq-step",
arguments=[
Parameter(name="parallel-id", value="{{inputs.parameters.parallel-id}}"),
seq_worker.get_parameter("parallel-id"),
Parameter(name="seq-id", value="{{item}}"),
],
with_param="{{inputs.parameters.seq-list}}",
with_param=seq_worker.get_parameter("seq-list"),
)

with Steps(
Expand All @@ -49,10 +49,10 @@ The upstream example can be [found here](https://github.com/argoproj/argo-workfl
name="parallel-worker",
template=seq_worker,
arguments=[
Parameter(name="seq-list", value="{{inputs.parameters.seq-list}}"),
seq_worker.get_parameter("seq-list"),
Parameter(name="parallel-id", value="{{item}}"),
],
with_param="{{inputs.parameters.parallel-list}}",
with_param=parallel_worker.get_parameter("parallel-list"),
)
```

Expand Down
54 changes: 54 additions & 0 deletions examples/workflows/callable-dag-with-param-get.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: callable-dag-
spec:
entrypoint: calling-dag
templates:
- dag:
tasks:
- arguments:
parameters:
- name: name
value: hello {{inputs.parameters.my-dag-input}}
name: hello
template: hello-with-output
inputs:
parameters:
- name: my-dag-input
name: my-dag-with-outputs
outputs:
parameters:
- name: my-dag-output
valueFrom:
parameter: '{{hello.outputs.parameters.output-message}}'
- inputs:
parameters:
- name: name
name: hello-with-output
script:
args:
- -m
- hera.workflows.runner
- -e
- examples.workflows.callable_dag_with_param_get:hello_with_output
command:
- python
image: python:3.8
source: '{{inputs.parameters}}'
- dag:
tasks:
- arguments:
parameters:
- name: my-dag-input
value: call-1
name: call-1
template: my-dag-with-outputs
- arguments:
parameters:
- name: my-dag-input
value: '{{tasks.call-1.outputs.parameters.my-dag-output}}'
depends: call-1
name: call-2
template: my-dag-with-outputs
name: calling-dag
30 changes: 30 additions & 0 deletions examples/workflows/callable_dag_with_param_get.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing_extensions import Annotated

from hera.workflows import DAG, Parameter, Workflow, script


@script(constructor="runner")
def hello_with_output(name: str) -> Annotated[str, Parameter(name="output-message")]:
return "Hello, {name}!".format(name=name)


with Workflow(
generate_name="callable-dag-",
entrypoint="calling-dag",
) as w:
with DAG(
name="my-dag-with-outputs",
inputs=Parameter(name="my-dag-input"),
outputs=Parameter(
name="my-dag-output",
value_from={"parameter": "{{hello.outputs.parameters.output-message}}"},
),
) as my_dag:
# Here, get_parameter searches through the *inputs* of my_dag
hello_with_output(name="hello", arguments={"name": f"hello {my_dag.get_parameter('my-dag-input')}"})

with DAG(name="calling-dag") as d:
t1 = my_dag(name="call-1", arguments={"my-dag-input": "call-1"})
# Here, t1 is a Task from the called dag, so get_parameter is called on the Task to get the output parameter! 🚀
t2 = my_dag(name="call-2", arguments=t1.get_parameter("my-dag-output").with_name("my-dag-input"))
t1 >> t2
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@
"exit_code": f"{g.item.exit_code:$}",
"message": f"{g.item.message:$}",
},
with_param="{{inputs.parameters.step_params}}",
with_param=s.get_parameter("step_params"),
)
8 changes: 4 additions & 4 deletions examples/workflows/upstream/parallelism_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
one_job(
name="seq-step",
arguments=[
Parameter(name="parallel-id", value="{{inputs.parameters.parallel-id}}"),
seq_worker.get_parameter("parallel-id"),
Parameter(name="seq-id", value="{{item}}"),
],
with_param="{{inputs.parameters.seq-list}}",
with_param=seq_worker.get_parameter("seq-list"),
)

with Steps(
Expand All @@ -36,8 +36,8 @@
name="parallel-worker",
template=seq_worker,
arguments=[
Parameter(name="seq-list", value="{{inputs.parameters.seq-list}}"),
seq_worker.get_parameter("seq-list"),
Parameter(name="parallel-id", value="{{item}}"),
],
with_param="{{inputs.parameters.parallel-list}}",
with_param=parallel_worker.get_parameter("parallel-list"),
)
26 changes: 26 additions & 0 deletions src/hera/workflows/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,32 @@ class IOMixin(BaseMixin):
inputs: InputsT = None
outputs: OutputsT = None

def get_parameter(self, name: str) -> Parameter:
"""Finds and returns the parameter with the supplied name.
Note that this method will raise an error if the parameter is not found.
Args:
name: name of the input parameter to find and return.
Returns:
Parameter: the parameter with the supplied name.
Raises:
KeyError: if the parameter is not found.
"""
inputs = self._build_inputs()
if inputs is None:
raise KeyError(f"No inputs set. Parameter {name} not found.")
if inputs.parameters is None:
raise KeyError(f"No parameters set. Parameter {name} not found.")
for p in inputs.parameters:
if p.name == name:
param = Parameter.from_model(p)
param.value = f"{{{{inputs.parameters.{param.name}}}}}"
return param
raise KeyError(f"Parameter {name} not found.")

def _build_inputs(self) -> Optional[ModelInputs]:
"""Processes the `inputs` field and returns a generated `ModelInputs`."""
if self.inputs is None:
Expand Down
1 change: 1 addition & 0 deletions src/hera/workflows/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class DAG(
>>> @script()
>>> def foo() -> None:
>>> print(42)
>>>
>>> with DAG(...) as dag:
>>> foo()
"""
Expand Down
15 changes: 10 additions & 5 deletions src/hera/workflows/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ def _check_values(cls, values):

return values

@classmethod
def _get_input_attributes(cls):
"""Return the attributes used for input parameter annotations."""
return ["enum", "description", "default", "name", "value", "value_from"]

def __str__(self):
"""Represent the parameter as a string by pointing to its value.
Expand All @@ -61,6 +66,11 @@ def __str__(self):
raise ValueError("Cannot represent `Parameter` as string as `value` is not set")
return self.value

@classmethod
def from_model(cls, model: _ModelParameter) -> Parameter:
"""Creates a `Parameter` from a `Parameter` model."""
return cls(**model.dict())

def with_name(self, name: str) -> Parameter:
"""Returns a copy of the parameter with the name set to the value."""
p = self.copy(deep=True)
Expand Down Expand Up @@ -108,10 +118,5 @@ def as_output(self) -> _ModelParameter:
value_from=self.value_from,
)

@classmethod
def _get_input_attributes(cls):
"""Return the attributes used for input parameter annotations."""
return ["enum", "description", "default", "name", "value", "value_from"]


__all__ = ["Parameter"]
47 changes: 45 additions & 2 deletions tests/test_unit/test_mixins.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from hera.workflows._mixins import ContainerMixin
from hera.workflows.models import ImagePullPolicy
import pytest

from hera.workflows import Parameter
from hera.workflows._mixins import ContainerMixin, IOMixin
from hera.workflows.models import (
ImagePullPolicy,
Inputs as ModelInputs,
)


class TestContainerMixin:
Expand All @@ -10,3 +16,40 @@ def test_build_image_pull_policy(self) -> None:
== ImagePullPolicy.always
)
assert ContainerMixin()._build_image_pull_policy() is None


class TestIOMixin:
@pytest.fixture(autouse=True)
def setup(self):
self.io_mixin = IOMixin()

def test_get_parameter_success(self):
self.io_mixin.inputs = ModelInputs(parameters=[Parameter(name="test", value="value")])
param = self.io_mixin.get_parameter("test")
assert param.name == "test"
assert param.value == "{{inputs.parameters.test}}"

def test_get_parameter_no_inputs(self):
with pytest.raises(KeyError):
self.io_mixin.get_parameter("test")

def test_get_parameter_no_parameters(self):
self.io_mixin.inputs = ModelInputs()
with pytest.raises(KeyError):
self.io_mixin.get_parameter("test")

def test_get_parameter_not_found(self):
self.io_mixin.inputs = ModelInputs(parameters=[Parameter(name="test", value="value")])
with pytest.raises(KeyError):
self.io_mixin.get_parameter("not_exist")

def test_build_inputs_none(self):
assert self.io_mixin._build_inputs() is None

def test_build_inputs_from_model_inputs(self):
model_inputs = ModelInputs(parameters=[Parameter(name="test", value="value")])
self.io_mixin.inputs = model_inputs
assert self.io_mixin._build_inputs() == model_inputs

def test_build_outputs_none(self):
assert self.io_mixin._build_outputs() is None

0 comments on commit 614d364

Please sign in to comment.