From 53c268a34d8ee63641bc0436744ed4a4bda576c2 Mon Sep 17 00:00:00 2001 From: Hash Brown Date: Sun, 22 Sep 2024 03:15:11 +0800 Subject: [PATCH] feat: regenerate in `Chat`, `agent` and `Chatflow` app (#7661) --- api/constants/__init__.py | 1 + api/controllers/console/app/completion.py | 1 + api/controllers/console/app/message.py | 2 - api/controllers/console/app/workflow.py | 2 + api/controllers/console/explore/completion.py | 1 + api/controllers/console/explore/message.py | 2 +- api/controllers/service_api/app/message.py | 1 + api/controllers/web/completion.py | 1 + api/controllers/web/message.py | 3 +- api/core/agent/base_agent_runner.py | 5 +- .../app/apps/advanced_chat/app_generator.py | 1 + api/core/app/apps/agent_chat/app_generator.py | 1 + api/core/app/apps/chat/app_generator.py | 1 + .../app/apps/message_based_app_generator.py | 1 + api/core/app/entities/app_invoke_entities.py | 3 + api/core/memory/token_buffer_memory.py | 21 ++- .../prompt/utils/extract_thread_messages.py | 22 +++ api/fields/conversation_fields.py | 1 + api/fields/message_fields.py | 1 + ...bb251_add_parent_message_id_to_messages.py | 36 +++++ api/models/model.py | 1 + api/services/message_service.py | 4 +- .../prompt/test_extract_thread_messages.py | 91 +++++++++++ .../debug-with-multiple-model/chat-item.tsx | 4 +- .../debug/debug-with-single-model/index.tsx | 26 ++- web/app/components/app/log/list.tsx | 149 ++++++++++-------- .../chat/chat-with-history/chat-wrapper.tsx | 25 ++- .../base/chat/chat-with-history/hooks.tsx | 35 +--- .../base/chat/chat/answer/index.tsx | 3 + .../base/chat/chat/answer/operation.tsx | 13 +- web/app/components/base/chat/chat/context.tsx | 3 + web/app/components/base/chat/chat/hooks.ts | 3 +- web/app/components/base/chat/chat/index.tsx | 5 + web/app/components/base/chat/chat/type.ts | 1 + web/app/components/base/chat/constants.ts | 1 + .../chat/embedded-chatbot/chat-wrapper.tsx | 25 ++- .../base/chat/embedded-chatbot/hooks.tsx | 38 +---- web/app/components/base/chat/types.ts | 4 +- web/app/components/base/chat/utils.ts | 57 ++++++- .../assets/vender/line/general/refresh.svg | 1 + .../src/vender/line/general/Refresh.json | 23 +++ .../icons/src/vender/line/general/Refresh.tsx | 16 ++ .../icons/src/vender/line/general/index.ts | 1 + .../components/base/regenerate-btn/index.tsx | 31 ++++ .../workflow/panel/chat-record/index.tsx | 79 ++++++---- .../panel/debug-and-preview/chat-wrapper.tsx | 26 ++- .../workflow/panel/debug-and-preview/hooks.ts | 2 + web/hooks/use-app-favicon.ts | 6 +- web/i18n/en-US/app-api.ts | 1 + web/i18n/zh-Hans/app-api.ts | 1 + web/models/log.ts | 1 + 51 files changed, 604 insertions(+), 179 deletions(-) create mode 100644 api/core/prompt/utils/extract_thread_messages.py create mode 100644 api/migrations/versions/2024_09_11_1012-d57ba9ebb251_add_parent_message_id_to_messages.py create mode 100644 api/tests/unit_tests/core/prompt/test_extract_thread_messages.py create mode 100644 web/app/components/base/icons/assets/vender/line/general/refresh.svg create mode 100644 web/app/components/base/icons/src/vender/line/general/Refresh.json create mode 100644 web/app/components/base/icons/src/vender/line/general/Refresh.tsx create mode 100644 web/app/components/base/regenerate-btn/index.tsx diff --git a/api/constants/__init__.py b/api/constants/__init__.py index e22c3268ef428b..75eaf81638cdfb 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -1 +1,2 @@ HIDDEN_VALUE = "[__HIDDEN__]" +UUID_NIL = "00000000-0000-0000-0000-000000000000" diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 53de51c24d798a..d3296d3dff44a5 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -109,6 +109,7 @@ def post(self, app_model): parser.add_argument("files", type=list, required=False, location="json") parser.add_argument("model_config", type=dict, required=True, location="json") parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index fe06201982374a..2fba3e0af02be9 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -105,8 +105,6 @@ def get(self, app_model): if rest_count > 0: has_more = True - history_messages = list(reversed(history_messages)) - return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index b488deb89d90c1..0a693b84e27ec2 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -166,6 +166,8 @@ def post(self, app_model: App): parser.add_argument("query", type=str, required=True, location="json", default="") parser.add_argument("files", type=list, location="json") parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") + args = parser.parse_args() try: diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index f4646920982bdb..125bc1af8c41ec 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -100,6 +100,7 @@ def post(self, installed_app): parser.add_argument("query", type=str, required=True, location="json") parser.add_argument("files", type=list, required=False, location="json") parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") args = parser.parse_args() diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 0e0238556cf9aa..3d221ff30a6599 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -51,7 +51,7 @@ def get(self, installed_app): try: return MessageService.pagination_by_first_id( - app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] + app_model, current_user, args["conversation_id"], args["first_id"], args["limit"], "desc" ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index e54e6f4903d574..a70ee89b5e2bf3 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -54,6 +54,7 @@ class MessageListApi(Resource): message_fields = { "id": fields.String, "conversation_id": fields.String, + "parent_message_id": fields.String, "inputs": fields.Raw, "query": fields.String, "answer": fields.String(attribute="re_sign_file_url_answer"), diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 115492b7966c01..45b890dfc4899d 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -96,6 +96,7 @@ def post(self, app_model, end_user): parser.add_argument("files", type=list, required=False, location="json") parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") args = parser.parse_args() diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 0d4047f4efbaf8..2d2a5866c8038b 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -57,6 +57,7 @@ class MessageListApi(WebApiResource): message_fields = { "id": fields.String, "conversation_id": fields.String, + "parent_message_id": fields.String, "inputs": fields.Raw, "query": fields.String, "answer": fields.String(attribute="re_sign_file_url_answer"), @@ -89,7 +90,7 @@ def get(self, app_model, end_user): try: return MessageService.pagination_by_first_id( - app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] + app_model, end_user, args["conversation_id"], args["first_id"], args["limit"], "desc" ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index d09a9956a4a591..5295f97bdbb86d 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -32,6 +32,7 @@ from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder +from core.prompt.utils.extract_thread_messages import extract_thread_messages from core.tools.entities.tool_entities import ( ToolParameter, ToolRuntimeVariablePool, @@ -441,10 +442,12 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P .filter( Message.conversation_id == self.message.conversation_id, ) - .order_by(Message.created_at.asc()) + .order_by(Message.created_at.desc()) .all() ) + messages = list(reversed(extract_thread_messages(messages))) + for message in messages: if message.id == self.message.id: continue diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 88e1256ed54e4d..445ef6d0ab19f3 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -121,6 +121,7 @@ def generate( inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), query=query, files=file_objs, + parent_message_id=args.get("parent_message_id"), user_id=user.id, stream=stream, invoke_from=invoke_from, diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index abf8a332ab27da..99abccf4f98402 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -127,6 +127,7 @@ def generate( inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), query=query, files=file_objs, + parent_message_id=args.get("parent_message_id"), user_id=user.id, stream=stream, invoke_from=invoke_from, diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 032556ec4c1ce1..9ef1366a0f6689 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -128,6 +128,7 @@ def generate( inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), query=query, files=file_objs, + parent_message_id=args.get("parent_message_id"), user_id=user.id, stream=stream, invoke_from=invoke_from, diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index c4db95cbd0c4a6..65b759acf5376f 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -218,6 +218,7 @@ def _init_generate_records( answer_tokens=0, answer_unit_price=0, answer_price_unit=0, + parent_message_id=getattr(application_generate_entity, "parent_message_id", None), provider_response_latency=0, total_price=0, currency="USD", diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index ab8d4e374e26f8..87ca51ef1b1db1 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -122,6 +122,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): """ conversation_id: Optional[str] = None + parent_message_id: Optional[str] = None class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity): @@ -138,6 +139,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): """ conversation_id: Optional[str] = None + parent_message_id: Optional[str] = None class AdvancedChatAppGenerateEntity(AppGenerateEntity): @@ -149,6 +151,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity): app_config: WorkflowUIBasedAppConfig conversation_id: Optional[str] = None + parent_message_id: Optional[str] = None query: str class SingleIterationRunEntity(BaseModel): diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index d3185c3b11aecb..60b36c50f00d38 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -11,6 +11,7 @@ TextPromptMessageContent, UserPromptMessage, ) +from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database import db from models.model import AppMode, Conversation, Message, MessageFile from models.workflow import WorkflowRun @@ -33,8 +34,17 @@ def get_history_prompt_messages( # fetch limited messages, and return reversed query = ( - db.session.query(Message.id, Message.query, Message.answer, Message.created_at, Message.workflow_run_id) - .filter(Message.conversation_id == self.conversation.id, Message.answer != "") + db.session.query( + Message.id, + Message.query, + Message.answer, + Message.created_at, + Message.workflow_run_id, + Message.parent_message_id, + ) + .filter( + Message.conversation_id == self.conversation.id, + ) .order_by(Message.created_at.desc()) ) @@ -45,7 +55,12 @@ def get_history_prompt_messages( messages = query.limit(message_limit).all() - messages = list(reversed(messages)) + # instead of all messages from the conversation, we only need to extract messages + # that belong to the thread of last message + thread_messages = extract_thread_messages(messages) + thread_messages.pop(0) + messages = list(reversed(thread_messages)) + message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id) prompt_messages = [] for message in messages: diff --git a/api/core/prompt/utils/extract_thread_messages.py b/api/core/prompt/utils/extract_thread_messages.py new file mode 100644 index 00000000000000..e8b626499fce19 --- /dev/null +++ b/api/core/prompt/utils/extract_thread_messages.py @@ -0,0 +1,22 @@ +from constants import UUID_NIL + + +def extract_thread_messages(messages: list[dict]) -> list[dict]: + thread_messages = [] + next_message = None + + for message in messages: + if not message.parent_message_id: + # If the message is regenerated and does not have a parent message, it is the start of a new thread + thread_messages.append(message) + break + + if not next_message: + thread_messages.append(message) + next_message = message.parent_message_id + else: + if next_message in {message.id, UUID_NIL}: + thread_messages.append(message) + next_message = message.parent_message_id + + return thread_messages diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 9207314fc21328..3dcd88d1de31dd 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -75,6 +75,7 @@ def format(self, value): "metadata": fields.Raw(attribute="message_metadata_dict"), "status": fields.String, "error": fields.String, + "parent_message_id": fields.String, } feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer} diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 3d2df87afb9b19..c938097131f7ad 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -62,6 +62,7 @@ message_fields = { "id": fields.String, "conversation_id": fields.String, + "parent_message_id": fields.String, "inputs": fields.Raw, "query": fields.String, "answer": fields.String(attribute="re_sign_file_url_answer"), diff --git a/api/migrations/versions/2024_09_11_1012-d57ba9ebb251_add_parent_message_id_to_messages.py b/api/migrations/versions/2024_09_11_1012-d57ba9ebb251_add_parent_message_id_to_messages.py new file mode 100644 index 00000000000000..fd957eeafb2b6c --- /dev/null +++ b/api/migrations/versions/2024_09_11_1012-d57ba9ebb251_add_parent_message_id_to_messages.py @@ -0,0 +1,36 @@ +"""add parent_message_id to messages + +Revision ID: d57ba9ebb251 +Revises: 675b5321501b +Create Date: 2024-09-11 10:12:45.826265 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = 'd57ba9ebb251' +down_revision = '675b5321501b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.add_column(sa.Column('parent_message_id', models.types.StringUUID(), nullable=True)) + + # Set parent_message_id for existing messages to uuid_nil() to distinguish them from new messages with actual parent IDs or NULLs + op.execute('UPDATE messages SET parent_message_id = uuid_nil() WHERE parent_message_id IS NULL') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_column('parent_message_id') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index ae0bc3210b6465..53940a5a16cc87 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -710,6 +710,7 @@ class Message(db.Model): answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + parent_message_id = db.Column(StringUUID, nullable=True) provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) diff --git a/api/services/message_service.py b/api/services/message_service.py index ecb121c36e4cc4..f432a77c80e511 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -34,6 +34,7 @@ def pagination_by_first_id( conversation_id: str, first_id: Optional[str], limit: int, + order: str = "asc", ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -91,7 +92,8 @@ def pagination_by_first_id( if rest_count > 0: has_more = True - history_messages = list(reversed(history_messages)) + if order == "asc": + history_messages = list(reversed(history_messages)) return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) diff --git a/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py b/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py new file mode 100644 index 00000000000000..ba3c1eb5e032a0 --- /dev/null +++ b/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py @@ -0,0 +1,91 @@ +from uuid import uuid4 + +from constants import UUID_NIL +from core.prompt.utils.extract_thread_messages import extract_thread_messages + + +class TestMessage: + def __init__(self, id, parent_message_id): + self.id = id + self.parent_message_id = parent_message_id + + def __getitem__(self, item): + return getattr(self, item) + + +def test_extract_thread_messages_single_message(): + messages = [TestMessage(str(uuid4()), UUID_NIL)] + result = extract_thread_messages(messages) + assert len(result) == 1 + assert result[0] == messages[0] + + +def test_extract_thread_messages_linear_thread(): + id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()) + messages = [ + TestMessage(id5, id4), + TestMessage(id4, id3), + TestMessage(id3, id2), + TestMessage(id2, id1), + TestMessage(id1, UUID_NIL), + ] + result = extract_thread_messages(messages) + assert len(result) == 5 + assert [msg["id"] for msg in result] == [id5, id4, id3, id2, id1] + + +def test_extract_thread_messages_branched_thread(): + id1, id2, id3, id4 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()) + messages = [ + TestMessage(id4, id2), + TestMessage(id3, id2), + TestMessage(id2, id1), + TestMessage(id1, UUID_NIL), + ] + result = extract_thread_messages(messages) + assert len(result) == 3 + assert [msg["id"] for msg in result] == [id4, id2, id1] + + +def test_extract_thread_messages_empty_list(): + messages = [] + result = extract_thread_messages(messages) + assert len(result) == 0 + + +def test_extract_thread_messages_partially_loaded(): + id0, id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()) + messages = [ + TestMessage(id3, id2), + TestMessage(id2, id1), + TestMessage(id1, id0), + ] + result = extract_thread_messages(messages) + assert len(result) == 3 + assert [msg["id"] for msg in result] == [id3, id2, id1] + + +def test_extract_thread_messages_legacy_messages(): + id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4()) + messages = [ + TestMessage(id3, UUID_NIL), + TestMessage(id2, UUID_NIL), + TestMessage(id1, UUID_NIL), + ] + result = extract_thread_messages(messages) + assert len(result) == 3 + assert [msg["id"] for msg in result] == [id3, id2, id1] + + +def test_extract_thread_messages_mixed_with_legacy_messages(): + id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()) + messages = [ + TestMessage(id5, id4), + TestMessage(id4, id2), + TestMessage(id3, id2), + TestMessage(id2, UUID_NIL), + TestMessage(id1, UUID_NIL), + ] + result = extract_thread_messages(messages) + assert len(result) == 4 + assert [msg["id"] for msg in result] == [id5, id4, id2, id1] diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx index 80dfb5c534ea5c..1c70f4fe77a9d5 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx @@ -46,6 +46,7 @@ const ChatItem: FC = ({ const config = useConfigFromDebugContext() const { chatList, + chatListRef, isResponding, handleSend, suggestedQuestions, @@ -80,6 +81,7 @@ const ChatItem: FC = ({ query: message, inputs, model_config: configData, + parent_message_id: chatListRef.current.at(-1)?.id || null, } if (visionConfig.enabled && files?.length && supportVision) @@ -93,7 +95,7 @@ const ChatItem: FC = ({ onGetSuggestedQuestions: (responseItemId, getAbortController) => fetchSuggestedQuestions(appId, responseItemId, getAbortController), }, ) - }, [appId, config, handleSend, inputs, modelAndParameter, textGenerationModelList, visionConfig.enabled]) + }, [appId, config, handleSend, inputs, modelAndParameter, textGenerationModelList, visionConfig.enabled, chatListRef]) const { eventEmitter } = useEventEmitterContextContext() eventEmitter?.useSubscription((v: any) => { diff --git a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx index d93ad00659e478..80e7c98a8f0c92 100644 --- a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx +++ b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx @@ -12,7 +12,7 @@ import { import Chat from '@/app/components/base/chat/chat' import { useChat } from '@/app/components/base/chat/chat/hooks' import { useDebugConfigurationContext } from '@/context/debug-configuration' -import type { OnSend } from '@/app/components/base/chat/types' +import type { ChatItem, OnSend } from '@/app/components/base/chat/types' import { useProviderContext } from '@/context/provider-context' import { fetchConversationMessages, @@ -45,10 +45,12 @@ const DebugWithSingleModel = forwardRef { + const doSend: OnSend = useCallback((message, files, last_answer) => { if (checkCanSend && !checkCanSend()) return const currentProvider = textGenerationModelList.find(item => item.provider === modelConfig.provider) @@ -85,6 +87,7 @@ const DebugWithSingleModel = forwardRef fetchSuggestedQuestions(appId, responseItemId, getAbortController), }, ) - }, [appId, checkCanSend, completionParams, config, handleSend, inputs, modelConfig, textGenerationModelList, visionConfig.enabled]) + }, [chatListRef, appId, checkCanSend, completionParams, config, handleSend, inputs, modelConfig, textGenerationModelList, visionConfig.enabled]) + + const doRegenerate = useCallback((chatItem: ChatItem) => { + const index = chatList.findIndex(item => item.id === chatItem.id) + if (index === -1) + return + + const prevMessages = chatList.slice(0, index) + const question = prevMessages.pop() + const lastAnswer = prevMessages.at(-1) + + if (!question) + return + + handleUpdateChatList(prevMessages) + doSend(question.content, question.message_files, (!lastAnswer || lastAnswer.isOpeningStatement) ? undefined : lastAnswer) + }, [chatList, handleUpdateChatList, doSend]) const allToolIcons = useMemo(() => { const icons: Record = {} @@ -123,6 +142,7 @@ const DebugWithSingleModel = forwardRef} diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index caec10c4f714b0..149e877fa4ac2d 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -16,6 +16,7 @@ import timezone from 'dayjs/plugin/timezone' import { createContext, useContext } from 'use-context-selector' import { useShallow } from 'zustand/react/shallow' import { useTranslation } from 'react-i18next' +import { UUID_NIL } from '../../base/chat/constants' import s from './style.module.css' import VarPanel from './var-panel' import cn from '@/utils/classnames' @@ -81,72 +82,92 @@ const PARAM_MAP = { frequency_penalty: 'Frequency Penalty', } -// Format interface data for easy display -const getFormattedChatList = (messages: ChatMessage[], conversationId: string, timezone: string, format: string) => { - const newChatList: IChatItem[] = [] - messages.forEach((item: ChatMessage) => { - newChatList.push({ - id: `question-${item.id}`, - content: item.inputs.query || item.inputs.default_input || item.query, // text generation: item.inputs.query; chat: item.query - isAnswer: false, - message_files: item.message_files?.filter((file: any) => file.belongs_to === 'user') || [], - }) - newChatList.push({ - id: item.id, - content: item.answer, - agent_thoughts: addFileInfos(item.agent_thoughts ? sortAgentSorts(item.agent_thoughts) : item.agent_thoughts, item.message_files), - feedback: item.feedbacks.find(item => item.from_source === 'user'), // user feedback - adminFeedback: item.feedbacks.find(item => item.from_source === 'admin'), // admin feedback - feedbackDisabled: false, - isAnswer: true, - message_files: item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [], - log: [ - ...item.message, - ...(item.message[item.message.length - 1]?.role !== 'assistant' - ? [ - { - role: 'assistant', - text: item.answer, - files: item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [], - }, - ] - : []), - ], - workflow_run_id: item.workflow_run_id, - conversationId, - input: { - inputs: item.inputs, - query: item.query, - }, - more: { - time: dayjs.unix(item.created_at).tz(timezone).format(format), - tokens: item.answer_tokens + item.message_tokens, - latency: item.provider_response_latency.toFixed(2), - }, - citation: item.metadata?.retriever_resources, - annotation: (() => { - if (item.annotation_hit_history) { - return { - id: item.annotation_hit_history.annotation_id, - authorName: item.annotation_hit_history.annotation_create_account?.name || 'N/A', - created_at: item.annotation_hit_history.created_at, - } +function appendQAToChatList(newChatList: IChatItem[], item: any, conversationId: string, timezone: string, format: string) { + newChatList.push({ + id: item.id, + content: item.answer, + agent_thoughts: addFileInfos(item.agent_thoughts ? sortAgentSorts(item.agent_thoughts) : item.agent_thoughts, item.message_files), + feedback: item.feedbacks.find((item: any) => item.from_source === 'user'), // user feedback + adminFeedback: item.feedbacks.find((item: any) => item.from_source === 'admin'), // admin feedback + feedbackDisabled: false, + isAnswer: true, + message_files: item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [], + log: [ + ...item.message, + ...(item.message[item.message.length - 1]?.role !== 'assistant' + ? [ + { + role: 'assistant', + text: item.answer, + files: item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [], + }, + ] + : []), + ], + workflow_run_id: item.workflow_run_id, + conversationId, + input: { + inputs: item.inputs, + query: item.query, + }, + more: { + time: dayjs.unix(item.created_at).tz(timezone).format(format), + tokens: item.answer_tokens + item.message_tokens, + latency: item.provider_response_latency.toFixed(2), + }, + citation: item.metadata?.retriever_resources, + annotation: (() => { + if (item.annotation_hit_history) { + return { + id: item.annotation_hit_history.annotation_id, + authorName: item.annotation_hit_history.annotation_create_account?.name || 'N/A', + created_at: item.annotation_hit_history.created_at, } + } - if (item.annotation) { - return { - id: item.annotation.id, - authorName: item.annotation.account.name, - logAnnotation: item.annotation, - created_at: 0, - } + if (item.annotation) { + return { + id: item.annotation.id, + authorName: item.annotation.account.name, + logAnnotation: item.annotation, + created_at: 0, } + } - return undefined - })(), - }) + return undefined + })(), + parentMessageId: `question-${item.id}`, }) - return newChatList + newChatList.push({ + id: `question-${item.id}`, + content: item.inputs.query || item.inputs.default_input || item.query, // text generation: item.inputs.query; chat: item.query + isAnswer: false, + message_files: item.message_files?.filter((file: any) => file.belongs_to === 'user') || [], + parentMessageId: item.parent_message_id || undefined, + }) +} + +const getFormattedChatList = (messages: ChatMessage[], conversationId: string, timezone: string, format: string) => { + const newChatList: IChatItem[] = [] + let nextMessageId = null + for (const item of messages) { + if (!item.parent_message_id) { + appendQAToChatList(newChatList, item, conversationId, timezone, format) + break + } + + if (!nextMessageId) { + appendQAToChatList(newChatList, item, conversationId, timezone, format) + nextMessageId = item.parent_message_id + } + else { + if (item.id === nextMessageId || nextMessageId === UUID_NIL) { + appendQAToChatList(newChatList, item, conversationId, timezone, format) + nextMessageId = item.parent_message_id + } + } + } + return newChatList.reverse() } // const displayedParams = CompletionParams.slice(0, -2) @@ -171,6 +192,7 @@ function DetailPanel([]) + const fetchedMessages = useRef([]) const [hasMore, setHasMore] = useState(true) const [varValues, setVarValues] = useState>({}) const fetchData = async () => { @@ -192,7 +214,8 @@ function DetailPanel - : items.length < 8 + : (items.length < 8 && !hasMore) ?
{ }, [appParams, currentConversationItem?.introduction, currentConversationId]) const { chatList, + chatListRef, + handleUpdateChatList, handleSend, handleStop, isResponding, @@ -63,11 +66,12 @@ const ChatWrapper = () => { currentChatInstanceRef.current.handleStop = handleStop }, []) - const doSend: OnSend = useCallback((message, files) => { + const doSend: OnSend = useCallback((message, files, last_answer) => { const data: any = { query: message, inputs: currentConversationId ? currentConversationItem?.inputs : newConversationInputs, conversation_id: currentConversationId, + parent_message_id: last_answer?.id || chatListRef.current.at(-1)?.id || null, } if (appConfig?.file_upload?.image.enabled && files?.length) @@ -83,6 +87,7 @@ const ChatWrapper = () => { }, ) }, [ + chatListRef, appConfig, currentConversationId, currentConversationItem, @@ -92,6 +97,23 @@ const ChatWrapper = () => { isInstalledApp, appId, ]) + + const doRegenerate = useCallback((chatItem: ChatItem) => { + const index = chatList.findIndex(item => item.id === chatItem.id) + if (index === -1) + return + + const prevMessages = chatList.slice(0, index) + const question = prevMessages.pop() + const lastAnswer = prevMessages.at(-1) + + if (!question) + return + + handleUpdateChatList(prevMessages) + doSend(question.content, question.message_files, (!lastAnswer || lastAnswer.isOpeningStatement) ? undefined : lastAnswer) + }, [chatList, handleUpdateChatList, doSend]) + const chatNode = useMemo(() => { if (inputsForms.length) { return ( @@ -148,6 +170,7 @@ const ChatWrapper = () => { chatFooterClassName='pb-4' chatFooterInnerClassName={`mx-auto w-full max-w-full ${isMobile && 'px-4'}`} onSend={doSend} + onRegenerate={doRegenerate} onStopResponding={handleStop} chatNode={chatNode} allToolIcons={appMeta?.tool_icons || {}} diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index 1e05cc39ef5dc6..b9ebc42ec88892 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -12,10 +12,10 @@ import produce from 'immer' import type { Callback, ChatConfig, - ChatItem, Feedback, } from '../types' import { CONVERSATION_ID_INFO } from '../constants' +import { getPrevChatList } from '../utils' import { delConversation, fetchAppInfo, @@ -34,7 +34,6 @@ import type { AppData, ConversationItem, } from '@/models/share' -import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' import { useToastContext } from '@/app/components/base/toast' import { changeLanguage } from '@/i18n/i18next-config' import { useAppFavicon } from '@/hooks/use-app-favicon' @@ -108,32 +107,12 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { const { data: appConversationData, isLoading: appConversationDataLoading, mutate: mutateAppConversationData } = useSWR(['appConversationData', isInstalledApp, appId, false], () => fetchConversations(isInstalledApp, appId, undefined, false, 100)) const { data: appChatListData, isLoading: appChatListDataLoading } = useSWR(chatShouldReloadKey ? ['appChatList', chatShouldReloadKey, isInstalledApp, appId] : null, () => fetchChatList(chatShouldReloadKey, isInstalledApp, appId)) - const appPrevChatList = useMemo(() => { - const data = appChatListData?.data || [] - const chatList: ChatItem[] = [] - - if (currentConversationId && data.length) { - data.forEach((item: any) => { - chatList.push({ - id: `question-${item.id}`, - content: item.query, - isAnswer: false, - message_files: item.message_files?.filter((file: any) => file.belongs_to === 'user') || [], - }) - chatList.push({ - id: item.id, - content: item.answer, - agent_thoughts: addFileInfos(item.agent_thoughts ? sortAgentSorts(item.agent_thoughts) : item.agent_thoughts, item.message_files), - feedback: item.feedback, - isAnswer: true, - citation: item.retriever_resources, - message_files: item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [], - }) - }) - } - - return chatList - }, [appChatListData, currentConversationId]) + const appPrevChatList = useMemo( + () => (currentConversationId && appChatListData?.data.length) + ? getPrevChatList(appChatListData.data) + : [], + [appChatListData, currentConversationId], + ) const [showNewConversationItemInList, setShowNewConversationItemInList] = useState(false) diff --git a/web/app/components/base/chat/chat/answer/index.tsx b/web/app/components/base/chat/chat/answer/index.tsx index 5fe2a7bad51072..705cd73ddf18d9 100644 --- a/web/app/components/base/chat/chat/answer/index.tsx +++ b/web/app/components/base/chat/chat/answer/index.tsx @@ -35,6 +35,7 @@ type AnswerProps = { chatAnswerContainerInner?: string hideProcessDetail?: boolean appData?: AppData + noChatInput?: boolean } const Answer: FC = ({ item, @@ -48,6 +49,7 @@ const Answer: FC = ({ chatAnswerContainerInner, hideProcessDetail, appData, + noChatInput, }) => { const { t } = useTranslation() const { @@ -110,6 +112,7 @@ const Answer: FC = ({ question={question} index={index} showPromptLog={showPromptLog} + noChatInput={noChatInput} /> ) } diff --git a/web/app/components/base/chat/chat/answer/operation.tsx b/web/app/components/base/chat/chat/answer/operation.tsx index 08267bb09cd5c9..5e5fc3b204641d 100644 --- a/web/app/components/base/chat/chat/answer/operation.tsx +++ b/web/app/components/base/chat/chat/answer/operation.tsx @@ -7,6 +7,7 @@ import { import { useTranslation } from 'react-i18next' import type { ChatItem } from '../../types' import { useChatContext } from '../context' +import RegenerateBtn from '@/app/components/base/regenerate-btn' import cn from '@/utils/classnames' import CopyBtn from '@/app/components/base/copy-btn' import { MessageFast } from '@/app/components/base/icons/src/vender/solid/communication' @@ -28,6 +29,7 @@ type OperationProps = { maxSize: number contentWidth: number hasWorkflowProcess: boolean + noChatInput?: boolean } const Operation: FC = ({ item, @@ -37,6 +39,7 @@ const Operation: FC = ({ maxSize, contentWidth, hasWorkflowProcess, + noChatInput, }) => { const { t } = useTranslation() const { @@ -45,6 +48,7 @@ const Operation: FC = ({ onAnnotationEdited, onAnnotationRemoved, onFeedback, + onRegenerate, } = useChatContext() const [isShowReplyModal, setIsShowReplyModal] = useState(false) const { @@ -159,12 +163,13 @@ const Operation: FC = ({
) } + { + !isOpeningStatement && !noChatInput && onRegenerate?.(item)} /> + } { config?.supportFeedback && !localFeedback?.rating && onFeedback && !isOpeningStatement && ( -
- +
+
handleFeedback('like')} diff --git a/web/app/components/base/chat/chat/context.tsx b/web/app/components/base/chat/chat/context.tsx index ba6f67189e1cc9..c47b7501762e18 100644 --- a/web/app/components/base/chat/chat/context.tsx +++ b/web/app/components/base/chat/chat/context.tsx @@ -12,6 +12,7 @@ export type ChatContextValue = Pick void noChatInput?: boolean onSend?: OnSend + onRegenerate?: OnRegenerate chatContainerClassName?: string chatContainerInnerClassName?: string chatFooterClassName?: string @@ -67,6 +69,7 @@ const Chat: FC = ({ appData, config, onSend, + onRegenerate, chatList, isResponding, noStopResponding, @@ -186,6 +189,7 @@ const Chat: FC = ({ answerIcon={answerIcon} allToolIcons={allToolIcons} onSend={onSend} + onRegenerate={onRegenerate} onAnnotationAdded={onAnnotationAdded} onAnnotationEdited={onAnnotationEdited} onAnnotationRemoved={onAnnotationRemoved} @@ -219,6 +223,7 @@ const Chat: FC = ({ showPromptLog={showPromptLog} chatAnswerContainerInner={chatAnswerContainerInner} hideProcessDetail={hideProcessDetail} + noChatInput={noChatInput} /> ) } diff --git a/web/app/components/base/chat/chat/type.ts b/web/app/components/base/chat/chat/type.ts index b2cb18011c6929..dd26a4179d4fb3 100644 --- a/web/app/components/base/chat/chat/type.ts +++ b/web/app/components/base/chat/chat/type.ts @@ -95,6 +95,7 @@ export type IChatItem = { // for agent log conversationId?: string input?: any + parentMessageId?: string } export type Metadata = { diff --git a/web/app/components/base/chat/constants.ts b/web/app/components/base/chat/constants.ts index 8249be7375d2e3..309f0f04a716b2 100644 --- a/web/app/components/base/chat/constants.ts +++ b/web/app/components/base/chat/constants.ts @@ -1 +1,2 @@ export const CONVERSATION_ID_INFO = 'conversationIdInfo' +export const UUID_NIL = '00000000-0000-0000-0000-000000000000' diff --git a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx index 48ee4110584ae0..8cb546fd52a0e8 100644 --- a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx +++ b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx @@ -2,6 +2,7 @@ import { useCallback, useEffect, useMemo } from 'react' import Chat from '../chat' import type { ChatConfig, + ChatItem, OnSend, } from '../types' import { useChat } from '../chat/hooks' @@ -45,11 +46,13 @@ const ChatWrapper = () => { } as ChatConfig }, [appParams, currentConversationItem?.introduction, currentConversationId]) const { + chatListRef, chatList, handleSend, handleStop, isResponding, suggestedQuestions, + handleUpdateChatList, } = useChat( appConfig, { @@ -65,11 +68,12 @@ const ChatWrapper = () => { currentChatInstanceRef.current.handleStop = handleStop }, []) - const doSend: OnSend = useCallback((message, files) => { + const doSend: OnSend = useCallback((message, files, last_answer) => { const data: any = { query: message, inputs: currentConversationId ? currentConversationItem?.inputs : newConversationInputs, conversation_id: currentConversationId, + parent_message_id: last_answer?.id || chatListRef.current.at(-1)?.id || null, } if (appConfig?.file_upload?.image.enabled && files?.length) @@ -85,6 +89,7 @@ const ChatWrapper = () => { }, ) }, [ + chatListRef, appConfig, currentConversationId, currentConversationItem, @@ -94,6 +99,23 @@ const ChatWrapper = () => { isInstalledApp, appId, ]) + + const doRegenerate = useCallback((chatItem: ChatItem) => { + const index = chatList.findIndex(item => item.id === chatItem.id) + if (index === -1) + return + + const prevMessages = chatList.slice(0, index) + const question = prevMessages.pop() + const lastAnswer = prevMessages.at(-1) + + if (!question) + return + + handleUpdateChatList(prevMessages) + doSend(question.content, question.message_files, (!lastAnswer || lastAnswer.isOpeningStatement) ? undefined : lastAnswer) + }, [chatList, handleUpdateChatList, doSend]) + const chatNode = useMemo(() => { if (inputsForms.length) { return ( @@ -136,6 +158,7 @@ const ChatWrapper = () => { chatFooterClassName='pb-4' chatFooterInnerClassName={cn('mx-auto w-full max-w-full tablet:px-4', isMobile && 'px-4')} onSend={doSend} + onRegenerate={doRegenerate} onStopResponding={handleStop} chatNode={chatNode} allToolIcons={appMeta?.tool_icons || {}} diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.tsx index 39d25f57d194e1..fd89efcbff3ab6 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.tsx @@ -11,10 +11,10 @@ import { useLocalStorageState } from 'ahooks' import produce from 'immer' import type { ChatConfig, - ChatItem, Feedback, } from '../types' import { CONVERSATION_ID_INFO } from '../constants' +import { getPrevChatList, getProcessedInputsFromUrlParams } from '../utils' import { fetchAppInfo, fetchAppMeta, @@ -28,10 +28,8 @@ import type { // AppData, ConversationItem, } from '@/models/share' -import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' import { useToastContext } from '@/app/components/base/toast' import { changeLanguage } from '@/i18n/i18next-config' -import { getProcessedInputsFromUrlParams } from '@/app/components/base/chat/utils' export const useEmbeddedChatbot = () => { const isInstalledApp = false @@ -75,32 +73,12 @@ export const useEmbeddedChatbot = () => { const { data: appConversationData, isLoading: appConversationDataLoading, mutate: mutateAppConversationData } = useSWR(['appConversationData', isInstalledApp, appId, false], () => fetchConversations(isInstalledApp, appId, undefined, false, 100)) const { data: appChatListData, isLoading: appChatListDataLoading } = useSWR(chatShouldReloadKey ? ['appChatList', chatShouldReloadKey, isInstalledApp, appId] : null, () => fetchChatList(chatShouldReloadKey, isInstalledApp, appId)) - const appPrevChatList = useMemo(() => { - const data = appChatListData?.data || [] - const chatList: ChatItem[] = [] - - if (currentConversationId && data.length) { - data.forEach((item: any) => { - chatList.push({ - id: `question-${item.id}`, - content: item.query, - isAnswer: false, - message_files: item.message_files?.filter((file: any) => file.belongs_to === 'user') || [], - }) - chatList.push({ - id: item.id, - content: item.answer, - agent_thoughts: addFileInfos(item.agent_thoughts ? sortAgentSorts(item.agent_thoughts) : item.agent_thoughts, item.message_files), - feedback: item.feedback, - isAnswer: true, - citation: item.retriever_resources, - message_files: item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [], - }) - }) - } - - return chatList - }, [appChatListData, currentConversationId]) + const appPrevChatList = useMemo( + () => (currentConversationId && appChatListData?.data.length) + ? getPrevChatList(appChatListData.data) + : [], + [appChatListData, currentConversationId], + ) const [showNewConversationItemInList, setShowNewConversationItemInList] = useState(false) @@ -155,7 +133,7 @@ export const useEmbeddedChatbot = () => { type: 'text-input', } }) - }, [appParams]) + }, [initInputs, appParams]) useEffect(() => { // init inputs from url params diff --git a/web/app/components/base/chat/types.ts b/web/app/components/base/chat/types.ts index 21277fec570994..489dbb44cf67e7 100644 --- a/web/app/components/base/chat/types.ts +++ b/web/app/components/base/chat/types.ts @@ -63,7 +63,9 @@ export type ChatItem = IChatItem & { conversationId?: string } -export type OnSend = (message: string, files?: VisionFile[]) => void +export type OnSend = (message: string, files?: VisionFile[], last_answer?: ChatItem) => void + +export type OnRegenerate = (chatItem: ChatItem) => void export type Callback = { onSuccess: () => void diff --git a/web/app/components/base/chat/utils.ts b/web/app/components/base/chat/utils.ts index 3fe5050cc7b34d..e851c4c4633fb5 100644 --- a/web/app/components/base/chat/utils.ts +++ b/web/app/components/base/chat/utils.ts @@ -1,7 +1,11 @@ +import { addFileInfos, sortAgentSorts } from '../../tools/utils' +import { UUID_NIL } from './constants' +import type { ChatItem } from './types' + async function decodeBase64AndDecompress(base64String: string) { const binaryString = atob(base64String) const compressedUint8Array = Uint8Array.from(binaryString, char => char.charCodeAt(0)) - const decompressedStream = new Response(compressedUint8Array).body.pipeThrough(new DecompressionStream('gzip')) + const decompressedStream = new Response(compressedUint8Array).body?.pipeThrough(new DecompressionStream('gzip')) const decompressedArrayBuffer = await new Response(decompressedStream).arrayBuffer() return new TextDecoder().decode(decompressedArrayBuffer) } @@ -15,6 +19,57 @@ function getProcessedInputsFromUrlParams(): Record { return inputs } +function appendQAToChatList(chatList: ChatItem[], item: any) { + // we append answer first and then question since will reverse the whole chatList later + chatList.push({ + id: item.id, + content: item.answer, + agent_thoughts: addFileInfos(item.agent_thoughts ? sortAgentSorts(item.agent_thoughts) : item.agent_thoughts, item.message_files), + feedback: item.feedback, + isAnswer: true, + citation: item.retriever_resources, + message_files: item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [], + }) + chatList.push({ + id: `question-${item.id}`, + content: item.query, + isAnswer: false, + message_files: item.message_files?.filter((file: any) => file.belongs_to === 'user') || [], + }) +} + +/** + * Computes the latest thread messages from all messages of the conversation. + * Same logic as backend codebase `api/core/prompt/utils/extract_thread_messages.py` + * + * @param fetchedMessages - The history chat list data from the backend, sorted by created_at in descending order. This includes all flattened history messages of the conversation. + * @returns An array of ChatItems representing the latest thread. + */ +function getPrevChatList(fetchedMessages: any[]) { + const ret: ChatItem[] = [] + let nextMessageId = null + + for (const item of fetchedMessages) { + if (!item.parent_message_id) { + appendQAToChatList(ret, item) + break + } + + if (!nextMessageId) { + appendQAToChatList(ret, item) + nextMessageId = item.parent_message_id + } + else { + if (item.id === nextMessageId || nextMessageId === UUID_NIL) { + appendQAToChatList(ret, item) + nextMessageId = item.parent_message_id + } + } + } + return ret.reverse() +} + export { getProcessedInputsFromUrlParams, + getPrevChatList, } diff --git a/web/app/components/base/icons/assets/vender/line/general/refresh.svg b/web/app/components/base/icons/assets/vender/line/general/refresh.svg new file mode 100644 index 00000000000000..05cf98682734ac --- /dev/null +++ b/web/app/components/base/icons/assets/vender/line/general/refresh.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/line/general/Refresh.json b/web/app/components/base/icons/src/vender/line/general/Refresh.json new file mode 100644 index 00000000000000..128dcb7d4d713f --- /dev/null +++ b/web/app/components/base/icons/src/vender/line/general/Refresh.json @@ -0,0 +1,23 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "xmlns": "http://www.w3.org/2000/svg", + "viewBox": "0 0 24 24", + "fill": "currentColor" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M5.46257 4.43262C7.21556 2.91688 9.5007 2 12 2C17.5228 2 22 6.47715 22 12C22 14.1361 21.3302 16.1158 20.1892 17.7406L17 12H20C20 7.58172 16.4183 4 12 4C9.84982 4 7.89777 4.84827 6.46023 6.22842L5.46257 4.43262ZM18.5374 19.5674C16.7844 21.0831 14.4993 22 12 22C6.47715 22 2 17.5228 2 12C2 9.86386 2.66979 7.88416 3.8108 6.25944L7 12H4C4 16.4183 7.58172 20 12 20C14.1502 20 16.1022 19.1517 17.5398 17.7716L18.5374 19.5674Z" + }, + "children": [] + } + ] + }, + "name": "Refresh" +} \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/line/general/Refresh.tsx b/web/app/components/base/icons/src/vender/line/general/Refresh.tsx new file mode 100644 index 00000000000000..96641f1c4243b2 --- /dev/null +++ b/web/app/components/base/icons/src/vender/line/general/Refresh.tsx @@ -0,0 +1,16 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './Refresh.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' + +const Icon = React.forwardRef, Omit>(( + props, + ref, +) => ) + +Icon.displayName = 'Refresh' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/line/general/index.ts b/web/app/components/base/icons/src/vender/line/general/index.ts index c1af2e49948a41..b5c7a7bbc1d9a2 100644 --- a/web/app/components/base/icons/src/vender/line/general/index.ts +++ b/web/app/components/base/icons/src/vender/line/general/index.ts @@ -18,6 +18,7 @@ export { default as Menu01 } from './Menu01' export { default as Pin01 } from './Pin01' export { default as Pin02 } from './Pin02' export { default as Plus02 } from './Plus02' +export { default as Refresh } from './Refresh' export { default as Settings01 } from './Settings01' export { default as Settings04 } from './Settings04' export { default as Target04 } from './Target04' diff --git a/web/app/components/base/regenerate-btn/index.tsx b/web/app/components/base/regenerate-btn/index.tsx new file mode 100644 index 00000000000000..aaf0206df609db --- /dev/null +++ b/web/app/components/base/regenerate-btn/index.tsx @@ -0,0 +1,31 @@ +'use client' +import { t } from 'i18next' +import { Refresh } from '../icons/src/vender/line/general' +import Tooltip from '@/app/components/base/tooltip' + +type Props = { + className?: string + onClick?: () => void +} + +const RegenerateBtn = ({ className, onClick }: Props) => { + return ( +
+ +
onClick?.()} + style={{ + boxShadow: '0px 4px 8px -2px rgba(16, 24, 40, 0.1), 0px 2px 4px -2px rgba(16, 24, 40, 0.06)', + }} + > + +
+
+
+ ) +} + +export default RegenerateBtn diff --git a/web/app/components/workflow/panel/chat-record/index.tsx b/web/app/components/workflow/panel/chat-record/index.tsx index afd20b7358670d..1bcfd6474d7c88 100644 --- a/web/app/components/workflow/panel/chat-record/index.tsx +++ b/web/app/components/workflow/panel/chat-record/index.tsx @@ -2,7 +2,6 @@ import { memo, useCallback, useEffect, - useMemo, useState, } from 'react' import { RiCloseLine } from '@remixicon/react' @@ -17,50 +16,70 @@ import type { ChatItem } from '@/app/components/base/chat/types' import { fetchConversationMessages } from '@/service/debug' import { useStore as useAppStore } from '@/app/components/app/store' import Loading from '@/app/components/base/loading' +import { UUID_NIL } from '@/app/components/base/chat/constants' + +function appendQAToChatList(newChatList: ChatItem[], item: any) { + newChatList.push({ + id: item.id, + content: item.answer, + feedback: item.feedback, + isAnswer: true, + citation: item.metadata?.retriever_resources, + message_files: item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [], + workflow_run_id: item.workflow_run_id, + }) + newChatList.push({ + id: `question-${item.id}`, + content: item.query, + isAnswer: false, + message_files: item.message_files?.filter((file: any) => file.belongs_to === 'user') || [], + }) +} + +function getFormattedChatList(messages: any[]) { + const newChatList: ChatItem[] = [] + let nextMessageId = null + for (const item of messages) { + if (!item.parent_message_id) { + appendQAToChatList(newChatList, item) + break + } + + if (!nextMessageId) { + appendQAToChatList(newChatList, item) + nextMessageId = item.parent_message_id + } + else { + if (item.id === nextMessageId || nextMessageId === UUID_NIL) { + appendQAToChatList(newChatList, item) + nextMessageId = item.parent_message_id + } + } + } + return newChatList.reverse() +} const ChatRecord = () => { const [fetched, setFetched] = useState(false) - const [chatList, setChatList] = useState([]) + const [chatList, setChatList] = useState([]) const appDetail = useAppStore(s => s.appDetail) const workflowStore = useWorkflowStore() const { handleLoadBackupDraft } = useWorkflowRun() const historyWorkflowData = useStore(s => s.historyWorkflowData) const currentConversationID = historyWorkflowData?.conversation_id - const chatMessageList = useMemo(() => { - const res: ChatItem[] = [] - if (chatList.length) { - chatList.forEach((item: any) => { - res.push({ - id: `question-${item.id}`, - content: item.query, - isAnswer: false, - message_files: item.message_files?.filter((file: any) => file.belongs_to === 'user') || [], - }) - res.push({ - id: item.id, - content: item.answer, - feedback: item.feedback, - isAnswer: true, - citation: item.metadata?.retriever_resources, - message_files: item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [], - workflow_run_id: item.workflow_run_id, - }) - }) - } - return res - }, [chatList]) - const handleFetchConversationMessages = useCallback(async () => { if (appDetail && currentConversationID) { try { setFetched(false) const res = await fetchConversationMessages(appDetail.id, currentConversationID) - setFetched(true) - setChatList((res as any).data) + setChatList(getFormattedChatList((res as any).data)) } catch (e) { - + console.error(e) + } + finally { + setFetched(true) } } }, [appDetail, currentConversationID]) @@ -101,7 +120,7 @@ const ChatRecord = () => { config={{ supportCitationHitInfo: true, } as any} - chatList={chatMessageList} + chatList={chatList} chatContainerClassName='px-4' chatContainerInnerClassName='pt-6 w-full max-w-full mx-auto' chatFooterClassName='px-4 rounded-b-2xl' diff --git a/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx b/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx index a7dd607e221ea5..107a5dc698f741 100644 --- a/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx +++ b/web/app/components/workflow/panel/debug-and-preview/chat-wrapper.tsx @@ -18,7 +18,7 @@ import ConversationVariableModal from './conversation-variable-modal' import { useChat } from './hooks' import type { ChatWrapperRefType } from './index' import Chat from '@/app/components/base/chat/chat' -import type { OnSend } from '@/app/components/base/chat/types' +import type { ChatItem, OnSend } from '@/app/components/base/chat/types' import { useFeaturesStore } from '@/app/components/base/features/hooks' import { fetchSuggestedQuestions, @@ -58,6 +58,8 @@ const ChatWrapper = forwardRef(({ showConv const { conversationId, chatList, + chatListRef, + handleUpdateChatList, handleStop, isResponding, suggestedQuestions, @@ -73,19 +75,36 @@ const ChatWrapper = forwardRef(({ showConv taskId => stopChatMessageResponding(appDetail!.id, taskId), ) - const doSend = useCallback((query, files) => { + const doSend = useCallback((query, files, last_answer) => { handleSend( { query, files, inputs: workflowStore.getState().inputs, conversation_id: conversationId, + parent_message_id: last_answer?.id || chatListRef.current.at(-1)?.id || null, }, { onGetSuggestedQuestions: (messageId, getAbortController) => fetchSuggestedQuestions(appDetail!.id, messageId, getAbortController), }, ) - }, [conversationId, handleSend, workflowStore, appDetail]) + }, [chatListRef, conversationId, handleSend, workflowStore, appDetail]) + + const doRegenerate = useCallback((chatItem: ChatItem) => { + const index = chatList.findIndex(item => item.id === chatItem.id) + if (index === -1) + return + + const prevMessages = chatList.slice(0, index) + const question = prevMessages.pop() + const lastAnswer = prevMessages.at(-1) + + if (!question) + return + + handleUpdateChatList(prevMessages) + doSend(question.content, question.message_files, (!lastAnswer || lastAnswer.isOpeningStatement) ? undefined : lastAnswer) + }, [chatList, handleUpdateChatList, doSend]) useImperativeHandle(ref, () => { return { @@ -107,6 +126,7 @@ const ChatWrapper = forwardRef(({ showConv chatFooterClassName='px-4 rounded-bl-2xl' chatFooterInnerClassName='pb-4 w-full max-w-full mx-auto' onSend={doSend} + onRegenerate={doRegenerate} onStopResponding={handleStop} chatNode={( <> diff --git a/web/app/components/workflow/panel/debug-and-preview/hooks.ts b/web/app/components/workflow/panel/debug-and-preview/hooks.ts index 51a018bcb15b43..cad76a4490c85d 100644 --- a/web/app/components/workflow/panel/debug-and-preview/hooks.ts +++ b/web/app/components/workflow/panel/debug-and-preview/hooks.ts @@ -387,6 +387,8 @@ export const useChat = ( return { conversationId: conversationId.current, chatList, + chatListRef, + handleUpdateChatList, handleSend, handleStop, handleRestart, diff --git a/web/hooks/use-app-favicon.ts b/web/hooks/use-app-favicon.ts index 86eadc1b3d0862..1ff743928faaed 100644 --- a/web/hooks/use-app-favicon.ts +++ b/web/hooks/use-app-favicon.ts @@ -5,10 +5,10 @@ import type { AppIconType } from '@/types/app' type UseAppFaviconOptions = { enable?: boolean - icon_type?: AppIconType + icon_type?: AppIconType | null icon?: string - icon_background?: string - icon_url?: string + icon_background?: string | null + icon_url?: string | null } export function useAppFavicon(options: UseAppFaviconOptions) { diff --git a/web/i18n/en-US/app-api.ts b/web/i18n/en-US/app-api.ts index 631faeee9a5e74..355ff306027364 100644 --- a/web/i18n/en-US/app-api.ts +++ b/web/i18n/en-US/app-api.ts @@ -6,6 +6,7 @@ const translation = { ok: 'In Service', copy: 'Copy', copied: 'Copied', + regenerate: 'Regenerate', play: 'Play', pause: 'Pause', playing: 'Playing', diff --git a/web/i18n/zh-Hans/app-api.ts b/web/i18n/zh-Hans/app-api.ts index 6b9048b66efb92..a0defdab625174 100644 --- a/web/i18n/zh-Hans/app-api.ts +++ b/web/i18n/zh-Hans/app-api.ts @@ -6,6 +6,7 @@ const translation = { ok: '运行中', copy: '复制', copied: '已复制', + regenerate: '重新生成', play: '播放', pause: '暂停', playing: '播放中', diff --git a/web/models/log.ts b/web/models/log.ts index 8da1c4cf4e826f..dc557bfe2142ba 100644 --- a/web/models/log.ts +++ b/web/models/log.ts @@ -106,6 +106,7 @@ export type MessageContent = { metadata: Metadata agent_thoughts: any[] // TODO workflow_run_id: string + parent_message_id: string | null } export type CompletionConversationGeneralDetail = {