From 55b8740a3514e245a4bfde126fbbc6e54740f876 Mon Sep 17 00:00:00 2001
From: Weves
Date: Wed, 2 Oct 2024 17:50:54 -0700
Subject: [PATCH] Fix Fix Refactor more more fix refactor Fix circular imports
 Refactor Move tests around

---
 backend/danswer/chat/models.py                | 2 +-
 backend/danswer/chat/process_message.py       | 86 ++-
 backend/danswer/llm/answering/answer.py       | 588 ++++--------------
 .../llm/answering/llm_response_handler.py     | 83 +++
 .../danswer/llm/answering/prompts/build.py    | 72 +--
 .../llm/answering/prompts/citations_prompt.py | 14 +-
 .../llm/answering/prompts/quotes_prompt.py    | 25 +-
 .../danswer/llm/answering/prune_and_merge.py  | 2 +-
 .../stream_processing/citation_processing.py  | 232 +++----
 .../citation_response_handler.py              | 61 ++
 .../answering/tool/tool_response_handler.py   | 205 ++++++
 backend/danswer/llm/utils.py                  | 22 +
 .../one_shot_answer/answer_question.py        | 37 +-
 .../danswer/server/features/persona/models.py | 2 +-
 backend/danswer/server/features/tool/api.py   | 14 +-
 backend/danswer/tools/base_tool.py            | 59 ++
 backend/danswer/tools/built_in_tools.py       | 10 +-
 .../custom/custom_tool_prompt_builder.py      | 21 -
 backend/danswer/tools/tool.py                 | 27 +-
 .../custom/base_tool_types.py                 | 0
 .../custom/custom_tool.py                     | 42 +-
 .../custom/custom_tool_prompts.py             | 0
 .../custom/openapi_parsing.py                 | 0
 .../images/image_generation_tool.py           | 38 +-
 .../images/prompt.py                          | 0
 .../internet_search/internet_search_tool.py   | 48 +-
 .../internet_search/models.py                 | 0
 .../search/search_tool.py                     | 51 +-
 .../search/search_utils.py                    | 0
 .../search_like_tool_utils.py                 | 71 +++
 backend/danswer/tools/tool_runner.py          | 2 +-
 .../ee/danswer/server/query_and_chat/utils.py | 2 +-
 .../tests/dev_apis/test_simple_chat_api.py    | 3 +
 .../unit/danswer/llm/answering/conftest.py    | 113 ++++
 .../test_citation_processing.py               | 18 +-
 .../unit/danswer/llm/answering/test_answer.py | 421 +++++++++++++
 .../danswer/llm/answering/test_skip_gen_ai.py | 35 +-
 .../danswer/tools/custom/test_custom_tools.py | 24 +-
 38 files changed, 1633 insertions(+), 797 deletions(-)
 create mode 100644 backend/danswer/llm/answering/llm_response_handler.py
 create mode 100644 backend/danswer/llm/answering/stream_processing/citation_response_handler.py
 create mode 100644 backend/danswer/llm/answering/tool/tool_response_handler.py
 create mode 100644 backend/danswer/tools/base_tool.py
 delete mode 100644 backend/danswer/tools/custom/custom_tool_prompt_builder.py
 rename backend/danswer/tools/{ => tool_implementations}/custom/base_tool_types.py (100%)
 rename backend/danswer/tools/{ => tool_implementations}/custom/custom_tool.py (88%)
 rename backend/danswer/tools/{ => tool_implementations}/custom/custom_tool_prompts.py (100%)
 rename backend/danswer/tools/{ => tool_implementations}/custom/openapi_parsing.py (100%)
 rename backend/danswer/tools/{ => tool_implementations}/images/image_generation_tool.py (86%)
 rename backend/danswer/tools/{ => tool_implementations}/images/prompt.py (100%)
 rename backend/danswer/tools/{ => tool_implementations}/internet_search/internet_search_tool.py (81%)
 rename backend/danswer/tools/{ => tool_implementations}/internet_search/models.py (100%)
 rename backend/danswer/tools/{ => tool_implementations}/search/search_tool.py (87%)
 rename backend/danswer/tools/{ => tool_implementations}/search/search_utils.py (100%)
 create mode 100644 backend/danswer/tools/tool_implementations/search_like_tool_utils.py
 create mode 100644 backend/tests/unit/danswer/llm/answering/conftest.py
 create mode 100644 backend/tests/unit/danswer/llm/answering/test_answer.py
