diff --git a/docs/examples/workflows/experimental/pydantic_io_in_dag_context.md b/docs/examples/workflows/experimental/pydantic_io_in_dag_context.md new file mode 100644 index 000000000..d5c58e8de --- /dev/null +++ b/docs/examples/workflows/experimental/pydantic_io_in_dag_context.md @@ -0,0 +1,154 @@ +# Pydantic Io In Dag Context + + + + + + +=== "Hera" + + ```python linenums="1" + import sys + from typing import List + + if sys.version_info >= (3, 9): + from typing import Annotated + else: + from typing_extensions import Annotated + + + from hera.shared import global_config + from hera.workflows import DAG, Parameter, WorkflowTemplate, script + from hera.workflows.io.v1 import Input, Output + + global_config.experimental_features["decorator_syntax"] = True + + + class CutInput(Input): + cut_after: Annotated[int, Parameter(name="cut-after")] + strings: List[str] + + + class CutOutput(Output): + first_strings: Annotated[List[str], Parameter(name="first-strings")] + remainder: List[str] + + + class JoinInput(Input): + strings: List[str] + joiner: str + + + class JoinOutput(Output): + joined_string: Annotated[str, Parameter(name="joined-string")] + + + @script(constructor="runner") + def cut(input: CutInput) -> CutOutput: + return CutOutput( + first_strings=input.strings[: input.cut_after], + remainder=input.strings[input.cut_after :], + exit_code=1 if len(input.strings) <= input.cut_after else 0, + ) + + + @script(constructor="runner") + def join(input: JoinInput) -> JoinOutput: + return JoinOutput(joined_string=input.joiner.join(input.strings)) + + + with WorkflowTemplate(generate_name="pydantic-io-in-steps-context-v1-", entrypoint="d") as w: + with DAG(name="d"): + cut_result = cut(CutInput(strings=["hello", "world", "it's", "hera"], cut_after=1)) + join(JoinInput(strings=cut_result.first_strings, joiner=" ")) + ``` + +=== "YAML" + + ```yaml linenums="1" + apiVersion: argoproj.io/v1alpha1 + kind: WorkflowTemplate + metadata: + generateName: pydantic-io-in-steps-context-v1- + spec: + entrypoint: d + templates: + - dag: + tasks: + - arguments: + parameters: + - name: cut-after + value: '1' + - name: strings + value: '["hello", "world", "it''s", "hera"]' + name: cut + template: cut + - arguments: + parameters: + - name: strings + value: '{{tasks.cut.outputs.parameters.first-strings}}' + - name: joiner + value: ' ' + depends: cut + name: join + template: join + name: d + - inputs: + parameters: + - name: cut-after + - name: strings + name: cut + outputs: + parameters: + - name: first-strings + valueFrom: + path: /tmp/hera-outputs/parameters/first-strings + - name: remainder + valueFrom: + path: /tmp/hera-outputs/parameters/remainder + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_dag_context:cut + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.8 + source: '{{inputs.parameters}}' + - inputs: + parameters: + - name: strings + - name: joiner + name: join + outputs: + parameters: + - name: joined-string + valueFrom: + path: /tmp/hera-outputs/parameters/joined-string + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_dag_context:join + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.8 + source: '{{inputs.parameters}}' + ``` + diff --git a/docs/examples/workflows/experimental/pydantic_io_in_steps_context.md b/docs/examples/workflows/experimental/pydantic_io_in_steps_context.md new file mode 100644 index 000000000..6bc6d7bbe --- /dev/null +++ b/docs/examples/workflows/experimental/pydantic_io_in_steps_context.md @@ -0,0 +1,152 @@ +# Pydantic Io In Steps Context + + + + + + +=== "Hera" + + ```python linenums="1" + import sys + from typing import List + + if sys.version_info >= (3, 9): + from typing import Annotated + else: + from typing_extensions import Annotated + + + from hera.shared import global_config + from hera.workflows import Parameter, Steps, WorkflowTemplate, script + from hera.workflows.io.v1 import Input, Output + + global_config.experimental_features["decorator_syntax"] = True + + + class CutInput(Input): + cut_after: Annotated[int, Parameter(name="cut-after")] + strings: List[str] + + + class CutOutput(Output): + first_strings: Annotated[List[str], Parameter(name="first-strings")] + remainder: List[str] + + + class JoinInput(Input): + strings: List[str] + joiner: str + + + class JoinOutput(Output): + joined_string: Annotated[str, Parameter(name="joined-string")] + + + @script(constructor="runner") + def cut(input: CutInput) -> CutOutput: + return CutOutput( + first_strings=input.strings[: input.cut_after], + remainder=input.strings[input.cut_after :], + exit_code=1 if len(input.strings) <= input.cut_after else 0, + ) + + + @script(constructor="runner") + def join(input: JoinInput) -> JoinOutput: + return JoinOutput(joined_string=input.joiner.join(input.strings)) + + + with WorkflowTemplate(generate_name="pydantic-io-in-steps-context-v1-", entrypoint="d") as w: + with Steps(name="d"): + cut_result = cut(CutInput(strings=["hello", "world", "it's", "hera"], cut_after=1)) + join(JoinInput(strings=cut_result.first_strings, joiner=" ")) + ``` + +=== "YAML" + + ```yaml linenums="1" + apiVersion: argoproj.io/v1alpha1 + kind: WorkflowTemplate + metadata: + generateName: pydantic-io-in-steps-context-v1- + spec: + entrypoint: d + templates: + - name: d + steps: + - - arguments: + parameters: + - name: cut-after + value: '1' + - name: strings + value: '["hello", "world", "it''s", "hera"]' + name: cut + template: cut + - - arguments: + parameters: + - name: strings + value: '{{steps.cut.outputs.parameters.first-strings}}' + - name: joiner + value: ' ' + name: join + template: join + - inputs: + parameters: + - name: cut-after + - name: strings + name: cut + outputs: + parameters: + - name: first-strings + valueFrom: + path: /tmp/hera-outputs/parameters/first-strings + - name: remainder + valueFrom: + path: /tmp/hera-outputs/parameters/remainder + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_steps_context:cut + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.8 + source: '{{inputs.parameters}}' + - inputs: + parameters: + - name: strings + - name: joiner + name: join + outputs: + parameters: + - name: joined-string + valueFrom: + path: /tmp/hera-outputs/parameters/joined-string + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_steps_context:join + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.8 + source: '{{inputs.parameters}}' + ``` + diff --git a/docs/user-guides/decorators.md b/docs/user-guides/decorators.md index 1526261a0..8e8559be6 100644 --- a/docs/user-guides/decorators.md +++ b/docs/user-guides/decorators.md @@ -170,3 +170,40 @@ We can simply call the script templates, passing the input objects in. For more complex examples, including use of a dag, see [the "experimental" examples](../examples/workflows/experimental/new_dag_decorator_params.md). + +## Incremental workflow migration + +If you have a larger workflow you want to migrate to decorator syntax, you can enable a hybrid mode where Pydantic types can be passed to functions in a Steps/DAG context block, intermixed with calls that pass dictionaries. This will allow you to make smaller changes, and verify that the generated YAML remains the same. For example: + +```py +from hera.shared import global_config +from hera.workflows import Input, Output, Steps, Workflow, script + +global_config.experimental_features["context_manager_pydantic_io"] = True + +class MyInput(Input): + value: int + +class MyOutput(Output): + value: int + +# Function migrated to Pydantic I/O +@script() +def double(input: MyInput) -> MyOutput: + return MyOutput(value = input.value * 2) + +# Not yet migrated to Pydantic I/O +@script() +def print_value(value: int) -> None: + print("Value was", value) + +# Not yet migrated to decorator syntax +with Workflow(name="my-template") as w: + with Steps(name="steps"): + # Can now pass Pydantic types to/from functions + first_step = double(Input(value=5)) + # Results can be passed into non-migrated functions + print_value(arguments={"value": first_step.value}) +``` + +This feature is turned on by a different experimental flag, as we recommend only using this as a temporary stop-gap during a migration. Once you have fully migrated, you can disable the flag again to verify you are no longer using hybrid mode. diff --git a/examples/workflows/experimental/incremental-workflow-migration.yaml b/examples/workflows/experimental/incremental-workflow-migration.yaml new file mode 100644 index 000000000..9a5aa16bd --- /dev/null +++ b/examples/workflows/experimental/incremental-workflow-migration.yaml @@ -0,0 +1,53 @@ +apiVersion: argoproj.io/v1alpha1 +kind: Workflow +metadata: + name: my-template +spec: + templates: + - name: steps + steps: + - - name: double + template: double + - - arguments: + parameters: + - name: value + value: '{{steps.double.outputs.parameters.value}}' + name: print-value + template: print-value + - inputs: + parameters: + - name: value + name: double + outputs: + parameters: + - name: value + script: + command: + - python + image: python:3.9 + source: |- + import os + import sys + sys.path.append(os.getcwd()) + import json + try: value = json.loads(r'''{{inputs.parameters.value}}''') + except: value = r'''{{inputs.parameters.value}}''' + + return MyOutput(value=input.value * 2) + - inputs: + parameters: + - name: value + name: print-value + script: + command: + - python + image: python:3.9 + source: |- + import os + import sys + sys.path.append(os.getcwd()) + import json + try: value = json.loads(r'''{{inputs.parameters.value}}''') + except: value = r'''{{inputs.parameters.value}}''' + + print('Value was', value) diff --git a/examples/workflows/experimental/incremental_workflow_migration.py b/examples/workflows/experimental/incremental_workflow_migration.py new file mode 100644 index 000000000..a82359239 --- /dev/null +++ b/examples/workflows/experimental/incremental_workflow_migration.py @@ -0,0 +1,33 @@ +from hera.shared import global_config +from hera.workflows import Input, Output, Steps, Workflow, script + +global_config.experimental_features["context_manager_pydantic_io"] = True + + +class MyInput(Input): + value: int + + +class MyOutput(Output): + value: int + + +# Function migrated to Pydantic I/O +@script() +def double(input: MyInput) -> MyOutput: + return MyOutput(value=input.value * 2) + + +# Not yet migrated to Pydantic I/O +@script() +def print_value(value: int) -> None: + print("Value was", value) + + +# Not yet migrated to decorator syntax +with Workflow(name="my-template") as w: + with Steps(name="steps"): + # Can now pass Pydantic types to/from functions + first_step = double(Input(value=5)) + # Results can be passed into non-migrated functions + print_value(arguments={"value": first_step.value}) diff --git a/examples/workflows/experimental/pydantic-io-in-dag-context.yaml b/examples/workflows/experimental/pydantic-io-in-dag-context.yaml new file mode 100644 index 000000000..6f2788fc4 --- /dev/null +++ b/examples/workflows/experimental/pydantic-io-in-dag-context.yaml @@ -0,0 +1,84 @@ +apiVersion: argoproj.io/v1alpha1 +kind: WorkflowTemplate +metadata: + generateName: pydantic-io-in-steps-context-v1- +spec: + entrypoint: d + templates: + - dag: + tasks: + - arguments: + parameters: + - name: cut-after + value: '1' + - name: strings + value: '["hello", "world", "it''s", "hera"]' + name: cut + template: cut + - arguments: + parameters: + - name: strings + value: '{{tasks.cut.outputs.parameters.first-strings}}' + - name: joiner + value: ' ' + depends: cut + name: join + template: join + name: d + - inputs: + parameters: + - name: cut-after + - name: strings + name: cut + outputs: + parameters: + - name: first-strings + valueFrom: + path: /tmp/hera-outputs/parameters/first-strings + - name: remainder + valueFrom: + path: /tmp/hera-outputs/parameters/remainder + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_dag_context:cut + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.9 + source: '{{inputs.parameters}}' + - inputs: + parameters: + - name: strings + - name: joiner + name: join + outputs: + parameters: + - name: joined-string + valueFrom: + path: /tmp/hera-outputs/parameters/joined-string + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_dag_context:join + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.9 + source: '{{inputs.parameters}}' diff --git a/examples/workflows/experimental/pydantic-io-in-steps-context.yaml b/examples/workflows/experimental/pydantic-io-in-steps-context.yaml new file mode 100644 index 000000000..72c4300f5 --- /dev/null +++ b/examples/workflows/experimental/pydantic-io-in-steps-context.yaml @@ -0,0 +1,82 @@ +apiVersion: argoproj.io/v1alpha1 +kind: WorkflowTemplate +metadata: + generateName: pydantic-io-in-steps-context-v1- +spec: + entrypoint: d + templates: + - name: d + steps: + - - arguments: + parameters: + - name: cut-after + value: '1' + - name: strings + value: '["hello", "world", "it''s", "hera"]' + name: cut + template: cut + - - arguments: + parameters: + - name: strings + value: '{{steps.cut.outputs.parameters.first-strings}}' + - name: joiner + value: ' ' + name: join + template: join + - inputs: + parameters: + - name: cut-after + - name: strings + name: cut + outputs: + parameters: + - name: first-strings + valueFrom: + path: /tmp/hera-outputs/parameters/first-strings + - name: remainder + valueFrom: + path: /tmp/hera-outputs/parameters/remainder + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_steps_context:cut + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.9 + source: '{{inputs.parameters}}' + - inputs: + parameters: + - name: strings + - name: joiner + name: join + outputs: + parameters: + - name: joined-string + valueFrom: + path: /tmp/hera-outputs/parameters/joined-string + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.experimental.pydantic_io_in_steps_context:join + command: + - python + env: + - name: hera__script_annotations + value: '' + - name: hera__outputs_directory + value: /tmp/hera-outputs + - name: hera__script_pydantic_io + value: '' + image: python:3.9 + source: '{{inputs.parameters}}' diff --git a/examples/workflows/experimental/pydantic_io_in_dag_context.py b/examples/workflows/experimental/pydantic_io_in_dag_context.py new file mode 100644 index 000000000..76863f283 --- /dev/null +++ b/examples/workflows/experimental/pydantic_io_in_dag_context.py @@ -0,0 +1,53 @@ +import sys +from typing import List + +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + + +from hera.shared import global_config +from hera.workflows import DAG, Parameter, WorkflowTemplate, script +from hera.workflows.io.v1 import Input, Output + +global_config.experimental_features["context_manager_pydantic_io"] = True + + +class CutInput(Input): + cut_after: Annotated[int, Parameter(name="cut-after")] + strings: List[str] + + +class CutOutput(Output): + first_strings: Annotated[List[str], Parameter(name="first-strings")] + remainder: List[str] + + +class JoinInput(Input): + strings: List[str] + joiner: str + + +class JoinOutput(Output): + joined_string: Annotated[str, Parameter(name="joined-string")] + + +@script(constructor="runner") +def cut(input: CutInput) -> CutOutput: + return CutOutput( + first_strings=input.strings[: input.cut_after], + remainder=input.strings[input.cut_after :], + exit_code=1 if len(input.strings) <= input.cut_after else 0, + ) + + +@script(constructor="runner") +def join(input: JoinInput) -> JoinOutput: + return JoinOutput(joined_string=input.joiner.join(input.strings)) + + +with WorkflowTemplate(generate_name="pydantic-io-in-steps-context-v1-", entrypoint="d") as w: + with DAG(name="d"): + cut_result = cut(CutInput(strings=["hello", "world", "it's", "hera"], cut_after=1)) + join(JoinInput(strings=cut_result.first_strings, joiner=" ")) diff --git a/examples/workflows/experimental/pydantic_io_in_steps_context.py b/examples/workflows/experimental/pydantic_io_in_steps_context.py new file mode 100644 index 000000000..c4169dc21 --- /dev/null +++ b/examples/workflows/experimental/pydantic_io_in_steps_context.py @@ -0,0 +1,53 @@ +import sys +from typing import List + +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + + +from hera.shared import global_config +from hera.workflows import Parameter, Steps, WorkflowTemplate, script +from hera.workflows.io.v1 import Input, Output + +global_config.experimental_features["context_manager_pydantic_io"] = True + + +class CutInput(Input): + cut_after: Annotated[int, Parameter(name="cut-after")] + strings: List[str] + + +class CutOutput(Output): + first_strings: Annotated[List[str], Parameter(name="first-strings")] + remainder: List[str] + + +class JoinInput(Input): + strings: List[str] + joiner: str + + +class JoinOutput(Output): + joined_string: Annotated[str, Parameter(name="joined-string")] + + +@script(constructor="runner") +def cut(input: CutInput) -> CutOutput: + return CutOutput( + first_strings=input.strings[: input.cut_after], + remainder=input.strings[input.cut_after :], + exit_code=1 if len(input.strings) <= input.cut_after else 0, + ) + + +@script(constructor="runner") +def join(input: JoinInput) -> JoinOutput: + return JoinOutput(joined_string=input.joiner.join(input.strings)) + + +with WorkflowTemplate(generate_name="pydantic-io-in-steps-context-v1-", entrypoint="d") as w: + with Steps(name="d"): + cut_result = cut(CutInput(strings=["hello", "world", "it's", "hera"], cut_after=1)) + join(JoinInput(strings=cut_result.first_strings, joiner=" ")) diff --git a/src/hera/shared/_global_config.py b/src/hera/shared/_global_config.py index a5c42353c..62ff0706e 100644 --- a/src/hera/shared/_global_config.py +++ b/src/hera/shared/_global_config.py @@ -15,9 +15,9 @@ TypeTBase = Type[TBase] Hook = Callable[[TBase], TBase] -"""`Hook` is a callable that takes a Hera objects and returns the same, optionally mutated, object. +"""`Hook` is a callable that takes a Hera objects and returns the same, optionally mutated, object. -This can be a Workflow, a Script, a Container, etc - any Hera object. +This can be a Workflow, a Script, a Container, etc - any Hera object. """ _HookMap = Dict[TypeTBase, List[Hook]] @@ -202,14 +202,15 @@ def _set_defaults(cls, values): _SCRIPT_ANNOTATIONS_FLAG = "script_annotations" _SCRIPT_PYDANTIC_IO_FLAG = "script_pydantic_io" _DECORATOR_SYNTAX_FLAG = "decorator_syntax" +_CONTEXT_MANAGER_PYDANTIC_IO_FLAG = "context_manager_pydantic_io" _SUPPRESS_PARAMETER_DEFAULT_ERROR_FLAG = "suppress_parameter_default_error" # A dictionary where each key is a flag that has a list of flags which supersede it, hence # the given flag key can also be switched on by any of the flags in the list. Using simple flat lists # for now, otherwise with many superseding flags we may want to have a recursive structure. _SUPERSEDING_FLAGS: Dict[str, List] = { - _SCRIPT_ANNOTATIONS_FLAG: [_SCRIPT_PYDANTIC_IO_FLAG, _DECORATOR_SYNTAX_FLAG], - _SCRIPT_PYDANTIC_IO_FLAG: [_DECORATOR_SYNTAX_FLAG], + _SCRIPT_ANNOTATIONS_FLAG: [_SCRIPT_PYDANTIC_IO_FLAG, _DECORATOR_SYNTAX_FLAG, _CONTEXT_MANAGER_PYDANTIC_IO_FLAG], + _SCRIPT_PYDANTIC_IO_FLAG: [_DECORATOR_SYNTAX_FLAG, _CONTEXT_MANAGER_PYDANTIC_IO_FLAG], _DECORATOR_SYNTAX_FLAG: [], } diff --git a/src/hera/workflows/_meta_mixins.py b/src/hera/workflows/_meta_mixins.py index d32dd55b6..fcc8716d2 100644 --- a/src/hera/workflows/_meta_mixins.py +++ b/src/hera/workflows/_meta_mixins.py @@ -347,12 +347,12 @@ def __call__(self, *args, **kwargs) -> Union[None, Step, Task]: from hera.workflows.dag import DAG from hera.workflows.script import Script - from hera.workflows.steps import Parallel, Step, Steps - from hera.workflows.task import Task + from hera.workflows.steps import Parallel, Steps from hera.workflows.workflow import Workflow if _context.pieces: - if isinstance(_context.pieces[-1], Workflow): + current_context = _context.pieces[-1] + if isinstance(current_context, Workflow): # Notes on callable templates under a Workflow: # * If the user calls a script directly under a Workflow (outside of a Steps/DAG) then we add the script # template to the workflow and return None. @@ -369,11 +369,8 @@ def __call__(self, *args, **kwargs) -> Union[None, Step, Task]: raise InvalidTemplateCall( f"Callable Template '{self.name}' is not callable under a Workflow" # type: ignore ) - if isinstance(_context.pieces[-1], (Steps, Parallel)): - return Step(template=self, **kwargs) - - if isinstance(_context.pieces[-1], DAG): - return Task(template=self, **kwargs) + if isinstance(current_context, (Steps, Parallel, DAG)): + return current_context._create_leaf_node(template=self, **kwargs) raise InvalidTemplateCall( f"Callable Template '{self.name}' is not under a Workflow, Steps, Parallel, or DAG context" # type: ignore @@ -527,8 +524,7 @@ def _create_subnode( ) -> Union[Step, Task]: from hera.workflows.cluster_workflow_template import ClusterWorkflowTemplate from hera.workflows.dag import DAG - from hera.workflows.steps import Parallel, Step, Steps - from hera.workflows.task import Task + from hera.workflows.steps import Parallel, Steps from hera.workflows.workflow_template import WorkflowTemplate subnode_args = None @@ -543,13 +539,10 @@ def _create_subnode( signature = inspect.signature(func) output_class = signature.return_annotation - subnode: Union[Step, Task] - assert _context.pieces template_ref = None - _context.declaring = False - if _context.pieces[0] != self and isinstance(self, WorkflowTemplate): + if _context.pieces[0] is not self and isinstance(self, WorkflowTemplate): # Using None for cluster_scope means it won't appear in the YAML spec (saving some bytes), # as cluster_scope=False is the default value template_ref = TemplateRef( @@ -560,27 +553,18 @@ def _create_subnode( # Set template to None as it cannot be set alongside template_ref template = None # type: ignore - if isinstance(_context.pieces[-1], (Steps, Parallel)): - subnode = Step( - name=subnode_name, - template=template, - template_ref=template_ref, - arguments=subnode_args, - **kwargs, - ) - elif isinstance(_context.pieces[-1], DAG): - subnode = Task( - name=subnode_name, - template=template, - template_ref=template_ref, - arguments=subnode_args, - depends=" && ".join(sorted(_context.pieces[-1]._current_task_depends)) or None, - **kwargs, - ) - _context.pieces[-1]._current_task_depends.clear() + current_context = _context.pieces[-1] + if not isinstance(current_context, (Steps, Parallel, DAG)): + raise InvalidTemplateCall("Not under a Steps, Parallel, or DAG context") + subnode = current_context._create_leaf_node( + name=subnode_name, + template=template, + template_ref=template_ref, + arguments=subnode_args, + **kwargs, + ) subnode._build_obj = HeraBuildObj(subnode._subtype, output_class) - _context.declaring = True return subnode @_add_type_hints(Script) # type: ignore diff --git a/src/hera/workflows/_mixins.py b/src/hera/workflows/_mixins.py index 9acf61062..1bf9f1d23 100644 --- a/src/hera/workflows/_mixins.py +++ b/src/hera/workflows/_mixins.py @@ -715,16 +715,18 @@ class TemplateInvocatorSubNodeMixin(BaseMixin): _build_obj: Optional[HeraBuildObj] = PrivateAttr(None) def __getattribute__(self, name: str) -> Any: - if _context.declaring: + try: # Use object's __getattribute__ to avoid infinite recursion build_obj = object.__getattribute__(self, "_build_obj") - assert build_obj # Assertions to fix type checking + except AttributeError: + build_obj = None + if build_obj and _context.active: fields = get_fields(build_obj.output_class) annotations = get_field_annotations(build_obj.output_class) if name in fields: # If the attribute name is in the build_obj's output class fields, then - # as we are in a declaring context, the access is for a Task/Step output + # as we are in an active context, the access is for a Task/Step output subnode_name = object.__getattribute__(self, "name") subnode_type = object.__getattribute__(self, "_subtype") diff --git a/src/hera/workflows/dag.py b/src/hera/workflows/dag.py index b4443b381..43dcbaf96 100644 --- a/src/hera/workflows/dag.py +++ b/src/hera/workflows/dag.py @@ -47,6 +47,11 @@ class DAG( _node_names = PrivateAttr(default_factory=set) _current_task_depends: Set[str] = PrivateAttr(set()) + def _create_leaf_node(self, **kwargs) -> Task: + if self._current_task_depends and "depends" not in kwargs: + kwargs["depends"] = " && ".join(sorted(self._current_task_depends)) + return Task(**kwargs) + def _add_sub(self, node: Any): if not isinstance(node, Task): raise InvalidType(type(node)) @@ -54,6 +59,7 @@ def _add_sub(self, node: Any): raise NodeNameConflict(f"Found multiple Task nodes with name: {node.name}") self._node_names.add(node.name) self.tasks.append(node) + self._current_task_depends.clear() def _build_template(self) -> _ModelTemplate: """Builds the auto-generated `Template` representation of the `DAG`.""" diff --git a/src/hera/workflows/io/_io_mixins.py b/src/hera/workflows/io/_io_mixins.py index 3ef106122..484514435 100644 --- a/src/hera/workflows/io/_io_mixins.py +++ b/src/hera/workflows/io/_io_mixins.py @@ -1,5 +1,6 @@ import sys import warnings +from contextlib import contextmanager from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Type, Union if sys.version_info >= (3, 11): @@ -59,21 +60,29 @@ def _construct_io_from_fields(cls: Type[BaseModel]) -> Iterator[Tuple[str, Field yield field, field_info, Parameter(name=field) +@contextmanager +def no_active_context() -> Iterator[None]: + pieces = _context.pieces + _context.pieces = [] + try: + yield + finally: + _context.pieces = pieces + + class InputMixin(BaseModel): def __new__(cls, **kwargs): - if _context.declaring: + if _context.active: # Intercept the declaration to avoid validation on the templated strings - # We must then turn off declaring mode to be able to "construct" an instance + # We must then disable the active context to be able to "construct" an instance # of the InputMixin subclass. - _context.declaring = False - instance = cls.construct(**kwargs) - _context.declaring = True - return instance + with no_active_context(): + return cls.construct(**kwargs) else: return super(InputMixin, cls).__new__(cls) def __init__(self, /, **kwargs): - if _context.declaring: + if _context.active: # Return in order to skip validation of `construct`ed instance return @@ -157,17 +166,15 @@ def _get_as_arguments(self) -> ModelArguments: class OutputMixin(BaseModel): def __new__(cls, **kwargs): - if _context.declaring: + if _context.active: # Intercept the declaration to avoid validation on the templated strings - _context.declaring = False - instance = cls.construct(**kwargs) - _context.declaring = True - return instance + with no_active_context(): + return cls.construct(**kwargs) else: return super(OutputMixin, cls).__new__(cls) def __init__(self, /, **kwargs): - if _context.declaring: + if _context.active: # Return in order to skip validation of `construct`ed instance return diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index dce234510..122842330 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -41,6 +41,7 @@ from hera.expr import g from hera.shared import BaseMixin, global_config from hera.shared._global_config import ( + _CONTEXT_MANAGER_PYDANTIC_IO_FLAG, _SCRIPT_ANNOTATIONS_FLAG, _SCRIPT_PYDANTIC_IO_FLAG, _SUPPRESS_PARAMETER_DEFAULT_ERROR_FLAG, @@ -50,7 +51,7 @@ from hera.shared._type_util import get_workflow_annotation, is_subscripted, origin_type_issubclass from hera.shared.serialization import serialize from hera.workflows._context import _context -from hera.workflows._meta_mixins import CallableTemplateMixin +from hera.workflows._meta_mixins import CallableTemplateMixin, HeraBuildObj from hera.workflows._mixins import ( ArgumentsT, ContainerMixin, @@ -379,6 +380,22 @@ def _get_parameters_from_callable(source: Callable) -> List[Parameter]: return parameters +def _enable_experimental_feature_msg(flag: str) -> str: + return ( + "Please turn on experimental features by setting " + f'`hera.shared.global_config.experimental_features["{flag}"] = True`.' + " Note that experimental features are unstable and subject to breaking changes." + ) + + +def _assert_pydantic_io_enabled(annotation: str) -> None: + if not _flag_enabled(_SCRIPT_PYDANTIC_IO_FLAG): + raise ValueError( + f"Unable to instantiate {annotation} since it is an experimental feature. " + + _enable_experimental_feature_msg(_SCRIPT_PYDANTIC_IO_FLAG) + ) + + def _get_outputs_from_return_annotation( source: Callable, outputs_directory: Optional[str], @@ -408,16 +425,7 @@ def append_annotation(annotation: Union[Artifact, Parameter]): if param_or_artifact := get_workflow_annotation(annotation): append_annotation(param_or_artifact) elif isinstance(return_annotation, type) and issubclass(return_annotation, (OutputV1, OutputV2)): - if not _flag_enabled(_SCRIPT_PYDANTIC_IO_FLAG): - raise ValueError( - ( - "Unable to instantiate {} since it is an experimental feature." - " Please turn on experimental features by setting " - '`hera.shared.global_config.experimental_features["{}"] = True`.' - " Note that experimental features are unstable and subject to breaking changes." - ).format(return_annotation, _SCRIPT_PYDANTIC_IO_FLAG) - ) - + _assert_pydantic_io_enabled(return_annotation) output_class = return_annotation for output in output_class._get_outputs(): append_annotation(output) @@ -471,15 +479,7 @@ class will be used as inputs, rather than the class itself. for func_param in inspect.signature(source).parameters.values(): if not is_subscripted(func_param.annotation) and issubclass(func_param.annotation, (InputV1, InputV2)): - if not _flag_enabled(_SCRIPT_PYDANTIC_IO_FLAG): - raise ValueError( - ( - "Unable to instantiate {} since it is an experimental feature." - " Please turn on experimental features by setting " - '`hera.shared.global_config.experimental_features["{}"] = True`.' - " Note that experimental features are unstable and subject to breaking changes." - ).format(func_param.annotation, _SCRIPT_PYDANTIC_IO_FLAG) - ) + _assert_pydantic_io_enabled(func_param.annotation) if len(inspect.signature(source).parameters) != 1: raise SyntaxError("Only one function parameter can be specified when using an Input.") @@ -774,6 +774,32 @@ def script_wrapper(func: Callable[FuncIns, FuncR]) -> Callable: def task_wrapper(*args, **kwargs) -> Union[FuncR, Step, Task, None]: """Invokes a `Script` object's `__call__` method using the given SubNode (Step or Task) args/kwargs.""" if _context.active: + if len(args) == 1 and isinstance(args[0], (InputV1, InputV2)): + if not _flag_enabled(_CONTEXT_MANAGER_PYDANTIC_IO_FLAG): + raise SyntaxError( + "Cannot pass a Pydantic type inside a context. " + + _enable_experimental_feature_msg(_CONTEXT_MANAGER_PYDANTIC_IO_FLAG) + ) + arguments = args[0]._get_as_arguments() + arguments_list = [ + *(arguments.artifacts or []), + *(arguments.parameters or []), + ] + + subnode = s.__call__(arguments=arguments_list, **kwargs) + if not subnode: + raise SyntaxError("Cannot use Pydantic I/O outside of a DAG, Steps or Parallel context") + + output_class = inspect.signature(func).return_annotation + if not output_class or output_class is NoneType: + return None + + if not issubclass(output_class, (OutputV1, OutputV2)): + raise SyntaxError("Cannot use Pydantic input type without a Pydantic output type") + + _assert_pydantic_io_enabled(output_class) + subnode._build_obj = HeraBuildObj(subnode._subtype, output_class) + return subnode return s.__call__(*args, **kwargs) return func(*args, **kwargs) diff --git a/src/hera/workflows/steps.py b/src/hera/workflows/steps.py index bfe57c645..6557ec0d1 100644 --- a/src/hera/workflows/steps.py +++ b/src/hera/workflows/steps.py @@ -110,6 +110,9 @@ class Parallel( _node_names = PrivateAttr(default_factory=set) + def _create_leaf_node(self, **kwargs) -> Step: + return Step(**kwargs) + def _add_sub(self, node: Any): if not isinstance(node, Step): raise InvalidType(type(node)) @@ -180,6 +183,9 @@ def _build_steps(self) -> Optional[List[ParallelSteps]]: return steps or None + def _create_leaf_node(self, **kwargs) -> Step: + return Step(**kwargs) + def _add_sub(self, node: Any): if not isinstance(node, (Step, Parallel)): raise InvalidType(type(node)) diff --git a/tests/test_pydantic_io_workflow_syntax.py b/tests/test_pydantic_io_workflow_syntax.py new file mode 100644 index 000000000..a4fd46d58 --- /dev/null +++ b/tests/test_pydantic_io_workflow_syntax.py @@ -0,0 +1,62 @@ +import pytest + +from hera.shared._global_config import _CONTEXT_MANAGER_PYDANTIC_IO_FLAG +from hera.workflows import Input, Output, Steps, Workflow, script + + +class IntInput(Input): + field: int + + +class IntOutput(Output): + field: int + + +@pytest.fixture(autouse=True) +def enable_pydantic_io(global_config_fixture): + global_config_fixture.experimental_features[_CONTEXT_MANAGER_PYDANTIC_IO_FLAG] = True + + +def test_output_field_contains_argo_template(global_config_fixture): + @script() + def triple(input: IntInput) -> IntOutput: + return IntOutput(field=input.field * 3) + + with Workflow(name="foo"): + with Steps(name="bar"): + result = triple(IntInput(field=5)).field + + assert result == "{{steps.triple.outputs.parameters.field}}" + + +def test_script_can_return_none(): + @script() + def print_field(input: IntInput) -> None: + print(input.field) + + with Workflow(name="foo"): + with Steps(name="bar"): + result = print_field(IntInput(field=5)) + + assert result is None + + +def test_invalid_pydantic_io_outside_of_context(): + @script() + def triple(input: IntInput) -> IntOutput: + return IntOutput(field=input.field * 3) + + with Workflow(name="foo"): + with pytest.raises(SyntaxError, match="Cannot use Pydantic I/O outside of a .* context"): + triple(IntInput(field=5)) + + +def test_invalid_non_pydantic_return_type(): + @script() + def triple(input: IntInput) -> int: + return input.field * 3 + + with Workflow(name="foo"): + with Steps(name="bar"): + with pytest.raises(SyntaxError, match="Cannot use Pydantic input type without a Pydantic output type"): + triple(IntInput(field=5))