
Commit

refactor
Weves committed Oct 8, 2024
1 parent 2400d9d commit 965ffe7
Showing 9 changed files with 216 additions and 211 deletions.
128 changes: 16 additions & 112 deletions backend/danswer/llm/answering/answer.py
@@ -9,29 +9,21 @@
 from danswer.chat.models import AnswerQuestionPossibleReturn
 from danswer.chat.models import CitationInfo
 from danswer.chat.models import DanswerAnswerPiece
-from danswer.chat.models import LlmDoc
-from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
 from danswer.file_store.utils import InMemoryChatFile
 from danswer.llm.answering.llm_response_handler import LLMCall
-from danswer.llm.answering.llm_response_handler import LLMResponseHandler
 from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
 from danswer.llm.answering.models import AnswerStyleConfig
 from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.answering.models import PromptConfig
-from danswer.llm.answering.models import StreamProcessor
 from danswer.llm.answering.prompts.build import AnswerPromptBuilder
 from danswer.llm.answering.prompts.build import default_build_system_message
 from danswer.llm.answering.prompts.build import default_build_user_message
-from danswer.llm.answering.stream_processing.citation_processing import (
-    build_citation_processor,
-)
 from danswer.llm.answering.stream_processing.citation_response_handler import (
     CitationResponseHandler,
 )
-from danswer.llm.answering.stream_processing.quotes_processing import (
-    build_quotes_processor,
-)
+from danswer.llm.answering.stream_processing.citation_response_handler import (
+    DummyAnswerResponseHandler,
+)
-from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
 from danswer.llm.answering.stream_processing.utils import map_document_id_order
 from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
 from danswer.llm.interfaces import LLM
@@ -40,41 +32,17 @@
 from danswer.tools.search.search_tool import SearchTool
 from danswer.tools.tool import Tool
 from danswer.tools.tool import ToolResponse
-from danswer.tools.tool_runner import (
-    check_which_tools_should_run_for_non_tool_calling_llm,
-)
 from danswer.tools.tool_runner import ToolCallKickoff
-from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
 from danswer.tools.utils import explicit_tool_calling_supported
 from danswer.utils.logger import setup_logger


-logger = setup_logger()
-
-
-def _get_answer_stream_processor(
-    context_docs: list[LlmDoc],
-    doc_id_to_rank_map: DocumentIdOrderMapping,
-    answer_style_configs: AnswerStyleConfig,
-) -> StreamProcessor:
-    if answer_style_configs.citation_config:
-        return build_citation_processor(
-            context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map
-        )
-    if answer_style_configs.quotes_config:
-        return build_quotes_processor(
-            context_docs=context_docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak")
-        )
-
-    raise RuntimeError("Not implemented yet")
-
-
 AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]


+logger = setup_logger()


 class Answer:
     def __init__(
         self,
@@ -163,72 +131,6 @@ def _get_tools_list(self) -> list[Tool]:
         )
         return [tool]

-    @classmethod
-    def _get_tool_call_for_non_tool_calling_llm(
-        cls, llm_call: LLMCall, llm: LLM
-    ) -> tuple[Tool, dict] | None:
-        if llm_call.force_use_tool.force_use:
-            # if we are forcing a tool, we don't need to check which tools to run
-            tool = next(
-                (
-                    t
-                    for t in llm_call.tools
-                    if t.name == llm_call.force_use_tool.tool_name
-                ),
-                None,
-            )
-            if not tool:
-                raise RuntimeError(
-                    f"Tool '{llm_call.force_use_tool.tool_name}' not found"
-                )
-
-            tool_args = (
-                llm_call.force_use_tool.args
-                if llm_call.force_use_tool.args is not None
-                else tool.get_args_for_non_tool_calling_llm(
-                    query=llm_call.prompt_builder.get_user_message_content(),
-                    history=llm_call.prompt_builder.raw_message_history,
-                    llm=llm,
-                    force_run=True,
-                )
-            )
-
-            if tool_args is None:
-                raise RuntimeError(f"Tool '{tool.name}' did not return args")
-
-            return (tool, tool_args)
-        else:
-            tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
-                tools=llm_call.tools,
-                query=llm_call.prompt_builder.get_user_message_content(),
-                history=llm_call.prompt_builder.raw_message_history,
-                llm=llm,
-            )
-
-            available_tools_and_args = [
-                (llm_call.tools[ind], args)
-                for ind, args in enumerate(tool_options)
-                if args is not None
-            ]
-
-            logger.info(
-                f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
-            )
-
-            chosen_tool_and_args = (
-                select_single_tool_for_non_tool_calling_llm(
-                    tools_and_args=available_tools_and_args,
-                    history=llm_call.prompt_builder.raw_message_history,
-                    query=llm_call.prompt_builder.get_user_message_content(),
-                    llm=llm,
-                )
-                if available_tools_and_args
-                else None
-            )
-
-            logger.notice(f"Chosen tool: {chosen_tool_and_args}")
-            return chosen_tool_and_args
-
     def _handle_specified_tool_call(
         self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict
     ) -> AnswerStream:
@@ -242,12 +144,14 @@ def _handle_specified_tool_call(
             ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
         ]

-        response_handler_manager = LLMResponseHandlerManager([tool_handler])
+        response_handler_manager = LLMResponseHandlerManager(
+            tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
+        )
         yield from response_handler_manager.handle_llm_response(
             iter([dummy_tool_call_chunk])
         )

-        new_llm_call = response_handler_manager.finish(current_llm_call)
+        new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
         if new_llm_call:
             yield from self._get_response(llm_calls + [new_llm_call])
         else:
@@ -262,8 +166,8 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
             and current_llm_call.force_use_tool.args is not None
         ):
             tool_name, tool_args = (
-                self.force_use_tool.tool_name,
-                self.force_use_tool.args,
+                current_llm_call.force_use_tool.tool_name,
+                current_llm_call.force_use_tool.args,
             )
             tool = next(
                 (t for t in current_llm_call.tools if t.name == tool_name), None
@@ -276,8 +180,10 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:

         # special pre-logic for non-tool calling LLM case
         if not self.using_tool_calling_llm and current_llm_call.tools:
-            chosen_tool_and_args = self._get_tool_call_for_non_tool_calling_llm(
-                current_llm_call, self.llm
+            chosen_tool_and_args = (
+                ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
+                    current_llm_call, self.llm
+                )
             )
             if chosen_tool_and_args:
                 tool, tool_args = chosen_tool_and_args
@@ -287,18 +193,17 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
         # set up "handlers" to listen to the LLM response stream and
         # feed back the processed results + handle tool call requests
        # + figure out what the next LLM call should be
-        handlers: list[LLMResponseHandler] = []
         tool_call_handler = ToolResponseHandler(current_llm_call.tools)
-        handlers.append(tool_call_handler)

         search_result = SearchTool.get_search_result(current_llm_call) or []
         citation_response_handler = CitationResponseHandler(
             context_docs=search_result,
             doc_id_to_rank_map=map_document_id_order(search_result),
         )
-        handlers.append(citation_response_handler)

-        response_handler_manager = LLMResponseHandlerManager(handlers)
+        response_handler_manager = LLMResponseHandlerManager(
+            tool_call_handler, citation_response_handler, self.is_cancelled
+        )

         # DEBUG: good breakpoint
         stream = self.llm.stream(
@@ -312,7 +217,7 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
         )
         yield from response_handler_manager.handle_llm_response(stream)

-        new_llm_call = response_handler_manager.finish(current_llm_call)
+        new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
         if new_llm_call:
             yield from self._get_response(llm_calls + [new_llm_call])
@@ -369,7 +274,6 @@ def citations(self) -> list[CitationInfo]:

         return citations

-    @property
     def is_cancelled(self) -> bool:
         if self._is_cancelled:
             return True
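A side effect worth noting in this file: dropping @property from is_cancelled turns self.is_cancelled into a bound method, which is what allows it to be handed to LLMResponseHandlerManager as a Callable[[], bool] and re-checked on every streamed chunk. A minimal sketch of the resulting wiring (the handler variables are illustrative stand-ins for the objects built in _get_response, not new API):

    # sketch only: assumes an Answer instance `answer` and handlers as above
    manager = LLMResponseHandlerManager(
        tool_call_handler,          # ToolResponseHandler(current_llm_call.tools)
        citation_response_handler,  # CitationResponseHandler(...)
        answer.is_cancelled,        # bound method, invoked as is_cancelled() per chunk
    )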
77 changes: 37 additions & 40 deletions backend/danswer/llm/answering/llm_response_handler.py
@@ -1,13 +1,15 @@
-import abc
+from collections.abc import Callable
 from collections.abc import Generator
 from collections.abc import Iterator
+from typing import TYPE_CHECKING

 from langchain_core.messages import BaseMessage
 from pydantic.v1 import BaseModel as BaseModel__v1

 from danswer.chat.models import CitationInfo
 from danswer.chat.models import DanswerAnswerPiece
 from danswer.chat.models import StreamStopInfo
+from danswer.chat.models import StreamStopReason
 from danswer.file_store.models import InMemoryChatFile
 from danswer.llm.answering.prompts.build import AnswerPromptBuilder
 from danswer.tools.force import ForceUseTool
@@ -16,6 +18,14 @@
 from danswer.tools.models import ToolResponse
 from danswer.tools.tool import Tool

+
+if TYPE_CHECKING:
+    from danswer.llm.answering.stream_processing.citation_response_handler import (
+        AnswerResponseHandler,
+    )
+    from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
+
+
 ResponsePart = (
     DanswerAnswerPiece
     | CitationInfo
@@ -38,49 +48,36 @@ class Config:
         arbitrary_types_allowed = True


-class LLMResponseHandler(abc.ABC):
-    @abc.abstractmethod
-    def handle_response_part(
-        self, response_item: BaseMessage, previous_response_items: list[BaseMessage]
-    ) -> Generator[ResponsePart, None, None]:
-        raise NotImplementedError
-
-    @abc.abstractmethod
-    def finish(self, current_llm_call: LLMCall) -> LLMCall | None:
-        raise NotImplementedError
-
-
 class LLMResponseHandlerManager:
-    def __init__(self, handlers: list[LLMResponseHandler]):
-        self.handlers = handlers
+    def __init__(
+        self,
+        tool_handler: "ToolResponseHandler",
+        answer_handler: "AnswerResponseHandler",
+        is_cancelled: Callable[[], bool],
+    ):
+        self.tool_handler = tool_handler
+        self.answer_handler = answer_handler
+        self.is_cancelled = is_cancelled

     def handle_llm_response(
         self,
         stream: Iterator[BaseMessage],
     ) -> Generator[ResponsePart, None, None]:
-        messages: list[BaseMessage] = []
+        all_messages: list[BaseMessage] = []
         for message in stream:
-            for handler in self.handlers:
-                responses = handler.handle_response_part(message, messages)
-                for response in responses:
-                    yield response
-
-            messages.append(message)
-
-        for handler in self.handlers:
-            yield from handler.handle_response_part(None, messages)
-
-    def finish(self, llm_call: LLMCall) -> LLMCall | None:
-        new_llm_call = None
-        for handler in self.handlers:
-            new_llm_call_temp = handler.finish(llm_call)
-
-            if new_llm_call and new_llm_call_temp:
-                raise RuntimeError(
-                    "Multiple handlers are trying to add a new LLM call, this is not allowed."
-                )
-
-            if new_llm_call_temp:
-                new_llm_call = new_llm_call_temp
-
-        return new_llm_call
+            # tool handler doesn't do anything until the full message is received
+            # NOTE: still need to run list() to get this to run
+            list(self.tool_handler.handle_response_part(message, all_messages))
+            yield from self.answer_handler.handle_response_part(message, all_messages)
+            all_messages.append(message)
+
+            if self.is_cancelled():
+                yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
+                return
+
+        # potentially give back all info on the selected tool call + its result
+        yield from self.tool_handler.handle_response_part(None, all_messages)
+        yield from self.answer_handler.handle_response_part(None, all_messages)
+
+    def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None:
+        return self.tool_handler.next_llm_call(llm_call)
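Taken together, handle_llm_response now hard-codes the two-handler split instead of looping over a generic handler list: the tool handler is drained with list() purely for its side effects (it buffers until the message is complete), the answer handler's output is streamed through, and cancellation is checked once per chunk. A hedged sketch of a caller, with the message stream left abstract since the llm.stream arguments are not shown in this diff:

    # sketch only; `message_stream` is whatever self.llm.stream(...) yields
    def drive_one_llm_call(manager, llm_call, message_stream):
        # forward processed parts (answer pieces, citations, tool packets)
        yield from manager.handle_llm_response(message_stream)
        next_call = manager.next_llm_call(llm_call)
        if next_call is not None:
            ...  # a tool ran: build a new stream and recurse, as Answer._get_response does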
backend/danswer/llm/answering/stream_processing/citation_processing.py
@@ -1,12 +1,10 @@
 import re
 from collections.abc import Generator
-from collections.abc import Iterator

 from danswer.chat.models import CitationInfo
 from danswer.chat.models import DanswerAnswerPiece
 from danswer.chat.models import LlmDoc
 from danswer.configs.chat_configs import STOP_STREAM_PAT
-from danswer.llm.answering.models import StreamProcessor
 from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
 from danswer.prompts.constants import TRIPLE_BACKTICK
 from danswer.utils.logger import setup_logger
@@ -180,20 +178,3 @@ def process_token(

         if result:
             yield DanswerAnswerPiece(answer_piece=result)
-
-
-def build_citation_processor(
-    context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
-) -> StreamProcessor:
-    def stream_processor(
-        tokens: Iterator[str],
-    ) -> Iterator[DanswerAnswerPiece | CitationInfo]:
-        processor = CitationProcessor(context_docs, doc_id_to_rank_map)
-        for token in tokens:
-            result = processor.process_token(token)
-            if result:
-                yield result
-        if processor.curr_segment:
-            yield DanswerAnswerPiece(answer_piece=processor.curr_segment)
-
-    return stream_processor
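With the build_citation_processor wrapper deleted, the equivalent behavior is simply driving CitationProcessor directly; this sketch restates the removed function's body, assuming the same CitationProcessor interface shown above:

    # sketch: inline replacement for the removed wrapper
    processor = CitationProcessor(context_docs, doc_id_to_rank_map)
    for token in tokens:
        result = processor.process_token(token)
        if result:
            yield result
    if processor.curr_segment:  # flush any trailing partial segment
        yield DanswerAnswerPiece(answer_piece=processor.curr_segment)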
