diff --git a/src/hera/workflows/_meta_mixins.py b/src/hera/workflows/_meta_mixins.py index 0da462f07..b575f9923 100644 --- a/src/hera/workflows/_meta_mixins.py +++ b/src/hera/workflows/_meta_mixins.py @@ -346,12 +346,12 @@ def __call__(self, *args, **kwargs) -> Union[None, Step, Task]: from hera.workflows.dag import DAG from hera.workflows.script import Script - from hera.workflows.steps import Parallel, Step, Steps - from hera.workflows.task import Task + from hera.workflows.steps import Parallel, Steps from hera.workflows.workflow import Workflow if _context.pieces: - if isinstance(_context.pieces[-1], Workflow): + current_context = _context.pieces[-1] + if isinstance(current_context, Workflow): # Notes on callable templates under a Workflow: # * If the user calls a script directly under a Workflow (outside of a Steps/DAG) then we add the script # template to the workflow and return None. @@ -368,16 +368,8 @@ def __call__(self, *args, **kwargs) -> Union[None, Step, Task]: raise InvalidTemplateCall( f"Callable Template '{self.name}' is not callable under a Workflow" # type: ignore ) - if isinstance(_context.pieces[-1], (Steps, Parallel)): - return Step(template=self, **kwargs) - - if isinstance(_context.pieces[-1], DAG): - # Add dependencies based on context if not explicitly provided - current_task_depends = _context.pieces[-1]._current_task_depends - if current_task_depends and "depends" not in kwargs: - kwargs["depends"] = " && ".join(sorted(current_task_depends)) - - return Task(template=self, **kwargs) + if isinstance(current_context, (Steps, Parallel, DAG)): + return current_context._create_leaf_node(template=self, **kwargs) raise InvalidTemplateCall( f"Callable Template '{self.name}' is not under a Workflow, Steps, Parallel, or DAG context" # type: ignore @@ -528,8 +520,7 @@ def _create_subnode( ) -> Union[Step, Task]: from hera.workflows.cluster_workflow_template import ClusterWorkflowTemplate from hera.workflows.dag import DAG - from hera.workflows.steps import Parallel, Step, Steps - from hera.workflows.task import Task + from hera.workflows.steps import Parallel, Steps from hera.workflows.workflow_template import WorkflowTemplate subnode_args = None @@ -539,8 +530,6 @@ def _create_subnode( signature = inspect.signature(func) output_class = signature.return_annotation - subnode: Union[Step, Task] - assert _context.pieces template_ref = None @@ -556,25 +545,16 @@ def _create_subnode( template = None # type: ignore current_context = _context.pieces[-1] - if isinstance(current_context, (Steps, Parallel)): - subnode = Step( - name=subnode_name, - template=template, - template_ref=template_ref, - arguments=subnode_args, - **kwargs, - ) - elif isinstance(current_context, DAG): - if current_context._current_task_depends and "depends" not in kwargs: - kwargs["depends"] = " && ".join(sorted(current_context._current_task_depends)) - subnode = Task( - name=subnode_name, - template=template, - template_ref=template_ref, - arguments=subnode_args, - **kwargs, - ) - + if not isinstance(current_context, (Steps, Parallel, DAG)): + raise InvalidTemplateCall("Not under a Steps, Parallel, or DAG context") + + subnode = current_context._create_leaf_node( + name=subnode_name, + template=template, + template_ref=template_ref, + arguments=subnode_args, + **kwargs, + ) subnode._build_obj = HeraBuildObj(subnode._subtype, output_class) return subnode diff --git a/src/hera/workflows/dag.py b/src/hera/workflows/dag.py index 2443c06fb..43dcbaf96 100644 --- a/src/hera/workflows/dag.py +++ b/src/hera/workflows/dag.py @@ -47,6 +47,11 @@ class DAG( _node_names = PrivateAttr(default_factory=set) _current_task_depends: Set[str] = PrivateAttr(set()) + def _create_leaf_node(self, **kwargs) -> Task: + if self._current_task_depends and "depends" not in kwargs: + kwargs["depends"] = " && ".join(sorted(self._current_task_depends)) + return Task(**kwargs) + def _add_sub(self, node: Any): if not isinstance(node, Task): raise InvalidType(type(node)) diff --git a/src/hera/workflows/steps.py b/src/hera/workflows/steps.py index bfe57c645..6557ec0d1 100644 --- a/src/hera/workflows/steps.py +++ b/src/hera/workflows/steps.py @@ -110,6 +110,9 @@ class Parallel( _node_names = PrivateAttr(default_factory=set) + def _create_leaf_node(self, **kwargs) -> Step: + return Step(**kwargs) + def _add_sub(self, node: Any): if not isinstance(node, Step): raise InvalidType(type(node)) @@ -180,6 +183,9 @@ def _build_steps(self) -> Optional[List[ParallelSteps]]: return steps or None + def _create_leaf_node(self, **kwargs) -> Step: + return Step(**kwargs) + def _add_sub(self, node: Any): if not isinstance(node, (Step, Parallel)): raise InvalidType(type(node))