From 2dc4bd4b2767a54d9c390fc43738f2178db3ca1a Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 3 Jan 2025 14:29:47 +0800 Subject: [PATCH 1/5] refactor: streamline initialization of application_generate_entity and task_state in task pipeline classes Signed-off-by: -LAN- --- .../advanced_chat/generate_task_pipeline.py | 90 +++++++++---------- .../apps/workflow/generate_task_pipeline.py | 68 +++++++------- .../based_generate_task_pipeline.py | 4 - .../app/task_pipeline/message_cycle_manage.py | 17 +++- .../task_pipeline/workflow_cycle_manage.py | 46 ++++++---- 5 files changed, 117 insertions(+), 108 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index c6c4923ee684f2..bf3ea7d992deaa 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -67,7 +67,6 @@ from models.enums import CreatedByRole from models.workflow import ( Workflow, - WorkflowNodeExecution, WorkflowRunStatus, ) @@ -79,12 +78,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _task_state: WorkflowTaskState - _application_generate_entity: AdvancedChatAppGenerateEntity - _workflow_system_variables: dict[SystemVariableKey, Any] - _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] - _conversation_name_generate_thread: Optional[Thread] = None - def __init__( self, application_generate_entity: AdvancedChatAppGenerateEntity, @@ -96,10 +89,8 @@ def __init__( stream: bool, dialogue_count: int, ) -> None: - super().__init__( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - stream=stream, + BasedGenerateTaskPipeline.__init__( + self, application_generate_entity=application_generate_entity, queue_manager=queue_manager, stream=stream ) if isinstance(user, EndUser): @@ -112,33 +103,36 @@ def __init__( self._created_by_role = CreatedByRole.ACCOUNT else: raise NotImplementedError(f"User type not supported: {type(user)}") + WorkflowCycleManage.__init__( + self, + application_generate_entity=application_generate_entity, + workflow_system_variables={ + SystemVariableKey.QUERY: message.query, + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.CONVERSATION_ID: conversation.id, + SystemVariableKey.USER_ID: user_session_id, + SystemVariableKey.DIALOGUE_COUNT: dialogue_count, + SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, + SystemVariableKey.WORKFLOW_ID: workflow.id, + SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, + }, + ) + + self._task_state = WorkflowTaskState() + MessageCycleManage.__init__( + self, application_generate_entity=application_generate_entity, task_state=self._task_state + ) + self._application_generate_entity = application_generate_entity self._workflow_id = workflow.id self._workflow_features_dict = workflow.features_dict - self._conversation_id = conversation.id self._conversation_mode = conversation.mode - self._message_id = message.id self._message_created_at = int(message.created_at.timestamp()) - - self._workflow_system_variables = { - SystemVariableKey.QUERY: message.query, - SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.CONVERSATION_ID: conversation.id, - SystemVariableKey.USER_ID: user_session_id, - SystemVariableKey.DIALOGUE_COUNT: dialogue_count, - SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, - } - - self._task_state = WorkflowTaskState() - self._wip_workflow_node_executions = {} - - self._conversation_name_generate_thread = None + self._conversation_name_generate_thread: Thread | None = None self._recorded_files: list[Mapping[str, Any]] = [] - self._workflow_run_id = "" + self._workflow_run_id: str = "" def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ @@ -275,7 +269,7 @@ def _process_stream_response( if isinstance(event, QueuePingEvent): yield self._ping_stream_response() elif isinstance(event, QueueErrorEvent): - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: err = self._handle_error(event=event, session=session, message_id=self._message_id) session.commit() yield self._error_to_stream_response(err) @@ -284,7 +278,7 @@ def _process_stream_response( # override graph runtime state graph_runtime_state = event.graph_runtime_state - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: # init workflow run workflow_run = self._handle_workflow_run_start( session=session, @@ -310,7 +304,7 @@ def _process_stream_response( if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_node_execution = self._handle_workflow_node_execution_retried( session=session, workflow_run=workflow_run, event=event @@ -329,7 +323,7 @@ def _process_stream_response( if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_node_execution = self._handle_node_execution_start( session=session, workflow_run=workflow_run, event=event @@ -350,7 +344,7 @@ def _process_stream_response( if event.node_type in [NodeType.ANSWER, NodeType.END]: self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) node_finish_resp = self._workflow_node_finish_to_stream_response( @@ -364,7 +358,7 @@ def _process_stream_response( if node_finish_resp: yield node_finish_resp elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event) node_finish_resp = self._workflow_node_finish_to_stream_response( @@ -381,7 +375,7 @@ def _process_stream_response( if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( session=session, @@ -395,7 +389,7 @@ def _process_stream_response( if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( session=session, @@ -409,7 +403,7 @@ def _process_stream_response( if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) iter_start_resp = self._workflow_iteration_start_to_stream_response( session=session, @@ -423,7 +417,7 @@ def _process_stream_response( if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) iter_next_resp = self._workflow_iteration_next_to_stream_response( session=session, @@ -437,7 +431,7 @@ def _process_stream_response( if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) iter_finish_resp = self._workflow_iteration_completed_to_stream_response( session=session, @@ -454,7 +448,7 @@ def _process_stream_response( if not graph_runtime_state: raise ValueError("workflow run not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._handle_workflow_run_success( session=session, workflow_run_id=self._workflow_run_id, @@ -479,7 +473,7 @@ def _process_stream_response( if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._handle_workflow_run_partial_success( session=session, workflow_run_id=self._workflow_run_id, @@ -504,7 +498,7 @@ def _process_stream_response( if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._handle_workflow_run_failed( session=session, workflow_run_id=self._workflow_run_id, @@ -529,7 +523,7 @@ def _process_stream_response( break elif isinstance(event, QueueStopEvent): if self._workflow_run_id and graph_runtime_state: - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._handle_workflow_run_failed( session=session, workflow_run_id=self._workflow_run_id, @@ -557,7 +551,7 @@ def _process_stream_response( elif isinstance(event, QueueRetrieverResourcesEvent): self._handle_retriever_resources(event) - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: message = self._get_message(session=session) message.message_metadata = ( json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None @@ -566,7 +560,7 @@ def _process_stream_response( elif isinstance(event, QueueAnnotationReplyEvent): self._handle_annotation_reply(event) - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: message = self._get_message(session=session) message.message_metadata = ( json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None @@ -603,7 +597,7 @@ def _process_stream_response( yield self._message_replace_to_stream_response(answer=output_moderation_answer) # Save message - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: self._save_message(session=session, graph_runtime_state=graph_runtime_state) session.commit() diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index c447f9c2fc1515..a49b101f0643d6 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -1,7 +1,7 @@ import logging import time from collections.abc import Generator -from typing import Any, Optional, Union +from typing import Optional, Union from sqlalchemy.orm import Session @@ -58,7 +58,6 @@ Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, - WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus, ) @@ -71,11 +70,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _task_state: WorkflowTaskState - _application_generate_entity: WorkflowAppGenerateEntity - _workflow_system_variables: dict[SystemVariableKey, Any] - _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] - def __init__( self, application_generate_entity: WorkflowAppGenerateEntity, @@ -84,12 +78,6 @@ def __init__( user: Union[Account, EndUser], stream: bool, ) -> None: - super().__init__( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - stream=stream, - ) - if isinstance(user, EndUser): self._user_id = user.id user_session_id = user.session_id @@ -101,17 +89,27 @@ def __init__( else: raise ValueError(f"Invalid user type: {type(user)}") + BasedGenerateTaskPipeline.__init__( + self, + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) + WorkflowCycleManage.__init__( + self, + application_generate_entity=application_generate_entity, + workflow_system_variables={ + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.USER_ID: user_session_id, + SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, + SystemVariableKey.WORKFLOW_ID: workflow.id, + SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, + }, + ) + + self._application_generate_entity = application_generate_entity self._workflow_id = workflow.id self._workflow_features_dict = workflow.features_dict - - self._workflow_system_variables = { - SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.USER_ID: user_session_id, - SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, - } - self._task_state = WorkflowTaskState() self._wip_workflow_node_executions = {} self._workflow_run_id = "" @@ -250,7 +248,7 @@ def _process_stream_response( # override graph runtime state graph_runtime_state = event.graph_runtime_state - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: # init workflow run workflow_run = self._handle_workflow_run_start( session=session, @@ -271,7 +269,7 @@ def _process_stream_response( ): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_node_execution = self._handle_workflow_node_execution_retried( session=session, workflow_run=workflow_run, event=event @@ -290,7 +288,7 @@ def _process_stream_response( if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_node_execution = self._handle_node_execution_start( session=session, workflow_run=workflow_run, event=event @@ -306,7 +304,7 @@ def _process_stream_response( if node_start_response: yield node_start_response elif isinstance(event, QueueNodeSucceededEvent): - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) node_success_response = self._workflow_node_finish_to_stream_response( session=session, @@ -319,7 +317,7 @@ def _process_stream_response( if node_success_response: yield node_success_response elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_node_execution = self._handle_workflow_node_execution_failed( session=session, event=event, @@ -339,7 +337,7 @@ def _process_stream_response( if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( session=session, @@ -354,7 +352,7 @@ def _process_stream_response( if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( session=session, @@ -369,7 +367,7 @@ def _process_stream_response( if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) iter_start_resp = self._workflow_iteration_start_to_stream_response( session=session, @@ -384,7 +382,7 @@ def _process_stream_response( if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) iter_next_resp = self._workflow_iteration_next_to_stream_response( session=session, @@ -399,7 +397,7 @@ def _process_stream_response( if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) iter_finish_resp = self._workflow_iteration_completed_to_stream_response( session=session, @@ -416,7 +414,7 @@ def _process_stream_response( if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._handle_workflow_run_success( session=session, workflow_run_id=self._workflow_run_id, @@ -445,7 +443,7 @@ def _process_stream_response( if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._handle_workflow_run_partial_success( session=session, workflow_run_id=self._workflow_run_id, @@ -473,7 +471,7 @@ def _process_stream_response( if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: workflow_run = self._handle_workflow_run_failed( session=session, workflow_run_id=self._workflow_run_id, diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index e363a7f64244d3..358cd2fc60b834 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -15,7 +15,6 @@ from core.app.entities.task_entities import ( ErrorStreamResponse, PingStreamResponse, - TaskState, ) from core.errors.error import QuotaExceededError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError @@ -30,9 +29,6 @@ class BasedGenerateTaskPipeline: BasedGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _task_state: TaskState - _application_generate_entity: AppGenerateEntity - def __init__( self, application_generate_entity: AppGenerateEntity, diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 15f2c25c66a3d2..6a4ab259ba4a1a 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -31,10 +31,19 @@ class MessageCycleManage: - _application_generate_entity: Union[ - ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity - ] - _task_state: Union[EasyUITaskState, WorkflowTaskState] + def __init__( + self, + *, + application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, + ], + task_state: Union[EasyUITaskState, WorkflowTaskState], + ) -> None: + self._application_generate_entity = application_generate_entity + self._task_state = task_state def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: """ diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 885b341196d04b..d255f56f9adff7 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -34,7 +34,6 @@ ParallelBranchStartStreamResponse, WorkflowFinishStreamResponse, WorkflowStartStreamResponse, - WorkflowTaskState, ) from core.file import FILE_MODEL_IDENTITY, File from core.model_runtime.utils.encoders import jsonable_encoder @@ -62,9 +61,16 @@ class WorkflowCycleManage: - _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] - _task_state: WorkflowTaskState - _workflow_system_variables: dict[SystemVariableKey, Any] + def __init__( + self, + *, + application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], + workflow_system_variables: dict[SystemVariableKey, Any], + ) -> None: + self._workflow_run: WorkflowRun | None = None + self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {} + self._application_generate_entity = application_generate_entity + self._workflow_system_variables = workflow_system_variables def _handle_workflow_run_start( self, @@ -240,7 +246,7 @@ def _handle_workflow_run_failed( workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.exceptions_count = exceptions_count - stmt = select(WorkflowNodeExecution).where( + stmt = select(WorkflowNodeExecution.id).where( WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, WorkflowNodeExecution.app_id == workflow_run.app_id, WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, @@ -248,16 +254,18 @@ def _handle_workflow_run_failed( WorkflowNodeExecution.workflow_run_id == workflow_run.id, WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, ) - - running_workflow_node_executions = session.scalars(stmt).all() + ids = session.scalars(stmt).all() + # Use self._get_workflow_node_execution here to make sure the cache is updated + running_workflow_node_executions = [ + self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids + ] for workflow_node_execution in running_workflow_node_executions: + now = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = error - workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) - workflow_node_execution.elapsed_time = ( - workflow_node_execution.finished_at - workflow_node_execution.created_at - ).total_seconds() + workflow_node_execution.finished_at = now + workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds() if trace_manager: trace_manager.add_trace_task( @@ -812,22 +820,26 @@ def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any return None def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: - """ - Refetch workflow run - :param workflow_run_id: workflow run id - :return: - """ + if self._workflow_run and self._workflow_run.id == workflow_run_id: + workflow_run = self._workflow_run + workflow_run = session.merge(workflow_run) + return workflow_run stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) workflow_run = session.scalar(stmt) if not workflow_run: raise WorkflowRunNotFoundError(workflow_run_id) + self._workflow_run = workflow_run return workflow_run def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: + if node_execution_id in self._workflow_node_executions: + workflow_node_execution = self._workflow_node_executions[node_execution_id] + workflow_node_execution = session.merge(workflow_node_execution) + return workflow_node_execution stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id) workflow_node_execution = session.scalar(stmt) if not workflow_node_execution: raise WorkflowNodeExecutionNotFoundError(node_execution_id) - + self._workflow_node_executions[node_execution_id] = workflow_node_execution return workflow_node_execution From db1b05a6abe1a82994963151c8841aebf01a147c Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 3 Jan 2025 15:06:47 +0800 Subject: [PATCH 2/5] refactor: improve variable naming for clarity in task pipeline classes Signed-off-by: -LAN- --- .../advanced_chat/generate_task_pipeline.py | 176 +++++++++++------- .../apps/workflow/generate_task_pipeline.py | 120 +++++++----- .../based_generate_task_pipeline.py | 7 - .../task_pipeline/workflow_cycle_manage.py | 12 +- 4 files changed, 182 insertions(+), 133 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index bf3ea7d992deaa..6aad805034ba9c 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -73,7 +73,7 @@ logger = logging.getLogger(__name__) -class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage): +class AdvancedChatAppGenerateTaskPipeline: """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ @@ -89,8 +89,10 @@ def __init__( stream: bool, dialogue_count: int, ) -> None: - BasedGenerateTaskPipeline.__init__( - self, application_generate_entity=application_generate_entity, queue_manager=queue_manager, stream=stream + self._base_task_pipeline = BasedGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, ) if isinstance(user, EndUser): @@ -103,8 +105,8 @@ def __init__( self._created_by_role = CreatedByRole.ACCOUNT else: raise NotImplementedError(f"User type not supported: {type(user)}") - WorkflowCycleManage.__init__( - self, + + self._workflow_cycle_manager = WorkflowCycleManage( application_generate_entity=application_generate_entity, workflow_system_variables={ SystemVariableKey.QUERY: message.query, @@ -119,8 +121,8 @@ def __init__( ) self._task_state = WorkflowTaskState() - MessageCycleManage.__init__( - self, application_generate_entity=application_generate_entity, task_state=self._task_state + self._message_cycle_manager = MessageCycleManage( + application_generate_entity=application_generate_entity, task_state=self._task_state ) self._application_generate_entity = application_generate_entity @@ -140,13 +142,13 @@ def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStrea :return: """ # start generate conversation name thread - self._conversation_name_generate_thread = self._generate_conversation_name( + self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name( conversation_id=self._conversation_id, query=self._application_generate_entity.query ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._stream: + if self._base_task_pipeline._stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -263,16 +265,18 @@ def _process_stream_response( # init fake graph runtime state graph_runtime_state: Optional[GraphRuntimeState] = None - for queue_message in self._queue_manager.listen(): + for queue_message in self._base_task_pipeline._queue_manager.listen(): event = queue_message.event if isinstance(event, QueuePingEvent): - yield self._ping_stream_response() + yield self._base_task_pipeline._ping_stream_response() elif isinstance(event, QueueErrorEvent): with Session(db.engine, expire_on_commit=False) as session: - err = self._handle_error(event=event, session=session, message_id=self._message_id) + err = self._base_task_pipeline._handle_error( + event=event, session=session, message_id=self._message_id + ) session.commit() - yield self._error_to_stream_response(err) + yield self._base_task_pipeline._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): # override graph runtime state @@ -280,7 +284,7 @@ def _process_stream_response( with Session(db.engine, expire_on_commit=False) as session: # init workflow run - workflow_run = self._handle_workflow_run_start( + workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( session=session, workflow_id=self._workflow_id, user_id=self._user_id, @@ -291,7 +295,7 @@ def _process_stream_response( if not message: raise ValueError(f"Message not found: {self._message_id}") message.workflow_run_id = workflow_run.id - workflow_start_resp = self._workflow_start_to_stream_response( + workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) session.commit() @@ -305,11 +309,13 @@ def _process_stream_response( raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) - workflow_node_execution = self._handle_workflow_node_execution_retried( + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( session=session, workflow_run=workflow_run, event=event ) - node_retry_resp = self._workflow_node_retry_to_stream_response( + node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( session=session, event=event, task_id=self._application_generate_entity.task_id, @@ -324,12 +330,14 @@ def _process_stream_response( raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) - workflow_node_execution = self._handle_node_execution_start( + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( session=session, workflow_run=workflow_run, event=event ) - node_start_resp = self._workflow_node_start_to_stream_response( + node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response( session=session, event=event, task_id=self._application_generate_entity.task_id, @@ -342,12 +350,16 @@ def _process_stream_response( elif isinstance(event, QueueNodeSucceededEvent): # Record files if it's an answer node or end node if event.node_type in [NodeType.ANSWER, NodeType.END]: - self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) + self._recorded_files.extend( + self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {}) + ) with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( + session=session, event=event + ) - node_finish_resp = self._workflow_node_finish_to_stream_response( + node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( session=session, event=event, task_id=self._application_generate_entity.task_id, @@ -359,9 +371,11 @@ def _process_stream_response( yield node_finish_resp elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event) + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( + session=session, event=event + ) - node_finish_resp = self._workflow_node_finish_to_stream_response( + node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( session=session, event=event, task_id=self._application_generate_entity.task_id, @@ -376,12 +390,16 @@ def _process_stream_response( raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) - parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + parallel_start_resp = ( + self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) ) yield parallel_start_resp @@ -390,12 +408,16 @@ def _process_stream_response( raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) - parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + parallel_finish_resp = ( + self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) ) yield parallel_finish_resp @@ -404,8 +426,10 @@ def _process_stream_response( raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) - iter_start_resp = self._workflow_iteration_start_to_stream_response( + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, @@ -418,8 +442,10 @@ def _process_stream_response( raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) - iter_next_resp = self._workflow_iteration_next_to_stream_response( + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, @@ -432,8 +458,10 @@ def _process_stream_response( raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) - iter_finish_resp = self._workflow_iteration_completed_to_stream_response( + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, @@ -449,7 +477,7 @@ def _process_stream_response( raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._handle_workflow_run_success( + workflow_run = self._workflow_cycle_manager._handle_workflow_run_success( session=session, workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, @@ -460,13 +488,15 @@ def _process_stream_response( trace_manager=trace_manager, ) - workflow_finish_resp = self._workflow_finish_to_stream_response( + workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) session.commit() yield workflow_finish_resp - self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + self._base_task_pipeline._queue_manager.publish( + QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE + ) elif isinstance(event, QueueWorkflowPartialSuccessEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") @@ -474,7 +504,7 @@ def _process_stream_response( raise ValueError("graph runtime state not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._handle_workflow_run_partial_success( + workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success( session=session, workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, @@ -485,13 +515,15 @@ def _process_stream_response( conversation_id=None, trace_manager=trace_manager, ) - workflow_finish_resp = self._workflow_finish_to_stream_response( + workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) session.commit() yield workflow_finish_resp - self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + self._base_task_pipeline._queue_manager.publish( + QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE + ) elif isinstance(event, QueueWorkflowFailedEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") @@ -499,7 +531,7 @@ def _process_stream_response( raise ValueError("graph runtime state not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._handle_workflow_run_failed( + workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( session=session, workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, @@ -511,20 +543,22 @@ def _process_stream_response( trace_manager=trace_manager, exceptions_count=event.exceptions_count, ) - workflow_finish_resp = self._workflow_finish_to_stream_response( + workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) - err = self._handle_error(event=err_event, session=session, message_id=self._message_id) + err = self._base_task_pipeline._handle_error( + event=err_event, session=session, message_id=self._message_id + ) session.commit() yield workflow_finish_resp - yield self._error_to_stream_response(err) + yield self._base_task_pipeline._error_to_stream_response(err) break elif isinstance(event, QueueStopEvent): if self._workflow_run_id and graph_runtime_state: with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._handle_workflow_run_failed( + workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( session=session, workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, @@ -535,7 +569,7 @@ def _process_stream_response( conversation_id=self._conversation_id, trace_manager=trace_manager, ) - workflow_finish_resp = self._workflow_finish_to_stream_response( + workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, @@ -549,7 +583,7 @@ def _process_stream_response( yield self._message_end_to_stream_response() break elif isinstance(event, QueueRetrieverResourcesEvent): - self._handle_retriever_resources(event) + self._message_cycle_manager._handle_retriever_resources(event) with Session(db.engine, expire_on_commit=False) as session: message = self._get_message(session=session) @@ -558,7 +592,7 @@ def _process_stream_response( ) session.commit() elif isinstance(event, QueueAnnotationReplyEvent): - self._handle_annotation_reply(event) + self._message_cycle_manager._handle_annotation_reply(event) with Session(db.engine, expire_on_commit=False) as session: message = self._get_message(session=session) @@ -581,20 +615,24 @@ def _process_stream_response( tts_publisher.publish(queue_message) self._task_state.answer += delta_text - yield self._message_to_stream_response( + yield self._message_cycle_manager._message_to_stream_response( answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector ) elif isinstance(event, QueueMessageReplaceEvent): # published by moderation - yield self._message_replace_to_stream_response(answer=event.text) + yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueueAdvancedChatMessageEndEvent): if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) + output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished( + self._task_state.answer + ) if output_moderation_answer: self._task_state.answer = output_moderation_answer - yield self._message_replace_to_stream_response(answer=output_moderation_answer) + yield self._message_cycle_manager._message_replace_to_stream_response( + answer=output_moderation_answer + ) # Save message with Session(db.engine, expire_on_commit=False) as session: @@ -615,7 +653,7 @@ def _process_stream_response( def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: message = self._get_message(session=session) message.answer = self._task_state.answer - message.provider_response_latency = time.perf_counter() - self._start_at + message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at message.message_metadata = ( json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None ) @@ -679,20 +717,20 @@ def _handle_output_moderation_chunk(self, text: str) -> bool: :param text: text :return: True if output moderation should direct output, otherwise False """ - if self._output_moderation_handler: - if self._output_moderation_handler.should_direct_output(): + if self._base_task_pipeline._output_moderation_handler: + if self._base_task_pipeline._output_moderation_handler.should_direct_output(): # stop subscribe new token when output moderation should direct output - self._task_state.answer = self._output_moderation_handler.get_final_output() - self._queue_manager.publish( + self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output() + self._base_task_pipeline._queue_manager.publish( QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE ) - self._queue_manager.publish( + self._base_task_pipeline._queue_manager.publish( QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) return True else: - self._output_moderation_handler.append_new_token(text) + self._base_task_pipeline._output_moderation_handler.append_new_token(text) return False diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index a49b101f0643d6..f89f456916e72e 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -65,7 +65,7 @@ logger = logging.getLogger(__name__) -class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage): +class WorkflowAppGenerateTaskPipeline: """ WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ @@ -78,6 +78,12 @@ def __init__( user: Union[Account, EndUser], stream: bool, ) -> None: + self._base_task_pipeline = BasedGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) + if isinstance(user, EndUser): self._user_id = user.id user_session_id = user.session_id @@ -89,14 +95,7 @@ def __init__( else: raise ValueError(f"Invalid user type: {type(user)}") - BasedGenerateTaskPipeline.__init__( - self, - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - stream=stream, - ) - WorkflowCycleManage.__init__( - self, + self._workflow_cycle_manager = WorkflowCycleManage( application_generate_entity=application_generate_entity, workflow_system_variables={ SystemVariableKey.FILES: application_generate_entity.files, @@ -111,7 +110,6 @@ def __init__( self._workflow_id = workflow.id self._workflow_features_dict = workflow.features_dict self._task_state = WorkflowTaskState() - self._wip_workflow_node_executions = {} self._workflow_run_id = "" def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: @@ -120,7 +118,7 @@ def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStr :return: """ generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._stream: + if self._base_task_pipeline._stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -235,14 +233,14 @@ def _process_stream_response( """ graph_runtime_state = None - for queue_message in self._queue_manager.listen(): + for queue_message in self._base_task_pipeline._queue_manager.listen(): event = queue_message.event if isinstance(event, QueuePingEvent): - yield self._ping_stream_response() + yield self._base_task_pipeline._ping_stream_response() elif isinstance(event, QueueErrorEvent): - err = self._handle_error(event=event) - yield self._error_to_stream_response(err) + err = self._base_task_pipeline._handle_error(event=event) + yield self._base_task_pipeline._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): # override graph runtime state @@ -250,14 +248,14 @@ def _process_stream_response( with Session(db.engine, expire_on_commit=False) as session: # init workflow run - workflow_run = self._handle_workflow_run_start( + workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( session=session, workflow_id=self._workflow_id, user_id=self._user_id, created_by_role=self._created_by_role, ) self._workflow_run_id = workflow_run.id - start_resp = self._workflow_start_to_stream_response( + start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) session.commit() @@ -270,11 +268,13 @@ def _process_stream_response( if not self._workflow_run_id: raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) - workflow_node_execution = self._handle_workflow_node_execution_retried( + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( session=session, workflow_run=workflow_run, event=event ) - response = self._workflow_node_retry_to_stream_response( + response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( session=session, event=event, task_id=self._application_generate_entity.task_id, @@ -289,11 +289,13 @@ def _process_stream_response( raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) - workflow_node_execution = self._handle_node_execution_start( + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( session=session, workflow_run=workflow_run, event=event ) - node_start_response = self._workflow_node_start_to_stream_response( + node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response( session=session, event=event, task_id=self._application_generate_entity.task_id, @@ -305,8 +307,10 @@ def _process_stream_response( yield node_start_response elif isinstance(event, QueueNodeSucceededEvent): with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) - node_success_response = self._workflow_node_finish_to_stream_response( + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( + session=session, event=event + ) + node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( session=session, event=event, task_id=self._application_generate_entity.task_id, @@ -318,11 +322,11 @@ def _process_stream_response( yield node_success_response elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._handle_workflow_node_execution_failed( + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( session=session, event=event, ) - node_failed_response = self._workflow_node_finish_to_stream_response( + node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( session=session, event=event, task_id=self._application_generate_entity.task_id, @@ -338,12 +342,16 @@ def _process_stream_response( raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) - parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + parallel_start_resp = ( + self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) ) yield parallel_start_resp @@ -353,12 +361,16 @@ def _process_stream_response( raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) - parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + parallel_finish_resp = ( + self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) ) yield parallel_finish_resp @@ -368,8 +380,10 @@ def _process_stream_response( raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) - iter_start_resp = self._workflow_iteration_start_to_stream_response( + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, @@ -383,8 +397,10 @@ def _process_stream_response( raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) - iter_next_resp = self._workflow_iteration_next_to_stream_response( + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, @@ -398,8 +414,10 @@ def _process_stream_response( raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) - iter_finish_resp = self._workflow_iteration_completed_to_stream_response( + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, @@ -415,7 +433,7 @@ def _process_stream_response( raise ValueError("graph runtime state not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._handle_workflow_run_success( + workflow_run = self._workflow_cycle_manager._handle_workflow_run_success( session=session, workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, @@ -429,7 +447,7 @@ def _process_stream_response( # save workflow app log self._save_workflow_app_log(session=session, workflow_run=workflow_run) - workflow_finish_resp = self._workflow_finish_to_stream_response( + workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, @@ -444,7 +462,7 @@ def _process_stream_response( raise ValueError("graph runtime state not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._handle_workflow_run_partial_success( + workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success( session=session, workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, @@ -459,7 +477,7 @@ def _process_stream_response( # save workflow app log self._save_workflow_app_log(session=session, workflow_run=workflow_run) - workflow_finish_resp = self._workflow_finish_to_stream_response( + workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) session.commit() @@ -472,7 +490,7 @@ def _process_stream_response( raise ValueError("graph runtime state not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._handle_workflow_run_failed( + workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( session=session, workflow_run_id=self._workflow_run_id, start_at=graph_runtime_state.start_at, @@ -490,7 +508,7 @@ def _process_stream_response( # save workflow app log self._save_workflow_app_log(session=session, workflow_run=workflow_run) - workflow_finish_resp = self._workflow_finish_to_stream_response( + workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) session.commit() diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 358cd2fc60b834..a2e06d4e1ff492 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -35,13 +35,6 @@ def __init__( queue_manager: AppQueueManager, stream: bool, ) -> None: - """ - Initialize GenerateTaskPipeline. - :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param user: user - :param stream: stream - """ self._application_generate_entity = application_generate_entity self._queue_manager = queue_manager self._start_at = time.perf_counter() diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index d255f56f9adff7..ed24b418bd46cb 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -821,9 +821,9 @@ def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: if self._workflow_run and self._workflow_run.id == workflow_run_id: - workflow_run = self._workflow_run - workflow_run = session.merge(workflow_run) - return workflow_run + cached_workflow_run = self._workflow_run + cached_workflow_run = session.merge(cached_workflow_run) + return cached_workflow_run stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) workflow_run = session.scalar(stmt) if not workflow_run: @@ -834,9 +834,9 @@ def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> Workfl def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: if node_execution_id in self._workflow_node_executions: - workflow_node_execution = self._workflow_node_executions[node_execution_id] - workflow_node_execution = session.merge(workflow_node_execution) - return workflow_node_execution + cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] + cached_workflow_node_execution = session.merge(cached_workflow_node_execution) + return cached_workflow_node_execution stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id) workflow_node_execution = session.scalar(stmt) if not workflow_node_execution: From d3161fffe53d675b02dbf8552b8995e926ef31ff Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 3 Jan 2025 16:04:53 +0800 Subject: [PATCH 3/5] fix: ensure workflow node executions are tracked after addition to session Signed-off-by: -LAN- --- api/core/app/task_pipeline/workflow_cycle_manage.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index ed24b418bd46cb..4208d4f739bffb 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -307,6 +307,8 @@ def _handle_node_execution_start( workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) session.add(workflow_node_execution) + + self._workflow_node_executions[workflow_node_execution.id] = workflow_node_execution return workflow_node_execution def _handle_workflow_node_execution_success( @@ -424,6 +426,8 @@ def _handle_workflow_node_execution_retried( workflow_node_execution.index = event.node_run_index session.add(workflow_node_execution) + + self._workflow_node_executions[workflow_node_execution.id] = workflow_node_execution return workflow_node_execution ################################################# From f459c025e23d0ba491f5839d9ad073f5d5916f55 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 3 Jan 2025 16:44:06 +0800 Subject: [PATCH 4/5] fix: update workflow node execution retrieval logic to improve error handling Signed-off-by: -LAN- --- .../task_pipeline/workflow_cycle_manage.py | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 4208d4f739bffb..a04b30093b1e8e 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -57,7 +57,7 @@ WorkflowRunStatus, ) -from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError +from .exc import WorkflowRunNotFoundError class WorkflowCycleManage: @@ -308,7 +308,7 @@ def _handle_node_execution_start( session.add(workflow_node_execution) - self._workflow_node_executions[workflow_node_execution.id] = workflow_node_execution + self._workflow_node_executions[event.node_execution_id] = workflow_node_execution return workflow_node_execution def _handle_workflow_node_execution_success( @@ -336,6 +336,7 @@ def _handle_workflow_node_execution_success( workflow_node_execution.finished_at = finished_at workflow_node_execution.elapsed_time = elapsed_time + workflow_node_execution = session.merge(workflow_node_execution) return workflow_node_execution def _handle_workflow_node_execution_failed( @@ -375,6 +376,7 @@ def _handle_workflow_node_execution_failed( workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.execution_metadata = execution_metadata + workflow_node_execution = session.merge(workflow_node_execution) return workflow_node_execution def _handle_workflow_node_execution_retried( @@ -427,7 +429,7 @@ def _handle_workflow_node_execution_retried( session.add(workflow_node_execution) - self._workflow_node_executions[workflow_node_execution.id] = workflow_node_execution + self._workflow_node_executions[event.node_execution_id] = workflow_node_execution return workflow_node_execution ################################################# @@ -837,13 +839,7 @@ def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> Workfl return workflow_run def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: - if node_execution_id in self._workflow_node_executions: - cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] - cached_workflow_node_execution = session.merge(cached_workflow_node_execution) - return cached_workflow_node_execution - stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id) - workflow_node_execution = session.scalar(stmt) - if not workflow_node_execution: - raise WorkflowNodeExecutionNotFoundError(node_execution_id) - self._workflow_node_executions[node_execution_id] = workflow_node_execution - return workflow_node_execution + if node_execution_id not in self._workflow_node_executions: + raise ValueError(f"Workflow node execution not found: {node_execution_id}") + cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] + return cached_workflow_node_execution From fb2019767b39c6d6b7440b8cf4ea135a01053025 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 3 Jan 2025 16:55:17 +0800 Subject: [PATCH 5/5] fix: correct selection of node_execution_id in workflow cycle management Signed-off-by: -LAN- --- api/core/app/task_pipeline/workflow_cycle_manage.py | 4 ++-- api/core/app/task_pipeline/workflow_cycle_state_manager.py | 0 2 files changed, 2 insertions(+), 2 deletions(-) delete mode 100644 api/core/app/task_pipeline/workflow_cycle_state_manager.py diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index a04b30093b1e8e..dcc364d22766e6 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -246,7 +246,7 @@ def _handle_workflow_run_failed( workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.exceptions_count = exceptions_count - stmt = select(WorkflowNodeExecution.id).where( + stmt = select(WorkflowNodeExecution.node_execution_id).where( WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, WorkflowNodeExecution.app_id == workflow_run.app_id, WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, @@ -257,7 +257,7 @@ def _handle_workflow_run_failed( ids = session.scalars(stmt).all() # Use self._get_workflow_node_execution here to make sure the cache is updated running_workflow_node_executions = [ - self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids + self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id ] for workflow_node_execution in running_workflow_node_executions: diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py deleted file mode 100644 index e69de29bb2d1d6..00000000000000