Skip to content

Commit

Permalink
Move Step/Task creation to Steps/Parallel/DAG
Browse files Browse the repository at this point in the history
Move the logic that creates the right leaf node for Steps, Parallel and
DAG to a _create_leaf_node method on those types. DAG now specifies how
to default the depends parameter to Task based on its
_current_task_depends field. This simplifies the duplicated logic in
_meta_mixins to a simple isinstance check for any of those three types,
followed by a _create_leaf_node call.

Signed-off-by: Alice Purcell <alicederyn@gmail.com>
  • Loading branch information
alicederyn committed Sep 5, 2024
1 parent 20dae4c commit ce12edf
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 36 deletions.
52 changes: 16 additions & 36 deletions src/hera/workflows/_meta_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,12 @@ def __call__(self, *args, **kwargs) -> Union[None, Step, Task]:

from hera.workflows.dag import DAG
from hera.workflows.script import Script
from hera.workflows.steps import Parallel, Step, Steps
from hera.workflows.task import Task
from hera.workflows.steps import Parallel, Steps
from hera.workflows.workflow import Workflow

if _context.pieces:
if isinstance(_context.pieces[-1], Workflow):
current_context = _context.pieces[-1]
if isinstance(current_context, Workflow):
# Notes on callable templates under a Workflow:
# * If the user calls a script directly under a Workflow (outside of a Steps/DAG) then we add the script
# template to the workflow and return None.
Expand All @@ -368,16 +368,8 @@ def __call__(self, *args, **kwargs) -> Union[None, Step, Task]:
raise InvalidTemplateCall(
f"Callable Template '{self.name}' is not callable under a Workflow" # type: ignore
)
if isinstance(_context.pieces[-1], (Steps, Parallel)):
return Step(template=self, **kwargs)

if isinstance(_context.pieces[-1], DAG):
# Add dependencies based on context if not explicitly provided
current_task_depends = _context.pieces[-1]._current_task_depends
if current_task_depends and "depends" not in kwargs:
kwargs["depends"] = " && ".join(sorted(current_task_depends))

return Task(template=self, **kwargs)
if isinstance(current_context, (Steps, Parallel, DAG)):
return current_context._create_leaf_node(template=self, **kwargs)

raise InvalidTemplateCall(
f"Callable Template '{self.name}' is not under a Workflow, Steps, Parallel, or DAG context" # type: ignore
Expand Down Expand Up @@ -528,8 +520,7 @@ def _create_subnode(
) -> Union[Step, Task]:
from hera.workflows.cluster_workflow_template import ClusterWorkflowTemplate
from hera.workflows.dag import DAG
from hera.workflows.steps import Parallel, Step, Steps
from hera.workflows.task import Task
from hera.workflows.steps import Parallel, Steps
from hera.workflows.workflow_template import WorkflowTemplate

subnode_args = None
Expand All @@ -539,8 +530,6 @@ def _create_subnode(
signature = inspect.signature(func)
output_class = signature.return_annotation

subnode: Union[Step, Task]

assert _context.pieces

template_ref = None
Expand All @@ -556,25 +545,16 @@ def _create_subnode(
template = None # type: ignore

current_context = _context.pieces[-1]
if isinstance(current_context, (Steps, Parallel)):
subnode = Step(
name=subnode_name,
template=template,
template_ref=template_ref,
arguments=subnode_args,
**kwargs,
)
elif isinstance(current_context, DAG):
if current_context._current_task_depends and "depends" not in kwargs:
kwargs["depends"] = " && ".join(sorted(current_context._current_task_depends))
subnode = Task(
name=subnode_name,
template=template,
template_ref=template_ref,
arguments=subnode_args,
**kwargs,
)

if not isinstance(current_context, (Steps, Parallel, DAG)):
raise InvalidTemplateCall("Not under a Steps, Parallel, or DAG context")

subnode = current_context._create_leaf_node(
name=subnode_name,
template=template,
template_ref=template_ref,
arguments=subnode_args,
**kwargs,
)
subnode._build_obj = HeraBuildObj(subnode._subtype, output_class)
return subnode

Expand Down
5 changes: 5 additions & 0 deletions src/hera/workflows/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ class DAG(
_node_names = PrivateAttr(default_factory=set)
_current_task_depends: Set[str] = PrivateAttr(set())

def _create_leaf_node(self, **kwargs) -> Task:
if self._current_task_depends and "depends" not in kwargs:
kwargs["depends"] = " && ".join(sorted(self._current_task_depends))
return Task(**kwargs)

def _add_sub(self, node: Any):
if not isinstance(node, Task):
raise InvalidType(type(node))
Expand Down
6 changes: 6 additions & 0 deletions src/hera/workflows/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ class Parallel(

_node_names = PrivateAttr(default_factory=set)

def _create_leaf_node(self, **kwargs) -> Step:
return Step(**kwargs)

def _add_sub(self, node: Any):
if not isinstance(node, Step):
raise InvalidType(type(node))
Expand Down Expand Up @@ -180,6 +183,9 @@ def _build_steps(self) -> Optional[List[ParallelSteps]]:

return steps or None

def _create_leaf_node(self, **kwargs) -> Step:
return Step(**kwargs)

def _add_sub(self, node: Any):
if not isinstance(node, (Step, Parallel)):
raise InvalidType(type(node))
Expand Down

0 comments on commit ce12edf

Please sign in to comment.