diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 4208d4f739bffb..a04b30093b1e8e 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -57,7 +57,7 @@ WorkflowRunStatus, ) -from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError +from .exc import WorkflowRunNotFoundError class WorkflowCycleManage: @@ -308,7 +308,7 @@ def _handle_node_execution_start( session.add(workflow_node_execution) - self._workflow_node_executions[workflow_node_execution.id] = workflow_node_execution + self._workflow_node_executions[event.node_execution_id] = workflow_node_execution return workflow_node_execution def _handle_workflow_node_execution_success( @@ -336,6 +336,7 @@ def _handle_workflow_node_execution_success( workflow_node_execution.finished_at = finished_at workflow_node_execution.elapsed_time = elapsed_time + workflow_node_execution = session.merge(workflow_node_execution) return workflow_node_execution def _handle_workflow_node_execution_failed( @@ -375,6 +376,7 @@ def _handle_workflow_node_execution_failed( workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.execution_metadata = execution_metadata + workflow_node_execution = session.merge(workflow_node_execution) return workflow_node_execution def _handle_workflow_node_execution_retried( @@ -427,7 +429,7 @@ def _handle_workflow_node_execution_retried( session.add(workflow_node_execution) - self._workflow_node_executions[workflow_node_execution.id] = workflow_node_execution + self._workflow_node_executions[event.node_execution_id] = workflow_node_execution return workflow_node_execution ################################################# @@ -837,13 +839,7 @@ def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> Workfl return workflow_run def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: - if node_execution_id in self._workflow_node_executions: - cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] - cached_workflow_node_execution = session.merge(cached_workflow_node_execution) - return cached_workflow_node_execution - stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id) - workflow_node_execution = session.scalar(stmt) - if not workflow_node_execution: - raise WorkflowNodeExecutionNotFoundError(node_execution_id) - self._workflow_node_executions[node_execution_id] = workflow_node_execution - return workflow_node_execution + if node_execution_id not in self._workflow_node_executions: + raise ValueError(f"Workflow node execution not found: {node_execution_id}") + cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] + return cached_workflow_node_execution