diff --git
a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index 97d5b9e7275..d5925fc2ed9 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -10,7 +10,7 @@ from danswer.search.enums import SearchType from danswer.search.models import RetrievalDocs from danswer.search.models import SearchResponse -from danswer.tools.custom.base_tool_types import ToolResultType +from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType class LlmDoc(BaseModel): diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index f58a34c3243..4ff30dd3c04 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -77,31 +77,49 @@ from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.server.utils import get_json_line from danswer.tools.built_in_tools import get_built_in_tool_by_id -from danswer.tools.custom.custom_tool import ( +from danswer.tools.force import ForceUseTool +from danswer.tools.models import DynamicSchemaInfo +from danswer.tools.models import ToolResponse +from danswer.tools.tool import Tool +from danswer.tools.tool_implementations.custom.custom_tool import ( build_custom_tools_from_openapi_schema_and_headers, ) -from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID -from danswer.tools.custom.custom_tool import CustomToolCallSummary -from danswer.tools.force import ForceUseTool -from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID -from danswer.tools.images.image_generation_tool import ImageGenerationResponse -from danswer.tools.images.image_generation_tool import ImageGenerationTool -from danswer.tools.internet_search.internet_search_tool import ( +from danswer.tools.tool_implementations.custom.custom_tool import ( + CUSTOM_TOOL_RESPONSE_ID, +) +from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary +from danswer.tools.tool_implementations.images.image_generation_tool import ( + IMAGE_GENERATION_RESPONSE_ID, +) +from danswer.tools.tool_implementations.images.image_generation_tool import ( + ImageGenerationResponse, +) +from danswer.tools.tool_implementations.images.image_generation_tool import ( + ImageGenerationTool, +) +from danswer.tools.tool_implementations.internet_search.internet_search_tool import ( INTERNET_SEARCH_RESPONSE_ID, ) -from danswer.tools.internet_search.internet_search_tool import ( +from danswer.tools.tool_implementations.internet_search.internet_search_tool import ( internet_search_response_to_search_docs, ) -from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse -from danswer.tools.internet_search.internet_search_tool import InternetSearchTool -from danswer.tools.models import DynamicSchemaInfo -from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID -from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID -from danswer.tools.search.search_tool import SearchResponseSummary -from danswer.tools.search.search_tool import SearchTool -from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID -from danswer.tools.tool import Tool -from danswer.tools.tool import ToolResponse +from danswer.tools.tool_implementations.internet_search.internet_search_tool import ( + InternetSearchResponse, +) +from danswer.tools.tool_implementations.internet_search.internet_search_tool import ( + InternetSearchTool, +) +from danswer.tools.tool_implementations.search.search_tool 
import ( + FINAL_CONTEXT_DOCUMENTS_ID, +) +from danswer.tools.tool_implementations.search.search_tool import ( + SEARCH_RESPONSE_SUMMARY_ID, +) +from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary +from danswer.tools.tool_implementations.search.search_tool import SearchTool +from danswer.tools.tool_implementations.search.search_tool import ( + SECTION_RELEVANCE_LIST_ID, +) from danswer.tools.tool_runner import ToolCallFinalResult from danswer.tools.utils import compute_all_tool_tokens from danswer.tools.utils import explicit_tool_calling_supported @@ -532,6 +550,13 @@ def stream_chat_message_objects( if not persona else PromptConfig.from_model(persona.prompts[0]) ) + answer_style_config = AnswerStyleConfig( + citation_config=CitationConfig( + all_docs_useful=selected_db_search_docs is not None + ), + document_pruning_config=document_pruning_config, + structured_response_format=new_msg_req.structured_response_format, + ) # find out what tools to use search_tool: SearchTool | None = None @@ -550,13 +575,16 @@ def stream_chat_message_objects( llm=llm, fast_llm=fast_llm, pruning_config=document_pruning_config, + answer_style_config=answer_style_config, selected_sections=selected_sections, chunks_above=new_msg_req.chunks_above, chunks_below=new_msg_req.chunks_below, full_doc=new_msg_req.full_doc, - evaluation_type=LLMEvaluationType.BASIC - if persona.llm_relevance_filter - else LLMEvaluationType.SKIP, + evaluation_type=( + LLMEvaluationType.BASIC + if persona.llm_relevance_filter + else LLMEvaluationType.SKIP + ), ) tool_dict[db_tool_model.id] = [search_tool] elif tool_cls.__name__ == ImageGenerationTool.__name__: @@ -626,7 +654,11 @@ def stream_chat_message_objects( "Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!" 
) tool_dict[db_tool_model.id] = [ - InternetSearchTool(api_key=bing_api_key) + InternetSearchTool( + api_key=bing_api_key, + answer_style_config=answer_style_config, + prompt_config=prompt_config, + ) ] continue @@ -667,13 +699,7 @@ def stream_chat_message_objects( is_connected=is_connected, question=final_msg.message, latest_query_files=latest_query_files, - answer_style_config=AnswerStyleConfig( - citation_config=CitationConfig( - all_docs_useful=selected_db_search_docs is not None - ), - document_pruning_config=document_pruning_config, - structured_response_format=new_msg_req.structured_response_format, - ), + answer_style_config=answer_style_config, prompt_config=prompt_config, llm=( llm diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index d2aeb1b14c4..0aea52c303b 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -1,72 +1,38 @@ -import itertools from collections.abc import Callable from collections.abc import Iterator -from typing import Any -from typing import cast from uuid import uuid4 from langchain.schema.messages import BaseMessage from langchain_core.messages import AIMessageChunk -from langchain_core.messages import HumanMessage +from langchain_core.messages import ToolCall -from danswer.chat.chat_utils import llm_doc_from_inference_section from danswer.chat.models import AnswerQuestionPossibleReturn from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import LlmDoc -from danswer.chat.models import StreamStopInfo -from danswer.chat.models import StreamStopReason -from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE from danswer.file_store.utils import InMemoryChatFile +from danswer.llm.answering.llm_response_handler import LLMCall +from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager from danswer.llm.answering.models import AnswerStyleConfig from danswer.llm.answering.models import PreviousMessage from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.models import StreamProcessor from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.answering.prompts.build import default_build_system_message from danswer.llm.answering.prompts.build import default_build_user_message -from danswer.llm.answering.prompts.citations_prompt import ( - build_citations_system_message, +from danswer.llm.answering.stream_processing.citation_response_handler import ( + CitationResponseHandler, ) -from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message -from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message -from danswer.llm.answering.stream_processing.citation_processing import ( - build_citation_processor, +from danswer.llm.answering.stream_processing.citation_response_handler import ( + DummyAnswerResponseHandler, ) -from danswer.llm.answering.stream_processing.quotes_processing import ( - build_quotes_processor, -) -from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping from danswer.llm.answering.stream_processing.utils import map_document_id_order +from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler from danswer.llm.interfaces import LLM -from danswer.llm.interfaces import ToolChoiceOptions from danswer.natural_language_processing.utils import get_tokenizer -from danswer.tools.custom.custom_tool_prompt_builder import ( - 
build_user_message_for_custom_tool_for_non_tool_calling_llm, -) -from danswer.tools.force import filter_tools_for_force_tool_use from danswer.tools.force import ForceUseTool -from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID -from danswer.tools.images.image_generation_tool import ImageGenerationResponse -from danswer.tools.images.image_generation_tool import ImageGenerationTool -from danswer.tools.images.prompt import build_image_generation_user_prompt -from danswer.tools.internet_search.internet_search_tool import InternetSearchTool -from danswer.tools.message import build_tool_message -from danswer.tools.message import ToolCallSummary -from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID -from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID -from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID -from danswer.tools.search.search_tool import SearchResponseSummary -from danswer.tools.search.search_tool import SearchTool +from danswer.tools.models import ToolResponse from danswer.tools.tool import Tool -from danswer.tools.tool import ToolResponse -from danswer.tools.tool_runner import ( - check_which_tools_should_run_for_non_tool_calling_llm, -) -from danswer.tools.tool_runner import ToolCallFinalResult +from danswer.tools.tool_implementations.search.search_tool import SearchTool from danswer.tools.tool_runner import ToolCallKickoff -from danswer.tools.tool_runner import ToolRunner -from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm from danswer.tools.utils import explicit_tool_calling_supported from danswer.utils.logger import setup_logger @@ -74,29 +40,9 @@ logger = setup_logger() -def _get_answer_stream_processor( - context_docs: list[LlmDoc], - doc_id_to_rank_map: DocumentIdOrderMapping, - answer_style_configs: AnswerStyleConfig, -) -> StreamProcessor: - if answer_style_configs.citation_config: - return build_citation_processor( - context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map - ) - if answer_style_configs.quotes_config: - return build_quotes_processor( - context_docs=context_docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak") - ) - - raise RuntimeError("Not implemented yet") - - AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse] -logger = setup_logger() - - class Answer: def __init__( self, @@ -136,8 +82,6 @@ def __init__( self.tools = tools or [] self.force_use_tool = force_use_tool - self.skip_explicit_tool_calling = skip_explicit_tool_calling - self.message_history = message_history or [] # used for QA flow where we only want to send a single message self.single_message_history = single_message_history @@ -162,335 +106,132 @@ def __init__( self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation self._is_cancelled = False - def _update_prompt_builder_for_search_tool( - self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc] - ) -> None: - if self.answer_style_config.citation_config: - prompt_builder.update_system_prompt( - build_citations_system_message(self.prompt_config) - ) - prompt_builder.update_user_prompt( - build_citations_user_message( - question=self.question, - prompt_config=self.prompt_config, - context_docs=final_context_documents, - files=self.latest_query_files, - all_doc_useful=( - self.answer_style_config.citation_config.all_docs_useful - if self.answer_style_config.citation_config - else False - ), - history_message=self.single_message_history or "", - ) - ) 
- elif self.answer_style_config.quotes_config: - prompt_builder.update_user_prompt( - build_quotes_user_message( - question=self.question, - context_docs=final_context_documents, - history_str=self.single_message_history or "", - prompt=self.prompt_config, - ) + self.using_tool_calling_llm = ( + explicit_tool_calling_supported( + self.llm.config.model_provider, self.llm.config.model_name ) + and not skip_explicit_tool_calling + ) - def _raw_output_for_explicit_tool_calling_llms( - self, - ) -> Iterator[ - str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult - ]: - prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) - - tool_call_chunk: AIMessageChunk | None = None - if self.force_use_tool.force_use and self.force_use_tool.args is not None: - # if we are forcing a tool WITH args specified, we don't need to check which tools to run - # / need to generate the args - tool_call_chunk = AIMessageChunk( - content="", - ) - tool_call_chunk.tool_calls = [ - { - "name": self.force_use_tool.tool_name, - "args": self.force_use_tool.args, - "id": str(uuid4()), - } - ] - else: - # if tool calling is supported, first try the raw message - # to see if we don't need to use any tools - prompt_builder.update_system_prompt( - default_build_system_message(self.prompt_config) - ) - prompt_builder.update_user_prompt( - default_build_user_message( - self.question, self.prompt_config, self.latest_query_files - ) - ) - prompt = prompt_builder.build() - final_tool_definitions = [ - tool.tool_definition() - for tool in filter_tools_for_force_tool_use( - self.tools, self.force_use_tool - ) - ] - - for message in self.llm.stream( - prompt=prompt, - tools=final_tool_definitions if final_tool_definitions else None, - tool_choice="required" if self.force_use_tool.force_use else None, - structured_response_format=self.answer_style_config.structured_response_format, - ): - if isinstance(message, AIMessageChunk) and ( - message.tool_call_chunks or message.tool_calls - ): - if tool_call_chunk is None: - tool_call_chunk = message - else: - tool_call_chunk += message # type: ignore - else: - if message.content: - if self.is_cancelled: - return - yield cast(str, message.content) - if ( - message.additional_kwargs.get("usage_metadata", {}).get("stop") - == "length" - ): - yield StreamStopInfo( - stop_reason=StreamStopReason.CONTEXT_LENGTH - ) - - if not tool_call_chunk: - return # no tool call needed - - # if we have a tool call, we need to call the tool - tool_call_requests = tool_call_chunk.tool_calls - for tool_call_request in tool_call_requests: - known_tools_by_name = [ - tool for tool in self.tools if tool.name == tool_call_request["name"] - ] - - if not known_tools_by_name: - logger.error( - "Tool call requested with unknown name field. 
\n" - f"self.tools: {self.tools}" - f"tool_call_request: {tool_call_request}" - ) - if self.tools: - tool = self.tools[0] - else: - continue - else: - tool = known_tools_by_name[0] - tool_args = ( - self.force_use_tool.args - if self.force_use_tool.tool_name == tool.name - and self.force_use_tool.args - else tool_call_request["args"] - ) + def _get_tools_list(self) -> list[Tool]: + if not self.force_use_tool.force_use: + return self.tools - tool_runner = ToolRunner(tool, tool_args) - yield tool_runner.kickoff() - yield from tool_runner.tool_responses() + tool = next( + (t for t in self.tools if t.name == self.force_use_tool.tool_name), None + ) + if tool is None: + raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found") - tool_call_summary = ToolCallSummary( - tool_call_request=tool_call_chunk, - tool_call_result=build_tool_message( - tool_call_request, tool_runner.tool_message_content() - ), + logger.info( + f"Forcefully using tool='{tool.name}'" + + ( + f" with args='{self.force_use_tool.args}'" + if self.force_use_tool.args is not None + else "" ) + ) + return [tool] - if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}: - self._update_prompt_builder_for_search_tool(prompt_builder, []) - elif tool.name == ImageGenerationTool._NAME: - img_urls = [ - img_generation_result["url"] - for img_generation_result in tool_runner.tool_final_result().tool_result - ] - prompt_builder.update_user_prompt( - build_image_generation_user_prompt( - query=self.question, img_urls=img_urls - ) - ) - yield tool_runner.tool_final_result() - if not self.skip_gen_ai_answer_generation: - prompt = prompt_builder.build(tool_call_summary=tool_call_summary) - - yield from self._process_llm_stream( - prompt=prompt, - # as of now, we don't support multiple tool calls in sequence, which is why - # we don't need to pass this in here - # tools=[tool.tool_definition() for tool in self.tools], - ) + def _handle_specified_tool_call( + self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict + ) -> AnswerStream: + current_llm_call = llm_calls[-1] - return + # make a dummy tool handler + tool_handler = ToolResponseHandler([tool]) - # This method processes the LLM stream and yields the content or stop information - def _process_llm_stream( - self, - prompt: Any, - tools: list[dict] | None = None, - tool_choice: ToolChoiceOptions | None = None, - ) -> Iterator[str | StreamStopInfo]: - for message in self.llm.stream( - prompt=prompt, - tools=tools, - tool_choice=tool_choice, - structured_response_format=self.answer_style_config.structured_response_format, - ): - if isinstance(message, AIMessageChunk): - if message.content: - if self.is_cancelled: - return StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) - yield cast(str, message.content) - - if ( - message.additional_kwargs.get("usage_metadata", {}).get("stop") - == "length" - ): - yield StreamStopInfo(stop_reason=StreamStopReason.CONTEXT_LENGTH) - - def _raw_output_for_non_explicit_tool_calling_llms( - self, - ) -> Iterator[ - str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult - ]: - prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) - chosen_tool_and_args: tuple[Tool, dict] | None = None - - if self.force_use_tool.force_use: - # if we are forcing a tool, we don't need to check which tools to run - tool = next( - iter( - [ - tool - for tool in self.tools - if tool.name == self.force_use_tool.tool_name - ] - ), - None, - ) - if not tool: - raise RuntimeError(f"Tool 
'{self.force_use_tool.tool_name}' not found") + dummy_tool_call_chunk = AIMessageChunk(content="") + dummy_tool_call_chunk.tool_calls = [ + ToolCall(name=tool.name, args=tool_args, id=str(uuid4())) + ] - tool_args = ( - self.force_use_tool.args - if self.force_use_tool.args is not None - else tool.get_args_for_non_tool_calling_llm( - query=self.question, - history=self.message_history, - llm=self.llm, - force_run=True, - ) - ) - - if tool_args is None: - raise RuntimeError(f"Tool '{tool.name}' did not return args") + response_handler_manager = LLMResponseHandlerManager( + tool_handler, DummyAnswerResponseHandler(), self.is_cancelled + ) + yield from response_handler_manager.handle_llm_response( + iter([dummy_tool_call_chunk]) + ) - chosen_tool_and_args = (tool, tool_args) + new_llm_call = response_handler_manager.next_llm_call(current_llm_call) + if new_llm_call: + yield from self._get_response(llm_calls + [new_llm_call]) else: - tool_options = check_which_tools_should_run_for_non_tool_calling_llm( - tools=self.tools, - query=self.question, - history=self.message_history, - llm=self.llm, - ) + raise RuntimeError("Tool call handler did not return a new LLM call") - available_tools_and_args = [ - (self.tools[ind], args) - for ind, args in enumerate(tool_options) - if args is not None - ] + def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream: + current_llm_call = llm_calls[-1] - logger.info( - f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}" + # handle the case where no decision has to be made; we simply run the tool + if ( + current_llm_call.force_use_tool.force_use + and current_llm_call.force_use_tool.args is not None + ): + tool_name, tool_args = ( + current_llm_call.force_use_tool.tool_name, + current_llm_call.force_use_tool.args, ) - - chosen_tool_and_args = ( - select_single_tool_for_non_tool_calling_llm( - tools_and_args=available_tools_and_args, - history=self.message_history, - query=self.question, - llm=self.llm, - ) - if available_tools_and_args - else None + tool = next( + (t for t in current_llm_call.tools if t.name == tool_name), None ) + if not tool: + raise RuntimeError(f"Tool '{tool_name}' not found") - logger.notice(f"Chosen tool: {chosen_tool_and_args}") + yield from self._handle_specified_tool_call(llm_calls, tool, tool_args) + return - if not chosen_tool_and_args: - if self.skip_gen_ai_answer_generation: - raise ValueError( - "skip_gen_ai_answer_generation is True, but no tool was chosen; no answer will be generated" - ) - prompt_builder.update_system_prompt( - default_build_system_message(self.prompt_config) - ) - prompt_builder.update_user_prompt( - default_build_user_message( - self.question, self.prompt_config, self.latest_query_files + # special pre-logic for non-tool calling LLM case + if not self.using_tool_calling_llm and current_llm_call.tools: + chosen_tool_and_args = ( + ToolResponseHandler.get_tool_call_for_non_tool_calling_llm( + current_llm_call, self.llm ) ) - prompt = prompt_builder.build() - yield from self._process_llm_stream( - prompt=prompt, - tools=None, - ) + if chosen_tool_and_args: + tool, tool_args = chosen_tool_and_args + yield from self._handle_specified_tool_call(llm_calls, tool, tool_args) + return + + # if we're skipping gen ai answer generation, we should break + # out unless we're forcing a tool call. If we don't, we might generate an + # answer, which is a no-no! 
+ if ( + self.skip_gen_ai_answer_generation + and not current_llm_call.force_use_tool.force_use + ): return - tool, tool_args = chosen_tool_and_args - tool_runner = ToolRunner(tool, tool_args) - yield tool_runner.kickoff() + # set up "handlers" to listen to the LLM response stream and + # feed back the processed results + handle tool call requests + # + figure out what the next LLM call should be + tool_call_handler = ToolResponseHandler(current_llm_call.tools) - if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}: - final_context_documents = None - for response in tool_runner.tool_responses(): - if response.id == FINAL_CONTEXT_DOCUMENTS_ID: - final_context_documents = cast(list[LlmDoc], response.response) - yield response - - if final_context_documents is None: - raise RuntimeError( - f"{tool.name} did not return final context documents" - ) + search_result = SearchTool.get_search_result(current_llm_call) or [] + citation_response_handler = CitationResponseHandler( + context_docs=search_result, + doc_id_to_rank_map=map_document_id_order(search_result), + ) - self._update_prompt_builder_for_search_tool( - prompt_builder, final_context_documents - ) - elif tool.name == ImageGenerationTool._NAME: - img_urls = [] - for response in tool_runner.tool_responses(): - if response.id == IMAGE_GENERATION_RESPONSE_ID: - img_generation_response = cast( - list[ImageGenerationResponse], response.response - ) - img_urls = [img.url for img in img_generation_response] - - yield response - - prompt_builder.update_user_prompt( - build_image_generation_user_prompt( - query=self.question, - img_urls=img_urls, - ) - ) - else: - prompt_builder.update_user_prompt( - HumanMessage( - content=build_user_message_for_custom_tool_for_non_tool_calling_llm( - self.question, - tool.name, - *tool_runner.tool_responses(), - ) - ) - ) - final = tool_runner.tool_final_result() + response_handler_manager = LLMResponseHandlerManager( + tool_call_handler, citation_response_handler, self.is_cancelled + ) - yield final - if not self.skip_gen_ai_answer_generation: - prompt = prompt_builder.build() + # DEBUG: good breakpoint + stream = self.llm.stream( + prompt=current_llm_call.prompt_builder.build(), + tools=[tool.tool_definition() for tool in current_llm_call.tools] or None, + tool_choice=( + "required" + if current_llm_call.tools and current_llm_call.force_use_tool.force_use + else None + ), + structured_response_format=self.answer_style_config.structured_response_format, + ) + yield from response_handler_manager.handle_llm_response(stream) - yield from self._process_llm_stream(prompt=prompt, tools=None) + new_llm_call = response_handler_manager.next_llm_call(current_llm_call) + if new_llm_call: + yield from self._get_response(llm_calls + [new_llm_call]) @property def processed_streamed_output(self) -> AnswerStream: @@ -498,94 +239,30 @@ def processed_streamed_output(self) -> AnswerStream: yield from self._processed_stream return - output_generator = ( - self._raw_output_for_explicit_tool_calling_llms() - if explicit_tool_calling_supported( - self.llm.config.model_provider, self.llm.config.model_name - ) - and not self.skip_explicit_tool_calling - else self._raw_output_for_non_explicit_tool_calling_llms() + prompt_builder = AnswerPromptBuilder( + user_message=default_build_user_message( + user_query=self.question, + prompt_config=self.prompt_config, + files=self.latest_query_files, + ), + message_history=self.message_history, + llm_config=self.llm.config, + single_message_history=self.single_message_history, + ) + 
prompt_builder.update_system_prompt( + default_build_system_message(self.prompt_config) + ) + llm_call = LLMCall( + prompt_builder=prompt_builder, + tools=self._get_tools_list(), + force_use_tool=self.force_use_tool, + files=self.latest_query_files, + tool_call_info=[], + using_tool_calling_llm=self.using_tool_calling_llm, ) - - def _process_stream( - stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo], - ) -> AnswerStream: - message = None - - # special things we need to keep track of for the SearchTool - # raw results that will be displayed to the user - search_results: list[LlmDoc] | None = None - # processed docs to feed into the LLM - final_context_docs: list[LlmDoc] | None = None - - for message in stream: - if isinstance(message, ToolCallKickoff) or isinstance( - message, ToolCallFinalResult - ): - yield message - elif isinstance(message, ToolResponse): - if message.id == SEARCH_RESPONSE_SUMMARY_ID: - # We don't need to run section merging in this flow, this variable is only used - # below to specify the ordering of the documents for the purpose of matching - # citations to the right search documents. The deduplication logic is more lightweight - # there and we don't need to do it twice - search_results = [ - llm_doc_from_inference_section(section) - for section in cast( - SearchResponseSummary, message.response - ).top_sections - ] - elif message.id == FINAL_CONTEXT_DOCUMENTS_ID: - final_context_docs = cast(list[LlmDoc], message.response) - yield message - - elif ( - message.id == SEARCH_DOC_CONTENT_ID - and not self._return_contexts - ): - continue - - yield message - else: - # assumes all tool responses will come first, then the final answer - break - - if not self.skip_gen_ai_answer_generation: - process_answer_stream_fn = _get_answer_stream_processor( - context_docs=final_context_docs or [], - # if doc selection is enabled, then search_results will be None, - # so we need to use the final_context_docs - doc_id_to_rank_map=map_document_id_order( - search_results or final_context_docs or [] - ), - answer_style_configs=self.answer_style_config, - ) - - stream_stop_info = None - - def _stream() -> Iterator[str]: - nonlocal stream_stop_info - for item in itertools.chain([message], stream): - if isinstance(item, StreamStopInfo): - stream_stop_info = item - return - - # this should never happen, but we're seeing weird behavior here so handling for now - if not isinstance(item, str): - logger.error( - f"Received non-string item in answer stream: {item}. Skipping." 
-                            )
-                            continue
-
-                        yield item
-
-                yield from process_answer_stream_fn(_stream())
-
-                if stream_stop_info:
-                    yield stream_stop_info
 
         processed_stream = []
-        for processed_packet in _process_stream(output_generator):
+        for processed_packet in self._get_response([llm_call]):
             processed_stream.append(processed_packet)
             yield processed_packet
@@ -609,7 +286,6 @@ def citations(self) -> list[CitationInfo]:
 
         return citations
 
-    @property
     def is_cancelled(self) -> bool:
         if self._is_cancelled:
             return True
diff --git a/backend/danswer/llm/answering/llm_response_handler.py b/backend/danswer/llm/answering/llm_response_handler.py
new file mode 100644
index 00000000000..6578e808952
--- /dev/null
+++ b/backend/danswer/llm/answering/llm_response_handler.py
@@ -0,0 +1,83 @@
+from collections.abc import Callable
+from collections.abc import Generator
+from collections.abc import Iterator
+from typing import TYPE_CHECKING
+
+from langchain_core.messages import BaseMessage
+from pydantic.v1 import BaseModel as BaseModel__v1
+
+from danswer.chat.models import CitationInfo
+from danswer.chat.models import DanswerAnswerPiece
+from danswer.chat.models import StreamStopInfo
+from danswer.chat.models import StreamStopReason
+from danswer.file_store.models import InMemoryChatFile
+from danswer.llm.answering.prompts.build import AnswerPromptBuilder
+from danswer.tools.force import ForceUseTool
+from danswer.tools.models import ToolCallFinalResult
+from danswer.tools.models import ToolCallKickoff
+from danswer.tools.models import ToolResponse
+from danswer.tools.tool import Tool
+
+
+if TYPE_CHECKING:
+    from danswer.llm.answering.stream_processing.citation_response_handler import (
+        AnswerResponseHandler,
+    )
+    from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
+
+
+ResponsePart = (
+    DanswerAnswerPiece
+    | CitationInfo
+    | ToolCallKickoff
+    | ToolResponse
+    | ToolCallFinalResult
+    | StreamStopInfo
+)
+
+
+class LLMCall(BaseModel__v1):
+    prompt_builder: AnswerPromptBuilder
+    tools: list[Tool]
+    force_use_tool: ForceUseTool
+    files: list[InMemoryChatFile]
+    tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
+    using_tool_calling_llm: bool
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class LLMResponseHandlerManager:
+    def __init__(
+        self,
+        tool_handler: "ToolResponseHandler",
+        answer_handler: "AnswerResponseHandler",
+        is_cancelled: Callable[[], bool],
+    ):
+        self.tool_handler = tool_handler
+        self.answer_handler = answer_handler
+        self.is_cancelled = is_cancelled
+
+    def handle_llm_response(
+        self,
+        stream: Iterator[BaseMessage],
+    ) -> Generator[ResponsePart, None, None]:
+        all_messages: list[BaseMessage] = []
+        for message in stream:
+            # tool handler doesn't do anything until the full message is received
+            # NOTE: still need to run list() to get this to run
+            list(self.tool_handler.handle_response_part(message, all_messages))
+            yield from self.answer_handler.handle_response_part(message, all_messages)
+            all_messages.append(message)
+
+            if self.is_cancelled():
+                yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
+                return
+
+        # potentially give back all info on the selected tool call + its result
+        yield from self.tool_handler.handle_response_part(None, all_messages)
+        yield from self.answer_handler.handle_response_part(None, all_messages)
+
+    def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None:
+        return self.tool_handler.next_llm_call(llm_call)
diff --git a/backend/danswer/llm/answering/prompts/build.py
b/backend/danswer/llm/answering/prompts/build.py index f53d4481f6e..b5b774f522d 100644 --- a/backend/danswer/llm/answering/prompts/build.py +++ b/backend/danswer/llm/answering/prompts/build.py @@ -12,12 +12,12 @@ from danswer.llm.interfaces import LLMConfig from danswer.llm.utils import build_content_with_imgs from danswer.llm.utils import check_message_tokens +from danswer.llm.utils import message_to_prompt_and_imgs from danswer.llm.utils import translate_history_to_basemessages from danswer.natural_language_processing.utils import get_tokenizer from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT from danswer.prompts.prompt_utils import add_date_time_to_prompt from danswer.prompts.prompt_utils import drop_messages_history_overflow -from danswer.tools.message import ToolCallSummary def default_build_system_message( @@ -54,18 +54,14 @@ def default_build_user_message( class AnswerPromptBuilder: def __init__( - self, message_history: list[PreviousMessage], llm_config: LLMConfig + self, + user_message: HumanMessage, + message_history: list[PreviousMessage], + llm_config: LLMConfig, + single_message_history: str | None = None, ) -> None: self.max_tokens = compute_max_llm_input_tokens(llm_config) - ( - self.message_history, - self.history_token_cnts, - ) = translate_history_to_basemessages(message_history) - - self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None - self.user_message_and_token_cnt: tuple[HumanMessage, int] | None = None - llm_tokenizer = get_tokenizer( provider_type=llm_config.model_provider, model_name=llm_config.model_name, @@ -74,6 +70,24 @@ def __init__( Callable[[str], list[int]], llm_tokenizer.encode ) + self.raw_message_history = message_history + ( + self.message_history, + self.history_token_cnts, + ) = translate_history_to_basemessages(message_history) + + # for cases where like the QA flow where we want to condense the chat history + # into a single message rather than a sequence of User / Assistant messages + self.single_message_history = single_message_history + + self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None + self.user_message_and_token_cnt = ( + user_message, + check_message_tokens(user_message, self.llm_tokenizer_encode_func), + ) + + self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = [] + def update_system_prompt(self, system_message: SystemMessage | None) -> None: if not system_message: self.system_message_and_token_cnt = None @@ -85,18 +99,21 @@ def update_system_prompt(self, system_message: SystemMessage | None) -> None: ) def update_user_prompt(self, user_message: HumanMessage) -> None: - if not user_message: - self.user_message_and_token_cnt = None - return - self.user_message_and_token_cnt = ( user_message, check_message_tokens(user_message, self.llm_tokenizer_encode_func), ) - def build( - self, tool_call_summary: ToolCallSummary | None = None - ) -> list[BaseMessage]: + def append_message(self, message: BaseMessage) -> None: + """Append a new message to the message history.""" + token_count = check_message_tokens(message, self.llm_tokenizer_encode_func) + self.new_messages_and_token_cnts.append((message, token_count)) + + def get_user_message_content(self) -> str: + query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0]) + return query + + def build(self) -> list[BaseMessage]: if not self.user_message_and_token_cnt: raise ValueError("User message must be set before building prompt") @@ -113,25 +130,8 @@ def build( 
final_messages_with_tokens.append(self.user_message_and_token_cnt) - if tool_call_summary: - final_messages_with_tokens.append( - ( - tool_call_summary.tool_call_request, - check_message_tokens( - tool_call_summary.tool_call_request, - self.llm_tokenizer_encode_func, - ), - ) - ) - final_messages_with_tokens.append( - ( - tool_call_summary.tool_call_result, - check_message_tokens( - tool_call_summary.tool_call_result, - self.llm_tokenizer_encode_func, - ), - ) - ) + if self.new_messages_and_token_cnts: + final_messages_with_tokens.extend(self.new_messages_and_token_cnts) return drop_messages_history_overflow( final_messages_with_tokens, self.max_tokens diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/llm/answering/prompts/citations_prompt.py index a2248da0585..b7ca7797e88 100644 --- a/backend/danswer/llm/answering/prompts/citations_prompt.py +++ b/backend/danswer/llm/answering/prompts/citations_prompt.py @@ -6,7 +6,6 @@ from danswer.db.models import Persona from danswer.db.persona import get_default_prompt__read_only from danswer.db.search_settings import get_multilingual_expansion -from danswer.file_store.utils import InMemoryChatFile from danswer.llm.answering.models import PromptConfig from danswer.llm.factory import get_llms_for_persona from danswer.llm.factory import get_main_llm_from_tuple @@ -14,6 +13,7 @@ from danswer.llm.utils import build_content_with_imgs from danswer.llm.utils import check_number_of_tokens from danswer.llm.utils import get_max_input_tokens +from danswer.llm.utils import message_to_prompt_and_imgs from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT @@ -132,10 +132,9 @@ def build_citations_system_message( def build_citations_user_message( - question: str, + message: HumanMessage, prompt_config: PromptConfig, context_docs: list[LlmDoc] | list[InferenceChunk], - files: list[InMemoryChatFile], all_doc_useful: bool, history_message: str = "", ) -> HumanMessage: @@ -149,6 +148,7 @@ def build_citations_user_message( if history_message else "" ) + query, img_urls = message_to_prompt_and_imgs(message) if context_docs: context_docs_str = build_complete_context_str(context_docs) @@ -158,20 +158,22 @@ def build_citations_user_message( optional_ignore_statement=optional_ignore, context_docs_str=context_docs_str, task_prompt=task_prompt_with_reminder, - user_query=question, + user_query=query, history_block=history_block, ) else: # if no context docs provided, assume we're in the tool calling flow user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format( task_prompt=task_prompt_with_reminder, - user_query=question, + user_query=query, history_block=history_block, ) user_prompt = user_prompt.strip() user_msg = HumanMessage( - content=build_content_with_imgs(user_prompt, files) if files else user_prompt + content=build_content_with_imgs(user_prompt, img_urls=img_urls) + if img_urls + else user_prompt ) return user_msg diff --git a/backend/danswer/llm/answering/prompts/quotes_prompt.py b/backend/danswer/llm/answering/prompts/quotes_prompt.py index 07abc4356b6..42f736b627d 100644 --- a/backend/danswer/llm/answering/prompts/quotes_prompt.py +++ b/backend/danswer/llm/answering/prompts/quotes_prompt.py @@ -5,6 +5,7 @@ from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE from danswer.db.search_settings import get_multilingual_expansion from danswer.llm.answering.models import PromptConfig +from 
danswer.llm.utils import message_to_prompt_and_imgs from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK from danswer.prompts.direct_qa_prompts import JSON_PROMPT @@ -75,7 +76,7 @@ def _build_strong_llm_quotes_prompt( def build_quotes_user_message( - question: str, + message: HumanMessage, context_docs: list[LlmDoc] | list[InferenceChunk], history_str: str, prompt: PromptConfig, @@ -86,28 +87,10 @@ def build_quotes_user_message( else _build_strong_llm_quotes_prompt ) - return prompt_builder( - question=question, - context_docs=context_docs, - history_str=history_str, - prompt=prompt, - ) - - -def build_quotes_prompt( - question: str, - context_docs: list[LlmDoc] | list[InferenceChunk], - history_str: str, - prompt: PromptConfig, -) -> HumanMessage: - prompt_builder = ( - _build_weak_llm_quotes_prompt - if QA_PROMPT_OVERRIDE == "weak" - else _build_strong_llm_quotes_prompt - ) + query, _ = message_to_prompt_and_imgs(message) return prompt_builder( - question=question, + question=query, context_docs=context_docs, history_str=history_str, prompt=prompt, diff --git a/backend/danswer/llm/answering/prune_and_merge.py b/backend/danswer/llm/answering/prune_and_merge.py index 0193de1f2aa..690a5d2280d 100644 --- a/backend/danswer/llm/answering/prune_and_merge.py +++ b/backend/danswer/llm/answering/prune_and_merge.py @@ -19,7 +19,7 @@ from danswer.prompts.prompt_utils import build_doc_context_str from danswer.search.models import InferenceChunk from danswer.search.models import InferenceSection -from danswer.tools.search.search_utils import section_to_dict +from danswer.tools.tool_implementations.search.search_utils import section_to_dict from danswer.utils.logger import setup_logger diff --git a/backend/danswer/llm/answering/stream_processing/citation_processing.py b/backend/danswer/llm/answering/stream_processing/citation_processing.py index f1e5489550d..950ad207878 100644 --- a/backend/danswer/llm/answering/stream_processing/citation_processing.py +++ b/backend/danswer/llm/answering/stream_processing/citation_processing.py @@ -1,12 +1,10 @@ import re -from collections.abc import Iterator +from collections.abc import Generator -from danswer.chat.models import AnswerQuestionStreamReturn from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc from danswer.configs.chat_configs import STOP_STREAM_PAT -from danswer.llm.answering.models import StreamProcessor from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping from danswer.prompts.constants import TRIPLE_BACKTICK from danswer.utils.logger import setup_logger @@ -19,128 +17,104 @@ def in_code_block(llm_text: str) -> bool: return count % 2 != 0 -def extract_citations_from_stream( - tokens: Iterator[str], - context_docs: list[LlmDoc], - doc_id_to_rank_map: DocumentIdOrderMapping, - stop_stream: str | None = STOP_STREAM_PAT, -) -> Iterator[DanswerAnswerPiece | CitationInfo]: - """ - Key aspects: - - 1. Stream Processing: - - Processes tokens one by one, allowing for real-time handling of large texts. - - 2. Citation Detection: - - Uses regex to find citations in the format [number]. - - Example: [1], [2], etc. - - 3. Citation Mapping: - - Maps detected citation numbers to actual document ranks using doc_id_to_rank_map. - - Example: [1] might become [3] if doc_id_to_rank_map maps it to 3. - - 4. Citation Formatting: - - Replaces citations with properly formatted versions. 
- - Adds links if available: [[1]](https://example.com) - - Handles cases where links are not available: [[1]]() - - 5. Duplicate Handling: - - Skips consecutive citations of the same document to avoid redundancy. - - 6. Output Generation: - - Yields DanswerAnswerPiece objects for regular text. - - Yields CitationInfo objects for each unique citation encountered. - - 7. Context Awareness: - - Uses context_docs to access document information for citations. - - This function effectively processes a stream of text, identifies and reformats citations, - and provides both the processed text and citation information as output. - """ - order_mapping = doc_id_to_rank_map.order_mapping - llm_out = "" - max_citation_num = len(context_docs) - citation_order = [] - curr_segment = "" - cited_inds = set() - hold = "" - - raw_out = "" - current_citations: list[int] = [] - past_cite_count = 0 - for raw_token in tokens: - raw_out += raw_token - if stop_stream: - next_hold = hold + raw_token - if stop_stream in next_hold: - break - if next_hold == stop_stream[: len(next_hold)]: - hold = next_hold - continue +class CitationProcessor: + def __init__( + self, + context_docs: list[LlmDoc], + doc_id_to_rank_map: DocumentIdOrderMapping, + stop_stream: str | None = STOP_STREAM_PAT, + ): + self.context_docs = context_docs + self.doc_id_to_rank_map = doc_id_to_rank_map + self.stop_stream = stop_stream + self.order_mapping = doc_id_to_rank_map.order_mapping + self.llm_out = "" + self.max_citation_num = len(context_docs) + self.citation_order: list[int] = [] + self.curr_segment = "" + self.cited_inds: set[int] = set() + self.hold = "" + self.current_citations: list[int] = [] + self.past_cite_count = 0 + + def process_token( + self, token: str | None + ) -> Generator[DanswerAnswerPiece | CitationInfo, None, None]: + # None -> end of stream + if token is None: + yield DanswerAnswerPiece(answer_piece=self.curr_segment) + return + + if self.stop_stream: + next_hold = self.hold + token + if self.stop_stream in next_hold: + return + if next_hold == self.stop_stream[: len(next_hold)]: + self.hold = next_hold + return token = next_hold - hold = "" - else: - token = raw_token + self.hold = "" - curr_segment += token - llm_out += token + self.curr_segment += token + self.llm_out += token # Handle code blocks without language tags - if "`" in curr_segment: - if curr_segment.endswith("`"): - continue - elif "```" in curr_segment: - piece_that_comes_after = curr_segment.split("```")[1][0] - if piece_that_comes_after == "\n" and in_code_block(llm_out): - curr_segment = curr_segment.replace("```", "```plaintext") + if "`" in self.curr_segment: + if self.curr_segment.endswith("`"): + return + elif "```" in self.curr_segment: + piece_that_comes_after = self.curr_segment.split("```")[1][0] + if piece_that_comes_after == "\n" and in_code_block(self.llm_out): + self.curr_segment = self.curr_segment.replace("```", "```plaintext") citation_pattern = r"\[(\d+)\]" - - citations_found = list(re.finditer(citation_pattern, curr_segment)) + citations_found = list(re.finditer(citation_pattern, self.curr_segment)) possible_citation_pattern = r"(\[\d*$)" # [1, [, etc - possible_citation_found = re.search(possible_citation_pattern, curr_segment) + possible_citation_found = re.search( + possible_citation_pattern, self.curr_segment + ) - # `past_cite_count`: number of characters since past citation - # 5 to ensure a citation hasn't occured - if len(citations_found) == 0 and len(llm_out) - past_cite_count > 5: - current_citations = [] + if 
len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5: + self.current_citations = [] - if citations_found and not in_code_block(llm_out): + result = "" # Initialize result here + if citations_found and not in_code_block(self.llm_out): last_citation_end = 0 length_to_add = 0 while len(citations_found) > 0: citation = citations_found.pop(0) numerical_value = int(citation.group(1)) - if 1 <= numerical_value <= max_citation_num: - context_llm_doc = context_docs[numerical_value - 1] - real_citation_num = order_mapping[context_llm_doc.document_id] + if 1 <= numerical_value <= self.max_citation_num: + context_llm_doc = self.context_docs[numerical_value - 1] + real_citation_num = self.order_mapping[context_llm_doc.document_id] - if real_citation_num not in citation_order: - citation_order.append(real_citation_num) + if real_citation_num not in self.citation_order: + self.citation_order.append(real_citation_num) - target_citation_num = citation_order.index(real_citation_num) + 1 + target_citation_num = ( + self.citation_order.index(real_citation_num) + 1 + ) # Skip consecutive citations of the same work - if target_citation_num in current_citations: + if target_citation_num in self.current_citations: start, end = citation.span() real_start = length_to_add + start diff = end - start - curr_segment = ( - curr_segment[: length_to_add + start] - + curr_segment[real_start + diff :] + self.curr_segment = ( + self.curr_segment[: length_to_add + start] + + self.curr_segment[real_start + diff :] ) length_to_add -= diff continue # Handle edge case where LLM outputs citation itself - # by allowing it to generate citations on its own. - if curr_segment.startswith("[["): - match = re.match(r"\[\[(\d+)\]\]", curr_segment) + if self.curr_segment.startswith("[["): + match = re.match(r"\[\[(\d+)\]\]", self.curr_segment) if match: try: doc_id = int(match.group(1)) - context_llm_doc = context_docs[doc_id - 1] + context_llm_doc = self.context_docs[doc_id - 1] yield CitationInfo( citation_num=target_citation_num, document_id=context_llm_doc.document_id, @@ -150,75 +124,57 @@ def extract_citations_from_stream( f"Manual LLM citation didn't properly cite documents {e}" ) else: - # Will continue attempt on next loops logger.warning( "Manual LLM citation wasn't able to close brackets" ) - continue link = context_llm_doc.link # Replace the citation in the current segment start, end = citation.span() - curr_segment = ( - curr_segment[: start + length_to_add] + self.curr_segment = ( + self.curr_segment[: start + length_to_add] + f"[{target_citation_num}]" - + curr_segment[end + length_to_add :] + + self.curr_segment[end + length_to_add :] ) - past_cite_count = len(llm_out) - current_citations.append(target_citation_num) + self.past_cite_count = len(self.llm_out) + self.current_citations.append(target_citation_num) - if target_citation_num not in cited_inds: - cited_inds.add(target_citation_num) + if target_citation_num not in self.cited_inds: + self.cited_inds.add(target_citation_num) yield CitationInfo( citation_num=target_citation_num, document_id=context_llm_doc.document_id, ) if link: - prev_length = len(curr_segment) - curr_segment = ( - curr_segment[: start + length_to_add] + prev_length = len(self.curr_segment) + self.curr_segment = ( + self.curr_segment[: start + length_to_add] + f"[[{target_citation_num}]]({link})" - + curr_segment[end + length_to_add :] + + self.curr_segment[end + length_to_add :] ) - length_to_add += len(curr_segment) - prev_length - + length_to_add += len(self.curr_segment) - 
prev_length else: - prev_length = len(curr_segment) - curr_segment = ( - curr_segment[: start + length_to_add] + prev_length = len(self.curr_segment) + self.curr_segment = ( + self.curr_segment[: start + length_to_add] + f"[[{target_citation_num}]]()" - + curr_segment[end + length_to_add :] + + self.curr_segment[end + length_to_add :] ) - length_to_add += len(curr_segment) - prev_length + length_to_add += len(self.curr_segment) - prev_length last_citation_end = end + length_to_add if last_citation_end > 0: - yield DanswerAnswerPiece(answer_piece=curr_segment[:last_citation_end]) - curr_segment = curr_segment[last_citation_end:] - if possible_citation_found: - continue - yield DanswerAnswerPiece(answer_piece=curr_segment) - curr_segment = "" - - if curr_segment: - yield DanswerAnswerPiece(answer_piece=curr_segment) - - -def build_citation_processor( - context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping -) -> StreamProcessor: - def stream_processor( - tokens: Iterator[str], - ) -> AnswerQuestionStreamReturn: - yield from extract_citations_from_stream( - tokens=tokens, - context_docs=context_docs, - doc_id_to_rank_map=doc_id_to_rank_map, - ) + result += self.curr_segment[:last_citation_end] + self.curr_segment = self.curr_segment[last_citation_end:] + + if not possible_citation_found: + result += self.curr_segment + self.curr_segment = "" - return stream_processor + if result: + yield DanswerAnswerPiece(answer_piece=result) diff --git a/backend/danswer/llm/answering/stream_processing/citation_response_handler.py b/backend/danswer/llm/answering/stream_processing/citation_response_handler.py new file mode 100644 index 00000000000..07a342fbd7e --- /dev/null +++ b/backend/danswer/llm/answering/stream_processing/citation_response_handler.py @@ -0,0 +1,61 @@ +import abc +from collections.abc import Generator + +from langchain_core.messages import BaseMessage + +from danswer.chat.models import CitationInfo +from danswer.chat.models import LlmDoc +from danswer.llm.answering.llm_response_handler import ResponsePart +from danswer.llm.answering.stream_processing.citation_processing import ( + CitationProcessor, +) +from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping + + +class AnswerResponseHandler(abc.ABC): + @abc.abstractmethod + def handle_response_part( + self, + response_item: BaseMessage | None, + previous_response_items: list[BaseMessage], + ) -> Generator[ResponsePart, None, None]: + raise NotImplementedError + + +class DummyAnswerResponseHandler(AnswerResponseHandler): + def handle_response_part( + self, + response_item: BaseMessage | None, + previous_response_items: list[BaseMessage], + ) -> Generator[ResponsePart, None, None]: + # This is a dummy handler that returns nothing + yield from [] + + +class CitationResponseHandler(AnswerResponseHandler): + def __init__( + self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping + ): + self.context_docs = context_docs + self.doc_id_to_rank_map = doc_id_to_rank_map + self.citation_processor = CitationProcessor( + context_docs=self.context_docs, + doc_id_to_rank_map=self.doc_id_to_rank_map, + ) + self.processed_text = "" + self.citations: list[CitationInfo] = [] + + def handle_response_part( + self, + response_item: BaseMessage | None, + previous_response_items: list[BaseMessage], + ) -> Generator[ResponsePart, None, None]: + if response_item is None: + return + + content = ( + response_item.content if isinstance(response_item.content, str) else "" + ) + + # Process the new content 
through the citation processor + yield from self.citation_processor.process_token(content) diff --git a/backend/danswer/llm/answering/tool/tool_response_handler.py b/backend/danswer/llm/answering/tool/tool_response_handler.py new file mode 100644 index 00000000000..6c4fec77941 --- /dev/null +++ b/backend/danswer/llm/answering/tool/tool_response_handler.py @@ -0,0 +1,205 @@ +from collections.abc import Generator + +from langchain_core.messages import AIMessageChunk +from langchain_core.messages import BaseMessage +from langchain_core.messages import ToolCall + +from danswer.llm.answering.llm_response_handler import LLMCall +from danswer.llm.answering.llm_response_handler import ResponsePart +from danswer.llm.interfaces import LLM +from danswer.tools.force import ForceUseTool +from danswer.tools.message import build_tool_message +from danswer.tools.message import ToolCallSummary +from danswer.tools.models import ToolCallFinalResult +from danswer.tools.models import ToolCallKickoff +from danswer.tools.models import ToolResponse +from danswer.tools.tool import Tool +from danswer.tools.tool_runner import ( + check_which_tools_should_run_for_non_tool_calling_llm, +) +from danswer.tools.tool_runner import ToolRunner +from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + + +class ToolResponseHandler: + def __init__(self, tools: list[Tool]): + self.tools = tools + + self.tool_call_chunk: AIMessageChunk | None = None + self.tool_call_requests: list[ToolCall] = [] + + self.tool_runner: ToolRunner | None = None + self.tool_call_summary: ToolCallSummary | None = None + + self.tool_kickoff: ToolCallKickoff | None = None + self.tool_responses: list[ToolResponse] = [] + self.tool_final_result: ToolCallFinalResult | None = None + + @classmethod + def get_tool_call_for_non_tool_calling_llm( + cls, llm_call: LLMCall, llm: LLM + ) -> tuple[Tool, dict] | None: + if llm_call.force_use_tool.force_use: + # if we are forcing a tool, we don't need to check which tools to run + tool = next( + ( + t + for t in llm_call.tools + if t.name == llm_call.force_use_tool.tool_name + ), + None, + ) + if not tool: + raise RuntimeError( + f"Tool '{llm_call.force_use_tool.tool_name}' not found" + ) + + tool_args = ( + llm_call.force_use_tool.args + if llm_call.force_use_tool.args is not None + else tool.get_args_for_non_tool_calling_llm( + query=llm_call.prompt_builder.get_user_message_content(), + history=llm_call.prompt_builder.raw_message_history, + llm=llm, + force_run=True, + ) + ) + + if tool_args is None: + raise RuntimeError(f"Tool '{tool.name}' did not return args") + + return (tool, tool_args) + else: + tool_options = check_which_tools_should_run_for_non_tool_calling_llm( + tools=llm_call.tools, + query=llm_call.prompt_builder.get_user_message_content(), + history=llm_call.prompt_builder.raw_message_history, + llm=llm, + ) + + available_tools_and_args = [ + (llm_call.tools[ind], args) + for ind, args in enumerate(tool_options) + if args is not None + ] + + logger.info( + f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}" + ) + + chosen_tool_and_args = ( + select_single_tool_for_non_tool_calling_llm( + tools_and_args=available_tools_and_args, + history=llm_call.prompt_builder.raw_message_history, + query=llm_call.prompt_builder.get_user_message_content(), + llm=llm, + ) + if available_tools_and_args + else None + ) + + logger.notice(f"Chosen tool: 
{chosen_tool_and_args}") + return chosen_tool_and_args + + def _handle_tool_call(self) -> Generator[ResponsePart, None, None]: + if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls: + return + + self.tool_call_requests = self.tool_call_chunk.tool_calls + + selected_tool: Tool | None = None + selected_tool_call_request: ToolCall | None = None + for tool_call_request in self.tool_call_requests: + known_tools_by_name = [ + tool for tool in self.tools if tool.name == tool_call_request["name"] + ] + + if not known_tools_by_name: + logger.error( + "Tool call requested with unknown name field. \n" + f"self.tools: {self.tools}" + f"tool_call_request: {tool_call_request}" + ) + continue + else: + selected_tool = known_tools_by_name[0] + selected_tool_call_request = tool_call_request + + if selected_tool and selected_tool_call_request: + break + + if not selected_tool or not selected_tool_call_request: + return + + self.tool_runner = ToolRunner(selected_tool, selected_tool_call_request["args"]) + self.tool_call_summary = ToolCallSummary( + tool_call_request=self.tool_call_chunk, + tool_call_result=build_tool_message( + tool_call_request, self.tool_runner.tool_message_content() + ), + ) + + self.tool_kickoff = self.tool_runner.kickoff() + yield self.tool_kickoff + + for response in self.tool_runner.tool_responses(): + self.tool_responses.append(response) + yield response + + self.tool_final_result = self.tool_runner.tool_final_result() + yield self.tool_final_result + + def handle_response_part( + self, + response_item: BaseMessage | None, + previous_response_items: list[BaseMessage], + ) -> Generator[ResponsePart, None, None]: + if response_item is None: + yield from self._handle_tool_call() + + if isinstance(response_item, AIMessageChunk) and ( + response_item.tool_call_chunks or response_item.tool_calls + ): + if self.tool_call_chunk is None: + self.tool_call_chunk = response_item + else: + self.tool_call_chunk += response_item # type: ignore + + return + + def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None: + if ( + self.tool_runner is None + or self.tool_call_summary is None + or self.tool_kickoff is None + or self.tool_final_result is None + ): + return None + + tool_runner = self.tool_runner + new_prompt_builder = tool_runner.tool.build_next_prompt( + prompt_builder=current_llm_call.prompt_builder, + tool_call_summary=self.tool_call_summary, + tool_responses=self.tool_responses, + using_tool_calling_llm=current_llm_call.using_tool_calling_llm, + ) + return LLMCall( + prompt_builder=new_prompt_builder, + tools=[], # for now, only allow one tool call per response + force_use_tool=ForceUseTool( + force_use=False, + tool_name="", + args=None, + ), + files=current_llm_call.files, + using_tool_calling_llm=current_llm_call.using_tool_calling_llm, + tool_call_info=[ + self.tool_kickoff, + *self.tool_responses, + self.tool_final_result, + ], + ) diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index bad18214b95..af480f83955 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -203,6 +203,28 @@ def build_content_with_imgs( ) +def message_to_prompt_and_imgs(message: BaseMessage) -> tuple[str, list[str]]: + if isinstance(message.content, str): + return message.content, [] + + imgs = [] + texts = [] + for part in message.content: + if isinstance(part, dict): + if part.get("type") == "image_url": + img_url = part.get("image_url", {}).get("url") + if img_url: + imgs.append(img_url) + elif part.get("type") == "text": + text = 
part.get("text") + if text: + texts.append(text) + else: + texts.append(part) + + return "".join(texts), imgs + + def dict_based_prompt_to_langchain_prompt( messages: list[dict[str, str]] ) -> list[BaseMessage]: diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 1bfac570aee..9ece5f4bba2 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -52,12 +52,16 @@ from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.utils import get_json_line from danswer.tools.force import ForceUseTool -from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID -from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID -from danswer.tools.search.search_tool import SearchResponseSummary -from danswer.tools.search.search_tool import SearchTool -from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID -from danswer.tools.tool import ToolResponse +from danswer.tools.models import ToolResponse +from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID +from danswer.tools.tool_implementations.search.search_tool import ( + SEARCH_RESPONSE_SUMMARY_ID, +) +from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary +from danswer.tools.tool_implementations.search.search_tool import SearchTool +from danswer.tools.tool_implementations.search.search_tool import ( + SECTION_RELEVANCE_LIST_ID, +) from danswer.tools.tool_runner import ToolCallKickoff from danswer.utils.logger import setup_logger from danswer.utils.timing import log_generator_function_time @@ -202,30 +206,33 @@ def stream_answer_objects( max_tokens=max_document_tokens, ) + answer_config = AnswerStyleConfig( + citation_config=CitationConfig() if use_citations else None, + quotes_config=QuotesConfig() if not use_citations else None, + document_pruning_config=document_pruning_config, + ) + search_tool = SearchTool( db_session=db_session, user=user, - evaluation_type=LLMEvaluationType.SKIP - if DISABLE_LLM_DOC_RELEVANCE - else query_req.evaluation_type, + evaluation_type=( + LLMEvaluationType.SKIP + if DISABLE_LLM_DOC_RELEVANCE + else query_req.evaluation_type + ), persona=persona, retrieval_options=query_req.retrieval_options, prompt_config=prompt_config, llm=llm, fast_llm=fast_llm, pruning_config=document_pruning_config, + answer_style_config=answer_config, bypass_acl=bypass_acl, chunks_above=query_req.chunks_above, chunks_below=query_req.chunks_below, full_doc=query_req.full_doc, ) - answer_config = AnswerStyleConfig( - citation_config=CitationConfig() if use_citations else None, - quotes_config=QuotesConfig() if not use_citations else None, - document_pruning_config=document_pruning_config, - ) - answer = Answer( question=query_msg.message, answer_style_config=answer_config, diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py index 0e69777a02c..5fa99952b1f 100644 --- a/backend/danswer/server/features/persona/models.py +++ b/backend/danswer/server/features/persona/models.py @@ -9,7 +9,7 @@ from danswer.search.enums import RecencyBiasSetting from danswer.server.features.document_set.models import DocumentSet from danswer.server.features.prompt.models import PromptSnapshot -from danswer.server.features.tool.api import ToolSnapshot +from danswer.server.features.tool.models import ToolSnapshot from danswer.server.models import 
MinimalUserSnapshot
 from danswer.utils.logger import setup_logger
diff --git a/backend/danswer/server/features/tool/api.py b/backend/danswer/server/features/tool/api.py
index 7e15c048826..48f857780ba 100644
--- a/backend/danswer/server/features/tool/api.py
+++ b/backend/danswer/server/features/tool/api.py
@@ -18,10 +18,16 @@
 from danswer.server.features.tool.models import CustomToolCreate
 from danswer.server.features.tool.models import CustomToolUpdate
 from danswer.server.features.tool.models import ToolSnapshot
-from danswer.tools.custom.openapi_parsing import MethodSpec
-from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
-from danswer.tools.custom.openapi_parsing import validate_openapi_schema
-from danswer.tools.images.image_generation_tool import ImageGenerationTool
+from danswer.tools.tool_implementations.custom.openapi_parsing import MethodSpec
+from danswer.tools.tool_implementations.custom.openapi_parsing import (
+    openapi_to_method_specs,
+)
+from danswer.tools.tool_implementations.custom.openapi_parsing import (
+    validate_openapi_schema,
+)
+from danswer.tools.tool_implementations.images.image_generation_tool import (
+    ImageGenerationTool,
+)
 from danswer.tools.utils import is_image_generation_available

 router = APIRouter(prefix="/tool")
diff --git a/backend/danswer/tools/base_tool.py b/backend/danswer/tools/base_tool.py
new file mode 100644
index 00000000000..73902504462
--- /dev/null
+++ b/backend/danswer/tools/base_tool.py
@@ -0,0 +1,59 @@
+from typing import cast
+from typing import TYPE_CHECKING
+
+from langchain_core.messages import HumanMessage
+
+from danswer.llm.utils import message_to_prompt_and_imgs
+from danswer.tools.tool import Tool
+
+if TYPE_CHECKING:
+    from danswer.llm.answering.prompts.build import AnswerPromptBuilder
+    from danswer.tools.tool_implementations.custom.custom_tool import (
+        CustomToolCallSummary,
+    )
+    from danswer.tools.message import ToolCallSummary
+    from danswer.tools.models import ToolResponse
+
+
+def build_user_message_for_non_tool_calling_llm(
+    message: HumanMessage,
+    tool_name: str,
+    *args: "ToolResponse",
+) -> str:
+    query, _ = message_to_prompt_and_imgs(message)
+
+    tool_run_summary = cast("CustomToolCallSummary", args[0].response).tool_result
+    return f"""
+Here's the result from the {tool_name} tool:
+
+{tool_run_summary}
+
+Now respond to the following:
+
+{query}
+""".strip()
+
+
+class BaseTool(Tool):
+    def build_next_prompt(
+        self,
+        prompt_builder: "AnswerPromptBuilder",
+        tool_call_summary: "ToolCallSummary",
+        tool_responses: list["ToolResponse"],
+        using_tool_calling_llm: bool,
+    ) -> "AnswerPromptBuilder":
+        if using_tool_calling_llm:
+            prompt_builder.append_message(tool_call_summary.tool_call_request)
+            prompt_builder.append_message(tool_call_summary.tool_call_result)
+        else:
+            prompt_builder.update_user_prompt(
+                HumanMessage(
+                    content=build_user_message_for_non_tool_calling_llm(
+                        prompt_builder.user_message_and_token_cnt[0],
+                        self.name,
+                        *tool_responses,
+                    )
+                )
+            )
+
+        return prompt_builder
diff --git a/backend/danswer/tools/built_in_tools.py b/backend/danswer/tools/built_in_tools.py
index 99b2ae3bbb6..fb64381f1d0 100644
--- a/backend/danswer/tools/built_in_tools.py
+++ b/backend/danswer/tools/built_in_tools.py
@@ -9,9 +9,13 @@
 from danswer.db.models import Persona
 from danswer.db.models import Tool as ToolDBModel
-from danswer.tools.images.image_generation_tool import ImageGenerationTool
-from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
-from
danswer.tools.search.search_tool import SearchTool +from danswer.tools.tool_implementations.images.image_generation_tool import ( + ImageGenerationTool, +) +from danswer.tools.tool_implementations.internet_search.internet_search_tool import ( + InternetSearchTool, +) +from danswer.tools.tool_implementations.search.search_tool import SearchTool from danswer.tools.tool import Tool from danswer.utils.logger import setup_logger diff --git a/backend/danswer/tools/custom/custom_tool_prompt_builder.py b/backend/danswer/tools/custom/custom_tool_prompt_builder.py deleted file mode 100644 index 8016363acc9..00000000000 --- a/backend/danswer/tools/custom/custom_tool_prompt_builder.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import cast - -from danswer.tools.custom.custom_tool import CustomToolCallSummary -from danswer.tools.models import ToolResponse - - -def build_user_message_for_custom_tool_for_non_tool_calling_llm( - query: str, - tool_name: str, - *args: ToolResponse, -) -> str: - tool_run_summary = cast(CustomToolCallSummary, args[0].response).tool_result - return f""" -Here's the result from the {tool_name} tool: - -{tool_run_summary} - -Now respond to the following: - -{query} -""".strip() diff --git a/backend/danswer/tools/tool.py b/backend/danswer/tools/tool.py index 29e5311fc15..1b1c43ab8da 100644 --- a/backend/danswer/tools/tool.py +++ b/backend/danswer/tools/tool.py @@ -1,11 +1,17 @@ import abc from collections.abc import Generator from typing import Any +from typing import TYPE_CHECKING from danswer.key_value_store.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM -from danswer.tools.models import ToolResponse + + +if TYPE_CHECKING: + from danswer.llm.answering.prompts.build import AnswerPromptBuilder + from danswer.tools.message import ToolCallSummary + from danswer.tools.models import ToolResponse class Tool(abc.ABC): @@ -32,7 +38,7 @@ def tool_definition(self) -> dict: @abc.abstractmethod def build_tool_message_content( - self, *args: ToolResponse + self, *args: "ToolResponse" ) -> str | list[str | dict[str, Any]]: raise NotImplementedError @@ -51,13 +57,26 @@ def get_args_for_non_tool_calling_llm( """Actual execution of the tool""" @abc.abstractmethod - def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: + def run(self, **kwargs: Any) -> Generator["ToolResponse", None, None]: raise NotImplementedError @abc.abstractmethod - def final_result(self, *args: ToolResponse) -> JSON_ro: + def final_result(self, *args: "ToolResponse") -> JSON_ro: """ This is the "final summary" result of the tool. It is the result that will be stored in the database. """ raise NotImplementedError + + """Some tools may want to modify the prompt based on the tool call summary and tool responses. 
+ Default behavior is to continue with just the raw tool call request/result passed to the LLM.""" + + @abc.abstractmethod + def build_next_prompt( + self, + prompt_builder: "AnswerPromptBuilder", + tool_call_summary: "ToolCallSummary", + tool_responses: list["ToolResponse"], + using_tool_calling_llm: bool, + ) -> "AnswerPromptBuilder": + raise NotImplementedError diff --git a/backend/danswer/tools/custom/base_tool_types.py b/backend/danswer/tools/tool_implementations/custom/base_tool_types.py similarity index 100% rename from backend/danswer/tools/custom/base_tool_types.py rename to backend/danswer/tools/tool_implementations/custom/base_tool_types.py diff --git a/backend/danswer/tools/custom/custom_tool.py b/backend/danswer/tools/tool_implementations/custom/custom_tool.py similarity index 88% rename from backend/danswer/tools/custom/custom_tool.py rename to backend/danswer/tools/tool_implementations/custom/custom_tool.py index ee431af70e1..a1fb4bb699e 100644 --- a/backend/danswer/tools/custom/custom_tool.py +++ b/backend/danswer/tools/tool_implementations/custom/custom_tool.py @@ -11,24 +11,34 @@ from danswer.key_value_store.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM -from danswer.tools.custom.base_tool_types import ToolResultType -from danswer.tools.custom.custom_tool_prompts import ( - SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT, -) -from danswer.tools.custom.custom_tool_prompts import SHOULD_USE_CUSTOM_TOOL_USER_PROMPT -from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_SYSTEM_PROMPT -from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_USER_PROMPT -from danswer.tools.custom.custom_tool_prompts import USE_TOOL -from danswer.tools.custom.openapi_parsing import MethodSpec -from danswer.tools.custom.openapi_parsing import openapi_to_method_specs -from danswer.tools.custom.openapi_parsing import openapi_to_url -from danswer.tools.custom.openapi_parsing import REQUEST_BODY -from danswer.tools.custom.openapi_parsing import validate_openapi_schema +from danswer.tools.base_tool import BaseTool from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER from danswer.tools.models import DynamicSchemaInfo from danswer.tools.models import MESSAGE_ID_PLACEHOLDER -from danswer.tools.tool import Tool -from danswer.tools.tool import ToolResponse +from danswer.tools.models import ToolResponse +from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType +from danswer.tools.tool_implementations.custom.custom_tool_prompts import ( + SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT, +) +from danswer.tools.tool_implementations.custom.custom_tool_prompts import ( + SHOULD_USE_CUSTOM_TOOL_USER_PROMPT, +) +from danswer.tools.tool_implementations.custom.custom_tool_prompts import ( + TOOL_ARG_SYSTEM_PROMPT, +) +from danswer.tools.tool_implementations.custom.custom_tool_prompts import ( + TOOL_ARG_USER_PROMPT, +) +from danswer.tools.tool_implementations.custom.custom_tool_prompts import USE_TOOL +from danswer.tools.tool_implementations.custom.openapi_parsing import MethodSpec +from danswer.tools.tool_implementations.custom.openapi_parsing import ( + openapi_to_method_specs, +) +from danswer.tools.tool_implementations.custom.openapi_parsing import openapi_to_url +from danswer.tools.tool_implementations.custom.openapi_parsing import REQUEST_BODY +from danswer.tools.tool_implementations.custom.openapi_parsing import ( + validate_openapi_schema, +) from danswer.utils.headers import header_list_to_header_dict from 
danswer.utils.headers import HeaderItemDict from danswer.utils.logger import setup_logger @@ -43,7 +53,7 @@ class CustomToolCallSummary(BaseModel): tool_result: ToolResultType -class CustomTool(Tool): +class CustomTool(BaseTool): def __init__( self, method_spec: MethodSpec, diff --git a/backend/danswer/tools/custom/custom_tool_prompts.py b/backend/danswer/tools/tool_implementations/custom/custom_tool_prompts.py similarity index 100% rename from backend/danswer/tools/custom/custom_tool_prompts.py rename to backend/danswer/tools/tool_implementations/custom/custom_tool_prompts.py diff --git a/backend/danswer/tools/custom/openapi_parsing.py b/backend/danswer/tools/tool_implementations/custom/openapi_parsing.py similarity index 100% rename from backend/danswer/tools/custom/openapi_parsing.py rename to backend/danswer/tools/tool_implementations/custom/openapi_parsing.py diff --git a/backend/danswer/tools/images/image_generation_tool.py b/backend/danswer/tools/tool_implementations/images/image_generation_tool.py similarity index 86% rename from backend/danswer/tools/images/image_generation_tool.py rename to backend/danswer/tools/tool_implementations/images/image_generation_tool.py index 3584d50f77e..6fb06fb534a 100644 --- a/backend/danswer/tools/images/image_generation_tool.py +++ b/backend/danswer/tools/tool_implementations/images/image_generation_tool.py @@ -11,12 +11,17 @@ from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.key_value_store.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage +from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.interfaces import LLM from danswer.llm.utils import build_content_with_imgs from danswer.llm.utils import message_to_string from danswer.prompts.constants import GENERAL_SEP_PAT +from danswer.tools.message import ToolCallSummary +from danswer.tools.models import ToolResponse from danswer.tools.tool import Tool -from danswer.tools.tool import ToolResponse +from danswer.tools.tool_implementations.images.prompt import ( + build_image_generation_user_prompt, +) from danswer.utils.headers import build_llm_extra_headers from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel @@ -258,3 +263,34 @@ def final_result(self, *args: ToolResponse) -> JSON_ro: image_generation_response.model_dump() for image_generation_response in image_generation_responses ] + + def build_next_prompt( + self, + prompt_builder: AnswerPromptBuilder, + tool_call_summary: ToolCallSummary, + tool_responses: list[ToolResponse], + using_tool_calling_llm: bool, + ) -> AnswerPromptBuilder: + img_generation_response = cast( + list[ImageGenerationResponse] | None, + next( + ( + response.response + for response in tool_responses + if response.id == IMAGE_GENERATION_RESPONSE_ID + ), + None, + ), + ) + if img_generation_response is None: + raise ValueError("No image generation response found") + + img_urls = [img.url for img in img_generation_response] + prompt_builder.update_user_prompt( + build_image_generation_user_prompt( + query=prompt_builder.get_user_message_content(), + img_urls=img_urls, + ) + ) + + return prompt_builder diff --git a/backend/danswer/tools/images/prompt.py b/backend/danswer/tools/tool_implementations/images/prompt.py similarity index 100% rename from backend/danswer/tools/images/prompt.py rename to backend/danswer/tools/tool_implementations/images/prompt.py diff --git 
a/backend/danswer/tools/internet_search/internet_search_tool.py b/backend/danswer/tools/tool_implementations/internet_search/internet_search_tool.py similarity index 81% rename from backend/danswer/tools/internet_search/internet_search_tool.py rename to backend/danswer/tools/tool_implementations/internet_search/internet_search_tool.py index 70b4483b996..12142bc4852 100644 --- a/backend/danswer/tools/internet_search/internet_search_tool.py +++ b/backend/danswer/tools/tool_implementations/internet_search/internet_search_tool.py @@ -11,18 +11,31 @@ from danswer.configs.constants import DocumentSource from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.key_value_store.interface import JSON_ro +from danswer.llm.answering.models import AnswerStyleConfig from danswer.llm.answering.models import PreviousMessage +from danswer.llm.answering.models import PromptConfig +from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.interfaces import LLM from danswer.llm.utils import message_to_string from danswer.prompts.chat_prompts import INTERNET_SEARCH_QUERY_REPHRASE from danswer.prompts.constants import GENERAL_SEP_PAT from danswer.search.models import SearchDoc from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase -from danswer.tools.internet_search.models import InternetSearchResponse -from danswer.tools.internet_search.models import InternetSearchResult -from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID +from danswer.tools.message import ToolCallSummary +from danswer.tools.models import ToolResponse from danswer.tools.tool import Tool -from danswer.tools.tool import ToolResponse +from danswer.tools.tool_implementations.internet_search.models import ( + InternetSearchResponse, +) +from danswer.tools.tool_implementations.internet_search.models import ( + InternetSearchResult, +) +from danswer.tools.tool_implementations.search_like_tool_utils import ( + build_next_prompt_for_search_like_tool, +) +from danswer.tools.tool_implementations.search_like_tool_utils import ( + FINAL_CONTEXT_DOCUMENTS_ID, +) from danswer.utils.logger import setup_logger logger = setup_logger() @@ -97,8 +110,17 @@ class InternetSearchTool(Tool): _DISPLAY_NAME = "[Beta] Internet Search Tool" _DESCRIPTION = "Perform an internet search for up-to-date information." 
- def __init__(self, api_key: str, num_results: int = 10) -> None: + def __init__( + self, + api_key: str, + answer_style_config: AnswerStyleConfig, + prompt_config: PromptConfig, + num_results: int = 10, + ) -> None: self.api_key = api_key + self.answer_style_config = answer_style_config + self.prompt_config = prompt_config + self.host = "https://api.bing.microsoft.com/v7.0" self.headers = { "Ocp-Apim-Subscription-Key": api_key, @@ -231,3 +253,19 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: def final_result(self, *args: ToolResponse) -> JSON_ro: search_response = cast(InternetSearchResponse, args[0].response) return search_response.model_dump() + + def build_next_prompt( + self, + prompt_builder: AnswerPromptBuilder, + tool_call_summary: ToolCallSummary, + tool_responses: list[ToolResponse], + using_tool_calling_llm: bool, + ) -> AnswerPromptBuilder: + return build_next_prompt_for_search_like_tool( + prompt_builder=prompt_builder, + tool_call_summary=tool_call_summary, + tool_responses=tool_responses, + using_tool_calling_llm=using_tool_calling_llm, + answer_style_config=self.answer_style_config, + prompt_config=self.prompt_config, + ) diff --git a/backend/danswer/tools/internet_search/models.py b/backend/danswer/tools/tool_implementations/internet_search/models.py similarity index 100% rename from backend/danswer/tools/internet_search/models.py rename to backend/danswer/tools/tool_implementations/internet_search/models.py diff --git a/backend/danswer/tools/search/search_tool.py b/backend/danswer/tools/tool_implementations/search/search_tool.py similarity index 87% rename from backend/danswer/tools/search/search_tool.py rename to backend/danswer/tools/tool_implementations/search/search_tool.py index 96ab7b843f6..6eda3013ab3 100644 --- a/backend/danswer/tools/search/search_tool.py +++ b/backend/danswer/tools/tool_implementations/search/search_tool.py @@ -17,10 +17,13 @@ from danswer.db.models import Persona from danswer.db.models import User from danswer.key_value_store.interface import JSON_ro +from danswer.llm.answering.llm_response_handler import LLMCall +from danswer.llm.answering.models import AnswerStyleConfig from danswer.llm.answering.models import ContextualPruningConfig from danswer.llm.answering.models import DocumentPruningConfig from danswer.llm.answering.models import PreviousMessage from danswer.llm.answering.models import PromptConfig +from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens from danswer.llm.answering.prune_and_merge import prune_and_merge_sections from danswer.llm.answering.prune_and_merge import prune_sections @@ -35,9 +38,16 @@ from danswer.search.pipeline import SearchPipeline from danswer.secondary_llm_flows.choose_search import check_if_need_search from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase -from danswer.tools.search.search_utils import llm_doc_to_dict +from danswer.tools.message import ToolCallSummary +from danswer.tools.models import ToolResponse from danswer.tools.tool import Tool -from danswer.tools.tool import ToolResponse +from danswer.tools.tool_implementations.search.search_utils import llm_doc_to_dict +from danswer.tools.tool_implementations.search_like_tool_utils import ( + build_next_prompt_for_search_like_tool, +) +from danswer.tools.tool_implementations.search_like_tool_utils import ( + FINAL_CONTEXT_DOCUMENTS_ID, +) from danswer.utils.logger import setup_logger logger = 
setup_logger() @@ -45,7 +55,6 @@ SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary" SEARCH_DOC_CONTENT_ID = "search_doc_content" SECTION_RELEVANCE_LIST_ID = "section_relevance_list" -FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents" SEARCH_EVALUATION_ID = "llm_doc_eval" @@ -85,6 +94,7 @@ def __init__( llm: LLM, fast_llm: LLM, pruning_config: DocumentPruningConfig, + answer_style_config: AnswerStyleConfig, evaluation_type: LLMEvaluationType, # if specified, will not actually run a search and will instead return these # sections. Used when the user selects specific docs to talk to @@ -136,6 +146,7 @@ def __init__( num_chunk_multiple = self.chunks_above + self.chunks_below + 1 + self.answer_style_config = answer_style_config self.contextual_pruning_config = ( ContextualPruningConfig.from_doc_pruning_config( num_chunk_multiple=num_chunk_multiple, doc_pruning_config=pruning_config @@ -353,4 +364,36 @@ def final_result(self, *args: ToolResponse) -> JSON_ro: # NOTE: need to do this json.loads(doc.json()) stuff because there are some # subfields that are not serializable by default (datetime) # this forces pydantic to make them JSON serializable for us - return [json.loads(doc.json()) for doc in final_docs] + return [json.loads(doc.model_dump_json()) for doc in final_docs] + + def build_next_prompt( + self, + prompt_builder: AnswerPromptBuilder, + tool_call_summary: ToolCallSummary, + tool_responses: list[ToolResponse], + using_tool_calling_llm: bool, + ) -> AnswerPromptBuilder: + return build_next_prompt_for_search_like_tool( + prompt_builder=prompt_builder, + tool_call_summary=tool_call_summary, + tool_responses=tool_responses, + using_tool_calling_llm=using_tool_calling_llm, + answer_style_config=self.answer_style_config, + prompt_config=self.prompt_config, + ) + + """Other utility functions""" + + @classmethod + def get_search_result(cls, llm_call: LLMCall) -> list[LlmDoc] | None: + if not llm_call.tool_call_info: + return None + + for yield_item in llm_call.tool_call_info: + if ( + isinstance(yield_item, ToolResponse) + and yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID + ): + return cast(list[LlmDoc], yield_item.response) + + return None diff --git a/backend/danswer/tools/search/search_utils.py b/backend/danswer/tools/tool_implementations/search/search_utils.py similarity index 100% rename from backend/danswer/tools/search/search_utils.py rename to backend/danswer/tools/tool_implementations/search/search_utils.py diff --git a/backend/danswer/tools/tool_implementations/search_like_tool_utils.py b/backend/danswer/tools/tool_implementations/search_like_tool_utils.py new file mode 100644 index 00000000000..6701f1602ea --- /dev/null +++ b/backend/danswer/tools/tool_implementations/search_like_tool_utils.py @@ -0,0 +1,71 @@ +from typing import cast + +from danswer.chat.models import LlmDoc +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import PromptConfig +from danswer.llm.answering.prompts.build import AnswerPromptBuilder +from danswer.llm.answering.prompts.citations_prompt import ( + build_citations_system_message, +) +from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message +from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message +from danswer.tools.message import ToolCallSummary +from danswer.tools.models import ToolResponse + + +FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents" + + +def build_next_prompt_for_search_like_tool( + prompt_builder: AnswerPromptBuilder, + 
tool_call_summary: ToolCallSummary, + tool_responses: list[ToolResponse], + using_tool_calling_llm: bool, + answer_style_config: AnswerStyleConfig, + prompt_config: PromptConfig, +) -> AnswerPromptBuilder: + if not using_tool_calling_llm: + final_context_docs_response = next( + response + for response in tool_responses + if response.id == FINAL_CONTEXT_DOCUMENTS_ID + ) + final_context_documents = cast( + list[LlmDoc], final_context_docs_response.response + ) + else: + # if using tool calling llm, then the final context documents are the tool responses + final_context_documents = [] + + if answer_style_config.citation_config: + prompt_builder.update_system_prompt( + build_citations_system_message(prompt_config) + ) + prompt_builder.update_user_prompt( + build_citations_user_message( + message=prompt_builder.user_message_and_token_cnt[0], + prompt_config=prompt_config, + context_docs=final_context_documents, + all_doc_useful=( + answer_style_config.citation_config.all_docs_useful + if answer_style_config.citation_config + else False + ), + history_message=prompt_builder.single_message_history or "", + ) + ) + elif answer_style_config.quotes_config: + prompt_builder.update_user_prompt( + build_quotes_user_message( + message=prompt_builder.user_message_and_token_cnt[0], + context_docs=final_context_documents, + history_str=prompt_builder.single_message_history or "", + prompt=prompt_config, + ) + ) + + if using_tool_calling_llm: + prompt_builder.append_message(tool_call_summary.tool_call_request) + prompt_builder.append_message(tool_call_summary.tool_call_result) + + return prompt_builder diff --git a/backend/danswer/tools/tool_runner.py b/backend/danswer/tools/tool_runner.py index 58b94bdb0c8..fb3eb8b9932 100644 --- a/backend/danswer/tools/tool_runner.py +++ b/backend/danswer/tools/tool_runner.py @@ -6,8 +6,8 @@ from danswer.llm.interfaces import LLM from danswer.tools.models import ToolCallFinalResult from danswer.tools.models import ToolCallKickoff +from danswer.tools.models import ToolResponse from danswer.tools.tool import Tool -from danswer.tools.tool import ToolResponse from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel diff --git a/backend/ee/danswer/server/query_and_chat/utils.py b/backend/ee/danswer/server/query_and_chat/utils.py index a2f7253517a..be5507b01c2 100644 --- a/backend/ee/danswer/server/query_and_chat/utils.py +++ b/backend/ee/danswer/server/query_and_chat/utils.py @@ -12,7 +12,7 @@ from danswer.db.models import User from danswer.db.persona import get_prompts_by_ids from danswer.one_shot_answer.models import PersonaConfig -from danswer.tools.custom.custom_tool import ( +from danswer.tools.tool_implementations.custom.custom_tool import ( build_custom_tools_from_openapi_schema_and_headers, ) diff --git a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py index 10d1950ae03..0ed40c758d0 100644 --- a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py +++ b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py @@ -142,6 +142,9 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) -> assert response.status_code == 200 response_json = response.json() + # make sure there is an answer + assert response_json["answer"] + # since we only gave it one search doc, all responses should only contain that doc assert response_json["final_context_doc_indices"] == [0] assert response_json["llm_selected_doc_indices"] == [0] diff --git 
a/backend/tests/unit/danswer/llm/answering/conftest.py b/backend/tests/unit/danswer/llm/answering/conftest.py
new file mode 100644
index 00000000000..a0077b53917
--- /dev/null
+++ b/backend/tests/unit/danswer/llm/answering/conftest.py
@@ -0,0 +1,113 @@
+import json
+from datetime import datetime
+from unittest.mock import MagicMock
+
+import pytest
+from langchain_core.messages import SystemMessage
+
+from danswer.chat.models import LlmDoc
+from danswer.configs.constants import DocumentSource
+from danswer.llm.answering.models import AnswerStyleConfig
+from danswer.llm.answering.models import CitationConfig
+from danswer.llm.answering.models import PromptConfig
+from danswer.llm.answering.prompts.build import AnswerPromptBuilder
+from danswer.llm.interfaces import LLMConfig
+from danswer.tools.models import ToolResponse
+from danswer.tools.tool_implementations.search.search_tool import SearchTool
+from danswer.tools.tool_implementations.search_like_tool_utils import (
+    FINAL_CONTEXT_DOCUMENTS_ID,
+)
+
+QUERY = "Test question"
+DEFAULT_SEARCH_ARGS = {"query": "search"}
+
+
+@pytest.fixture
+def answer_style_config() -> AnswerStyleConfig:
+    return AnswerStyleConfig(citation_config=CitationConfig())
+
+
+@pytest.fixture
+def prompt_config() -> PromptConfig:
+    return PromptConfig(
+        system_prompt="System prompt",
+        task_prompt="Task prompt",
+        datetime_aware=False,
+        include_citations=True,
+    )
+
+
+@pytest.fixture
+def mock_llm() -> MagicMock:
+    mock_llm_obj = MagicMock()
+    mock_llm_obj.config = LLMConfig(
+        model_provider="openai",
+        model_name="gpt-4o",
+        temperature=0.0,
+        api_key=None,
+        api_base=None,
+        api_version=None,
+    )
+    return mock_llm_obj
+
+
+@pytest.fixture
+def mock_search_results() -> list[LlmDoc]:
+    return [
+        LlmDoc(
+            content="Search result 1",
+            source_type=DocumentSource.WEB,
+            metadata={"id": "doc1"},
+            document_id="doc1",
+            blurb="Blurb 1",
+            semantic_identifier="Semantic ID 1",
+            updated_at=datetime(2023, 1, 1),
+            link="https://example.com/doc1",
+            source_links={0: "https://example.com/doc1"},
+        ),
+        LlmDoc(
+            content="Search result 2",
+            source_type=DocumentSource.WEB,
+            metadata={"id": "doc2"},
+            document_id="doc2",
+            blurb="Blurb 2",
+            semantic_identifier="Semantic ID 2",
+            updated_at=datetime(2023, 1, 2),
+            link="https://example.com/doc2",
+            source_links={0: "https://example.com/doc2"},
+        ),
+    ]
+
+
+@pytest.fixture
+def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
+    mock_tool = MagicMock(spec=SearchTool)
+    mock_tool.name = "search"
+    mock_tool.build_tool_message_content.return_value = "search_response"
+    mock_tool.get_args_for_non_tool_calling_llm.return_value = DEFAULT_SEARCH_ARGS
+    mock_tool.final_result.return_value = [
+        json.loads(doc.model_dump_json()) for doc in mock_search_results
+    ]
+    mock_tool.run.return_value = [
+        ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results)
+    ]
+    mock_tool.tool_definition.return_value = {
+        "type": "function",
+        "function": {
+            "name": "search",
+            "description": "Search for information",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "query": {"type": "string", "description": "The search query"},
+                },
+                "required": ["query"],
+            },
+        },
+    }
+    mock_post_search_tool_prompt_builder = MagicMock(spec=AnswerPromptBuilder)
+    mock_post_search_tool_prompt_builder.build.return_value = [
+        SystemMessage(content="Updated system prompt"),
+    ]
+    mock_tool.build_next_prompt.return_value = mock_post_search_tool_prompt_builder
+    return mock_tool
diff --git
a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py b/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py index 12e3254d6d6..e6a5fe1f027 100644 --- a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py +++ b/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py @@ -7,7 +7,7 @@ from danswer.chat.models import LlmDoc from danswer.configs.constants import DocumentSource from danswer.llm.answering.stream_processing.citation_processing import ( - extract_citations_from_stream, + CitationProcessor, ) from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping @@ -70,14 +70,16 @@ def process_text( ) -> tuple[str, list[CitationInfo]]: mock_docs, mock_doc_id_to_rank_map = mock_data mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map) - result = list( - extract_citations_from_stream( - tokens=iter(tokens), - context_docs=mock_docs, - doc_id_to_rank_map=mapping, - stop_stream=None, - ) + processor = CitationProcessor( + context_docs=mock_docs, + doc_id_to_rank_map=mapping, + stop_stream=None, ) + result: list[DanswerAnswerPiece | CitationInfo] = [] + for token in tokens: + result.extend(processor.process_token(token)) + result.extend(processor.process_token(None)) + final_answer_text = "" citations = [] for piece in result: diff --git a/backend/tests/unit/danswer/llm/answering/test_answer.py b/backend/tests/unit/danswer/llm/answering/test_answer.py new file mode 100644 index 00000000000..f38f705441c --- /dev/null +++ b/backend/tests/unit/danswer/llm/answering/test_answer.py @@ -0,0 +1,421 @@ +import json +from typing import cast +from unittest.mock import MagicMock +from unittest.mock import Mock + +import pytest +from langchain_core.messages import AIMessageChunk +from langchain_core.messages import BaseMessage +from langchain_core.messages import HumanMessage +from langchain_core.messages import SystemMessage +from langchain_core.messages import ToolCall +from langchain_core.messages import ToolCallChunk + +from danswer.chat.models import CitationInfo +from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import LlmDoc +from danswer.chat.models import StreamStopInfo +from danswer.chat.models import StreamStopReason +from danswer.llm.answering.answer import Answer +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import CitationConfig +from danswer.llm.answering.models import PromptConfig +from danswer.llm.interfaces import LLM +from danswer.tools.force import ForceUseTool +from danswer.tools.models import ToolCallFinalResult +from danswer.tools.models import ToolCallKickoff +from danswer.tools.models import ToolResponse +from tests.unit.danswer.llm.answering.conftest import DEFAULT_SEARCH_ARGS +from tests.unit.danswer.llm.answering.conftest import QUERY + + +@pytest.fixture +def answer_instance( + mock_llm: LLM, answer_style_config: AnswerStyleConfig, prompt_config: PromptConfig +) -> Answer: + return Answer( + question=QUERY, + answer_style_config=answer_style_config, + llm=mock_llm, + prompt_config=prompt_config, + force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None), + ) + + +def test_basic_answer(answer_instance: Answer) -> None: + mock_llm = cast(Mock, answer_instance.llm) + mock_llm.stream.return_value = [ + AIMessageChunk(content="This is a "), + AIMessageChunk(content="mock answer."), + ] + + output = 
list(answer_instance.processed_streamed_output) + assert len(output) == 2 + assert isinstance(output[0], DanswerAnswerPiece) + assert isinstance(output[1], DanswerAnswerPiece) + + full_answer = "".join( + piece.answer_piece + for piece in output + if isinstance(piece, DanswerAnswerPiece) and piece.answer_piece is not None + ) + assert full_answer == "This is a mock answer." + + assert answer_instance.llm_answer == "This is a mock answer." + assert answer_instance.citations == [] + + assert mock_llm.stream.call_count == 1 + mock_llm.stream.assert_called_once_with( + prompt=[ + SystemMessage(content="System prompt"), + HumanMessage(content="Task prompt\n\nQUERY:\nTest question"), + ], + tools=None, + tool_choice=None, + ) + + +@pytest.mark.parametrize( + "force_use_tool, expected_tool_args", + [ + ( + ForceUseTool(force_use=False, tool_name="", args=None), + DEFAULT_SEARCH_ARGS, + ), + ( + ForceUseTool( + force_use=True, tool_name="search", args={"query": "forced search"} + ), + {"query": "forced search"}, + ), + ], +) +def test_answer_with_search_call( + answer_instance: Answer, + mock_search_results: list[LlmDoc], + mock_search_tool: MagicMock, + force_use_tool: ForceUseTool, + expected_tool_args: dict, +) -> None: + answer_instance.tools = [mock_search_tool] + answer_instance.force_use_tool = force_use_tool + + # Set up the LLM mock to return search results and then an answer + mock_llm = cast(Mock, answer_instance.llm) + + stream_side_effect: list[list[BaseMessage]] = [] + + if not force_use_tool.force_use: + tool_call_chunk = AIMessageChunk(content="") + tool_call_chunk.tool_calls = [ + ToolCall( + id="search", + name="search", + args=expected_tool_args, + ) + ] + tool_call_chunk.tool_call_chunks = [ + ToolCallChunk( + id="search", + name="search", + args=json.dumps(expected_tool_args), + index=0, + ) + ] + stream_side_effect.append([tool_call_chunk]) + + stream_side_effect.append( + [ + AIMessageChunk(content="Based on the search results, "), + AIMessageChunk(content="the answer is abc[1]. "), + AIMessageChunk(content="This is some other stuff."), + ], + ) + mock_llm.stream.side_effect = stream_side_effect + + # Process the output + output = list(answer_instance.processed_streamed_output) + print(output) + + # Updated assertions + assert len(output) == 7 + assert output[0] == ToolCallKickoff( + tool_name="search", tool_args=expected_tool_args + ) + assert output[1] == ToolResponse( + id="final_context_documents", + response=mock_search_results, + ) + assert output[2] == ToolCallFinalResult( + tool_name="search", + tool_args=expected_tool_args, + tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results], + ) + assert output[3] == DanswerAnswerPiece(answer_piece="Based on the search results, ") + expected_citation = CitationInfo(citation_num=1, document_id="doc1") + assert output[4] == expected_citation + assert output[5] == DanswerAnswerPiece( + answer_piece="the answer is abc[[1]](https://example.com/doc1). " + ) + assert output[6] == DanswerAnswerPiece(answer_piece="This is some other stuff.") + + expected_answer = ( + "Based on the search results, " + "the answer is abc[[1]](https://example.com/doc1). " + "This is some other stuff." 
+ ) + full_answer = "".join( + piece.answer_piece + for piece in output + if isinstance(piece, DanswerAnswerPiece) and piece.answer_piece is not None + ) + assert full_answer == expected_answer + + assert answer_instance.llm_answer == expected_answer + assert len(answer_instance.citations) == 1 + assert answer_instance.citations[0] == expected_citation + + # Verify LLM calls + if not force_use_tool.force_use: + assert mock_llm.stream.call_count == 2 + first_call, second_call = mock_llm.stream.call_args_list + + # First call should include the search tool definition + assert len(first_call.kwargs["tools"]) == 1 + assert ( + first_call.kwargs["tools"][0] + == mock_search_tool.tool_definition.return_value + ) + + # Second call should not include tools (as we're just generating the final answer) + assert "tools" not in second_call.kwargs or not second_call.kwargs["tools"] + # Second call should use the returned prompt from build_next_prompt + assert ( + second_call.kwargs["prompt"] + == mock_search_tool.build_next_prompt.return_value.build.return_value + ) + + # Verify that tool_definition was called on the mock_search_tool + mock_search_tool.tool_definition.assert_called_once() + else: + assert mock_llm.stream.call_count == 1 + + call = mock_llm.stream.call_args_list[0] + assert ( + call.kwargs["prompt"] + == mock_search_tool.build_next_prompt.return_value.build.return_value + ) + + +def test_answer_with_search_no_tool_calling( + answer_instance: Answer, + mock_search_results: list[LlmDoc], + mock_search_tool: MagicMock, +) -> None: + answer_instance.tools = [mock_search_tool] + + # Set up the LLM mock to return an answer + mock_llm = cast(Mock, answer_instance.llm) + mock_llm.stream.return_value = [ + AIMessageChunk(content="Based on the search results, "), + AIMessageChunk(content="the answer is abc[1]. "), + AIMessageChunk(content="This is some other stuff."), + ] + + # Force non-tool calling behavior + answer_instance.using_tool_calling_llm = False + + # Process the output + output = list(answer_instance.processed_streamed_output) + + # Assertions + assert len(output) == 7 + assert output[0] == ToolCallKickoff( + tool_name="search", tool_args=DEFAULT_SEARCH_ARGS + ) + assert output[1] == ToolResponse( + id="final_context_documents", + response=mock_search_results, + ) + assert output[2] == ToolCallFinalResult( + tool_name="search", + tool_args=DEFAULT_SEARCH_ARGS, + tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results], + ) + assert output[3] == DanswerAnswerPiece(answer_piece="Based on the search results, ") + expected_citation = CitationInfo(citation_num=1, document_id="doc1") + assert output[4] == expected_citation + assert output[5] == DanswerAnswerPiece( + answer_piece="the answer is abc[[1]](https://example.com/doc1). " + ) + assert output[6] == DanswerAnswerPiece(answer_piece="This is some other stuff.") + + expected_answer = ( + "Based on the search results, " + "the answer is abc[[1]](https://example.com/doc1). " + "This is some other stuff." 
+ ) + assert answer_instance.llm_answer == expected_answer + assert len(answer_instance.citations) == 1 + assert answer_instance.citations[0] == expected_citation + + # Verify LLM calls + assert mock_llm.stream.call_count == 1 + call_args = mock_llm.stream.call_args + + # Verify that no tools were passed to the LLM + assert "tools" not in call_args.kwargs or not call_args.kwargs["tools"] + + # Verify that the prompt was built correctly + assert ( + call_args.kwargs["prompt"] + == mock_search_tool.build_next_prompt.return_value.build.return_value + ) + + # Verify that get_args_for_non_tool_calling_llm was called on the mock_search_tool + mock_search_tool.get_args_for_non_tool_calling_llm.assert_called_once_with( + f"Task prompt\n\nQUERY:\n{QUERY}", [], answer_instance.llm + ) + + # Verify that the search tool's run method was called + mock_search_tool.run.assert_called_once() + + +def test_answer_with_search_call_quotes_enabled( + answer_instance: Answer, + mock_search_results: list[LlmDoc], + mock_search_tool: MagicMock, +) -> None: + answer_instance.tools = [mock_search_tool] + answer_instance.force_use_tool = ForceUseTool( + force_use=False, tool_name="", args=None + ) + answer_instance.answer_style_config.citation_config = CitationConfig( + use_quotes=True + ) + + # Set up the LLM mock to return search results and then an answer + mock_llm = cast(Mock, answer_instance.llm) + + tool_call_chunk = AIMessageChunk(content="") + tool_call_chunk.tool_calls = [ + ToolCall( + id="search", + name="search", + args=DEFAULT_SEARCH_ARGS, + ) + ] + tool_call_chunk.tool_call_chunks = [ + ToolCallChunk( + id="search", + name="search", + args=json.dumps(DEFAULT_SEARCH_ARGS), + index=0, + ) + ] + + mock_llm.stream.side_effect = [ + [tool_call_chunk], + [ + AIMessageChunk(content="Answer"), + ], + ] + + # Process the output + output = list(answer_instance.processed_streamed_output) + + # Assertions + assert len(output) == 7 + assert output[0] == ToolCallKickoff( + tool_name="search", tool_args=DEFAULT_SEARCH_ARGS + ) + assert output[1] == ToolResponse( + id="final_context_documents", + response=mock_search_results, + ) + assert output[2] == ToolCallFinalResult( + tool_name="search", + tool_args=DEFAULT_SEARCH_ARGS, + tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results], + ) + assert output[3] == DanswerAnswerPiece(answer_piece="Based on the search results, ") + expected_citation = CitationInfo(citation_num=1, document_id="doc1") + assert output[4] == expected_citation + assert output[5] == DanswerAnswerPiece( + answer_piece='the answer is "abc"[[1]](https://example.com/doc1). ' + ) + assert output[6] == DanswerAnswerPiece(answer_piece="This is some other stuff.") + + expected_answer = ( + "Based on the search results, " + 'the answer is "abc"[[1]](https://example.com/doc1). ' + "This is some other stuff." 
+ ) + full_answer = "".join( + piece.answer_piece + for piece in output + if isinstance(piece, DanswerAnswerPiece) and piece.answer_piece is not None + ) + assert full_answer == expected_answer + + assert answer_instance.llm_answer == expected_answer + assert len(answer_instance.citations) == 1 + assert answer_instance.citations[0] == expected_citation + + # Verify LLM calls + assert mock_llm.stream.call_count == 2 + first_call, second_call = mock_llm.stream.call_args_list + + # First call should include the search tool definition + assert len(first_call.kwargs["tools"]) == 1 + assert ( + first_call.kwargs["tools"][0] == mock_search_tool.tool_definition.return_value + ) + + # Second call should not include tools (as we're just generating the final answer) + assert "tools" not in second_call.kwargs or not second_call.kwargs["tools"] + # Second call should use the returned prompt from build_next_prompt + assert ( + second_call.kwargs["prompt"] + == mock_search_tool.build_next_prompt.return_value.build.return_value + ) + + # Verify that tool_definition was called on the mock_search_tool + mock_search_tool.tool_definition.assert_called_once() + + +def test_is_cancelled(answer_instance: Answer) -> None: + # Set up the LLM mock to return multiple chunks + mock_llm = Mock() + answer_instance.llm = mock_llm + mock_llm.stream.return_value = [ + AIMessageChunk(content="This is the "), + AIMessageChunk(content="first part."), + AIMessageChunk(content="This should not be seen."), + ] + + # Create a mutable object to control is_connected behavior + connection_status = {"connected": True} + answer_instance.is_connected = lambda: connection_status["connected"] + + # Process the output + output = [] + for i, chunk in enumerate(answer_instance.processed_streamed_output): + output.append(chunk) + # Simulate disconnection after the second chunk + if i == 1: + connection_status["connected"] = False + + assert len(output) == 3 + assert output[0] == DanswerAnswerPiece(answer_piece="This is the ") + assert output[1] == DanswerAnswerPiece(answer_piece="first part.") + assert output[2] == StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) + + # Verify that the stream was cancelled + assert answer_instance.is_cancelled() is True + + # Verify that the final answer only contains the streamed parts + assert answer_instance.llm_answer == "This is the first part." 
+ + # Verify LLM calls + mock_llm.stream.assert_called_once() diff --git a/backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py b/backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py index 998b2932cbb..7bd4a498bd7 100644 --- a/backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py +++ b/backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py @@ -6,8 +6,11 @@ from pytest_mock import MockerFixture from danswer.llm.answering.answer import Answer +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import PromptConfig from danswer.one_shot_answer.answer_question import AnswerObjectIterator from danswer.tools.force import ForceUseTool +from danswer.tools.tool_implementations.search.search_tool import SearchTool from tests.regression.answer_quality.run_qa import _process_and_write_query_results @@ -24,39 +27,43 @@ }, ], ) -def test_skip_gen_ai_answer_generation_flag(config: dict[str, Any]) -> None: - search_tool = Mock() - search_tool.name = "search" - search_tool.run = Mock() - search_tool.run.return_value = [Mock()] +def test_skip_gen_ai_answer_generation_flag( + config: dict[str, Any], + mock_search_tool: SearchTool, + answer_style_config: AnswerStyleConfig, + prompt_config: PromptConfig, +) -> None: + question = config["question"] + skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"] + mock_llm = Mock() mock_llm.config = Mock() mock_llm.config.model_name = "gpt-4o-mini" mock_llm.stream = Mock() mock_llm.stream.return_value = [Mock()] answer = Answer( - question=config["question"], - answer_style_config=Mock(), - prompt_config=Mock(), + question=question, + answer_style_config=answer_style_config, + prompt_config=prompt_config, llm=mock_llm, single_message_history="history", - tools=[search_tool], + tools=[mock_search_tool], force_use_tool=( ForceUseTool( - tool_name=search_tool.name, - args={"query": config["question"]}, + tool_name=mock_search_tool.name, + args={"query": question}, force_use=True, ) ), skip_explicit_tool_calling=True, return_contexts=True, - skip_gen_ai_answer_generation=config["skip_gen_ai_answer_generation"], + skip_gen_ai_answer_generation=skip_gen_ai_answer_generation, ) count = 0 for _ in cast(AnswerObjectIterator, answer.processed_streamed_output): count += 1 - assert count == 2 - if not config["skip_gen_ai_answer_generation"]: + assert count == 3 if skip_gen_ai_answer_generation else 4 + if not skip_gen_ai_answer_generation: mock_llm.stream.assert_called_once() else: mock_llm.stream.assert_not_called() diff --git a/backend/tests/unit/danswer/tools/custom/test_custom_tools.py b/backend/tests/unit/danswer/tools/custom/test_custom_tools.py index 6139f41e62a..f56336809e4 100644 --- a/backend/tests/unit/danswer/tools/custom/test_custom_tools.py +++ b/backend/tests/unit/danswer/tools/custom/test_custom_tools.py @@ -5,14 +5,18 @@ import pytest -from danswer.tools.custom.custom_tool import ( +from danswer.tools.models import DynamicSchemaInfo +from danswer.tools.models import ToolResponse +from danswer.tools.tool_implementations.custom.custom_tool import ( build_custom_tools_from_openapi_schema_and_headers, ) -from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID -from danswer.tools.custom.custom_tool import CustomToolCallSummary -from danswer.tools.custom.custom_tool import validate_openapi_schema -from danswer.tools.models import DynamicSchemaInfo -from danswer.tools.tool import ToolResponse +from danswer.tools.tool_implementations.custom.custom_tool import ( + 
CUSTOM_TOOL_RESPONSE_ID, +) +from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary +from danswer.tools.tool_implementations.custom.custom_tool import ( + validate_openapi_schema, +) from danswer.utils.headers import HeaderItemDict @@ -78,7 +82,7 @@ def setUp(self) -> None: chat_session_id=uuid.uuid4(), message_id=20 ) - @patch("danswer.tools.custom.custom_tool.requests.request") + @patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request") def test_custom_tool_run_get(self, mock_request: unittest.mock.MagicMock) -> None: """ Test the GET method of a custom tool. @@ -106,7 +110,7 @@ def test_custom_tool_run_get(self, mock_request: unittest.mock.MagicMock) -> Non "Tool name in response does not match expected value", ) - @patch("danswer.tools.custom.custom_tool.requests.request") + @patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request") def test_custom_tool_run_post(self, mock_request: unittest.mock.MagicMock) -> None: """ Test the POST method of a custom tool. @@ -136,7 +140,7 @@ def test_custom_tool_run_post(self, mock_request: unittest.mock.MagicMock) -> No "Tool name in response does not match expected value", ) - @patch("danswer.tools.custom.custom_tool.requests.request") + @patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request") def test_custom_tool_with_headers( self, mock_request: unittest.mock.MagicMock ) -> None: @@ -164,7 +168,7 @@ def test_custom_tool_with_headers( "GET", expected_url, json=None, headers=expected_headers ) - @patch("danswer.tools.custom.custom_tool.requests.request") + @patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request") def test_custom_tool_with_empty_headers( self, mock_request: unittest.mock.MagicMock ) -> None: