diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index c6c4923ee684f2..6aad805034ba9c 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -67,24 +67,17 @@ from models.enums import CreatedByRole
 from models.workflow import (
     Workflow,
-    WorkflowNodeExecution,
     WorkflowRunStatus,
 )
 
 logger = logging.getLogger(__name__)
 
 
-class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
+class AdvancedChatAppGenerateTaskPipeline:
     """
     AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """
 
-    _task_state: WorkflowTaskState
-    _application_generate_entity: AdvancedChatAppGenerateEntity
-    _workflow_system_variables: dict[SystemVariableKey, Any]
-    _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
-
-    _conversation_name_generate_thread: Optional[Thread] = None
-
     def __init__(
         self,
         application_generate_entity: AdvancedChatAppGenerateEntity,
@@ -96,7 +89,7 @@ def __init__(
         stream: bool,
         dialogue_count: int,
     ) -> None:
-        super().__init__(
+        self._base_task_pipeline = BasedGenerateTaskPipeline(
             application_generate_entity=application_generate_entity,
             queue_manager=queue_manager,
             stream=stream,
@@ -113,32 +106,35 @@ def __init__(
         else:
             raise NotImplementedError(f"User type not supported: {type(user)}")
 
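+        # Handles workflow-run/node-execution persistence and their stream responses for this run.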
+        self._workflow_cycle_manager = WorkflowCycleManage(
+            application_generate_entity=application_generate_entity,
+            workflow_system_variables={
+                SystemVariableKey.QUERY: message.query,
+                SystemVariableKey.FILES: application_generate_entity.files,
+                SystemVariableKey.CONVERSATION_ID: conversation.id,
+                SystemVariableKey.USER_ID: user_session_id,
+                SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
+                SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
+                SystemVariableKey.WORKFLOW_ID: workflow.id,
+                SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
+            },
+        )
+
+        self._task_state = WorkflowTaskState()
+        self._message_cycle_manager = MessageCycleManage(
+            application_generate_entity=application_generate_entity, task_state=self._task_state
+        )
+
+        self._application_generate_entity = application_generate_entity
         self._workflow_id = workflow.id
         self._workflow_features_dict = workflow.features_dict
-
         self._conversation_id = conversation.id
         self._conversation_mode = conversation.mode
-
         self._message_id = message.id
         self._message_created_at = int(message.created_at.timestamp())
-
-        self._workflow_system_variables = {
-            SystemVariableKey.QUERY: message.query,
-            SystemVariableKey.FILES: application_generate_entity.files,
-            SystemVariableKey.CONVERSATION_ID: conversation.id,
-            SystemVariableKey.USER_ID: user_session_id,
-            SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
-            SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
-            SystemVariableKey.WORKFLOW_ID: workflow.id,
-            SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
-        }
-
-        self._task_state = WorkflowTaskState()
-        self._wip_workflow_node_executions = {}
-
-        self._conversation_name_generate_thread = None
+        self._conversation_name_generate_thread: Thread | None = None
         self._recorded_files: list[Mapping[str, Any]] = []
-        self._workflow_run_id = ""
+        self._workflow_run_id: str = ""
 
     def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
         """
@@ -146,13 +142,13 @@ def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStrea
         :return:
         """
         # start generate conversation name thread
-        self._conversation_name_generate_thread = self._generate_conversation_name(
+        self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
            conversation_id=self._conversation_id, query=self._application_generate_entity.query
         )
 
         generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
-        if self._stream:
+        if self._base_task_pipeline._stream:
             return self._to_stream_response(generator)
         else:
             return self._to_blocking_response(generator)
@@ -269,24 +265,26 @@ def _process_stream_response(
         # init fake graph runtime state
         graph_runtime_state: Optional[GraphRuntimeState] = None
 
-        for queue_message in self._queue_manager.listen():
+        for queue_message in self._base_task_pipeline._queue_manager.listen():
             event = queue_message.event
 
             if isinstance(event, QueuePingEvent):
-                yield self._ping_stream_response()
+                yield self._base_task_pipeline._ping_stream_response()
             elif isinstance(event, QueueErrorEvent):
-                with Session(db.engine) as session:
-                    err = self._handle_error(event=event, session=session, message_id=self._message_id)
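+                # expire_on_commit=False keeps these ORM objects readable after the commit below.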
+                with Session(db.engine, expire_on_commit=False) as session:
+                    err = self._base_task_pipeline._handle_error(
+                        event=event, session=session, message_id=self._message_id
+                    )
                     session.commit()
-                yield self._error_to_stream_response(err)
+                yield self._base_task_pipeline._error_to_stream_response(err)
                 break
             elif isinstance(event, QueueWorkflowStartedEvent):
                 # override graph runtime state
                 graph_runtime_state = event.graph_runtime_state
 
-                with Session(db.engine) as session:
+                with Session(db.engine, expire_on_commit=False) as session:
                     # init workflow run
-                    workflow_run = self._handle_workflow_run_start(
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
                         session=session,
                         workflow_id=self._workflow_id,
                         user_id=self._user_id,
@@ -297,7 +295,7 @@
                     if not message:
                         raise ValueError(f"Message not found: {self._message_id}")
                     message.workflow_run_id = workflow_run.id
-                    workflow_start_resp = self._workflow_start_to_stream_response(
+                    workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     session.commit()
@@ -310,12 +308,14 @@
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    workflow_node_execution = self._handle_workflow_node_execution_retried(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
                         session=session, workflow_run=workflow_run, event=event
                     )
-                    node_retry_resp = self._workflow_node_retry_to_stream_response(
+                    node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -329,13 +329,15 @@
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    workflow_node_execution = self._handle_node_execution_start(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
                         session=session, workflow_run=workflow_run, event=event
                     )
-                    node_start_resp = self._workflow_node_start_to_stream_response(
+                    node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -348,12 +350,16 @@
             elif isinstance(event, QueueNodeSucceededEvent):
                 # Record files if it's an answer node or end node
                 if event.node_type in [NodeType.ANSWER, NodeType.END]:
-                    self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
+                    self._recorded_files.extend(
+                        self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
+                    )
 
-                with Session(db.engine) as session:
-                    workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
+                        session=session, event=event
+                    )
 
-                    node_finish_resp = self._workflow_node_finish_to_stream_response(
+                    node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -364,10 +370,12 @@
                 if node_finish_resp:
                     yield node_finish_resp
             elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
-                with Session(db.engine) as session:
-                    workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
+                        session=session, event=event
+                    )
 
-                    node_finish_resp = self._workflow_node_finish_to_stream_response(
+                    node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -381,13 +389,17 @@
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
-                        session=session,
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_run=workflow_run,
-                        event=event,
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    parallel_start_resp = (
+                        self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
+                            session=session,
+                            task_id=self._application_generate_entity.task_id,
+                            workflow_run=workflow_run,
+                            event=event,
+                        )
                     )
 
                 yield parallel_start_resp
@@ -395,13 +407,17 @@
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
-                        session=session,
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_run=workflow_run,
-                        event=event,
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    parallel_finish_resp = (
+                        self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
+                            session=session,
+                            task_id=self._application_generate_entity.task_id,
+                            workflow_run=workflow_run,
+                            event=event,
+                        )
                     )
 
                 yield parallel_finish_resp
@@ -409,9 +425,11 @@
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    iter_start_resp = self._workflow_iteration_start_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -423,9 +441,11 @@
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    iter_next_resp = self._workflow_iteration_next_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -437,9 +457,11 @@
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -454,8 +476,8 @@
                 if not graph_runtime_state:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._handle_workflow_run_success(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
                         session=session,
                         workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
@@ -466,21 +488,23 @@
                         trace_manager=trace_manager,
                     )
 
-                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     session.commit()
 
                 yield workflow_finish_resp
-                self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
+                self._base_task_pipeline._queue_manager.publish(
+                    QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
+                )
             elif isinstance(event, QueueWorkflowPartialSuccessEvent):
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._handle_workflow_run_partial_success(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
                         session=session,
                         workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
@@ -491,21 +515,23 @@
                         conversation_id=None,
                         trace_manager=trace_manager,
                     )
-                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     session.commit()
 
                 yield workflow_finish_resp
-                self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
+                self._base_task_pipeline._queue_manager.publish(
+                    QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
+                )
             elif isinstance(event, QueueWorkflowFailedEvent):
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._handle_workflow_run_failed(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
                         session=session,
                         workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
@@ -517,20 +543,22 @@
                         trace_manager=trace_manager,
                         exceptions_count=event.exceptions_count,
                     )
-                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
 
                     err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
-                    err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
+                    err = self._base_task_pipeline._handle_error(
+                        event=err_event, session=session, message_id=self._message_id
+                    )
                     session.commit()
 
                 yield workflow_finish_resp
-                yield self._error_to_stream_response(err)
+                yield self._base_task_pipeline._error_to_stream_response(err)
                 break
             elif isinstance(event, QueueStopEvent):
                 if self._workflow_run_id and graph_runtime_state:
-                    with Session(db.engine) as session:
-                        workflow_run = self._handle_workflow_run_failed(
+                    with Session(db.engine, expire_on_commit=False) as session:
+                        workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
                             session=session,
                             workflow_run_id=self._workflow_run_id,
                             start_at=graph_runtime_state.start_at,
@@ -541,7 +569,7 @@
                             conversation_id=self._conversation_id,
                             trace_manager=trace_manager,
                         )
-                        workflow_finish_resp = self._workflow_finish_to_stream_response(
+                        workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                             session=session,
                             task_id=self._application_generate_entity.task_id,
                             workflow_run=workflow_run,
@@ -555,18 +583,18 @@
                 yield self._message_end_to_stream_response()
                 break
             elif isinstance(event, QueueRetrieverResourcesEvent):
-                self._handle_retriever_resources(event)
+                self._message_cycle_manager._handle_retriever_resources(event)
 
-                with Session(db.engine) as session:
+                with Session(db.engine, expire_on_commit=False) as session:
                     message = self._get_message(session=session)
                     message.message_metadata = (
                         json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
                     )
                     session.commit()
             elif isinstance(event, QueueAnnotationReplyEvent):
-                self._handle_annotation_reply(event)
+                self._message_cycle_manager._handle_annotation_reply(event)
 
-                with Session(db.engine) as session:
+                with Session(db.engine, expire_on_commit=False) as session:
                     message = self._get_message(session=session)
                     message.message_metadata = (
                         json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
@@ -587,23 +615,27 @@
                     tts_publisher.publish(queue_message)
 
                 self._task_state.answer += delta_text
-                yield self._message_to_stream_response(
+                yield self._message_cycle_manager._message_to_stream_response(
                     answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
                 )
             elif isinstance(event, QueueMessageReplaceEvent):
                 # published by moderation
-                yield self._message_replace_to_stream_response(answer=event.text)
+                yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
             elif isinstance(event, QueueAdvancedChatMessageEndEvent):
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
+                output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
+                    self._task_state.answer
+                )
                 if output_moderation_answer:
                     self._task_state.answer = output_moderation_answer
-                    yield self._message_replace_to_stream_response(answer=output_moderation_answer)
+                    yield self._message_cycle_manager._message_replace_to_stream_response(
+                        answer=output_moderation_answer
+                    )
 
                 # Save message
-                with Session(db.engine) as session:
+                with Session(db.engine, expire_on_commit=False) as session:
                     self._save_message(session=session, graph_runtime_state=graph_runtime_state)
                     session.commit()
 
@@ -621,7 +653,7 @@ def _process_stream_response(
     def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
         message = self._get_message(session=session)
         message.answer = self._task_state.answer
-        message.provider_response_latency = time.perf_counter() - self._start_at
+        message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
         message.message_metadata = (
             json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
         )
@@ -685,20 +717,20 @@ def _handle_output_moderation_chunk(self, text: str) -> bool:
         :param text: text
         :return: True if output moderation should direct output, otherwise False
         """
-        if self._output_moderation_handler:
-            if self._output_moderation_handler.should_direct_output():
+        if self._base_task_pipeline._output_moderation_handler:
+            if self._base_task_pipeline._output_moderation_handler.should_direct_output():
                 # stop subscribe new token when output moderation should direct output
-                self._task_state.answer = self._output_moderation_handler.get_final_output()
-                self._queue_manager.publish(
+                self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
+                self._base_task_pipeline._queue_manager.publish(
                     QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
                 )
 
-                self._queue_manager.publish(
+                self._base_task_pipeline._queue_manager.publish(
                     QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
                 )
 
                 return True
             else:
-                self._output_moderation_handler.append_new_token(text)
+                self._base_task_pipeline._output_moderation_handler.append_new_token(text)
 
         return False
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py
index c447f9c2fc1515..f89f456916e72e 100644
--- a/api/core/app/apps/workflow/generate_task_pipeline.py
+++ b/api/core/app/apps/workflow/generate_task_pipeline.py
@@ -1,7 +1,7 @@
 import logging
 import time
 from collections.abc import Generator
-from typing import Any, Optional, Union
+from typing import Optional, Union
 
 from sqlalchemy.orm import Session
 
@@ -58,7 +58,6 @@
     Workflow,
     WorkflowAppLog,
     WorkflowAppLogCreatedFrom,
-    WorkflowNodeExecution,
     WorkflowRun,
     WorkflowRunStatus,
 )
@@ -66,16 +65,11 @@
 logger = logging.getLogger(__name__)
 
 
-class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
+class WorkflowAppGenerateTaskPipeline:
     """
     WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """
 
-    _task_state: WorkflowTaskState
-    _application_generate_entity: WorkflowAppGenerateEntity
-    _workflow_system_variables: dict[SystemVariableKey, Any]
-    _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
-
     def __init__(
         self,
         application_generate_entity: WorkflowAppGenerateEntity,
@@ -84,7 +78,7 @@ def __init__(
         user: Union[Account, EndUser],
         stream: bool,
     ) -> None:
-        super().__init__(
+        self._base_task_pipeline = BasedGenerateTaskPipeline(
             application_generate_entity=application_generate_entity,
             queue_manager=queue_manager,
             stream=stream,
@@ -101,19 +95,21 @@ def __init__(
         else:
             raise ValueError(f"Invalid user type: {type(user)}")
 
+        self._workflow_cycle_manager = WorkflowCycleManage(
+            application_generate_entity=application_generate_entity,
+            workflow_system_variables={
+                SystemVariableKey.FILES: application_generate_entity.files,
+                SystemVariableKey.USER_ID: user_session_id,
+                SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
+                SystemVariableKey.WORKFLOW_ID: workflow.id,
+                SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
+            },
+        )
+
+        self._application_generate_entity = application_generate_entity
         self._workflow_id = workflow.id
         self._workflow_features_dict = workflow.features_dict
-
-        self._workflow_system_variables = {
-            SystemVariableKey.FILES: application_generate_entity.files,
-            SystemVariableKey.USER_ID: user_session_id,
-            SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
-            SystemVariableKey.WORKFLOW_ID: workflow.id,
-            SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
-        }
-
         self._task_state = WorkflowTaskState()
-        self._wip_workflow_node_executions = {}
         self._workflow_run_id = ""
 
     def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@@ -122,7 +118,7 @@ def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStr
         :return:
         """
         generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
-        if self._stream:
+        if self._base_task_pipeline._stream:
             return self._to_stream_response(generator)
         else:
             return self._to_blocking_response(generator)
@@ -237,29 +233,29 @@
         """
         graph_runtime_state = None
 
-        for queue_message in self._queue_manager.listen():
+        for queue_message in self._base_task_pipeline._queue_manager.listen():
             event = queue_message.event
 
             if isinstance(event, QueuePingEvent):
-                yield self._ping_stream_response()
+                yield self._base_task_pipeline._ping_stream_response()
             elif isinstance(event, QueueErrorEvent):
-                err = self._handle_error(event=event)
-                yield self._error_to_stream_response(err)
+                err = self._base_task_pipeline._handle_error(event=event)
+                yield self._base_task_pipeline._error_to_stream_response(err)
                 break
             elif isinstance(event, QueueWorkflowStartedEvent):
                 # override graph runtime state
                 graph_runtime_state = event.graph_runtime_state
 
-                with Session(db.engine) as session:
+                with Session(db.engine, expire_on_commit=False) as session:
                     # init workflow run
-                    workflow_run = self._handle_workflow_run_start(
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
                         session=session,
                         workflow_id=self._workflow_id,
                         user_id=self._user_id,
                         created_by_role=self._created_by_role,
                     )
                     self._workflow_run_id = workflow_run.id
-                    start_resp = self._workflow_start_to_stream_response(
+                    start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     session.commit()
@@ -271,12 +267,14 @@
             ):
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    workflow_node_execution = self._handle_workflow_node_execution_retried(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
                         session=session, workflow_run=workflow_run, event=event
                    )
-                    response = self._workflow_node_retry_to_stream_response(
+                    response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -290,12 +288,14 @@ def _process_stream_response(
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    workflow_node_execution = self._handle_node_execution_start(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
                         session=session, workflow_run=workflow_run, event=event
                     )
-                    node_start_response = self._workflow_node_start_to_stream_response(
+                    node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -306,9 +306,11 @@ def _process_stream_response(
                 if node_start_response:
                     yield node_start_response
             elif isinstance(event, QueueNodeSucceededEvent):
-                with Session(db.engine) as session:
-                    workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
-                    node_success_response = self._workflow_node_finish_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
+                        session=session, event=event
+                    )
+                    node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -319,12 +321,12 @@ def _process_stream_response(
                 if node_success_response:
                     yield node_success_response
             elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
-                with Session(db.engine) as session:
-                    workflow_node_execution = self._handle_workflow_node_execution_failed(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
                         session=session,
                         event=event,
                     )
-                    node_failed_response = self._workflow_node_finish_to_stream_response(
+                    node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -339,13 +341,17 @@ def _process_stream_response(
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
-                        session=session,
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_run=workflow_run,
-                        event=event,
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    parallel_start_resp = (
+                        self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
+                            session=session,
+                            task_id=self._application_generate_entity.task_id,
+                            workflow_run=workflow_run,
+                            event=event,
+                        )
                     )
 
                 yield parallel_start_resp
@@ -354,13 +360,17 @@ def _process_stream_response(
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
-                        session=session,
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_run=workflow_run,
-                        event=event,
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    parallel_finish_resp = (
+                        self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
+                            session=session,
+                            task_id=self._application_generate_entity.task_id,
+                            workflow_run=workflow_run,
+                            event=event,
+                        )
                     )
 
                 yield parallel_finish_resp
@@ -369,9 +379,11 @@ def _process_stream_response(
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    iter_start_resp = self._workflow_iteration_start_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -384,9 +396,11 @@ def _process_stream_response(
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    iter_next_resp = self._workflow_iteration_next_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -399,9 +413,11 @@ def _process_stream_response(
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -416,8 +432,8 @@ def _process_stream_response(
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._handle_workflow_run_success(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
                         session=session,
                         workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
@@ -431,7 +447,7 @@ def _process_stream_response(
                     # save workflow app log
                     self._save_workflow_app_log(session=session, workflow_run=workflow_run)
 
-                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -445,8 +461,8 @@ def _process_stream_response(
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._handle_workflow_run_partial_success(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
                         session=session,
                         workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
@@ -461,7 +477,7 @@ def _process_stream_response(
                     # save workflow app log
                     self._save_workflow_app_log(session=session, workflow_run=workflow_run)
 
-                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     session.commit()
@@ -473,8 +489,8 @@ def _process_stream_response(
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._handle_workflow_run_failed(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
                         session=session,
                         workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
@@ -492,7 +508,7 @@ def _process_stream_response(
                     # save workflow app log
                     self._save_workflow_app_log(session=session, workflow_run=workflow_run)
 
-                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     session.commit()
diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py
index e363a7f64244d3..a2e06d4e1ff492 100644
--- a/api/core/app/task_pipeline/based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py
@@ -15,7 +15,6 @@
 from core.app.entities.task_entities import (
     ErrorStreamResponse,
     PingStreamResponse,
-    TaskState,
 )
 from core.errors.error import QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
@@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
     BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """
 
-    _task_state: TaskState
-    _application_generate_entity: AppGenerateEntity
-
     def __init__(
         self,
         application_generate_entity: AppGenerateEntity,
         queue_manager: AppQueueManager,
         stream: bool,
     ) -> None:
-        """
-        Initialize GenerateTaskPipeline.
-        :param application_generate_entity: application generate entity
-        :param queue_manager: queue manager
-        :param user: user
-        :param stream: stream
-        """
         self._application_generate_entity = application_generate_entity
         self._queue_manager = queue_manager
         self._start_at = time.perf_counter()
diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py
index 15f2c25c66a3d2..6a4ab259ba4a1a 100644
--- a/api/core/app/task_pipeline/message_cycle_manage.py
+++ b/api/core/app/task_pipeline/message_cycle_manage.py
@@ -31,10 +31,19 @@
 
 
 class MessageCycleManage:
-    _application_generate_entity: Union[
-        ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
-    ]
-    _task_state: Union[EasyUITaskState, WorkflowTaskState]
+    def __init__(
+        self,
+        *,
+        application_generate_entity: Union[
+            ChatAppGenerateEntity,
+            CompletionAppGenerateEntity,
+            AgentChatAppGenerateEntity,
+            AdvancedChatAppGenerateEntity,
+        ],
+        task_state: Union[EasyUITaskState, WorkflowTaskState],
+    ) -> None:
+        self._application_generate_entity = application_generate_entity
+        self._task_state = task_state
 
     def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
         """
diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py
index 885b341196d04b..dcc364d22766e6 100644
--- a/api/core/app/task_pipeline/workflow_cycle_manage.py
+++ b/api/core/app/task_pipeline/workflow_cycle_manage.py
@@ -34,7 +34,6 @@
     ParallelBranchStartStreamResponse,
     WorkflowFinishStreamResponse,
     WorkflowStartStreamResponse,
-    WorkflowTaskState,
 )
 from core.file import FILE_MODEL_IDENTITY, File
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -58,13 +57,20 @@
     WorkflowRunStatus,
 )
 
-from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
+from .exc import WorkflowRunNotFoundError
 
 
 class WorkflowCycleManage:
-    _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
-    _task_state: WorkflowTaskState
-    _workflow_system_variables: dict[SystemVariableKey, Any]
+    def __init__(
+        self,
+        *,
+        application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
+        workflow_system_variables: dict[SystemVariableKey, Any],
+    ) -> None:
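+        # In-memory caches for the current run; event handlers reuse these instead of refetching rows.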
+        self._workflow_run: WorkflowRun | None = None
+        self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
+        self._application_generate_entity = application_generate_entity
+        self._workflow_system_variables = workflow_system_variables
 
     def _handle_workflow_run_start(
         self,
@@ -240,7 +246,7 @@ def _handle_workflow_run_failed(
         workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
         workflow_run.exceptions_count = exceptions_count
 
-        stmt = select(WorkflowNodeExecution).where(
+        stmt = select(WorkflowNodeExecution.node_execution_id).where(
             WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
             WorkflowNodeExecution.app_id == workflow_run.app_id,
             WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
@@ -248,16 +254,18 @@
             WorkflowNodeExecution.workflow_run_id == workflow_run.id,
             WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
         )
-
-        running_workflow_node_executions = session.scalars(stmt).all()
+        ids = session.scalars(stmt).all()
+        # Use self._get_workflow_node_execution here to make sure the cache is updated
+        running_workflow_node_executions = [
+            self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
+        ]
 
         for workflow_node_execution in running_workflow_node_executions:
+            now = datetime.now(UTC).replace(tzinfo=None)
             workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
             workflow_node_execution.error = error
-            workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
-            workflow_node_execution.elapsed_time = (
-                workflow_node_execution.finished_at - workflow_node_execution.created_at
-            ).total_seconds()
+            workflow_node_execution.finished_at = now
+            workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
 
         if trace_manager:
             trace_manager.add_trace_task(
@@ -299,6 +307,8 @@ def _handle_node_execution_start(
         workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
 
         session.add(workflow_node_execution)
+
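+        # Remember the in-flight execution so later success/failure events can update it from the cache.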
+        self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
         return workflow_node_execution
 
     def _handle_workflow_node_execution_success(
@@ -326,6 +336,7 @@ def _handle_workflow_node_execution_success(
         workflow_node_execution.finished_at = finished_at
         workflow_node_execution.elapsed_time = elapsed_time
 
+        workflow_node_execution = session.merge(workflow_node_execution)
         return workflow_node_execution
 
     def _handle_workflow_node_execution_failed(
@@ -365,6 +376,7 @@ def _handle_workflow_node_execution_failed(
         workflow_node_execution.elapsed_time = elapsed_time
         workflow_node_execution.execution_metadata = execution_metadata
 
+        workflow_node_execution = session.merge(workflow_node_execution)
         return workflow_node_execution
 
     def _handle_workflow_node_execution_retried(
@@ -416,6 +428,8 @@ def _handle_workflow_node_execution_retried(
         workflow_node_execution.index = event.node_run_index
 
         session.add(workflow_node_execution)
+
+        self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
         return workflow_node_execution
 
     #################################################
@@ -812,22 +826,20 @@ def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any
         return None
 
     def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
-        """
-        Refetch workflow run
-        :param workflow_run_id: workflow run id
-        :return:
-        """
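+        # Reuse the cached workflow run when the id matches; merge attaches it to the current session.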
+        if self._workflow_run and self._workflow_run.id == workflow_run_id:
+            cached_workflow_run = self._workflow_run
+            cached_workflow_run = session.merge(cached_workflow_run)
+            return cached_workflow_run
         stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
         workflow_run = session.scalar(stmt)
         if not workflow_run:
             raise WorkflowRunNotFoundError(workflow_run_id)
+        self._workflow_run = workflow_run
 
         return workflow_run
 
     def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
-        stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id)
-        workflow_node_execution = session.scalar(stmt)
-        if not workflow_node_execution:
-            raise WorkflowNodeExecutionNotFoundError(node_execution_id)
-
-        return workflow_node_execution
+        if node_execution_id not in self._workflow_node_executions:
+            raise ValueError(f"Workflow node execution not found: {node_execution_id}")
+        cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
+        return cached_workflow_node_execution
diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py
deleted file mode 100644
index e69de29bb2d1d6..00000000000000