
Commit 2dc4bd4

refactor: streamline initialization of application_generate_entity and task_state in task pipeline classes

Signed-off-by: -LAN- <laipz8200@outlook.com>
laipz8200 committed Jan 3, 2025
1 parent 3d150c3 commit 2dc4bd4
Showing 5 changed files with 117 additions and 108 deletions.
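
The heart of the refactor is visible in __init__: instead of a single super().__init__(...) call followed by loose attribute assignments, the pipeline now initializes each base class explicitly (BasedGenerateTaskPipeline.__init__, WorkflowCycleManage.__init__, MessageCycleManage.__init__) with exactly the state that class owns. Below is a minimal sketch of that pattern; the class names mirror the real pipeline classes, but the constructors are simplified stand-ins rather than the actual signatures.

# A minimal sketch of the explicit base-class initialization pattern adopted here.
# Class names mirror the real pipeline classes; the constructor arguments are
# simplified stand-ins, not the actual signatures.
class BasedGenerateTaskPipeline:
    def __init__(self, *, application_generate_entity, queue_manager, stream):
        self._application_generate_entity = application_generate_entity
        self._queue_manager = queue_manager
        self._stream = stream


class WorkflowCycleManage:
    def __init__(self, *, application_generate_entity, workflow_system_variables):
        self._application_generate_entity = application_generate_entity
        self._workflow_system_variables = workflow_system_variables


class MessageCycleManage:
    def __init__(self, *, application_generate_entity, task_state):
        self._application_generate_entity = application_generate_entity
        self._task_state = task_state


class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
    def __init__(self, *, application_generate_entity, queue_manager, stream, system_variables, task_state):
        # Each base class is initialized explicitly with the state it owns,
        # instead of threading every argument through a cooperative super() chain.
        BasedGenerateTaskPipeline.__init__(
            self, application_generate_entity=application_generate_entity, queue_manager=queue_manager, stream=stream
        )
        WorkflowCycleManage.__init__(
            self, application_generate_entity=application_generate_entity, workflow_system_variables=system_variables
        )
        MessageCycleManage.__init__(
            self, application_generate_entity=application_generate_entity, task_state=task_state
        )

The trade-off is that the derived class takes responsibility for calling every base initializer itself, which reads more explicitly than super() when the bases have unrelated constructor signatures.
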
api/core/app/apps/advanced_chat/generate_task_pipeline.py (90 changes: 42 additions, 48 deletions)
@@ -67,7 +67,6 @@
from models.enums import CreatedByRole
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRunStatus,
)

@@ -79,12 +78,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""

_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_conversation_name_generate_thread: Optional[Thread] = None

def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
@@ -96,10 +89,8 @@ def __init__(
stream: bool,
dialogue_count: int,
) -> None:
super().__init__(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
BasedGenerateTaskPipeline.__init__(
self, application_generate_entity=application_generate_entity, queue_manager=queue_manager, stream=stream
)

if isinstance(user, EndUser):
@@ -112,33 +103,36 @@ def __init__(
self._created_by_role = CreatedByRole.ACCOUNT
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
WorkflowCycleManage.__init__(
self,
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
)

self._task_state = WorkflowTaskState()
MessageCycleManage.__init__(
self, application_generate_entity=application_generate_entity, task_state=self._task_state
)

self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict

self._conversation_id = conversation.id
self._conversation_mode = conversation.mode

self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())

self._workflow_system_variables = {
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}

self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}

self._conversation_name_generate_thread = None
self._conversation_name_generate_thread: Thread | None = None
self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id = ""
self._workflow_run_id: str = ""

def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
@@ -275,7 +269,7 @@ def _process_stream_response(
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id)
session.commit()
yield self._error_to_stream_response(err)
@@ -284,7 +278,7 @@ def _process_stream_response(
# override graph runtime state
graph_runtime_state = event.graph_runtime_state

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._handle_workflow_run_start(
session=session,
@@ -310,7 +304,7 @@ def _process_stream_response(
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event
@@ -329,7 +323,7 @@ def _process_stream_response(
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event
@@ -350,7 +344,7 @@ def _process_stream_response(
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)

node_finish_resp = self._workflow_node_finish_to_stream_response(
@@ -364,7 +358,7 @@ def _process_stream_response(
if node_finish_resp:
yield node_finish_resp
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)

node_finish_resp = self._workflow_node_finish_to_stream_response(
@@ -381,7 +375,7 @@ def _process_stream_response(
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
session=session,
@@ -395,7 +389,7 @@ def _process_stream_response(
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
session=session,
@@ -409,7 +403,7 @@ def _process_stream_response(
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_start_resp = self._workflow_iteration_start_to_stream_response(
session=session,
@@ -423,7 +417,7 @@ def _process_stream_response(
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_next_resp = self._workflow_iteration_next_to_stream_response(
session=session,
@@ -437,7 +431,7 @@ def _process_stream_response(
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
session=session,
@@ -454,7 +448,7 @@ def _process_stream_response(
if not graph_runtime_state:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_success(
session=session,
workflow_run_id=self._workflow_run_id,
@@ -479,7 +473,7 @@ def _process_stream_response(
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_partial_success(
session=session,
workflow_run_id=self._workflow_run_id,
@@ -504,7 +498,7 @@ def _process_stream_response(
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
@@ -529,7 +523,7 @@ def _process_stream_response(
break
elif isinstance(event, QueueStopEvent):
if self._workflow_run_id and graph_runtime_state:
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
@@ -557,7 +551,7 @@ def _process_stream_response(
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
@@ -566,7 +560,7 @@ def _process_stream_response(
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
@@ -603,7 +597,7 @@ def _process_stream_response(
yield self._message_replace_to_stream_response(answer=output_moderation_answer)

# Save message
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
session.commit()
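
The other recurring change in this file is Session(db.engine, expire_on_commit=False). By default SQLAlchemy expires all ORM instances when a session commits, so any attribute touched after session.commit() triggers a reload, and raises DetachedInstanceError once the session is closed. Passing expire_on_commit=False keeps already-loaded attributes readable after the commit, which is presumably why it is added here: these handlers commit inside short-lived sessions and then keep using the loaded objects (error results, workflow runs, messages) afterwards. A small, self-contained illustration follows; the User model and in-memory engine are placeholders, not part of this codebase.

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):  # throwaway model, for illustration only
    __tablename__ = "users"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(50))


engine = create_engine("sqlite:///:memory:")  # placeholder engine
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(User(name="alice"))
    session.commit()

with Session(engine, expire_on_commit=False) as session:
    user = session.execute(select(User)).scalar_one()
    user.name = "bob"
    # With the default expire_on_commit=True, this commit would expire `user`,
    # and reading user.name after the session closes would raise
    # DetachedInstanceError. With False, the loaded attributes stay populated.
    session.commit()

print(user.name)  # still readable on the detached instance -> "bob"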
