diff --git a/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py b/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py index 272335ca07c..1e2e7cd3c1b 100644 --- a/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py +++ b/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from danswer.db.models import IndexModelStatus -from danswer.search.models import RecencyBiasSetting +from danswer.search.enums import RecencyBiasSetting from danswer.search.models import SearchType # revision identifiers, used by Alembic. diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index fe97b0b3923..ee2f582c954 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -1,97 +1,29 @@ import re -from collections.abc import Callable -from collections.abc import Iterator from collections.abc import Sequence -from functools import lru_cache -from typing import cast -from langchain.schema.messages import BaseMessage -from langchain.schema.messages import HumanMessage -from langchain.schema.messages import SystemMessage from sqlalchemy.orm import Session -from tiktoken.core import Encoding from danswer.chat.models import CitationInfo -from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc -from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION -from danswer.configs.chat_configs import STOP_STREAM_PAT -from danswer.configs.constants import IGNORE_FOR_QA -from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE -from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS from danswer.db.chat import get_chat_messages_by_session -from danswer.db.chat import get_default_prompt from danswer.db.models import ChatMessage -from danswer.db.models import Persona -from danswer.db.models import Prompt from danswer.indexing.models import InferenceChunk -from danswer.llm.utils import check_number_of_tokens -from danswer.llm.utils import get_default_llm_tokenizer -from danswer.llm.utils import get_default_llm_version -from danswer.llm.utils import get_max_input_tokens -from danswer.llm.utils import tokenizer_trim_content -from danswer.prompts.chat_prompts import ADDITIONAL_INFO -from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT -from danswer.prompts.chat_prompts import CHAT_USER_PROMPT -from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT -from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT -from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT -from danswer.prompts.constants import TRIPLE_BACKTICK -from danswer.prompts.prompt_utils import build_complete_context_str -from danswer.prompts.prompt_utils import build_task_prompt_reminders -from danswer.prompts.prompt_utils import get_current_llm_day_time -from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT -from danswer.prompts.token_counts import ( - CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT, -) -from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT -from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT -from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT from danswer.utils.logger import setup_logger logger = setup_logger() -@lru_cache() -def build_chat_system_message( - prompt: Prompt, - context_exists: bool, - llm_tokenizer_encode_func: Callable, - citation_line: str = REQUIRE_CITATION_STATEMENT, - no_citation_line: str = 
NO_CITATION_STATEMENT, -) -> tuple[SystemMessage | None, int]: - system_prompt = prompt.system_prompt.strip() - if prompt.include_citations: - if context_exists: - system_prompt += citation_line - else: - system_prompt += no_citation_line - if prompt.datetime_aware: - if system_prompt: - system_prompt += ADDITIONAL_INFO.format( - datetime_info=get_current_llm_day_time() - ) - else: - system_prompt = get_current_llm_day_time() - - if not system_prompt: - return None, 0 - - token_count = len(llm_tokenizer_encode_func(system_prompt)) - system_msg = SystemMessage(content=system_prompt) - - return system_msg, token_count - - def llm_doc_from_inference_chunk(inf_chunk: InferenceChunk) -> LlmDoc: return LlmDoc( document_id=inf_chunk.document_id, content=inf_chunk.content, + blurb=inf_chunk.blurb, semantic_identifier=inf_chunk.semantic_identifier, source_type=inf_chunk.source_type, metadata=inf_chunk.metadata, updated_at=inf_chunk.updated_at, link=inf_chunk.source_links[0] if inf_chunk.source_links else None, + source_links=inf_chunk.source_links, ) @@ -108,170 +40,6 @@ def map_document_id_order( return order_mapping -def build_chat_user_message( - chat_message: ChatMessage, - prompt: Prompt, - context_docs: list[LlmDoc], - llm_tokenizer_encode_func: Callable, - all_doc_useful: bool, - user_prompt_template: str = CHAT_USER_PROMPT, - context_free_template: str = CHAT_USER_CONTEXT_FREE_PROMPT, - ignore_str: str = DEFAULT_IGNORE_STATEMENT, -) -> tuple[HumanMessage, int]: - user_query = chat_message.message - - if not context_docs: - # Simpler prompt for cases where there is no context - user_prompt = ( - context_free_template.format( - task_prompt=prompt.task_prompt, user_query=user_query - ) - if prompt.task_prompt - else user_query - ) - user_prompt = user_prompt.strip() - token_count = len(llm_tokenizer_encode_func(user_prompt)) - user_msg = HumanMessage(content=user_prompt) - return user_msg, token_count - - context_docs_str = build_complete_context_str( - cast(list[LlmDoc | InferenceChunk], context_docs) - ) - optional_ignore = "" if all_doc_useful else ignore_str - - task_prompt_with_reminder = build_task_prompt_reminders(prompt) - - user_prompt = user_prompt_template.format( - optional_ignore_statement=optional_ignore, - context_docs_str=context_docs_str, - task_prompt=task_prompt_with_reminder, - user_query=user_query, - ) - - user_prompt = user_prompt.strip() - token_count = len(llm_tokenizer_encode_func(user_prompt)) - user_msg = HumanMessage(content=user_prompt) - - return user_msg, token_count - - -def _get_usable_chunks( - chunks: list[InferenceChunk], token_limit: int -) -> list[InferenceChunk]: - total_token_count = 0 - usable_chunks = [] - for chunk in chunks: - chunk_token_count = check_number_of_tokens(chunk.content) - if total_token_count + chunk_token_count > token_limit: - break - - total_token_count += chunk_token_count - usable_chunks.append(chunk) - - # try and return at least one chunk if possible. This chunk will - # get truncated later on in the pipeline. 
This would only occur if - # the first chunk is larger than the token limit (usually due to character - # count -> token count mismatches caused by special characters / non-ascii - # languages) - if not usable_chunks and chunks: - usable_chunks = [chunks[0]] - - return usable_chunks - - -def get_usable_chunks( - chunks: list[InferenceChunk], - token_limit: int, - offset: int = 0, -) -> list[InferenceChunk]: - offset_into_chunks = 0 - usable_chunks: list[InferenceChunk] = [] - for _ in range(min(offset + 1, 1)): # go through this process at least once - if offset_into_chunks >= len(chunks) and offset_into_chunks > 0: - raise ValueError( - "Chunks offset too large, should not retry this many times" - ) - - usable_chunks = _get_usable_chunks( - chunks=chunks[offset_into_chunks:], token_limit=token_limit - ) - offset_into_chunks += len(usable_chunks) - - return usable_chunks - - -def get_chunks_for_qa( - chunks: list[InferenceChunk], - llm_chunk_selection: list[bool], - token_limit: int | None, - llm_tokenizer: Encoding | None = None, - batch_offset: int = 0, -) -> list[int]: - """ - Gives back indices of chunks to pass into the LLM for Q&A. - - Only selects chunks viable for Q&A, within the token limit, and prioritize those selected - by the LLM in a separate flow (this can be turned off) - - Note, the batch_offset calculation has to count the batches from the beginning each time as - there's no way to know which chunks were included in the prior batches without recounting atm, - this is somewhat slow as it requires tokenizing all the chunks again - """ - token_leeway = 50 - batch_index = 0 - latest_batch_indices: list[int] = [] - token_count = 0 - - # First iterate the LLM selected chunks, then iterate the rest if tokens remaining - for selection_target in [True, False]: - for ind, chunk in enumerate(chunks): - if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get( - IGNORE_FOR_QA - ): - continue - - # We calculate it live in case the user uses a different LLM + tokenizer - chunk_token = check_number_of_tokens(chunk.content) - if chunk_token > DOC_EMBEDDING_CONTEXT_SIZE + token_leeway: - logger.warning( - "Found more tokens in chunk than expected, " - "likely mismatch between embedding and LLM tokenizers. Trimming content..." 
- ) - chunk.content = tokenizer_trim_content( - content=chunk.content, - desired_length=DOC_EMBEDDING_CONTEXT_SIZE, - tokenizer=llm_tokenizer or get_default_llm_tokenizer(), - ) - - # 50 for an approximate/slight overestimate for # tokens for metadata for the chunk - token_count += chunk_token + token_leeway - - # Always use at least 1 chunk - if ( - token_limit is None - or token_count <= token_limit - or not latest_batch_indices - ): - latest_batch_indices.append(ind) - current_chunk_unused = False - else: - current_chunk_unused = True - - if token_limit is not None and token_count >= token_limit: - if batch_index < batch_offset: - batch_index += 1 - if current_chunk_unused: - latest_batch_indices = [ind] - token_count = chunk_token - else: - latest_batch_indices = [] - token_count = 0 - else: - return latest_batch_indices - - return latest_batch_indices - - def create_chat_chain( chat_session_id: int, db_session: Session, @@ -341,157 +109,6 @@ def combine_message_chain( return "\n\n".join(message_strs) -_PER_MESSAGE_TOKEN_BUFFER = 7 - - -def find_last_index(lst: list[int], max_prompt_tokens: int) -> int: - """From the back, find the index of the last element to include - before the list exceeds the maximum""" - running_sum = 0 - - last_ind = 0 - for i in range(len(lst) - 1, -1, -1): - running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER - if running_sum > max_prompt_tokens: - last_ind = i + 1 - break - if last_ind >= len(lst): - raise ValueError("Last message alone is too large!") - return last_ind - - -def drop_messages_history_overflow( - system_msg: BaseMessage | None, - system_token_count: int, - history_msgs: list[BaseMessage], - history_token_counts: list[int], - final_msg: BaseMessage, - final_msg_token_count: int, - max_allowed_tokens: int, -) -> list[BaseMessage]: - """As message history grows, messages need to be dropped starting from the furthest in the past. 
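# Illustrative sketch (not part of the patch): a simplified, self-contained version of
# the chunk-selection behavior that the removed get_chunks_for_qa helper implemented and
# that the new danswer.llm.answering.doc_pruning module replaces. LLM-selected chunks are
# considered first, then the rest, and chunks are accepted greedily until a token budget
# is reached, always keeping at least one. Token counting is stubbed with a word count
# purely for illustration; the real code uses the LLM tokenizer.
def select_chunk_indices(
    chunk_texts: list[str], llm_selected: list[bool], token_limit: int
) -> list[int]:
    chosen: list[int] = []
    used_tokens = 0
    for target in (True, False):  # LLM-selected chunks get priority
        for ind, text in enumerate(chunk_texts):
            if llm_selected[ind] is not target:
                continue
            tokens = len(text.split())  # stand-in for a real tokenizer
            if chosen and used_tokens + tokens > token_limit:
                return chosen
            chosen.append(ind)
            used_tokens += tokens
    return chosen


# With a 10 "token" budget the LLM-selected chunk is taken first: prints [1, 0, 2]
print(select_chunk_indices(["a b c d e", "f g h", "i j"], [False, True, False], 10))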
- The System message should be kept if at all possible and the latest user input which is inserted in the - prompt template must be included""" - if len(history_msgs) != len(history_token_counts): - # This should never happen - raise ValueError("Need exactly 1 token count per message for tracking overflow") - - prompt: list[BaseMessage] = [] - - # Start dropping from the history if necessary - all_tokens = history_token_counts + [system_token_count, final_msg_token_count] - ind_prev_msg_start = find_last_index( - all_tokens, max_prompt_tokens=max_allowed_tokens - ) - - if system_msg and ind_prev_msg_start <= len(history_msgs): - prompt.append(system_msg) - - prompt.extend(history_msgs[ind_prev_msg_start:]) - - prompt.append(final_msg) - - return prompt - - -def in_code_block(llm_text: str) -> bool: - count = llm_text.count(TRIPLE_BACKTICK) - return count % 2 != 0 - - -def extract_citations_from_stream( - tokens: Iterator[str], - context_docs: list[LlmDoc], - doc_id_to_rank_map: dict[str, int], - stop_stream: str | None = STOP_STREAM_PAT, -) -> Iterator[DanswerAnswerPiece | CitationInfo]: - llm_out = "" - max_citation_num = len(context_docs) - curr_segment = "" - prepend_bracket = False - cited_inds = set() - hold = "" - for raw_token in tokens: - if stop_stream: - next_hold = hold + raw_token - - if stop_stream in next_hold: - break - - if next_hold == stop_stream[: len(next_hold)]: - hold = next_hold - continue - - token = next_hold - hold = "" - else: - token = raw_token - - # Special case of [1][ where ][ is a single token - # This is where the model attempts to do consecutive citations like [1][2] - if prepend_bracket: - curr_segment += "[" + curr_segment - prepend_bracket = False - - curr_segment += token - llm_out += token - - possible_citation_pattern = r"(\[\d*$)" # [1, [, etc - possible_citation_found = re.search(possible_citation_pattern, curr_segment) - - citation_pattern = r"\[(\d+)\]" # [1], [2] etc - citation_found = re.search(citation_pattern, curr_segment) - - if citation_found and not in_code_block(llm_out): - numerical_value = int(citation_found.group(1)) - if 1 <= numerical_value <= max_citation_num: - context_llm_doc = context_docs[ - numerical_value - 1 - ] # remove 1 index offset - - link = context_llm_doc.link - target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id] - - # Use the citation number for the document's rank in - # the search (or selected docs) results - curr_segment = re.sub( - rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment - ) - - if target_citation_num not in cited_inds: - cited_inds.add(target_citation_num) - yield CitationInfo( - citation_num=target_citation_num, - document_id=context_llm_doc.document_id, - ) - - if link: - curr_segment = re.sub(r"\[", "[[", curr_segment, count=1) - curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1) - - # In case there's another open bracket like [1][, don't want to match this - possible_citation_found = None - - # if we see "[", but haven't seen the right side, hold back - this may be a - # citation that needs to be replaced with a link - if possible_citation_found: - continue - - # Special case with back to back citations [1][2] - if curr_segment and curr_segment[-1] == "[": - curr_segment = curr_segment[:-1] - prepend_bracket = True - - yield DanswerAnswerPiece(answer_piece=curr_segment) - curr_segment = "" - - if curr_segment: - if prepend_bracket: - yield DanswerAnswerPiece(answer_piece="[" + curr_segment) - else: - yield 
DanswerAnswerPiece(answer_piece=curr_segment) - - def reorganize_citations( answer: str, citations: list[CitationInfo] ) -> tuple[str, list[CitationInfo]]: @@ -547,72 +164,3 @@ def slack_link_format(match: re.Match) -> str: new_citation_info[citation.citation_num] = citation return new_answer, list(new_citation_info.values()) - - -def get_prompt_tokens(prompt: Prompt) -> int: - # Note: currently custom prompts do not allow datetime aware, only default prompts - return ( - check_number_of_tokens(prompt.system_prompt) - + check_number_of_tokens(prompt.task_prompt) - + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT - + CITATION_STATEMENT_TOKEN_CNT - + CITATION_REMINDER_TOKEN_CNT - + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0) - + (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0) - ) - - -# buffer just to be safe so that we don't overflow the token limit due to -# a small miscalculation -_MISC_BUFFER = 40 - - -def compute_max_document_tokens( - persona: Persona, - actual_user_input: str | None = None, - max_llm_token_override: int | None = None, -) -> int: - """Estimates the number of tokens available for context documents. Formula is roughly: - - ( - model_context_window - reserved_output_tokens - prompt_tokens - - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe) - ) - - The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g. - if we're trying to determine if the user should be able to select another document) then we just set an - arbitrary "upper bound". - """ - llm_name = get_default_llm_version()[0] - if persona.llm_model_version_override: - llm_name = persona.llm_model_version_override - - # if we can't find a number of tokens, just assume some common default - max_input_tokens = ( - max_llm_token_override - if max_llm_token_override - else get_max_input_tokens(model_name=llm_name) - ) - if persona.prompts: - # TODO this may not always be the first prompt - prompt_tokens = get_prompt_tokens(persona.prompts[0]) - else: - prompt_tokens = get_prompt_tokens(get_default_prompt()) - - user_input_tokens = ( - check_number_of_tokens(actual_user_input) - if actual_user_input is not None - else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS - ) - - return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER - - -def compute_max_llm_input_tokens(persona: Persona) -> int: - """Maximum tokens allows in the input to the LLM (of any type).""" - llm_name = get_default_llm_version()[0] - if persona.llm_model_version_override: - llm_name = persona.llm_model_version_override - - input_tokens = get_max_input_tokens(model_name=llm_name) - return input_tokens - _MISC_BUFFER diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py index 0800abb70a1..ccc75443749 100644 --- a/backend/danswer/chat/load_yamls.py +++ b/backend/danswer/chat/load_yamls.py @@ -13,7 +13,7 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import DocumentSet as DocumentSetDBModel from danswer.db.models import Prompt as PromptDBModel -from danswer.search.models import RecencyBiasSetting +from danswer.search.enums import RecencyBiasSetting def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None: diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index de3f7e4f017..d2dd9f31faf 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -5,10 +5,10 @@ from pydantic import BaseModel from 
danswer.configs.constants import DocumentSource -from danswer.search.models import QueryFlow +from danswer.search.enums import QueryFlow +from danswer.search.enums import SearchType from danswer.search.models import RetrievalDocs from danswer.search.models import SearchResponse -from danswer.search.models import SearchType class LlmDoc(BaseModel): @@ -16,11 +16,13 @@ class LlmDoc(BaseModel): document_id: str content: str + blurb: str semantic_identifier: str source_type: DocumentSource metadata: dict[str, str | list[str]] updated_at: datetime | None link: str | None + source_links: dict[int, str] | None # First chunk of info for streaming QA @@ -100,9 +102,12 @@ class QAResponse(SearchResponse, DanswerAnswer): error_msg: str | None = None -AnswerQuestionStreamReturn = Iterator[ - DanswerAnswerPiece | DanswerQuotes | DanswerContexts | StreamingError -] +AnswerQuestionPossibleReturn = ( + DanswerAnswerPiece | DanswerQuotes | CitationInfo | DanswerContexts | StreamingError +) + + +AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn] class LLMMetricsContainer(BaseModel): diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index aafe5d000f8..270afc67e29 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -5,16 +5,8 @@ from sqlalchemy.orm import Session -from danswer.chat.chat_utils import build_chat_system_message -from danswer.chat.chat_utils import build_chat_user_message -from danswer.chat.chat_utils import compute_max_document_tokens -from danswer.chat.chat_utils import compute_max_llm_input_tokens from danswer.chat.chat_utils import create_chat_chain -from danswer.chat.chat_utils import drop_messages_history_overflow -from danswer.chat.chat_utils import extract_citations_from_stream -from danswer.chat.chat_utils import get_chunks_for_qa from danswer.chat.chat_utils import llm_doc_from_inference_chunk -from danswer.chat.chat_utils import map_document_id_order from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc @@ -23,9 +15,7 @@ from danswer.chat.models import StreamingError from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT -from danswer.configs.constants import DISABLED_GEN_AI_MSG from danswer.configs.constants import MessageType -from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.db.chat import create_db_search_doc from danswer.db.chat import create_new_chat_message from danswer.db.chat import get_chat_message @@ -37,27 +27,22 @@ from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_session_context_manager -from danswer.db.models import ChatMessage -from danswer.db.models import Persona from danswer.db.models import SearchDoc as DbSearchDoc from danswer.db.models import User from danswer.document_index.factory import get_default_document_index -from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.answer import Answer +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import CitationConfig +from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.models import PreviousMessage from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import 
get_default_llm -from danswer.llm.interfaces import LLM from danswer.llm.utils import get_default_llm_tokenizer -from danswer.llm.utils import get_default_llm_version -from danswer.llm.utils import get_max_input_tokens -from danswer.llm.utils import tokenizer_trim_content -from danswer.llm.utils import translate_history_to_basemessages -from danswer.prompts.prompt_utils import build_doc_context_str from danswer.search.models import OptionalSearchSetting -from danswer.search.models import RetrievalDetails -from danswer.search.request_preprocessing import retrieval_preprocessing -from danswer.search.search_runner import chunks_to_search_docs -from danswer.search.search_runner import full_chunk_search_generator -from danswer.search.search_runner import inference_documents_from_ids +from danswer.search.models import SearchRequest +from danswer.search.pipeline import SearchPipeline +from danswer.search.retrieval.search_runner import inference_documents_from_ids +from danswer.search.utils import chunks_to_search_docs from danswer.secondary_llm_flows.choose_search import check_if_need_search from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail @@ -69,72 +54,6 @@ logger = setup_logger() -def generate_ai_chat_response( - query_message: ChatMessage, - history: list[ChatMessage], - persona: Persona, - context_docs: list[LlmDoc], - doc_id_to_rank_map: dict[str, int], - llm: LLM | None, - llm_tokenizer_encode_func: Callable, - all_doc_useful: bool, -) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]: - if llm is None: - try: - llm = get_default_llm() - except GenAIDisabledException: - # Not an error if it's a user configuration - yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG) - return - - if query_message.prompt is None: - raise RuntimeError("No prompt received for generating Gen AI answer.") - - try: - context_exists = len(context_docs) > 0 - - system_message_or_none, system_tokens = build_chat_system_message( - prompt=query_message.prompt, - context_exists=context_exists, - llm_tokenizer_encode_func=llm_tokenizer_encode_func, - ) - - history_basemessages, history_token_counts = translate_history_to_basemessages( - history - ) - - # Be sure the context_docs passed to build_chat_user_message - # Is the same as passed in later for extracting citations - user_message, user_tokens = build_chat_user_message( - chat_message=query_message, - prompt=query_message.prompt, - context_docs=context_docs, - llm_tokenizer_encode_func=llm_tokenizer_encode_func, - all_doc_useful=all_doc_useful, - ) - - prompt = drop_messages_history_overflow( - system_msg=system_message_or_none, - system_token_count=system_tokens, - history_msgs=history_basemessages, - history_token_counts=history_token_counts, - final_msg=user_message, - final_msg_token_count=user_tokens, - max_allowed_tokens=compute_max_llm_input_tokens(persona), - ) - - # Good Debug/Breakpoint - tokens = llm.stream(prompt) - - yield from extract_citations_from_stream( - tokens, context_docs, doc_id_to_rank_map - ) - - except Exception as e: - logger.exception(f"LLM failed to produce valid chat message, error: {e}") - yield StreamingError(error=str(e)) - - def translate_citations( citations_list: list[CitationInfo], db_docs: list[DbSearchDoc] ) -> dict[int, int]: @@ -155,24 +74,26 @@ def translate_citations( return citation_to_saved_doc_id_map +ChatPacketStream = Iterator[ + StreamingError + | QADocsResponse + | LLMRelevanceFilterResponse + | 
ChatMessageDetail + | DanswerAnswerPiece + | CitationInfo +] + + def stream_chat_message_objects( new_msg_req: CreateChatMessageRequest, user: User | None, db_session: Session, # Needed to translate persona num_chunks to tokens to the LLM default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT, - default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE, # For flow with search, don't include as many chunks as possible since we need to leave space # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE, -) -> Iterator[ - StreamingError - | QADocsResponse - | LLMRelevanceFilterResponse - | ChatMessageDetail - | DanswerAnswerPiece - | CitationInfo -]: +) -> ChatPacketStream: """Streams in order: 1. [conditional] Retrieved documents if a search needs to be run 2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on @@ -278,10 +199,6 @@ def stream_chat_message_objects( query_message=final_msg, history=history_msgs, llm=llm ) - max_document_tokens = compute_max_document_tokens( - persona=persona, actual_user_input=message_text - ) - rephrased_query = None if reference_doc_ids: identifier_tuples = get_doc_query_identifiers_from_model( @@ -297,64 +214,8 @@ def stream_chat_message_objects( doc_identifiers=identifier_tuples, document_index=document_index, ) - - # truncate the last document if it exceeds the token limit - tokens_per_doc = [ - len( - llm_tokenizer_encode_func( - build_doc_context_str( - semantic_identifier=llm_doc.semantic_identifier, - source_type=llm_doc.source_type, - content=llm_doc.content, - metadata_dict=llm_doc.metadata, - updated_at=llm_doc.updated_at, - ind=ind, - ) - ) - ) - for ind, llm_doc in enumerate(llm_docs) - ] - final_doc_ind = None - total_tokens = 0 - for ind, tokens in enumerate(tokens_per_doc): - total_tokens += tokens - if total_tokens > max_document_tokens: - final_doc_ind = ind - break - if final_doc_ind is not None: - # only allow the final document to get truncated - # if more than that, then the user message is too long - if final_doc_ind != len(tokens_per_doc) - 1: - yield StreamingError( - error="LLM context window exceeded. Please de-select some documents or shorten your query." - ) - return - - final_doc_desired_length = tokens_per_doc[final_doc_ind] - ( - total_tokens - max_document_tokens - ) - # 75 tokens is a reasonable over-estimate of the metadata and title - final_doc_content_length = final_doc_desired_length - 75 - # this could occur if we only have space for the title / metadata - # not ideal, but it's the most reasonable thing to do - # NOTE: the frontend prevents documents from being selected if - # less than 75 tokens are available to try and avoid this situation - # from occuring in the first place - if final_doc_content_length <= 0: - logger.error( - f"Final doc ({llm_docs[final_doc_ind].semantic_identifier}) content " - "length is less than 0. Removing this doc from the final prompt." 
- ) - llm_docs.pop() - else: - llm_docs[final_doc_ind].content = tokenizer_trim_content( - content=llm_docs[final_doc_ind].content, - desired_length=final_doc_content_length, - tokenizer=llm_tokenizer, - ) - - doc_id_to_rank_map = map_document_id_order( - cast(list[InferenceChunk | LlmDoc], llm_docs) + document_pruning_config = DocumentPruningConfig( + is_manually_selected_docs=True ) # In case the search doc is deleted, just don't include it @@ -377,36 +238,21 @@ def stream_chat_message_objects( else query_override ) - ( - retrieval_request, - predicted_search_type, - predicted_flow, - ) = retrieval_preprocessing( - query=rephrased_query, - retrieval_details=cast(RetrievalDetails, retrieval_options), - persona=persona, + search_pipeline = SearchPipeline( + search_request=SearchRequest( + query=rephrased_query, + human_selected_filters=retrieval_options.filters + if retrieval_options + else None, + persona=persona, + offset=retrieval_options.offset if retrieval_options else None, + limit=retrieval_options.limit if retrieval_options else None, + ), user=user, db_session=db_session, ) - documents_generator = full_chunk_search_generator( - search_query=retrieval_request, - document_index=document_index, - db_session=db_session, - ) - time_cutoff = retrieval_request.filters.time_cutoff - recency_bias_multiplier = retrieval_request.recency_bias_multiplier - run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter - - # First fetch and return the top chunks to the UI so the user can - # immediately see some results - top_chunks = cast(list[InferenceChunk], next(documents_generator)) - - # Get ranking of the documents for citation purposes later - doc_id_to_rank_map = map_document_id_order( - cast(list[InferenceChunk | LlmDoc], top_chunks) - ) - + top_chunks = search_pipeline.reranked_docs top_docs = chunks_to_search_docs(top_chunks) reference_db_search_docs = [ @@ -422,62 +268,35 @@ def stream_chat_message_objects( initial_response = QADocsResponse( rephrased_query=rephrased_query, top_documents=response_docs, - predicted_flow=predicted_flow, - predicted_search=predicted_search_type, - applied_source_filters=retrieval_request.filters.source_type, - applied_time_cutoff=time_cutoff, - recency_bias_multiplier=recency_bias_multiplier, + predicted_flow=search_pipeline.predicted_flow, + predicted_search=search_pipeline.predicted_search_type, + applied_source_filters=search_pipeline.search_query.filters.source_type, + applied_time_cutoff=search_pipeline.search_query.filters.time_cutoff, + recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier, ) yield initial_response - # Get the final ordering of chunks for the LLM call - llm_chunk_selection = cast(list[bool], next(documents_generator)) - # Yield the list of LLM selected chunks for showing the LLM selected icons in the UI llm_relevance_filtering_response = LLMRelevanceFilterResponse( - relevant_chunk_indices=[ - index for index, value in enumerate(llm_chunk_selection) if value - ] - if run_llm_chunk_filter - else [] + relevant_chunk_indices=search_pipeline.relevant_chunk_indicies ) yield llm_relevance_filtering_response - # Prep chunks to pass to LLM - num_llm_chunks = ( - persona.num_chunks - if persona.num_chunks is not None - else default_num_chunks + document_pruning_config = DocumentPruningConfig( + max_chunks=int( + persona.num_chunks + if persona.num_chunks is not None + else default_num_chunks + ), + max_window_percentage=max_document_percentage, ) - llm_name = get_default_llm_version()[0] - if 
persona.llm_model_version_override: - llm_name = persona.llm_model_version_override - - llm_max_input_tokens = get_max_input_tokens(model_name=llm_name) - - llm_token_based_chunk_lim = max_document_percentage * llm_max_input_tokens - - chunk_token_limit = int( - min( - num_llm_chunks * default_chunk_size, - max_document_tokens, - llm_token_based_chunk_lim, - ) - ) - llm_chunks_indices = get_chunks_for_qa( - chunks=top_chunks, - llm_chunk_selection=llm_chunk_selection, - token_limit=chunk_token_limit, - llm_tokenizer=llm_tokenizer, - ) - llm_chunks = [top_chunks[i] for i in llm_chunks_indices] - llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks] + llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in top_chunks] else: llm_docs = [] - doc_id_to_rank_map = {} reference_db_search_docs = None + document_pruning_config = DocumentPruningConfig() # Cannot determine these without the LLM step or breaking out early partial_response = partial( @@ -515,33 +334,24 @@ def stream_chat_message_objects( return # LLM prompt building, response capturing, etc. - response_packets = generate_ai_chat_response( - query_message=final_msg, - history=history_msgs, + answer = Answer( + question=final_msg.message, + docs=llm_docs, + answer_style_config=AnswerStyleConfig( + citation_config=CitationConfig( + all_docs_useful=reference_db_search_docs is not None + ), + document_pruning_config=document_pruning_config, + ), + prompt=final_msg.prompt, persona=persona, - context_docs=llm_docs, - doc_id_to_rank_map=doc_id_to_rank_map, - llm=llm, - llm_tokenizer_encode_func=llm_tokenizer_encode_func, - all_doc_useful=reference_doc_ids is not None, + message_history=[ + PreviousMessage.from_chat_message(msg) for msg in history_msgs + ], ) + # generator will not include quotes, so we can cast + yield from cast(ChatPacketStream, answer.processed_streamed_output) - # Capture outputs and errors - llm_output = "" - error: str | None = None - citations: list[CitationInfo] = [] - for packet in response_packets: - if isinstance(packet, DanswerAnswerPiece): - token = packet.answer_piece - if token: - llm_output += token - elif isinstance(packet, StreamingError): - error = packet.error - elif isinstance(packet, CitationInfo): - citations.append(packet) - continue - - yield packet except Exception as e: logger.exception(e) @@ -555,16 +365,16 @@ def stream_chat_message_objects( db_citations = None if reference_db_search_docs: db_citations = translate_citations( - citations_list=citations, + citations_list=answer.citations, db_docs=reference_db_search_docs, ) # Saving Gen AI answer and responding with message info gen_ai_response_message = partial_response( - message=llm_output, - token_count=len(llm_tokenizer_encode_func(llm_output)), + message=answer.llm_answer, + token_count=len(llm_tokenizer_encode_func(answer.llm_answer)), citations=db_citations, - error=error, + error=None, ) msg_detail_response = translate_db_message_to_chat_message_detail( diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index 1e065dd1dad..b3fdb79c88b 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -12,7 +12,6 @@ from slack_sdk.models.blocks import DividerBlock from sqlalchemy.orm import Session -from danswer.chat.chat_utils import compute_max_document_tokens from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT from 
danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER @@ -39,6 +38,7 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import SlackBotConfig from danswer.db.models import SlackBotResponseType +from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens from danswer.llm.utils import check_number_of_tokens from danswer.llm.utils import get_default_llm_version from danswer.llm.utils import get_max_input_tokens diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 343912e275b..6dfa02c2f9c 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -27,7 +27,7 @@ from danswer.db.models import SearchDoc as DBSearchDoc from danswer.db.models import StarterMessage from danswer.db.models import User__UserGroup -from danswer.search.models import RecencyBiasSetting +from danswer.search.enums import RecencyBiasSetting from danswer.search.models import RetrievalDocs from danswer.search.models import SavedSearchDoc from danswer.search.models import SearchDoc as ServerSearchDoc diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index abe189c45ec..faafd2aedf8 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -36,8 +36,8 @@ from danswer.configs.constants import SearchFeedbackType from danswer.connectors.models import InputType from danswer.dynamic_configs.interface import JSON_ro -from danswer.search.models import RecencyBiasSetting -from danswer.search.models import SearchType +from danswer.search.enums import RecencyBiasSetting +from danswer.search.enums import SearchType class IndexingStatus(str, PyEnum): diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py index 3e93a76cf6a..c3b463e35d2 100644 --- a/backend/danswer/db/slack_bot_config.py +++ b/backend/danswer/db/slack_bot_config.py @@ -12,7 +12,7 @@ from danswer.db.models import Persona__DocumentSet from danswer.db.models import SlackBotConfig from danswer.db.models import SlackBotResponseType -from danswer.search.models import RecencyBiasSetting +from danswer.search.enums import RecencyBiasSetting def _build_persona_name(channel_names: list[str]) -> str: diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py index 178aadf3eea..9f78f05c20e 100644 --- a/backend/danswer/document_index/vespa/index.py +++ b/backend/danswer/document_index/vespa/index.py @@ -64,8 +64,8 @@ from danswer.indexing.models import DocMetadataAwareIndexChunk from danswer.indexing.models import InferenceChunk from danswer.search.models import IndexFilters -from danswer.search.search_runner import query_processing -from danswer.search.search_runner import remove_stop_words_and_punctuation +from danswer.search.retrieval.search_runner import query_processing +from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation from danswer.utils.batching import batch_generator from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py new file mode 100644 index 00000000000..76d399d8bd9 --- /dev/null +++ b/backend/danswer/llm/answering/answer.py @@ -0,0 +1,176 @@ +from collections.abc import Iterator +from typing import cast + +from langchain.schema.messages import BaseMessage + +from 
danswer.chat.models import AnswerQuestionPossibleReturn +from danswer.chat.models import AnswerQuestionStreamReturn +from danswer.chat.models import CitationInfo +from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE +from danswer.configs.chat_configs import QA_TIMEOUT +from danswer.db.models import Persona +from danswer.db.models import Prompt +from danswer.llm.answering.doc_pruning import prune_documents +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.answering.models import StreamProcessor +from danswer.llm.answering.prompts.citations_prompt import build_citations_prompt +from danswer.llm.answering.prompts.quotes_prompt import ( + build_quotes_prompt, +) +from danswer.llm.answering.stream_processing.citation_processing import ( + build_citation_processor, +) +from danswer.llm.answering.stream_processing.quotes_processing import ( + build_quotes_processor, +) +from danswer.llm.factory import get_default_llm +from danswer.llm.utils import get_default_llm_tokenizer + + +def _get_stream_processor( + docs: list[LlmDoc], answer_style_configs: AnswerStyleConfig +) -> StreamProcessor: + if answer_style_configs.citation_config: + return build_citation_processor( + context_docs=docs, + ) + if answer_style_configs.quotes_config: + return build_quotes_processor( + context_docs=docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak") + ) + + raise RuntimeError("Not implemented yet") + + +class Answer: + def __init__( + self, + question: str, + docs: list[LlmDoc], + answer_style_config: AnswerStyleConfig, + prompt: Prompt, + persona: Persona, + # must be the same length as `docs`. 
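# Hypothetical usage sketch of the new Answer class (not additional patch content),
# mirroring the call site added to stream_chat_message_objects above. The persona,
# prompt, and docs are assumed to be loaded elsewhere; the names here are placeholders.
from danswer.chat.models import LlmDoc
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig


def run_answer(question: str, docs: list[LlmDoc], prompt: Prompt, persona: Persona):
    answer = Answer(
        question=question,
        docs=docs,
        answer_style_config=AnswerStyleConfig(
            citation_config=CitationConfig(all_docs_useful=False)
        ),
        prompt=prompt,
        persona=persona,
    )
    # processed_streamed_output lazily calls the LLM and yields DanswerAnswerPiece /
    # CitationInfo packets; the stream is cached, so llm_answer and citations can be
    # read afterwards without a second LLM call.
    for packet in answer.processed_streamed_output:
        ...  # forward each packet to the client as it arrives
    return answer.llm_answer, answer.citations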
If None, all docs are considered "relevant" + doc_relevance_list: list[bool] | None = None, + message_history: list[PreviousMessage] | None = None, + single_message_history: str | None = None, + timeout: int = QA_TIMEOUT, + ) -> None: + if single_message_history and message_history: + raise ValueError( + "Cannot provide both `message_history` and `single_message_history`" + ) + + self.question = question + self.docs = docs + self.doc_relevance_list = doc_relevance_list + self.message_history = message_history or [] + # used for QA flow where we only want to send a single message + self.single_message_history = single_message_history + + self.answer_style_config = answer_style_config + + self.llm = get_default_llm( + gen_ai_model_version_override=persona.llm_model_version_override, + timeout=timeout, + ) + self.llm_tokenizer = get_default_llm_tokenizer() + + self.prompt = prompt + self.persona = persona + + self.process_stream_fn = _get_stream_processor(docs, answer_style_config) + + self._final_prompt: list[BaseMessage] | None = None + + self._pruned_docs: list[LlmDoc] | None = None + + self._streamed_output: list[str] | None = None + self._processed_stream: list[AnswerQuestionPossibleReturn] | None = None + + @property + def pruned_docs(self) -> list[LlmDoc]: + if self._pruned_docs is not None: + return self._pruned_docs + + self._pruned_docs = prune_documents( + docs=self.docs, + doc_relevance_list=self.doc_relevance_list, + persona=self.persona, + question=self.question, + document_pruning_config=self.answer_style_config.document_pruning_config, + ) + return self._pruned_docs + + @property + def final_prompt(self) -> list[BaseMessage]: + if self._final_prompt is not None: + return self._final_prompt + + if self.answer_style_config.citation_config: + self._final_prompt = build_citations_prompt( + question=self.question, + message_history=self.message_history, + persona=self.persona, + prompt=self.prompt, + context_docs=self.pruned_docs, + all_doc_useful=self.answer_style_config.citation_config.all_docs_useful, + llm_tokenizer_encode_func=self.llm_tokenizer.encode, + history_message=self.single_message_history or "", + ) + elif self.answer_style_config.quotes_config: + self._final_prompt = build_quotes_prompt( + question=self.question, + context_docs=self.pruned_docs, + history_str=self.single_message_history or "", + prompt=self.prompt, + ) + + return cast(list[BaseMessage], self._final_prompt) + + @property + def raw_streamed_output(self) -> Iterator[str]: + if self._streamed_output is not None: + yield from self._streamed_output + return + + streamed_output = [] + for message in self.llm.stream(self.final_prompt): + streamed_output.append(message) + yield message + + self._streamed_output = streamed_output + + @property + def processed_streamed_output(self) -> AnswerQuestionStreamReturn: + if self._processed_stream is not None: + yield from self._processed_stream + return + + processed_stream = [] + for processed_packet in self.process_stream_fn(self.raw_streamed_output): + processed_stream.append(processed_packet) + yield processed_packet + + self._processed_stream = processed_stream + + @property + def llm_answer(self) -> str: + answer = "" + for packet in self.processed_streamed_output: + if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece: + answer += packet.answer_piece + + return answer + + @property + def citations(self) -> list[CitationInfo]: + citations: list[CitationInfo] = [] + for packet in self.processed_streamed_output: + if isinstance(packet, CitationInfo): 
+ citations.append(packet) + + return citations diff --git a/backend/danswer/llm/answering/doc_pruning.py b/backend/danswer/llm/answering/doc_pruning.py new file mode 100644 index 00000000000..29c913673d5 --- /dev/null +++ b/backend/danswer/llm/answering/doc_pruning.py @@ -0,0 +1,205 @@ +from copy import deepcopy +from typing import TypeVar + +from danswer.chat.models import ( + LlmDoc, +) +from danswer.configs.constants import IGNORE_FOR_QA +from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE +from danswer.db.models import Persona +from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens +from danswer.llm.utils import get_default_llm_tokenizer +from danswer.llm.utils import tokenizer_trim_content +from danswer.prompts.prompt_utils import build_doc_context_str +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + +T = TypeVar("T", bound=LlmDoc | InferenceChunk) + +_METADATA_TOKEN_ESTIMATE = 75 + + +class PruningError(Exception): + pass + + +def _compute_limit( + persona: Persona, + question: str, + max_chunks: int | None, + max_window_percentage: float | None, + max_tokens: int | None, +) -> int: + llm_max_document_tokens = compute_max_document_tokens( + persona=persona, actual_user_input=question + ) + + window_percentage_based_limit = ( + max_window_percentage * llm_max_document_tokens + if max_window_percentage + else None + ) + chunk_count_based_limit = ( + max_chunks * DOC_EMBEDDING_CONTEXT_SIZE if max_chunks else None + ) + + limit_options = [ + lim + for lim in [ + window_percentage_based_limit, + chunk_count_based_limit, + max_tokens, + llm_max_document_tokens, + ] + if lim + ] + return int(min(limit_options)) + + +def reorder_docs( + docs: list[T], + doc_relevance_list: list[bool] | None, +) -> list[T]: + if doc_relevance_list is None: + return docs + + reordered_docs: list[T] = [] + if doc_relevance_list is not None: + for selection_target in [True, False]: + for doc, is_relevant in zip(docs, doc_relevance_list): + if is_relevant == selection_target: + reordered_docs.append(doc) + return reordered_docs + + +def _remove_docs_to_ignore(docs: list[LlmDoc]) -> list[LlmDoc]: + return [doc for doc in docs if not doc.metadata.get(IGNORE_FOR_QA)] + + +def _apply_pruning( + docs: list[LlmDoc], + doc_relevance_list: list[bool] | None, + token_limit: int, + is_manually_selected_docs: bool, +) -> list[LlmDoc]: + llm_tokenizer = get_default_llm_tokenizer() + docs = deepcopy(docs) # don't modify in place + + # re-order docs with all the "relevant" docs at the front + docs = reorder_docs(docs=docs, doc_relevance_list=doc_relevance_list) + # remove docs that are explicitly marked as not for QA + docs = _remove_docs_to_ignore(docs=docs) + + tokens_per_doc: list[int] = [] + final_doc_ind = None + total_tokens = 0 + for ind, llm_doc in enumerate(docs): + doc_tokens = len( + llm_tokenizer.encode( + build_doc_context_str( + semantic_identifier=llm_doc.semantic_identifier, + source_type=llm_doc.source_type, + content=llm_doc.content, + metadata_dict=llm_doc.metadata, + updated_at=llm_doc.updated_at, + ind=ind, + ) + ) + ) + # if chunks, truncate chunks that are way too long + # this can happen if the embedding model tokenizer is different + # than the LLM tokenizer + if ( + not is_manually_selected_docs + and doc_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE + ): + logger.warning( + "Found more tokens 
in chunk than expected, " + "likely mismatch between embedding and LLM tokenizers. Trimming content..." + ) + llm_doc.content = tokenizer_trim_content( + content=llm_doc.content, + desired_length=DOC_EMBEDDING_CONTEXT_SIZE, + tokenizer=llm_tokenizer, + ) + doc_tokens = DOC_EMBEDDING_CONTEXT_SIZE + tokens_per_doc.append(doc_tokens) + total_tokens += doc_tokens + if total_tokens > token_limit: + final_doc_ind = ind + break + + if final_doc_ind is not None: + if is_manually_selected_docs: + # for document selection, only allow the final document to get truncated + # if more than that, then the user message is too long + if final_doc_ind != len(docs) - 1: + raise PruningError( + "LLM context window exceeded. Please de-select some documents or shorten your query." + ) + + final_doc_desired_length = tokens_per_doc[final_doc_ind] - ( + total_tokens - token_limit + ) + final_doc_content_length = ( + final_doc_desired_length - _METADATA_TOKEN_ESTIMATE + ) + # this could occur if we only have space for the title / metadata + # not ideal, but it's the most reasonable thing to do + # NOTE: the frontend prevents documents from being selected if + # less than 75 tokens are available to try and avoid this situation + # from occuring in the first place + if final_doc_content_length <= 0: + logger.error( + f"Final doc ({docs[final_doc_ind].semantic_identifier}) content " + "length is less than 0. Removing this doc from the final prompt." + ) + docs.pop() + else: + docs[final_doc_ind].content = tokenizer_trim_content( + content=docs[final_doc_ind].content, + desired_length=final_doc_content_length, + tokenizer=llm_tokenizer, + ) + else: + # for regular search, don't truncate the final document unless it's the only one + if final_doc_ind != 0: + docs = docs[:final_doc_ind] + else: + docs[0].content = tokenizer_trim_content( + content=docs[0].content, + desired_length=token_limit - _METADATA_TOKEN_ESTIMATE, + tokenizer=llm_tokenizer, + ) + docs = [docs[0]] + + return docs + + +def prune_documents( + docs: list[LlmDoc], + doc_relevance_list: list[bool] | None, + persona: Persona, + question: str, + document_pruning_config: DocumentPruningConfig, +) -> list[LlmDoc]: + if doc_relevance_list is not None: + assert len(docs) == len(doc_relevance_list) + + doc_token_limit = _compute_limit( + persona=persona, + question=question, + max_chunks=document_pruning_config.max_chunks, + max_window_percentage=document_pruning_config.max_window_percentage, + max_tokens=document_pruning_config.max_tokens, + ) + return _apply_pruning( + docs=docs, + doc_relevance_list=doc_relevance_list, + token_limit=doc_token_limit, + is_manually_selected_docs=document_pruning_config.is_manually_selected_docs, + ) diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py new file mode 100644 index 00000000000..360535ac803 --- /dev/null +++ b/backend/danswer/llm/answering/models.py @@ -0,0 +1,77 @@ +from collections.abc import Callable +from collections.abc import Iterator +from typing import Any +from typing import TYPE_CHECKING + +from pydantic import BaseModel +from pydantic import Field +from pydantic import root_validator + +from danswer.chat.models import AnswerQuestionStreamReturn +from danswer.configs.constants import MessageType + +if TYPE_CHECKING: + from danswer.db.models import ChatMessage + + +StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn] + + +class PreviousMessage(BaseModel): + """Simplified version of `ChatMessage`""" + + message: str + token_count: int + 
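# Worked example (assumed numbers, not values taken from the configuration) of how
# _compute_limit in doc_pruning.py above combines its budgets; the rule matches the
# function: take the minimum of whichever limits are set.
llm_max_document_tokens = 3000    # assumed result of compute_max_document_tokens(persona, question)
doc_embedding_context_size = 512  # assumed DOC_EMBEDDING_CONTEXT_SIZE
max_window_percentage = 0.5       # from DocumentPruningConfig.max_window_percentage
max_chunks = 10                   # from DocumentPruningConfig.max_chunks

candidate_limits = [
    max_window_percentage * llm_max_document_tokens,  # 1500.0
    max_chunks * doc_embedding_context_size,          # 5120
    llm_max_document_tokens,                          # 3000
]
# max_tokens is None in this example, so it contributes no candidate
token_budget = int(min(candidate_limits))             # 1500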
message_type: MessageType + + @classmethod + def from_chat_message(cls, chat_message: "ChatMessage") -> "PreviousMessage": + return cls( + message=chat_message.message, + token_count=chat_message.token_count, + message_type=chat_message.message_type, + ) + + +class DocumentPruningConfig(BaseModel): + max_chunks: int | None = None + max_window_percentage: float | None = None + max_tokens: int | None = None + # different pruning behavior is expected when the + # user manually selects documents they want to chat with + # e.g. we don't want to truncate each document to be no more + # than one chunk long + is_manually_selected_docs: bool = False + + +class CitationConfig(BaseModel): + all_docs_useful: bool = False + + +class QuotesConfig(BaseModel): + pass + + +class AnswerStyleConfig(BaseModel): + citation_config: CitationConfig | None = None + quotes_config: QuotesConfig | None = None + document_pruning_config: DocumentPruningConfig = Field( + default_factory=DocumentPruningConfig + ) + + @root_validator + def check_quotes_and_citation(cls, values: dict[str, Any]) -> dict[str, Any]: + citation_config = values.get("citation_config") + quotes_config = values.get("quotes_config") + + if citation_config is None and quotes_config is None: + raise ValueError( + "One of `citation_config` or `quotes_config` must be provided" + ) + + if citation_config is not None and quotes_config is not None: + raise ValueError( + "Only one of `citation_config` or `quotes_config` must be provided" + ) + + return values diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/llm/answering/prompts/citations_prompt.py new file mode 100644 index 00000000000..61c42c19c78 --- /dev/null +++ b/backend/danswer/llm/answering/prompts/citations_prompt.py @@ -0,0 +1,281 @@ +from collections.abc import Callable +from functools import lru_cache +from typing import cast + +from langchain.schema.messages import BaseMessage +from langchain.schema.messages import HumanMessage +from langchain.schema.messages import SystemMessage + +from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION +from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS +from danswer.db.chat import get_default_prompt +from danswer.db.models import Persona +from danswer.db.models import Prompt +from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.utils import check_number_of_tokens +from danswer.llm.utils import get_default_llm_tokenizer +from danswer.llm.utils import get_default_llm_version +from danswer.llm.utils import get_max_input_tokens +from danswer.llm.utils import translate_history_to_basemessages +from danswer.prompts.chat_prompts import ADDITIONAL_INFO +from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT +from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT +from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT +from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT +from danswer.prompts.direct_qa_prompts import ( + CITATIONS_PROMPT, +) +from danswer.prompts.prompt_utils import build_complete_context_str +from danswer.prompts.prompt_utils import build_task_prompt_reminders +from danswer.prompts.prompt_utils import get_current_llm_day_time +from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT +from danswer.prompts.token_counts import ( + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT, +) +from 
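# Hypothetical quick check of the AnswerStyleConfig validator above: exactly one of
# citation_config / quotes_config may be set, so the first construction succeeds and
# the other two raise pydantic.ValidationError (pydantic wraps the ValueError raised
# inside check_quotes_and_citation).
import pydantic

from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import QuotesConfig

AnswerStyleConfig(citation_config=CitationConfig(all_docs_useful=True))  # valid

for bad_kwargs in ({}, {"citation_config": CitationConfig(), "quotes_config": QuotesConfig()}):
    try:
        AnswerStyleConfig(**bad_kwargs)
    except pydantic.ValidationError as err:
        print(err)  # neither config, or both configs, is rejected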
danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT +from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT +from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT + + +_PER_MESSAGE_TOKEN_BUFFER = 7 + + +def find_last_index(lst: list[int], max_prompt_tokens: int) -> int: + """From the back, find the index of the last element to include + before the list exceeds the maximum""" + running_sum = 0 + + last_ind = 0 + for i in range(len(lst) - 1, -1, -1): + running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER + if running_sum > max_prompt_tokens: + last_ind = i + 1 + break + if last_ind >= len(lst): + raise ValueError("Last message alone is too large!") + return last_ind + + +def drop_messages_history_overflow( + system_msg: BaseMessage | None, + system_token_count: int, + history_msgs: list[BaseMessage], + history_token_counts: list[int], + final_msg: BaseMessage, + final_msg_token_count: int, + max_allowed_tokens: int, +) -> list[BaseMessage]: + """As message history grows, messages need to be dropped starting from the furthest in the past. + The System message should be kept if at all possible and the latest user input which is inserted in the + prompt template must be included""" + if len(history_msgs) != len(history_token_counts): + # This should never happen + raise ValueError("Need exactly 1 token count per message for tracking overflow") + + prompt: list[BaseMessage] = [] + + # Start dropping from the history if necessary + all_tokens = history_token_counts + [system_token_count, final_msg_token_count] + ind_prev_msg_start = find_last_index( + all_tokens, max_prompt_tokens=max_allowed_tokens + ) + + if system_msg and ind_prev_msg_start <= len(history_msgs): + prompt.append(system_msg) + + prompt.extend(history_msgs[ind_prev_msg_start:]) + + prompt.append(final_msg) + + return prompt + + +def get_prompt_tokens(prompt: Prompt) -> int: + # Note: currently custom prompts do not allow datetime aware, only default prompts + return ( + check_number_of_tokens(prompt.system_prompt) + + check_number_of_tokens(prompt.task_prompt) + + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT + + CITATION_STATEMENT_TOKEN_CNT + + CITATION_REMINDER_TOKEN_CNT + + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0) + + (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0) + ) + + +# buffer just to be safe so that we don't overflow the token limit due to +# a small miscalculation +_MISC_BUFFER = 40 + + +def compute_max_document_tokens( + persona: Persona, + actual_user_input: str | None = None, + max_llm_token_override: int | None = None, +) -> int: + """Estimates the number of tokens available for context documents. Formula is roughly: + + ( + model_context_window - reserved_output_tokens - prompt_tokens + - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe) + ) + + The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g. + if we're trying to determine if the user should be able to select another document) then we just set an + arbitrary "upper bound". 
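# Worked example (hypothetical token counts) for find_last_index above. With per-message
# counts [50, 400, 120, 300], the 7-token-per-message buffer, and a 500-token budget,
# summing from the back gives 307, then 434, then 841, which exceeds 500, so index 2 is
# returned and only the last two entries survive. In drop_messages_history_overflow the
# list passed in ends with the system and latest-user-message token counts, so those are
# counted first and it is the older history messages that get dropped.
from danswer.llm.answering.prompts.citations_prompt import find_last_index

assert find_last_index([50, 400, 120, 300], max_prompt_tokens=500) == 2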
+ """ + llm_name = get_default_llm_version()[0] + if persona.llm_model_version_override: + llm_name = persona.llm_model_version_override + + # if we can't find a number of tokens, just assume some common default + max_input_tokens = ( + max_llm_token_override + if max_llm_token_override + else get_max_input_tokens(model_name=llm_name) + ) + if persona.prompts: + # TODO this may not always be the first prompt + prompt_tokens = get_prompt_tokens(persona.prompts[0]) + else: + prompt_tokens = get_prompt_tokens(get_default_prompt()) + + user_input_tokens = ( + check_number_of_tokens(actual_user_input) + if actual_user_input is not None + else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS + ) + + return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER + + +def compute_max_llm_input_tokens(persona: Persona) -> int: + """Maximum tokens allows in the input to the LLM (of any type).""" + llm_name = get_default_llm_version()[0] + if persona.llm_model_version_override: + llm_name = persona.llm_model_version_override + + input_tokens = get_max_input_tokens(model_name=llm_name) + return input_tokens - _MISC_BUFFER + + +@lru_cache() +def build_system_message( + prompt: Prompt, + context_exists: bool, + llm_tokenizer_encode_func: Callable, + citation_line: str = REQUIRE_CITATION_STATEMENT, + no_citation_line: str = NO_CITATION_STATEMENT, +) -> tuple[SystemMessage | None, int]: + system_prompt = prompt.system_prompt.strip() + if prompt.include_citations: + if context_exists: + system_prompt += citation_line + else: + system_prompt += no_citation_line + if prompt.datetime_aware: + if system_prompt: + system_prompt += ADDITIONAL_INFO.format( + datetime_info=get_current_llm_day_time() + ) + else: + system_prompt = get_current_llm_day_time() + + if not system_prompt: + return None, 0 + + token_count = len(llm_tokenizer_encode_func(system_prompt)) + system_msg = SystemMessage(content=system_prompt) + + return system_msg, token_count + + +def build_user_message( + question: str, + prompt: Prompt, + context_docs: list[LlmDoc] | list[InferenceChunk], + all_doc_useful: bool, + history_message: str, +) -> tuple[HumanMessage, int]: + llm_tokenizer = get_default_llm_tokenizer() + llm_tokenizer_encode_func = cast(Callable[[str], list[int]], llm_tokenizer.encode) + + if not context_docs: + # Simpler prompt for cases where there is no context + user_prompt = ( + CHAT_USER_CONTEXT_FREE_PROMPT.format( + task_prompt=prompt.task_prompt, user_query=question + ) + if prompt.task_prompt + else question + ) + user_prompt = user_prompt.strip() + token_count = len(llm_tokenizer_encode_func(user_prompt)) + user_msg = HumanMessage(content=user_prompt) + return user_msg, token_count + + context_docs_str = build_complete_context_str(context_docs) + optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT + + task_prompt_with_reminder = build_task_prompt_reminders(prompt) + + user_prompt = CITATIONS_PROMPT.format( + optional_ignore_statement=optional_ignore, + context_docs_str=context_docs_str, + task_prompt=task_prompt_with_reminder, + user_query=question, + history_block=history_message, + ) + + user_prompt = user_prompt.strip() + token_count = len(llm_tokenizer_encode_func(user_prompt)) + user_msg = HumanMessage(content=user_prompt) + + return user_msg, token_count + + +def build_citations_prompt( + question: str, + message_history: list[PreviousMessage], + persona: Persona, + prompt: Prompt, + context_docs: list[LlmDoc] | list[InferenceChunk], + all_doc_useful: bool, + history_message: str, + 
llm_tokenizer_encode_func: Callable, +) -> list[BaseMessage]: + context_exists = len(context_docs) > 0 + + system_message_or_none, system_tokens = build_system_message( + prompt=prompt, + context_exists=context_exists, + llm_tokenizer_encode_func=llm_tokenizer_encode_func, + ) + + history_basemessages, history_token_counts = translate_history_to_basemessages( + message_history + ) + + # Be sure the context_docs passed to build_chat_user_message + # Is the same as passed in later for extracting citations + user_message, user_tokens = build_user_message( + question=question, + prompt=prompt, + context_docs=context_docs, + all_doc_useful=all_doc_useful, + history_message=history_message, + ) + + final_prompt_msgs = drop_messages_history_overflow( + system_msg=system_message_or_none, + system_token_count=system_tokens, + history_msgs=history_basemessages, + history_token_counts=history_token_counts, + final_msg=user_message, + final_msg_token_count=user_tokens, + max_allowed_tokens=compute_max_llm_input_tokens(persona), + ) + + return final_prompt_msgs diff --git a/backend/danswer/llm/answering/prompts/quotes_prompt.py b/backend/danswer/llm/answering/prompts/quotes_prompt.py new file mode 100644 index 00000000000..c9e145e8100 --- /dev/null +++ b/backend/danswer/llm/answering/prompts/quotes_prompt.py @@ -0,0 +1,88 @@ +from langchain.schema.messages import BaseMessage +from langchain.schema.messages import HumanMessage + +from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION +from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE +from danswer.db.models import Prompt +from danswer.indexing.models import InferenceChunk +from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK +from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK +from danswer.prompts.direct_qa_prompts import JSON_PROMPT +from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT +from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT +from danswer.prompts.prompt_utils import build_complete_context_str + + +def _build_weak_llm_quotes_prompt( + question: str, + context_docs: list[LlmDoc] | list[InferenceChunk], + history_str: str, + prompt: Prompt, + use_language_hint: bool, +) -> list[BaseMessage]: + """Since Danswer supports a variety of LLMs, this less demanding prompt is provided + as an option to use with weaker LLMs such as small version, low float precision, quantized, + or distilled models. It only uses one context document and has very weak requirements of + output format. 
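+
+    Only the first document in context_docs is included in the prompt;
+    history_str and use_language_hint are accepted for signature parity with
+    _build_strong_llm_quotes_prompt and are not referenced here.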
+ """ + context_block = "" + if context_docs: + context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs[0].content) + + prompt_str = WEAK_LLM_PROMPT.format( + system_prompt=prompt.system_prompt, + context_block=context_block, + task_prompt=prompt.task_prompt, + user_query=question, + ) + return [HumanMessage(content=prompt_str)] + + +def _build_strong_llm_quotes_prompt( + question: str, + context_docs: list[LlmDoc] | list[InferenceChunk], + history_str: str, + prompt: Prompt, + use_language_hint: bool, +) -> list[BaseMessage]: + context_block = "" + if context_docs: + context_docs_str = build_complete_context_str(context_docs) + context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str) + + history_block = "" + if history_str: + history_block = HISTORY_BLOCK.format(history_str=history_str) + + full_prompt = JSON_PROMPT.format( + system_prompt=prompt.system_prompt, + context_block=context_block, + history_block=history_block, + task_prompt=prompt.task_prompt, + user_query=question, + language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "", + ).strip() + return [HumanMessage(content=full_prompt)] + + +def build_quotes_prompt( + question: str, + context_docs: list[LlmDoc] | list[InferenceChunk], + history_str: str, + prompt: Prompt, + use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), +) -> list[BaseMessage]: + prompt_builder = ( + _build_weak_llm_quotes_prompt + if QA_PROMPT_OVERRIDE == "weak" + else _build_strong_llm_quotes_prompt + ) + + return prompt_builder( + question=question, + context_docs=context_docs, + history_str=history_str, + prompt=prompt, + use_language_hint=use_language_hint, + ) diff --git a/backend/danswer/llm/answering/prompts/utils.py b/backend/danswer/llm/answering/prompts/utils.py new file mode 100644 index 00000000000..bcc8b891815 --- /dev/null +++ b/backend/danswer/llm/answering/prompts/utils.py @@ -0,0 +1,20 @@ +from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT +from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT + + +def build_dummy_prompt( + system_prompt: str, task_prompt: str, retrieval_disabled: bool +) -> str: + if retrieval_disabled: + return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format( + user_query="", + system_prompt=system_prompt, + task_prompt=task_prompt, + ).strip() + + return PARAMATERIZED_PROMPT.format( + context_docs_str="", + user_query="", + system_prompt=system_prompt, + task_prompt=task_prompt, + ).strip() diff --git a/backend/danswer/llm/answering/stream_processing/citation_processing.py b/backend/danswer/llm/answering/stream_processing/citation_processing.py new file mode 100644 index 00000000000..a26021835cc --- /dev/null +++ b/backend/danswer/llm/answering/stream_processing/citation_processing.py @@ -0,0 +1,126 @@ +import re +from collections.abc import Iterator + +from danswer.chat.models import AnswerQuestionStreamReturn +from danswer.chat.models import CitationInfo +from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import STOP_STREAM_PAT +from danswer.llm.answering.models import StreamProcessor +from danswer.llm.answering.stream_processing.utils import map_document_id_order +from danswer.prompts.constants import TRIPLE_BACKTICK +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + + +def in_code_block(llm_text: str) -> bool: + count = llm_text.count(TRIPLE_BACKTICK) + return count % 2 != 0 + + +def extract_citations_from_stream( + tokens: 
Iterator[str], + context_docs: list[LlmDoc], + doc_id_to_rank_map: dict[str, int], + stop_stream: str | None = STOP_STREAM_PAT, +) -> Iterator[DanswerAnswerPiece | CitationInfo]: + llm_out = "" + max_citation_num = len(context_docs) + curr_segment = "" + prepend_bracket = False + cited_inds = set() + hold = "" + for raw_token in tokens: + if stop_stream: + next_hold = hold + raw_token + + if stop_stream in next_hold: + break + + if next_hold == stop_stream[: len(next_hold)]: + hold = next_hold + continue + + token = next_hold + hold = "" + else: + token = raw_token + + # Special case of [1][ where ][ is a single token + # This is where the model attempts to do consecutive citations like [1][2] + if prepend_bracket: + curr_segment += "[" + curr_segment + prepend_bracket = False + + curr_segment += token + llm_out += token + + possible_citation_pattern = r"(\[\d*$)" # [1, [, etc + possible_citation_found = re.search(possible_citation_pattern, curr_segment) + + citation_pattern = r"\[(\d+)\]" # [1], [2] etc + citation_found = re.search(citation_pattern, curr_segment) + + if citation_found and not in_code_block(llm_out): + numerical_value = int(citation_found.group(1)) + if 1 <= numerical_value <= max_citation_num: + context_llm_doc = context_docs[ + numerical_value - 1 + ] # remove 1 index offset + + link = context_llm_doc.link + target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id] + + # Use the citation number for the document's rank in + # the search (or selected docs) results + curr_segment = re.sub( + rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment + ) + + if target_citation_num not in cited_inds: + cited_inds.add(target_citation_num) + yield CitationInfo( + citation_num=target_citation_num, + document_id=context_llm_doc.document_id, + ) + + if link: + curr_segment = re.sub(r"\[", "[[", curr_segment, count=1) + curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1) + + # In case there's another open bracket like [1][, don't want to match this + possible_citation_found = None + + # if we see "[", but haven't seen the right side, hold back - this may be a + # citation that needs to be replaced with a link + if possible_citation_found: + continue + + # Special case with back to back citations [1][2] + if curr_segment and curr_segment[-1] == "[": + curr_segment = curr_segment[:-1] + prepend_bracket = True + + yield DanswerAnswerPiece(answer_piece=curr_segment) + curr_segment = "" + + if curr_segment: + if prepend_bracket: + yield DanswerAnswerPiece(answer_piece="[" + curr_segment) + else: + yield DanswerAnswerPiece(answer_piece=curr_segment) + + +def build_citation_processor( + context_docs: list[LlmDoc], +) -> StreamProcessor: + def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn: + yield from extract_citations_from_stream( + tokens=tokens, + context_docs=context_docs, + doc_id_to_rank_map=map_document_id_order(context_docs), + ) + + return stream_processor diff --git a/backend/danswer/llm/answering/stream_processing/quotes_processing.py b/backend/danswer/llm/answering/stream_processing/quotes_processing.py new file mode 100644 index 00000000000..daa966e6947 --- /dev/null +++ b/backend/danswer/llm/answering/stream_processing/quotes_processing.py @@ -0,0 +1,282 @@ +import math +import re +from collections.abc import Callable +from collections.abc import Generator +from collections.abc import Iterator +from json import JSONDecodeError +from typing import Optional + +import regex + +from danswer.chat.models import 
AnswerQuestionStreamReturn +from danswer.chat.models import DanswerAnswer +from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import DanswerQuote +from danswer.chat.models import DanswerQuotes +from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT +from danswer.indexing.models import InferenceChunk +from danswer.prompts.constants import ANSWER_PAT +from danswer.prompts.constants import QUOTE_PAT +from danswer.prompts.constants import UNCERTAINTY_PAT +from danswer.utils.logger import setup_logger +from danswer.utils.text_processing import clean_model_quote +from danswer.utils.text_processing import clean_up_code_blocks +from danswer.utils.text_processing import extract_embedded_json +from danswer.utils.text_processing import shared_precompare_cleanup + + +logger = setup_logger() + + +def _extract_answer_quotes_freeform( + answer_raw: str, +) -> tuple[Optional[str], Optional[list[str]]]: + """Splits the model output into an Answer and 0 or more Quote sections. + Splits by the Quote pattern, if not exist then assume it's all answer and no quotes + """ + # If no answer section, don't care about the quote + if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()): + return None, None + + # Sometimes model regenerates the Answer: pattern despite it being provided in the prompt + if answer_raw.lower().startswith(ANSWER_PAT.lower()): + answer_raw = answer_raw[len(ANSWER_PAT) :] + + # Accept quote sections starting with the lower case version + answer_raw = answer_raw.replace( + f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}" + ) # Just in case model unreliable + + sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw) + sections_clean = [ + str(section).strip() for section in sections if str(section).strip() + ] + if not sections_clean: + return None, None + + answer = str(sections_clean[0]) + if len(sections) == 1: + return answer, None + return answer, sections_clean[1:] + + +def _extract_answer_quotes_json( + answer_dict: dict[str, str | list[str]] +) -> tuple[Optional[str], Optional[list[str]]]: + answer_dict = {k.lower(): v for k, v in answer_dict.items()} + answer = str(answer_dict.get("answer")) + quotes = answer_dict.get("quotes") or answer_dict.get("quote") + if isinstance(quotes, str): + quotes = [quotes] + return answer, quotes + + +def _extract_answer_json(raw_model_output: str) -> dict: + try: + answer_json = extract_embedded_json(raw_model_output) + except (ValueError, JSONDecodeError): + # LLMs get confused when handling the list in the json. Sometimes it doesn't attend + # enough to the previous { token so it just ends the list of quotes and stops there + # here, we add logic to try to fix this LLM error. 
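+        # For example, an output like '{"answer": "A", "quotes": ["B"]' (final
+        # closing brace missing) becomes valid JSON once the trailing "}" is appended.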
+ answer_json = extract_embedded_json(raw_model_output + "}") + + if "answer" not in answer_json: + raise ValueError("Model did not output an answer as expected.") + + return answer_json + + +def match_quotes_to_docs( + quotes: list[str], + docs: list[LlmDoc] | list[InferenceChunk], + max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT, + fuzzy_search: bool = False, + prefix_only_length: int = 100, +) -> DanswerQuotes: + danswer_quotes: list[DanswerQuote] = [] + for quote in quotes: + max_edits = math.ceil(float(len(quote)) * max_error_percent) + + for doc in docs: + if not doc.source_links: + continue + + quote_clean = shared_precompare_cleanup( + clean_model_quote(quote, trim_length=prefix_only_length) + ) + chunk_clean = shared_precompare_cleanup(doc.content) + + # Finding the offset of the quote in the plain text + if fuzzy_search: + re_search_str = ( + r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}" + ) + found = regex.search(re_search_str, chunk_clean) + if not found: + continue + offset = found.span()[0] + else: + if quote_clean not in chunk_clean: + continue + offset = chunk_clean.index(quote_clean) + + # Extracting the link from the offset + curr_link = None + for link_offset, link in doc.source_links.items(): + # Should always find one because offset is at least 0 and there + # must be a 0 link_offset + if int(link_offset) <= offset: + curr_link = link + else: + break + + danswer_quotes.append( + DanswerQuote( + quote=quote, + document_id=doc.document_id, + link=curr_link, + source_type=doc.source_type, + semantic_identifier=doc.semantic_identifier, + blurb=doc.blurb, + ) + ) + break + + return DanswerQuotes(quotes=danswer_quotes) + + +def separate_answer_quotes( + answer_raw: str, is_json_prompt: bool = False +) -> tuple[Optional[str], Optional[list[str]]]: + """Takes in a raw model output and pulls out the answer and the quotes sections.""" + if is_json_prompt: + model_raw_json = _extract_answer_json(answer_raw) + return _extract_answer_quotes_json(model_raw_json) + + return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw)) + + +def process_answer( + answer_raw: str, + docs: list[LlmDoc], + is_json_prompt: bool = True, +) -> tuple[DanswerAnswer, DanswerQuotes]: + """Used (1) in the non-streaming case to process the model output + into an Answer and Quotes AND (2) after the complete streaming response + has been received to process the model output into an Answer and Quotes.""" + answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt) + if answer == UNCERTAINTY_PAT or not answer: + if answer == UNCERTAINTY_PAT: + logger.debug("Answer matched UNCERTAINTY_PAT") + else: + logger.debug("No answer extracted from raw output") + return DanswerAnswer(answer=None), DanswerQuotes(quotes=[]) + + logger.info(f"Answer: {answer}") + if not quote_strings: + logger.debug("No quotes extracted from raw output") + return DanswerAnswer(answer=answer), DanswerQuotes(quotes=[]) + logger.info(f"All quotes (including unmatched): {quote_strings}") + quotes = match_quotes_to_docs(quote_strings, docs) + logger.debug(f"Final quotes: {quotes}") + + return DanswerAnswer(answer=answer), quotes + + +def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool: + next_token = next_token.replace('\\"', "") + # If the previous character is an escape token, don't consider the first character of next_token + # This does not work if it's an escaped escape sign before the " but this is rare, not worth handling + if answer_so_far and answer_so_far[-1] 
== "\\": + next_token = next_token[1:] + if '"' in next_token: + return True + return False + + +def _extract_quotes_from_completed_token_stream( + model_output: str, context_docs: list[LlmDoc], is_json_prompt: bool = True +) -> DanswerQuotes: + answer, quotes = process_answer(model_output, context_docs, is_json_prompt) + if answer: + logger.info(answer) + elif model_output: + logger.warning("Answer extraction from model output failed.") + + return quotes + + +def process_model_tokens( + tokens: Iterator[str], + context_docs: list[LlmDoc], + is_json_prompt: bool = True, +) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]: + """Used in the streaming case to process the model output + into an Answer and Quotes + + Yields Answer tokens back out in a dict for streaming to frontend + When Answer section ends, yields dict with answer_finished key + Collects all the tokens at the end to form the complete model output""" + quote_pat = f"\n{QUOTE_PAT}" + # Sometimes worse model outputs new line instead of : + quote_loose = f"\n{quote_pat[:-1]}\n" + # Sometime model outputs two newlines before quote section + quote_pat_full = f"\n{quote_pat}" + model_output: str = "" + found_answer_start = False if is_json_prompt else True + found_answer_end = False + hold_quote = "" + for token in tokens: + model_previous = model_output + model_output += token + + if not found_answer_start and '{"answer":"' in re.sub(r"\s", "", model_output): + # Note, if the token that completes the pattern has additional text, for example if the token is "? + # Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? in the + # event that the model outputs the UNCERTAINTY_PAT + found_answer_start = True + + # Prevent heavy cases of hallucinations where model is not even providing a json until later + if is_json_prompt and len(model_output) > 40: + logger.warning("LLM did not produce json as prompted") + found_answer_end = True + + continue + + if found_answer_start and not found_answer_end: + if is_json_prompt and _stream_json_answer_end(model_previous, token): + found_answer_end = True + yield DanswerAnswerPiece(answer_piece=None) + continue + elif not is_json_prompt: + if quote_pat in hold_quote + token or quote_loose in hold_quote + token: + found_answer_end = True + yield DanswerAnswerPiece(answer_piece=None) + continue + if hold_quote + token in quote_pat_full: + hold_quote += token + continue + yield DanswerAnswerPiece(answer_piece=hold_quote + token) + hold_quote = "" + + logger.debug(f"Raw Model QnA Output: {model_output}") + + yield _extract_quotes_from_completed_token_stream( + model_output=model_output, + context_docs=context_docs, + is_json_prompt=is_json_prompt, + ) + + +def build_quotes_processor( + context_docs: list[LlmDoc], is_json_prompt: bool +) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]: + def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn: + yield from process_model_tokens( + tokens=tokens, + context_docs=context_docs, + is_json_prompt=is_json_prompt, + ) + + return stream_processor diff --git a/backend/danswer/llm/answering/stream_processing/utils.py b/backend/danswer/llm/answering/stream_processing/utils.py new file mode 100644 index 00000000000..1ddcdf605ef --- /dev/null +++ b/backend/danswer/llm/answering/stream_processing/utils.py @@ -0,0 +1,17 @@ +from collections.abc import Sequence + +from danswer.chat.models import LlmDoc +from danswer.indexing.models import InferenceChunk + + +def map_document_id_order( + chunks: 
Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True +) -> dict[str, int]: + order_mapping = {} + current = 1 if one_indexed else 0 + for chunk in chunks: + if chunk.document_id not in order_mapping: + order_mapping[chunk.document_id] = current + current += 1 + + return order_mapping diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index f36f285461b..c07b708bb51 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -33,6 +33,7 @@ from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM from danswer.utils.logger import setup_logger @@ -114,7 +115,9 @@ def tokenizer_trim_chunks( return new_chunks -def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage: +def translate_danswer_msg_to_langchain( + msg: ChatMessage | PreviousMessage, +) -> BaseMessage: if msg.message_type == MessageType.SYSTEM: raise ValueError("System messages are not currently part of history") if msg.message_type == MessageType.ASSISTANT: @@ -126,7 +129,7 @@ def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage: def translate_history_to_basemessages( - history: list[ChatMessage], + history: list[ChatMessage] | list[PreviousMessage], ) -> tuple[list[BaseMessage], list[int]]: history_basemessages = [ translate_danswer_msg_to_langchain(msg) diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 529180a7993..e863f4ac098 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -1,63 +1,43 @@ -import itertools from collections.abc import Callable from collections.abc import Iterator -from typing import cast -from langchain.schema.messages import BaseMessage -from langchain.schema.messages import HumanMessage from sqlalchemy.orm import Session -from danswer.chat.chat_utils import build_chat_system_message -from danswer.chat.chat_utils import compute_max_document_tokens -from danswer.chat.chat_utils import extract_citations_from_stream -from danswer.chat.chat_utils import get_chunks_for_qa from danswer.chat.chat_utils import llm_doc_from_inference_chunk -from danswer.chat.chat_utils import map_document_id_order from danswer.chat.chat_utils import reorganize_citations from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import DanswerContext from danswer.chat.models import DanswerContexts from danswer.chat.models import DanswerQuotes -from danswer.chat.models import LLMMetricsContainer from danswer.chat.models import LLMRelevanceFilterResponse from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT from danswer.configs.chat_configs import QA_TIMEOUT from danswer.configs.constants import MessageType -from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.db.chat import create_chat_session from danswer.db.chat import create_new_chat_message from danswer.db.chat import get_or_create_root_message -from danswer.db.chat import get_persona_by_id from danswer.db.chat import get_prompt_by_id from danswer.db.chat import translate_db_message_to_chat_message_detail -from danswer.db.embedding_model import 
get_current_db_embedding_model from danswer.db.engine import get_session_context_manager -from danswer.db.models import Prompt from danswer.db.models import User -from danswer.document_index.factory import get_default_document_index -from danswer.indexing.models import InferenceChunk -from danswer.llm.factory import get_default_llm +from danswer.llm.answering.answer import Answer +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import CitationConfig +from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.models import QuotesConfig from danswer.llm.utils import get_default_llm_token_encode -from danswer.llm.utils import get_default_llm_tokenizer -from danswer.one_shot_answer.factory import get_question_answer_model from danswer.one_shot_answer.models import DirectQARequest from danswer.one_shot_answer.models import OneShotQAResponse from danswer.one_shot_answer.models import QueryRephrase -from danswer.one_shot_answer.models import ThreadMessage -from danswer.one_shot_answer.qa_block import no_gen_ai_response from danswer.one_shot_answer.qa_utils import combine_message_thread -from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT -from danswer.prompts.prompt_utils import build_complete_context_str -from danswer.prompts.prompt_utils import build_task_prompt_reminders from danswer.search.models import RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer from danswer.search.models import SavedSearchDoc -from danswer.search.request_preprocessing import retrieval_preprocessing -from danswer.search.search_runner import chunks_to_search_docs -from danswer.search.search_runner import full_chunk_search_generator +from danswer.search.models import SearchRequest +from danswer.search.pipeline import SearchPipeline +from danswer.search.utils import chunks_to_search_docs from danswer.secondary_llm_flows.answer_validation import get_answer_validity from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail @@ -80,106 +60,6 @@ ] -def quote_based_qa( - prompt: Prompt, - query_message: ThreadMessage, - history_str: str, - context_chunks: list[InferenceChunk], - llm_override: str | None, - timeout: int, - use_chain_of_thought: bool, - return_contexts: bool, - llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, -) -> AnswerObjectIterator: - qa_model = get_question_answer_model( - prompt=prompt, - timeout=timeout, - chain_of_thought=use_chain_of_thought, - llm_version=llm_override, - ) - - full_prompt_str = ( - qa_model.build_prompt( - query=query_message.message, - history_str=history_str, - context_chunks=context_chunks, - ) - if qa_model is not None - else "Gen AI Disabled" - ) - - response_packets = ( - qa_model.answer_question_stream( - prompt=full_prompt_str, - llm_context_docs=context_chunks, - metrics_callback=llm_metrics_callback, - ) - if qa_model is not None - else no_gen_ai_response() - ) - - if qa_model is not None and return_contexts: - contexts = DanswerContexts( - contexts=[ - DanswerContext( - content=context_chunk.content, - document_id=context_chunk.document_id, - semantic_identifier=context_chunk.semantic_identifier, - blurb=context_chunk.semantic_identifier, - ) - for context_chunk in context_chunks - ] - ) - - response_packets = itertools.chain(response_packets, [contexts]) - - yield from response_packets - - -def citation_based_qa( - prompt: Prompt, - 
query_message: ThreadMessage, - history_str: str, - context_chunks: list[InferenceChunk], - llm_override: str | None, - timeout: int, -) -> AnswerObjectIterator: - llm_tokenizer = get_default_llm_tokenizer() - - system_prompt_or_none, _ = build_chat_system_message( - prompt=prompt, - context_exists=True, - llm_tokenizer_encode_func=llm_tokenizer.encode, - ) - - task_prompt_with_reminder = build_task_prompt_reminders(prompt) - - context_docs_str = build_complete_context_str(context_chunks) - user_message = HumanMessage( - content=CITATIONS_PROMPT.format( - task_prompt=task_prompt_with_reminder, - user_query=query_message.message, - history_block=history_str, - context_docs_str=context_docs_str, - ) - ) - - llm = get_default_llm( - timeout=timeout, - gen_ai_model_version_override=llm_override, - ) - - llm_prompt: list[BaseMessage] = [user_message] - if system_prompt_or_none is not None: - llm_prompt = [system_prompt_or_none] + llm_prompt - - llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in context_chunks] - doc_id_to_rank_map = map_document_id_order(llm_docs) - - tokens = llm.stream(llm_prompt) - yield from extract_citations_from_stream(tokens, llm_docs, doc_id_to_rank_map) - - def stream_answer_objects( query_req: DirectQARequest, user: User | None, @@ -191,14 +71,12 @@ def stream_answer_objects( db_session: Session, # Needed to translate persona num_chunks to tokens to the LLM default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT, - default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE, timeout: int = QA_TIMEOUT, bypass_acl: bool = False, use_citations: bool = False, retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, - llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, ) -> AnswerObjectIterator: """Streams in order: 1. 
[always] Retrieved documents, stops flow if nothing is found @@ -221,12 +99,6 @@ def stream_answer_objects( llm_tokenizer = get_default_llm_token_encode() - embedding_model = get_current_db_embedding_model(db_session) - - document_index = get_default_document_index( - primary_index_name=embedding_model.index_name, secondary_index_name=None - ) - # Create a chat session which will just store the root message, the query, and the AI response root_message = get_or_create_root_message( chat_session_id=chat_session.id, db_session=db_session @@ -244,33 +116,23 @@ def stream_answer_objects( # In chat flow it's given back along with the documents yield QueryRephrase(rephrased_query=rephrased_query) - ( - retrieval_request, - predicted_search_type, - predicted_flow, - ) = retrieval_preprocessing( - query=rephrased_query, - retrieval_details=query_req.retrieval_options, - persona=chat_session.persona, + search_pipeline = SearchPipeline( + search_request=SearchRequest( + query=rephrased_query, + human_selected_filters=query_req.retrieval_options.filters, + persona=chat_session.persona, + offset=query_req.retrieval_options.offset, + limit=query_req.retrieval_options.limit, + ), user=user, db_session=db_session, bypass_acl=bypass_acl, - ) - - documents_generator = full_chunk_search_generator( - search_query=retrieval_request, - document_index=document_index, - db_session=db_session, retrieval_metrics_callback=retrieval_metrics_callback, rerank_metrics_callback=rerank_metrics_callback, ) - applied_time_cutoff = retrieval_request.filters.time_cutoff - recency_bias_multiplier = retrieval_request.recency_bias_multiplier - run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter # First fetch and return the top chunks so the user can immediately see some results - top_chunks = cast(list[InferenceChunk], next(documents_generator)) - + top_chunks = search_pipeline.reranked_docs top_docs = chunks_to_search_docs(top_chunks) fake_saved_docs = [SavedSearchDoc.from_search_doc(doc) for doc in top_docs] @@ -278,64 +140,25 @@ def stream_answer_objects( initial_response = QADocsResponse( rephrased_query=rephrased_query, top_documents=fake_saved_docs, - predicted_flow=predicted_flow, - predicted_search=predicted_search_type, - applied_source_filters=retrieval_request.filters.source_type, - applied_time_cutoff=applied_time_cutoff, - recency_bias_multiplier=recency_bias_multiplier, + predicted_flow=search_pipeline.predicted_flow, + predicted_search=search_pipeline.predicted_search_type, + applied_source_filters=search_pipeline.search_query.filters.source_type, + applied_time_cutoff=search_pipeline.search_query.filters.time_cutoff, + recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier, ) yield initial_response - # Get the final ordering of chunks for the LLM call - llm_chunk_selection = cast(list[bool], next(documents_generator)) - # Yield the list of LLM selected chunks for showing the LLM selected icons in the UI llm_relevance_filtering_response = LLMRelevanceFilterResponse( - relevant_chunk_indices=[ - index for index, value in enumerate(llm_chunk_selection) if value - ] - if run_llm_chunk_filter - else [] + relevant_chunk_indices=search_pipeline.relevant_chunk_indicies ) yield llm_relevance_filtering_response - # Prep chunks to pass to LLM - num_llm_chunks = ( - chat_session.persona.num_chunks - if chat_session.persona.num_chunks is not None - else default_num_chunks - ) - - chunk_token_limit = int(num_llm_chunks * default_chunk_size) - if max_document_tokens: - chunk_token_limit = 
min(chunk_token_limit, max_document_tokens) - else: - max_document_tokens = compute_max_document_tokens( - persona=chat_session.persona, actual_user_input=query_msg.message - ) - chunk_token_limit = min(chunk_token_limit, max_document_tokens) - - llm_chunks_indices = get_chunks_for_qa( - chunks=top_chunks, - llm_chunk_selection=llm_chunk_selection, - token_limit=chunk_token_limit, - ) - llm_chunks = [top_chunks[i] for i in llm_chunks_indices] - - logger.debug( - f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}" - ) - prompt = None - llm_override = None if query_req.prompt_id is not None: prompt = get_prompt_by_id( prompt_id=query_req.prompt_id, user_id=user_id, db_session=db_session ) - persona = get_persona_by_id( - persona_id=query_req.persona_id, user_id=user_id, db_session=db_session - ) - llm_override = persona.llm_model_version_override if prompt is None: if not chat_session.persona.prompts: raise RuntimeError( @@ -355,52 +178,39 @@ def stream_answer_objects( commit=True, ) - if use_citations: - qa_stream = citation_based_qa( - prompt=prompt, - query_message=query_msg, - history_str=history_str, - context_chunks=llm_chunks, - llm_override=llm_override, - timeout=timeout, - ) - else: - qa_stream = quote_based_qa( - prompt=prompt, - query_message=query_msg, - history_str=history_str, - context_chunks=llm_chunks, - llm_override=llm_override, - timeout=timeout, - use_chain_of_thought=False, - return_contexts=False, - llm_metrics_callback=llm_metrics_callback, - ) - - # Capture outputs and errors - llm_output = "" - error: str | None = None - for packet in qa_stream: - logger.debug(packet) - - if isinstance(packet, DanswerAnswerPiece): - token = packet.answer_piece - if token: - llm_output += token - elif isinstance(packet, StreamingError): - error = packet.error - - yield packet + answer_config = AnswerStyleConfig( + citation_config=CitationConfig() if use_citations else None, + quotes_config=QuotesConfig() if not use_citations else None, + document_pruning_config=DocumentPruningConfig( + max_chunks=int( + chat_session.persona.num_chunks + if chat_session.persona.num_chunks is not None + else default_num_chunks + ), + max_tokens=max_document_tokens, + ), + ) + answer = Answer( + question=query_msg.message, + docs=[llm_doc_from_inference_chunk(chunk) for chunk in top_chunks], + answer_style_config=answer_config, + prompt=prompt, + persona=chat_session.persona, + doc_relevance_list=search_pipeline.chunk_relevance_list, + single_message_history=history_str, + timeout=timeout, + ) + yield from answer.processed_streamed_output # Saving Gen AI answer and responding with message info gen_ai_response_message = create_new_chat_message( chat_session_id=chat_session.id, parent_message=new_user_message, prompt_id=query_req.prompt_id, - message=llm_output, - token_count=len(llm_tokenizer(llm_output)), + message=answer.llm_answer, + token_count=len(llm_tokenizer(answer.llm_answer)), message_type=MessageType.ASSISTANT, - error=error, + error=None, reference_docs=None, # Don't need to save reference docs for one shot flow db_session=db_session, commit=True, @@ -445,7 +255,6 @@ def get_search_answer( retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, - llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, ) -> OneShotQAResponse: """Collects the streamed one shot answer responses into a single object""" qa_response = OneShotQAResponse() @@ 
-461,7 +270,6 @@ def get_search_answer( timeout=answer_generation_timeout, retrieval_metrics_callback=retrieval_metrics_callback, rerank_metrics_callback=rerank_metrics_callback, - llm_metrics_callback=llm_metrics_callback, ) answer = "" diff --git a/backend/danswer/one_shot_answer/factory.py b/backend/danswer/one_shot_answer/factory.py deleted file mode 100644 index 122ed6ac06f..00000000000 --- a/backend/danswer/one_shot_answer/factory.py +++ /dev/null @@ -1,48 +0,0 @@ -from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE -from danswer.configs.chat_configs import QA_TIMEOUT -from danswer.db.models import Prompt -from danswer.llm.exceptions import GenAIDisabledException -from danswer.llm.factory import get_default_llm -from danswer.one_shot_answer.interfaces import QAModel -from danswer.one_shot_answer.qa_block import QABlock -from danswer.one_shot_answer.qa_block import QAHandler -from danswer.one_shot_answer.qa_block import SingleMessageQAHandler -from danswer.one_shot_answer.qa_block import WeakLLMQAHandler -from danswer.utils.logger import setup_logger - -logger = setup_logger() - - -def get_question_answer_model( - prompt: Prompt | None, - api_key: str | None = None, - timeout: int = QA_TIMEOUT, - chain_of_thought: bool = False, - llm_version: str | None = None, - qa_model_version: str | None = QA_PROMPT_OVERRIDE, -) -> QAModel | None: - if chain_of_thought: - raise NotImplementedError("COT has been disabled") - - system_prompt = prompt.system_prompt if prompt is not None else None - task_prompt = prompt.task_prompt if prompt is not None else None - - try: - llm = get_default_llm( - api_key=api_key, - timeout=timeout, - gen_ai_model_version_override=llm_version, - ) - except GenAIDisabledException: - return None - - if qa_model_version == "weak": - qa_handler: QAHandler = WeakLLMQAHandler( - system_prompt=system_prompt, task_prompt=task_prompt - ) - else: - qa_handler = SingleMessageQAHandler( - system_prompt=system_prompt, task_prompt=task_prompt - ) - - return QABlock(llm=llm, qa_handler=qa_handler) diff --git a/backend/danswer/one_shot_answer/interfaces.py b/backend/danswer/one_shot_answer/interfaces.py deleted file mode 100644 index ca916d699df..00000000000 --- a/backend/danswer/one_shot_answer/interfaces.py +++ /dev/null @@ -1,26 +0,0 @@ -import abc -from collections.abc import Callable - -from danswer.chat.models import AnswerQuestionStreamReturn -from danswer.chat.models import LLMMetricsContainer -from danswer.indexing.models import InferenceChunk - - -class QAModel: - @abc.abstractmethod - def build_prompt( - self, - query: str, - history_str: str, - context_chunks: list[InferenceChunk], - ) -> str: - raise NotImplementedError - - @abc.abstractmethod - def answer_question_stream( - self, - prompt: str, - llm_context_docs: list[InferenceChunk], - metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, - ) -> AnswerQuestionStreamReturn: - raise NotImplementedError diff --git a/backend/danswer/one_shot_answer/qa_block.py b/backend/danswer/one_shot_answer/qa_block.py deleted file mode 100644 index 68cb6e4a821..00000000000 --- a/backend/danswer/one_shot_answer/qa_block.py +++ /dev/null @@ -1,313 +0,0 @@ -import abc -import re -from collections.abc import Callable -from collections.abc import Iterator -from typing import cast - -from danswer.chat.models import AnswerQuestionStreamReturn -from danswer.chat.models import DanswerAnswer -from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import DanswerQuotes -from danswer.chat.models 
import LlmDoc -from danswer.chat.models import LLMMetricsContainer -from danswer.chat.models import StreamingError -from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION -from danswer.configs.constants import DISABLED_GEN_AI_MSG -from danswer.indexing.models import InferenceChunk -from danswer.llm.interfaces import LLM -from danswer.llm.utils import check_number_of_tokens -from danswer.llm.utils import get_default_llm_token_encode -from danswer.one_shot_answer.interfaces import QAModel -from danswer.one_shot_answer.qa_utils import process_answer -from danswer.one_shot_answer.qa_utils import process_model_tokens -from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK -from danswer.prompts.direct_qa_prompts import COT_PROMPT -from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK -from danswer.prompts.direct_qa_prompts import JSON_PROMPT -from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT -from danswer.prompts.direct_qa_prompts import ONE_SHOT_SYSTEM_PROMPT -from danswer.prompts.direct_qa_prompts import ONE_SHOT_TASK_PROMPT -from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT -from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT -from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT -from danswer.prompts.direct_qa_prompts import WEAK_MODEL_SYSTEM_PROMPT -from danswer.prompts.direct_qa_prompts import WEAK_MODEL_TASK_PROMPT -from danswer.prompts.prompt_utils import build_complete_context_str -from danswer.utils.logger import setup_logger -from danswer.utils.text_processing import clean_up_code_blocks -from danswer.utils.text_processing import escape_newlines - -logger = setup_logger() - - -class QAHandler(abc.ABC): - @property - @abc.abstractmethod - def is_json_output(self) -> bool: - """Does the model output a valid json with answer and quotes keys? Most flows with a - capable model should output a json. This hints to the model that the output is used - with a downstream system rather than freeform creative output. Most models should be - finetuned to recognize this.""" - raise NotImplementedError - - @abc.abstractmethod - def build_prompt( - self, - query: str, - history_str: str, - context_chunks: list[InferenceChunk], - ) -> str: - raise NotImplementedError - - def process_llm_token_stream( - self, tokens: Iterator[str], context_chunks: list[InferenceChunk] - ) -> AnswerQuestionStreamReturn: - yield from process_model_tokens( - tokens=tokens, - context_docs=context_chunks, - is_json_prompt=self.is_json_output, - ) - - -class WeakLLMQAHandler(QAHandler): - """Since Danswer supports a variety of LLMs, this less demanding prompt is provided - as an option to use with weaker LLMs such as small version, low float precision, quantized, - or distilled models. It only uses one context document and has very weak requirements of - output format. 
- """ - - def __init__( - self, - system_prompt: str | None, - task_prompt: str | None, - ) -> None: - if not system_prompt and not task_prompt: - self.system_prompt = WEAK_MODEL_SYSTEM_PROMPT - self.task_prompt = WEAK_MODEL_TASK_PROMPT - else: - self.system_prompt = system_prompt or "" - self.task_prompt = task_prompt or "" - - self.task_prompt = self.task_prompt.rstrip() - if self.task_prompt and self.task_prompt[0] != "\n": - self.task_prompt = "\n" + self.task_prompt - - @property - def is_json_output(self) -> bool: - return False - - def build_prompt( - self, - query: str, - history_str: str, - context_chunks: list[InferenceChunk], - ) -> str: - context_block = "" - if context_chunks: - context_block = CONTEXT_BLOCK.format( - context_docs_str=context_chunks[0].content - ) - - prompt_str = WEAK_LLM_PROMPT.format( - system_prompt=self.system_prompt, - context_block=context_block, - task_prompt=self.task_prompt, - user_query=query, - ) - return prompt_str - - -class SingleMessageQAHandler(QAHandler): - def __init__( - self, - system_prompt: str | None, - task_prompt: str | None, - use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), - ) -> None: - self.use_language_hint = use_language_hint - if not system_prompt and not task_prompt: - self.system_prompt = ONE_SHOT_SYSTEM_PROMPT - self.task_prompt = ONE_SHOT_TASK_PROMPT - else: - self.system_prompt = system_prompt or "" - self.task_prompt = task_prompt or "" - - self.task_prompt = self.task_prompt.rstrip() - if self.task_prompt and self.task_prompt[0] != "\n": - self.task_prompt = "\n" + self.task_prompt - - @property - def is_json_output(self) -> bool: - return True - - def build_prompt( - self, query: str, history_str: str, context_chunks: list[InferenceChunk] - ) -> str: - context_block = "" - if context_chunks: - context_docs_str = build_complete_context_str( - cast(list[LlmDoc | InferenceChunk], context_chunks) - ) - context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str) - - history_block = "" - if history_str: - history_block = HISTORY_BLOCK.format(history_str=history_str) - - full_prompt = JSON_PROMPT.format( - system_prompt=self.system_prompt, - context_block=context_block, - history_block=history_block, - task_prompt=self.task_prompt, - user_query=query, - language_hint_or_none=LANGUAGE_HINT.strip() - if self.use_language_hint - else "", - ).strip() - return full_prompt - - -# This one isn't used, currently only streaming prompts are used -class SingleMessageScratchpadHandler(QAHandler): - def __init__( - self, - system_prompt: str | None, - task_prompt: str | None, - use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), - ) -> None: - self.use_language_hint = use_language_hint - if not system_prompt and not task_prompt: - self.system_prompt = ONE_SHOT_SYSTEM_PROMPT - self.task_prompt = ONE_SHOT_TASK_PROMPT - else: - self.system_prompt = system_prompt or "" - self.task_prompt = task_prompt or "" - - self.task_prompt = self.task_prompt.rstrip() - if self.task_prompt and self.task_prompt[0] != "\n": - self.task_prompt = "\n" + self.task_prompt - - @property - def is_json_output(self) -> bool: - return True - - def build_prompt( - self, query: str, history_str: str, context_chunks: list[InferenceChunk] - ) -> str: - context_docs_str = build_complete_context_str( - cast(list[LlmDoc | InferenceChunk], context_chunks) - ) - - # Outdated - prompt = COT_PROMPT.format( - context_docs_str=context_docs_str, - user_query=query, - language_hint_or_none=LANGUAGE_HINT.strip() - if self.use_language_hint - else 
"", - ).strip() - - return prompt - - def process_llm_output( - self, model_output: str, context_chunks: list[InferenceChunk] - ) -> tuple[DanswerAnswer, DanswerQuotes]: - logger.debug(model_output) - - model_clean = clean_up_code_blocks(model_output) - - match = re.search(r'{\s*"answer":', model_clean) - if not match: - return DanswerAnswer(answer=None), DanswerQuotes(quotes=[]) - - final_json = escape_newlines(model_clean[match.start() :]) - - return process_answer( - final_json, context_chunks, is_json_prompt=self.is_json_output - ) - - def process_llm_token_stream( - self, tokens: Iterator[str], context_chunks: list[InferenceChunk] - ) -> AnswerQuestionStreamReturn: - # Can be supported but the parsing is more involved, not handling until needed - raise ValueError( - "This Scratchpad approach is not suitable for real time uses like streaming" - ) - - -def build_dummy_prompt( - system_prompt: str, task_prompt: str, retrieval_disabled: bool -) -> str: - if retrieval_disabled: - return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format( - user_query="", - system_prompt=system_prompt, - task_prompt=task_prompt, - ).strip() - - return PARAMATERIZED_PROMPT.format( - context_docs_str="", - user_query="", - system_prompt=system_prompt, - task_prompt=task_prompt, - ).strip() - - -def no_gen_ai_response() -> Iterator[DanswerAnswerPiece]: - yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG) - - -class QABlock(QAModel): - def __init__(self, llm: LLM, qa_handler: QAHandler) -> None: - self._llm = llm - self._qa_handler = qa_handler - - def build_prompt( - self, - query: str, - history_str: str, - context_chunks: list[InferenceChunk], - ) -> str: - prompt = self._qa_handler.build_prompt( - query=query, history_str=history_str, context_chunks=context_chunks - ) - return prompt - - def answer_question_stream( - self, - prompt: str, - llm_context_docs: list[InferenceChunk], - metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, - ) -> AnswerQuestionStreamReturn: - tokens_stream = self._llm.stream(prompt) - - captured_tokens = [] - - try: - for answer_piece in self._qa_handler.process_llm_token_stream( - iter(tokens_stream), llm_context_docs - ): - if ( - isinstance(answer_piece, DanswerAnswerPiece) - and answer_piece.answer_piece - ): - captured_tokens.append(answer_piece.answer_piece) - yield answer_piece - - except Exception as e: - yield StreamingError(error=str(e)) - - if metrics_callback is not None: - prompt_tokens = check_number_of_tokens( - text=str(prompt), encode_fn=get_default_llm_token_encode() - ) - - response_tokens = check_number_of_tokens( - text="".join(captured_tokens), encode_fn=get_default_llm_token_encode() - ) - - metrics_callback( - LLMMetricsContainer( - prompt_tokens=prompt_tokens, response_tokens=response_tokens - ) - ) diff --git a/backend/danswer/one_shot_answer/qa_utils.py b/backend/danswer/one_shot_answer/qa_utils.py index 032d2434594..e912a915e2e 100644 --- a/backend/danswer/one_shot_answer/qa_utils.py +++ b/backend/danswer/one_shot_answer/qa_utils.py @@ -1,275 +1,14 @@ -import math -import re from collections.abc import Callable from collections.abc import Generator -from collections.abc import Iterator -from json.decoder import JSONDecodeError -from typing import Optional -from typing import Tuple -import regex - -from danswer.chat.models import DanswerAnswer -from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import DanswerQuote -from danswer.chat.models import DanswerQuotes -from danswer.configs.chat_configs import 
QUOTE_ALLOWED_ERROR_PERCENT from danswer.configs.constants import MessageType -from danswer.indexing.models import InferenceChunk from danswer.llm.utils import get_default_llm_token_encode from danswer.one_shot_answer.models import ThreadMessage -from danswer.prompts.constants import ANSWER_PAT -from danswer.prompts.constants import QUOTE_PAT -from danswer.prompts.constants import UNCERTAINTY_PAT from danswer.utils.logger import setup_logger -from danswer.utils.text_processing import clean_model_quote -from danswer.utils.text_processing import clean_up_code_blocks -from danswer.utils.text_processing import extract_embedded_json -from danswer.utils.text_processing import shared_precompare_cleanup logger = setup_logger() -def _extract_answer_quotes_freeform( - answer_raw: str, -) -> Tuple[Optional[str], Optional[list[str]]]: - """Splits the model output into an Answer and 0 or more Quote sections. - Splits by the Quote pattern, if not exist then assume it's all answer and no quotes - """ - # If no answer section, don't care about the quote - if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()): - return None, None - - # Sometimes model regenerates the Answer: pattern despite it being provided in the prompt - if answer_raw.lower().startswith(ANSWER_PAT.lower()): - answer_raw = answer_raw[len(ANSWER_PAT) :] - - # Accept quote sections starting with the lower case version - answer_raw = answer_raw.replace( - f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}" - ) # Just in case model unreliable - - sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw) - sections_clean = [ - str(section).strip() for section in sections if str(section).strip() - ] - if not sections_clean: - return None, None - - answer = str(sections_clean[0]) - if len(sections) == 1: - return answer, None - return answer, sections_clean[1:] - - -def _extract_answer_quotes_json( - answer_dict: dict[str, str | list[str]] -) -> Tuple[Optional[str], Optional[list[str]]]: - answer_dict = {k.lower(): v for k, v in answer_dict.items()} - answer = str(answer_dict.get("answer")) - quotes = answer_dict.get("quotes") or answer_dict.get("quote") - if isinstance(quotes, str): - quotes = [quotes] - return answer, quotes - - -def _extract_answer_json(raw_model_output: str) -> dict: - try: - answer_json = extract_embedded_json(raw_model_output) - except (ValueError, JSONDecodeError): - # LLMs get confused when handling the list in the json. Sometimes it doesn't attend - # enough to the previous { token so it just ends the list of quotes and stops there - # here, we add logic to try to fix this LLM error. 
- answer_json = extract_embedded_json(raw_model_output + "}") - - if "answer" not in answer_json: - raise ValueError("Model did not output an answer as expected.") - - return answer_json - - -def separate_answer_quotes( - answer_raw: str, is_json_prompt: bool = False -) -> Tuple[Optional[str], Optional[list[str]]]: - """Takes in a raw model output and pulls out the answer and the quotes sections.""" - if is_json_prompt: - model_raw_json = _extract_answer_json(answer_raw) - return _extract_answer_quotes_json(model_raw_json) - - return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw)) - - -def match_quotes_to_docs( - quotes: list[str], - chunks: list[InferenceChunk], - max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT, - fuzzy_search: bool = False, - prefix_only_length: int = 100, -) -> DanswerQuotes: - danswer_quotes: list[DanswerQuote] = [] - for quote in quotes: - max_edits = math.ceil(float(len(quote)) * max_error_percent) - - for chunk in chunks: - if not chunk.source_links: - continue - - quote_clean = shared_precompare_cleanup( - clean_model_quote(quote, trim_length=prefix_only_length) - ) - chunk_clean = shared_precompare_cleanup(chunk.content) - - # Finding the offset of the quote in the plain text - if fuzzy_search: - re_search_str = ( - r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}" - ) - found = regex.search(re_search_str, chunk_clean) - if not found: - continue - offset = found.span()[0] - else: - if quote_clean not in chunk_clean: - continue - offset = chunk_clean.index(quote_clean) - - # Extracting the link from the offset - curr_link = None - for link_offset, link in chunk.source_links.items(): - # Should always find one because offset is at least 0 and there - # must be a 0 link_offset - if int(link_offset) <= offset: - curr_link = link - else: - break - - danswer_quotes.append( - DanswerQuote( - quote=quote, - document_id=chunk.document_id, - link=curr_link, - source_type=chunk.source_type, - semantic_identifier=chunk.semantic_identifier, - blurb=chunk.blurb, - ) - ) - break - - return DanswerQuotes(quotes=danswer_quotes) - - -def process_answer( - answer_raw: str, - chunks: list[InferenceChunk], - is_json_prompt: bool = True, -) -> tuple[DanswerAnswer, DanswerQuotes]: - """Used (1) in the non-streaming case to process the model output - into an Answer and Quotes AND (2) after the complete streaming response - has been received to process the model output into an Answer and Quotes.""" - answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt) - if answer == UNCERTAINTY_PAT or not answer: - if answer == UNCERTAINTY_PAT: - logger.debug("Answer matched UNCERTAINTY_PAT") - else: - logger.debug("No answer extracted from raw output") - return DanswerAnswer(answer=None), DanswerQuotes(quotes=[]) - - logger.info(f"Answer: {answer}") - if not quote_strings: - logger.debug("No quotes extracted from raw output") - return DanswerAnswer(answer=answer), DanswerQuotes(quotes=[]) - logger.info(f"All quotes (including unmatched): {quote_strings}") - quotes = match_quotes_to_docs(quote_strings, chunks) - logger.debug(f"Final quotes: {quotes}") - - return DanswerAnswer(answer=answer), quotes - - -def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool: - next_token = next_token.replace('\\"', "") - # If the previous character is an escape token, don't consider the first character of next_token - # This does not work if it's an escaped escape sign before the " but this is rare, not worth handling - if answer_so_far and 
answer_so_far[-1] == "\\": - next_token = next_token[1:] - if '"' in next_token: - return True - return False - - -def _extract_quotes_from_completed_token_stream( - model_output: str, context_chunks: list[InferenceChunk], is_json_prompt: bool = True -) -> DanswerQuotes: - answer, quotes = process_answer(model_output, context_chunks, is_json_prompt) - if answer: - logger.info(answer) - elif model_output: - logger.warning("Answer extraction from model output failed.") - - return quotes - - -def process_model_tokens( - tokens: Iterator[str], - context_docs: list[InferenceChunk], - is_json_prompt: bool = True, -) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]: - """Used in the streaming case to process the model output - into an Answer and Quotes - - Yields Answer tokens back out in a dict for streaming to frontend - When Answer section ends, yields dict with answer_finished key - Collects all the tokens at the end to form the complete model output""" - quote_pat = f"\n{QUOTE_PAT}" - # Sometimes worse model outputs new line instead of : - quote_loose = f"\n{quote_pat[:-1]}\n" - # Sometime model outputs two newlines before quote section - quote_pat_full = f"\n{quote_pat}" - model_output: str = "" - found_answer_start = False if is_json_prompt else True - found_answer_end = False - hold_quote = "" - for token in tokens: - model_previous = model_output - model_output += token - - if not found_answer_start and '{"answer":"' in re.sub(r"\s", "", model_output): - # Note, if the token that completes the pattern has additional text, for example if the token is "? - # Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? in the - # event that the model outputs the UNCERTAINTY_PAT - found_answer_start = True - - # Prevent heavy cases of hallucinations where model is not even providing a json until later - if is_json_prompt and len(model_output) > 40: - logger.warning("LLM did not produce json as prompted") - found_answer_end = True - - continue - - if found_answer_start and not found_answer_end: - if is_json_prompt and _stream_json_answer_end(model_previous, token): - found_answer_end = True - yield DanswerAnswerPiece(answer_piece=None) - continue - elif not is_json_prompt: - if quote_pat in hold_quote + token or quote_loose in hold_quote + token: - found_answer_end = True - yield DanswerAnswerPiece(answer_piece=None) - continue - if hold_quote + token in quote_pat_full: - hold_quote += token - continue - yield DanswerAnswerPiece(answer_piece=hold_quote + token) - hold_quote = "" - - logger.debug(f"Raw Model QnA Output: {model_output}") - - yield _extract_quotes_from_completed_token_stream( - model_output=model_output, - context_chunks=context_docs, - is_json_prompt=is_json_prompt, - ) - - def simulate_streaming_response(model_out: str) -> Generator[str, None, None]: """Mock streaming by generating the passed in model output, character by character""" for token in model_out: diff --git a/backend/danswer/search/enums.py b/backend/danswer/search/enums.py new file mode 100644 index 00000000000..9ba44ada2cb --- /dev/null +++ b/backend/danswer/search/enums.py @@ -0,0 +1,30 @@ +"""NOTE: this needs to be separate from models.py because of circular imports. 
+Both search/models.py and db/models.py import enums from this file AND +search/models.py imports from db/models.py.""" +from enum import Enum + + +class OptionalSearchSetting(str, Enum): + ALWAYS = "always" + NEVER = "never" + # Determine whether to run search based on history and latest query + AUTO = "auto" + + +class RecencyBiasSetting(str, Enum): + FAVOR_RECENT = "favor_recent" # 2x decay rate + BASE_DECAY = "base_decay" + NO_DECAY = "no_decay" + # Determine based on query if to use base_decay or favor_recent + AUTO = "auto" + + +class SearchType(str, Enum): + KEYWORD = "keyword" + SEMANTIC = "semantic" + HYBRID = "hybrid" + + +class QueryFlow(str, Enum): + SEARCH = "search" + QUESTION_ANSWER = "question-answer" diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index db3dc31f83b..d2ad74c34e3 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -1,46 +1,24 @@ from datetime import datetime -from enum import Enum from typing import Any from pydantic import BaseModel from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER +from danswer.configs.chat_configs import HYBRID_ALPHA from danswer.configs.chat_configs import NUM_RERANKED_RESULTS from danswer.configs.chat_configs import NUM_RETURNED_HITS from danswer.configs.constants import DocumentSource from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW +from danswer.db.models import Persona +from danswer.search.enums import OptionalSearchSetting +from danswer.search.enums import SearchType + MAX_METRICS_CONTENT = ( 200 # Just need enough characters to identify where in the doc the chunk is ) -class OptionalSearchSetting(str, Enum): - ALWAYS = "always" - NEVER = "never" - # Determine whether to run search based on history and latest query - AUTO = "auto" - - -class RecencyBiasSetting(str, Enum): - FAVOR_RECENT = "favor_recent" # 2x decay rate - BASE_DECAY = "base_decay" - NO_DECAY = "no_decay" - # Determine based on query if to use base_decay or favor_recent - AUTO = "auto" - - -class SearchType(str, Enum): - KEYWORD = "keyword" - SEMANTIC = "semantic" - HYBRID = "hybrid" - - -class QueryFlow(str, Enum): - SEARCH = "search" - QUESTION_ANSWER = "question-answer" - - class Tag(BaseModel): tag_key: str tag_value: str @@ -64,6 +42,28 @@ class ChunkMetric(BaseModel): score: float +class SearchRequest(BaseModel): + """Input to the SearchPipeline.""" + + query: str + search_type: SearchType = SearchType.HYBRID + + human_selected_filters: BaseFilters | None = None + enable_auto_detect_filters: bool | None = None + persona: Persona | None = None + + # if None, no offset / limit + offset: int | None = None + limit: int | None = None + + recency_bias_multiplier: float = 1.0 + hybrid_alpha: float = HYBRID_ALPHA + skip_rerank: bool = True + + class Config: + arbitrary_types_allowed = True + + class SearchQuery(BaseModel): query: str filters: IndexFilters diff --git a/backend/danswer/search/pipeline.py b/backend/danswer/search/pipeline.py new file mode 100644 index 00000000000..5c590939b54 --- /dev/null +++ b/backend/danswer/search/pipeline.py @@ -0,0 +1,160 @@ +from collections.abc import Callable +from collections.abc import Generator +from typing import cast + +from sqlalchemy.orm import Session + +from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION +from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.db.models import User +from danswer.document_index.factory import get_default_document_index +from 
danswer.indexing.models import InferenceChunk +from danswer.search.enums import QueryFlow +from danswer.search.enums import SearchType +from danswer.search.models import RerankMetricsContainer +from danswer.search.models import RetrievalMetricsContainer +from danswer.search.models import SearchQuery +from danswer.search.models import SearchRequest +from danswer.search.postprocessing.postprocessing import search_postprocessing +from danswer.search.preprocessing.preprocessing import retrieval_preprocessing +from danswer.search.retrieval.search_runner import retrieve_chunks + + +class SearchPipeline: + def __init__( + self, + search_request: SearchRequest, + user: User | None, + db_session: Session, + bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION + retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] + | None = None, + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, + ): + self.search_request = search_request + self.user = user + self.db_session = db_session + self.bypass_acl = bypass_acl + self.retrieval_metrics_callback = retrieval_metrics_callback + self.rerank_metrics_callback = rerank_metrics_callback + + self.embedding_model = get_current_db_embedding_model(db_session) + self.document_index = get_default_document_index( + primary_index_name=self.embedding_model.index_name, + secondary_index_name=None, + ) + + self._search_query: SearchQuery | None = None + self._predicted_search_type: SearchType | None = None + self._predicted_flow: QueryFlow | None = None + + self._retrieved_docs: list[InferenceChunk] | None = None + self._reranked_docs: list[InferenceChunk] | None = None + self._relevant_chunk_indicies: list[int] | None = None + + # generator state + self._postprocessing_generator: Generator[ + list[InferenceChunk] | list[str], None, None + ] | None = None + + """Pre-processing""" + + def _run_preprocessing(self) -> None: + ( + final_search_query, + predicted_search_type, + predicted_flow, + ) = retrieval_preprocessing( + search_request=self.search_request, + user=self.user, + db_session=self.db_session, + bypass_acl=self.bypass_acl, + ) + self._predicted_search_type = predicted_search_type + self._predicted_flow = predicted_flow + self._search_query = final_search_query + + @property + def search_query(self) -> SearchQuery: + if self._search_query is not None: + return self._search_query + + self._run_preprocessing() + return cast(SearchQuery, self._search_query) + + @property + def predicted_search_type(self) -> SearchType: + if self._predicted_search_type is not None: + return self._predicted_search_type + + self._run_preprocessing() + return cast(SearchType, self._predicted_search_type) + + @property + def predicted_flow(self) -> QueryFlow: + if self._predicted_flow is not None: + return self._predicted_flow + + self._run_preprocessing() + return cast(QueryFlow, self._predicted_flow) + + """Retrieval""" + + @property + def retrieved_docs(self) -> list[InferenceChunk]: + if self._retrieved_docs is not None: + return self._retrieved_docs + + self._retrieved_docs = retrieve_chunks( + query=self.search_query, + document_index=self.document_index, + db_session=self.db_session, + hybrid_alpha=self.search_request.hybrid_alpha, + multilingual_expansion_str=MULTILINGUAL_QUERY_EXPANSION, + retrieval_metrics_callback=self.retrieval_metrics_callback, + ) + + # self._retrieved_docs = chunks_to_search_docs(retrieved_chunks) + return cast(list[InferenceChunk], self._retrieved_docs) + + """Post-Processing""" + + @property + 
def reranked_docs(self) -> list[InferenceChunk]: + if self._reranked_docs is not None: + return self._reranked_docs + + self._postprocessing_generator = search_postprocessing( + search_query=self.search_query, + retrieved_chunks=self.retrieved_docs, + rerank_metrics_callback=self.rerank_metrics_callback, + ) + self._reranked_docs = cast( + list[InferenceChunk], next(self._postprocessing_generator) + ) + return self._reranked_docs + + @property + def relevant_chunk_indicies(self) -> list[int]: + if self._relevant_chunk_indicies is not None: + return self._relevant_chunk_indicies + + # run first step of postprocessing generator if not already done + reranked_docs = self.reranked_docs + + relevant_chunk_ids = next( + cast(Generator[list[str], None, None], self._postprocessing_generator) + ) + self._relevant_chunk_indicies = [ + ind + for ind, chunk in enumerate(reranked_docs) + if chunk.unique_id in relevant_chunk_ids + ] + return self._relevant_chunk_indicies + + @property + def chunk_relevance_list(self) -> list[bool]: + return [ + True if ind in self.relevant_chunk_indicies else False + for ind in range(len(self.reranked_docs)) + ] diff --git a/backend/danswer/search/postprocessing/postprocessing.py b/backend/danswer/search/postprocessing/postprocessing.py new file mode 100644 index 00000000000..e1cee4bd6d5 --- /dev/null +++ b/backend/danswer/search/postprocessing/postprocessing.py @@ -0,0 +1,222 @@ +from collections.abc import Callable +from collections.abc import Generator +from typing import cast + +import numpy + +from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX +from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN +from danswer.document_index.document_index_utils import ( + translate_boost_count_to_multiplier, +) +from danswer.indexing.models import InferenceChunk +from danswer.search.models import ChunkMetric +from danswer.search.models import MAX_METRICS_CONTENT +from danswer.search.models import RerankMetricsContainer +from danswer.search.models import SearchQuery +from danswer.search.models import SearchType +from danswer.search.search_nlp_models import CrossEncoderEnsembleModel +from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks +from danswer.utils.logger import setup_logger +from danswer.utils.threadpool_concurrency import FunctionCall +from danswer.utils.threadpool_concurrency import run_functions_in_parallel +from danswer.utils.timing import log_function_time + + +logger = setup_logger() + + +def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None: + top_links = [ + c.source_links[0] if c.source_links is not None else "No Link" for c in chunks + ] + logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}") + + +def should_rerank(query: SearchQuery) -> bool: + # Don't re-rank for keyword search + return query.search_type != SearchType.KEYWORD and not query.skip_rerank + + +def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool: + return not query.skip_llm_chunk_filter + + +@log_function_time(print_only=True) +def semantic_reranking( + query: str, + chunks: list[InferenceChunk], + model_min: int = CROSS_ENCODER_RANGE_MIN, + model_max: int = CROSS_ENCODER_RANGE_MAX, + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, +) -> tuple[list[InferenceChunk], list[int]]: + """Reranks chunks based on cross-encoder models. Additionally provides the original indices + of the chunks in their new sorted order. 
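+    The ensemble's per-encoder scores are averaged, shifted by the ensemble minimum,
+    multiplied by each chunk's boost and recency multipliers, and finally normalized
+    against the configured cross-encoder score range (model_min to model_max).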
+ + Note: this updates the chunks in place, it updates the chunk scores which came from retrieval + """ + cross_encoders = CrossEncoderEnsembleModel() + passages = [chunk.content for chunk in chunks] + sim_scores_floats = cross_encoders.predict(query=query, passages=passages) + + sim_scores = [numpy.array(scores) for scores in sim_scores_floats] + + raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores)) + + cross_models_min = numpy.min(sim_scores) + + shifted_sim_scores = sum( + [enc_n_scores - cross_models_min for enc_n_scores in sim_scores] + ) / len(sim_scores) + + boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks] + recency_multiplier = [chunk.recency_bias for chunk in chunks] + boosted_sim_scores = shifted_sim_scores * boosts * recency_multiplier + normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / ( + model_max - model_min + ) + orig_indices = [i for i in range(len(normalized_b_s_scores))] + scored_results = list( + zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices) + ) + scored_results.sort(key=lambda x: x[0], reverse=True) + ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip( + *scored_results + ) + + logger.debug( + f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}" + ) + + # Assign new chunk scores based on reranking + for ind, chunk in enumerate(ranked_chunks): + chunk.score = ranked_sim_scores[ind] + + if rerank_metrics_callback is not None: + chunk_metrics = [ + ChunkMetric( + document_id=chunk.document_id, + chunk_content_start=chunk.content[:MAX_METRICS_CONTENT], + first_link=chunk.source_links[0] if chunk.source_links else None, + score=chunk.score if chunk.score is not None else 0, + ) + for chunk in ranked_chunks + ] + + rerank_metrics_callback( + RerankMetricsContainer( + metrics=chunk_metrics, raw_similarity_scores=ranked_raw_scores # type: ignore + ) + ) + + return list(ranked_chunks), list(ranked_indices) + + +def rerank_chunks( + query: SearchQuery, + chunks_to_rerank: list[InferenceChunk], + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, +) -> list[InferenceChunk]: + ranked_chunks, _ = semantic_reranking( + query=query.query, + chunks=chunks_to_rerank[: query.num_rerank], + rerank_metrics_callback=rerank_metrics_callback, + ) + lower_chunks = chunks_to_rerank[query.num_rerank :] + # Scores from rerank cannot be meaningfully combined with scores without rerank + for lower_chunk in lower_chunks: + lower_chunk.score = None + ranked_chunks.extend(lower_chunks) + return ranked_chunks + + +@log_function_time(print_only=True) +def filter_chunks( + query: SearchQuery, + chunks_to_filter: list[InferenceChunk], +) -> list[str]: + """Filters chunks based on whether the LLM thought they were relevant to the query. 
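+    Only the first `query.max_llm_filter_chunks` chunks are evaluated; `llm_batch_eval_chunks`
+    is expected to return one relevance flag per candidate chunk, in input order.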
+
+    Returns a list of the unique chunk IDs that were marked as relevant"""
+    chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks]
+    llm_chunk_selection = llm_batch_eval_chunks(
+        query=query.query,
+        chunk_contents=[chunk.content for chunk in chunks_to_filter],
+    )
+    return [
+        chunk.unique_id
+        for ind, chunk in enumerate(chunks_to_filter)
+        if llm_chunk_selection[ind]
+    ]
+
+
+def search_postprocessing(
+    search_query: SearchQuery,
+    retrieved_chunks: list[InferenceChunk],
+    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
+) -> Generator[list[InferenceChunk] | list[str], None, None]:
+    post_processing_tasks: list[FunctionCall] = []
+
+    rerank_task_id = None
+    chunks_yielded = False
+    if should_rerank(search_query):
+        post_processing_tasks.append(
+            FunctionCall(
+                rerank_chunks,
+                (
+                    search_query,
+                    retrieved_chunks,
+                    rerank_metrics_callback,
+                ),
+            )
+        )
+        rerank_task_id = post_processing_tasks[-1].result_id
+    else:
+        final_chunks = retrieved_chunks
+        # NOTE: if we don't rerank, we can return the chunks immediately
+        # since we know this is the final order
+        _log_top_chunk_links(search_query.search_type.value, final_chunks)
+        yield final_chunks
+        chunks_yielded = True
+
+    llm_filter_task_id = None
+    if should_apply_llm_based_relevance_filter(search_query):
+        post_processing_tasks.append(
+            FunctionCall(
+                filter_chunks,
+                (search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]),
+            )
+        )
+        llm_filter_task_id = post_processing_tasks[-1].result_id
+
+    post_processing_results = (
+        run_functions_in_parallel(post_processing_tasks)
+        if post_processing_tasks
+        else {}
+    )
+    reranked_chunks = cast(
+        list[InferenceChunk] | None,
+        post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None,
+    )
+    if reranked_chunks:
+        if chunks_yielded:
+            logger.error(
+                "Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen."
+ ) + else: + _log_top_chunk_links(search_query.search_type.value, reranked_chunks) + yield reranked_chunks + + llm_chunk_selection = cast( + list[str] | None, + post_processing_results.get(str(llm_filter_task_id)) + if llm_filter_task_id + else None, + ) + if llm_chunk_selection is not None: + yield [ + chunk.unique_id + for chunk in reranked_chunks or retrieved_chunks + if chunk.unique_id in llm_chunk_selection + ] + else: + yield [] diff --git a/backend/danswer/search/access_filters.py b/backend/danswer/search/preprocessing/access_filters.py similarity index 100% rename from backend/danswer/search/access_filters.py rename to backend/danswer/search/preprocessing/access_filters.py diff --git a/backend/danswer/search/danswer_helper.py b/backend/danswer/search/preprocessing/danswer_helper.py similarity index 96% rename from backend/danswer/search/danswer_helper.py rename to backend/danswer/search/preprocessing/danswer_helper.py index d5dbeb8a3e8..88e465dacb5 100644 --- a/backend/danswer/search/danswer_helper.py +++ b/backend/danswer/search/preprocessing/danswer_helper.py @@ -1,10 +1,10 @@ from typing import TYPE_CHECKING -from danswer.search.models import QueryFlow +from danswer.search.enums import QueryFlow from danswer.search.models import SearchType +from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation from danswer.search.search_nlp_models import get_default_tokenizer from danswer.search.search_nlp_models import IntentModel -from danswer.search.search_runner import remove_stop_words_and_punctuation from danswer.server.query_and_chat.models import HelperResponse from danswer.utils.logger import setup_logger diff --git a/backend/danswer/search/request_preprocessing.py b/backend/danswer/search/preprocessing/preprocessing.py similarity index 76% rename from backend/danswer/search/request_preprocessing.py rename to backend/danswer/search/preprocessing/preprocessing.py index e74618d3950..f35afe43895 100644 --- a/backend/danswer/search/request_preprocessing.py +++ b/backend/danswer/search/preprocessing/preprocessing.py @@ -5,19 +5,16 @@ from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER from danswer.configs.chat_configs import NUM_RETURNED_HITS -from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW -from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW -from danswer.db.models import Persona from danswer.db.models import User -from danswer.search.access_filters import build_access_filters_for_user -from danswer.search.danswer_helper import query_intent +from danswer.search.enums import QueryFlow +from danswer.search.enums import RecencyBiasSetting from danswer.search.models import BaseFilters from danswer.search.models import IndexFilters -from danswer.search.models import QueryFlow -from danswer.search.models import RecencyBiasSetting -from danswer.search.models import RetrievalDetails from danswer.search.models import SearchQuery +from danswer.search.models import SearchRequest from danswer.search.models import SearchType +from danswer.search.preprocessing.access_filters import build_access_filters_for_user +from danswer.search.preprocessing.danswer_helper import query_intent from danswer.secondary_llm_flows.source_filter import extract_source_filter from danswer.secondary_llm_flows.time_filter import extract_time_filter from danswer.utils.logger import setup_logger @@ -31,15 +28,12 @@ @log_function_time(print_only=True) def 
retrieval_preprocessing( - query: str, - retrieval_details: RetrievalDetails, - persona: Persona, + search_request: SearchRequest, user: User | None, db_session: Session, bypass_acl: bool = False, include_query_intent: bool = True, - skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW, - skip_rerank_non_realtime: bool = not ENABLE_RERANKING_ASYNC_FLOW, + enable_auto_detect_filters: bool = False, disable_llm_filter_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION, disable_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER, base_recency_decay: float = BASE_RECENCY_DECAY, @@ -50,8 +44,12 @@ def retrieval_preprocessing( Then any filters or settings as part of the query are used Then defaults to Persona settings if not specified by the query """ + query = search_request.query + limit = search_request.limit + offset = search_request.offset + persona = search_request.persona - preset_filters = retrieval_details.filters or BaseFilters() + preset_filters = search_request.human_selected_filters or BaseFilters() if persona and persona.document_sets and preset_filters.document_set is None: preset_filters.document_set = [ document_set.name for document_set in persona.document_sets @@ -65,16 +63,20 @@ def retrieval_preprocessing( if disable_llm_filter_extraction: auto_detect_time_filter = False auto_detect_source_filter = False - elif retrieval_details.enable_auto_detect_filters is False: + elif enable_auto_detect_filters is False: logger.debug("Retrieval details disables auto detect filters") auto_detect_time_filter = False auto_detect_source_filter = False - elif persona.llm_filter_extraction is False: + elif persona and persona.llm_filter_extraction is False: logger.debug("Persona disables auto detect filters") auto_detect_time_filter = False auto_detect_source_filter = False - if time_filter is not None and persona.recency_bias != RecencyBiasSetting.AUTO: + if ( + time_filter is not None + and persona + and persona.recency_bias != RecencyBiasSetting.AUTO + ): auto_detect_time_filter = False logger.debug("Not extract time filter - already provided") if source_filter is not None: @@ -138,24 +140,18 @@ def retrieval_preprocessing( access_control_list=user_acl_filters, ) - # Tranformer-based re-ranking to run at same time as LLM chunk relevance filter - # This one is only set globally, not via query or Persona settings - skip_reranking = ( - skip_rerank_realtime - if retrieval_details.real_time - else skip_rerank_non_realtime - ) - - llm_chunk_filter = persona.llm_relevance_filter + llm_chunk_filter = False + if persona: + llm_chunk_filter = persona.llm_relevance_filter if disable_llm_chunk_filter: llm_chunk_filter = False # Decays at 1 / (1 + (multiplier * num years)) - if persona.recency_bias == RecencyBiasSetting.NO_DECAY: + if persona and persona.recency_bias == RecencyBiasSetting.NO_DECAY: recency_bias_multiplier = 0.0 - elif persona.recency_bias == RecencyBiasSetting.BASE_DECAY: + elif persona and persona.recency_bias == RecencyBiasSetting.BASE_DECAY: recency_bias_multiplier = base_recency_decay - elif persona.recency_bias == RecencyBiasSetting.FAVOR_RECENT: + elif persona and persona.recency_bias == RecencyBiasSetting.FAVOR_RECENT: recency_bias_multiplier = base_recency_decay * favor_recent_decay_multiplier else: if predicted_favor_recent: @@ -166,14 +162,12 @@ def retrieval_preprocessing( return ( SearchQuery( query=query, - search_type=persona.search_type, + search_type=persona.search_type if persona else SearchType.HYBRID, filters=final_filters, 
recency_bias_multiplier=recency_bias_multiplier, - num_hits=retrieval_details.limit - if retrieval_details.limit is not None - else NUM_RETURNED_HITS, - offset=retrieval_details.offset or 0, - skip_rerank=skip_reranking, + num_hits=limit if limit is not None else NUM_RETURNED_HITS, + offset=offset or 0, + skip_rerank=search_request.skip_rerank, skip_llm_chunk_filter=not llm_chunk_filter, ), predicted_search_type, diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py new file mode 100644 index 00000000000..41aa3a3c7e4 --- /dev/null +++ b/backend/danswer/search/retrieval/search_runner.py @@ -0,0 +1,258 @@ +import string +from collections.abc import Callable + +from nltk.corpus import stopwords # type:ignore +from nltk.stem import WordNetLemmatizer # type:ignore +from nltk.tokenize import word_tokenize # type:ignore +from sqlalchemy.orm import Session + +from danswer.chat.models import LlmDoc +from danswer.configs.app_configs import MODEL_SERVER_HOST +from danswer.configs.app_configs import MODEL_SERVER_PORT +from danswer.configs.chat_configs import HYBRID_ALPHA +from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION +from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.document_index.interfaces import DocumentIndex +from danswer.indexing.models import InferenceChunk +from danswer.search.models import ChunkMetric +from danswer.search.models import IndexFilters +from danswer.search.models import MAX_METRICS_CONTENT +from danswer.search.models import RetrievalMetricsContainer +from danswer.search.models import SearchQuery +from danswer.search.models import SearchType +from danswer.search.search_nlp_models import EmbeddingModel +from danswer.search.search_nlp_models import EmbedTextType +from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion +from danswer.utils.logger import setup_logger +from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel +from danswer.utils.timing import log_function_time + + +logger = setup_logger() + + +def lemmatize_text(text: str) -> list[str]: + lemmatizer = WordNetLemmatizer() + word_tokens = word_tokenize(text) + return [lemmatizer.lemmatize(word) for word in word_tokens] + + +def remove_stop_words_and_punctuation(text: str) -> list[str]: + stop_words = set(stopwords.words("english")) + word_tokens = word_tokenize(text) + text_trimmed = [ + word + for word in word_tokens + if (word.casefold() not in stop_words and word not in string.punctuation) + ] + return text_trimmed or word_tokens + + +def query_processing( + query: str, +) -> str: + query = " ".join(remove_stop_words_and_punctuation(query)) + query = " ".join(lemmatize_text(query)) + return query + + +def combine_retrieval_results( + chunk_sets: list[list[InferenceChunk]], +) -> list[InferenceChunk]: + all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set] + + unique_chunks: dict[tuple[str, int], InferenceChunk] = {} + for chunk in all_chunks: + key = (chunk.document_id, chunk.chunk_id) + if key not in unique_chunks: + unique_chunks[key] = chunk + continue + + stored_chunk_score = unique_chunks[key].score or 0 + this_chunk_score = chunk.score or 0 + if stored_chunk_score < this_chunk_score: + unique_chunks[key] = chunk + + sorted_chunks = sorted( + unique_chunks.values(), key=lambda x: x.score or 0, reverse=True + ) + + return sorted_chunks + + +@log_function_time(print_only=True) +def doc_index_retrieval( + query: SearchQuery, + 
document_index: DocumentIndex, + db_session: Session, + hybrid_alpha: float = HYBRID_ALPHA, +) -> list[InferenceChunk]: + if query.search_type == SearchType.KEYWORD: + top_chunks = document_index.keyword_retrieval( + query=query.query, + filters=query.filters, + time_decay_multiplier=query.recency_bias_multiplier, + num_to_retrieve=query.num_hits, + ) + else: + db_embedding_model = get_current_db_embedding_model(db_session) + + model = EmbeddingModel( + model_name=db_embedding_model.model_name, + query_prefix=db_embedding_model.query_prefix, + passage_prefix=db_embedding_model.passage_prefix, + normalize=db_embedding_model.normalize, + # The below are globally set, this flow always uses the indexing one + server_host=MODEL_SERVER_HOST, + server_port=MODEL_SERVER_PORT, + ) + + query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0] + + if query.search_type == SearchType.SEMANTIC: + top_chunks = document_index.semantic_retrieval( + query=query.query, + query_embedding=query_embedding, + filters=query.filters, + time_decay_multiplier=query.recency_bias_multiplier, + num_to_retrieve=query.num_hits, + ) + + elif query.search_type == SearchType.HYBRID: + top_chunks = document_index.hybrid_retrieval( + query=query.query, + query_embedding=query_embedding, + filters=query.filters, + time_decay_multiplier=query.recency_bias_multiplier, + num_to_retrieve=query.num_hits, + offset=query.offset, + hybrid_alpha=hybrid_alpha, + ) + + else: + raise RuntimeError("Invalid Search Flow") + + return top_chunks + + +def _simplify_text(text: str) -> str: + return "".join( + char for char in text if char not in string.punctuation and not char.isspace() + ).lower() + + +def retrieve_chunks( + query: SearchQuery, + document_index: DocumentIndex, + db_session: Session, + hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search + multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION, + retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] + | None = None, +) -> list[InferenceChunk]: + """Returns a list of the best chunks from an initial keyword/semantic/ hybrid search.""" + # Don't do query expansion on complex queries, rephrasings likely would not work well + if not multilingual_expansion_str or "\n" in query.query or "\r" in query.query: + top_chunks = doc_index_retrieval( + query=query, + document_index=document_index, + db_session=db_session, + hybrid_alpha=hybrid_alpha, + ) + else: + simplified_queries = set() + run_queries: list[tuple[Callable, tuple]] = [] + + # Currently only uses query expansion on multilingual use cases + query_rephrases = multilingual_query_expansion( + query.query, multilingual_expansion_str + ) + # Just to be extra sure, add the original query. 
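+        # It is deduplicated together with the rephrases below, so at most one
+        # retrieval runs per distinct simplified query. The per-rephrase result
+        # lists are then merged by combine_retrieval_results, which keeps the
+        # highest-scoring copy of each (document_id, chunk_id) pair and sorts by score.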
+ query_rephrases.append(query.query) + for rephrase in set(query_rephrases): + # Sometimes the model rephrases the query in the same language with minor changes + # Avoid doing an extra search with the minor changes as this biases the results + simplified_rephrase = _simplify_text(rephrase) + if simplified_rephrase in simplified_queries: + continue + simplified_queries.add(simplified_rephrase) + + q_copy = query.copy(update={"query": rephrase}, deep=True) + run_queries.append( + ( + doc_index_retrieval, + (q_copy, document_index, db_session, hybrid_alpha), + ) + ) + parallel_search_results = run_functions_tuples_in_parallel(run_queries) + top_chunks = combine_retrieval_results(parallel_search_results) + + if not top_chunks: + logger.info( + f"{query.search_type.value.capitalize()} search returned no results " + f"with filters: {query.filters}" + ) + return [] + + if retrieval_metrics_callback is not None: + chunk_metrics = [ + ChunkMetric( + document_id=chunk.document_id, + chunk_content_start=chunk.content[:MAX_METRICS_CONTENT], + first_link=chunk.source_links[0] if chunk.source_links else None, + score=chunk.score if chunk.score is not None else 0, + ) + for chunk in top_chunks + ] + retrieval_metrics_callback( + RetrievalMetricsContainer( + search_type=query.search_type, metrics=chunk_metrics + ) + ) + + return top_chunks + + +def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc: + if not inf_chunks: + raise ValueError("Cannot combine empty list of chunks") + + # Use the first link of the document + first_chunk = inf_chunks[0] + chunk_texts = [chunk.content for chunk in inf_chunks] + return LlmDoc( + document_id=first_chunk.document_id, + content="\n".join(chunk_texts), + blurb=first_chunk.blurb, + semantic_identifier=first_chunk.semantic_identifier, + source_type=first_chunk.source_type, + metadata=first_chunk.metadata, + updated_at=first_chunk.updated_at, + link=first_chunk.source_links[0] if first_chunk.source_links else None, + source_links=first_chunk.source_links, + ) + + +def inference_documents_from_ids( + doc_identifiers: list[tuple[str, int]], + document_index: DocumentIndex, +) -> list[LlmDoc]: + # Currently only fetches whole docs + doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers) + + # No need for ACL here because the doc ids were validated beforehand + filters = IndexFilters(access_control_list=None) + + functions_with_args: list[tuple[Callable, tuple]] = [ + (document_index.id_based_retrieval, (doc_id, None, filters)) + for doc_id in doc_ids_set + ] + + parallel_results = run_functions_tuples_in_parallel( + functions_with_args, allow_failures=True + ) + + # Any failures to retrieve would give a None, drop the Nones and empty lists + inference_chunks_sets = [res for res in parallel_results if res] + + return [combine_inference_chunks(chunk_set) for chunk_set in inference_chunks_sets] diff --git a/backend/danswer/search/search_runner.py b/backend/danswer/search/search_runner.py deleted file mode 100644 index 18bfa1a3c13..00000000000 --- a/backend/danswer/search/search_runner.py +++ /dev/null @@ -1,645 +0,0 @@ -import string -from collections.abc import Callable -from collections.abc import Iterator -from typing import cast - -import numpy -from nltk.corpus import stopwords # type:ignore -from nltk.stem import WordNetLemmatizer # type:ignore -from nltk.tokenize import word_tokenize # type:ignore -from sqlalchemy.orm import Session - -from danswer.chat.models import LlmDoc -from danswer.configs.app_configs import MODEL_SERVER_HOST -from 
danswer.configs.app_configs import MODEL_SERVER_PORT -from danswer.configs.chat_configs import HYBRID_ALPHA -from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION -from danswer.configs.chat_configs import NUM_RERANKED_RESULTS -from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX -from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN -from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH -from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW -from danswer.db.embedding_model import get_current_db_embedding_model -from danswer.document_index.document_index_utils import ( - translate_boost_count_to_multiplier, -) -from danswer.document_index.interfaces import DocumentIndex -from danswer.indexing.models import InferenceChunk -from danswer.search.models import ChunkMetric -from danswer.search.models import IndexFilters -from danswer.search.models import MAX_METRICS_CONTENT -from danswer.search.models import RerankMetricsContainer -from danswer.search.models import RetrievalMetricsContainer -from danswer.search.models import SearchDoc -from danswer.search.models import SearchQuery -from danswer.search.models import SearchType -from danswer.search.search_nlp_models import CrossEncoderEnsembleModel -from danswer.search.search_nlp_models import EmbeddingModel -from danswer.search.search_nlp_models import EmbedTextType -from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks -from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion -from danswer.utils.logger import setup_logger -from danswer.utils.threadpool_concurrency import FunctionCall -from danswer.utils.threadpool_concurrency import run_functions_in_parallel -from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel -from danswer.utils.timing import log_function_time - - -logger = setup_logger() - - -def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None: - top_links = [ - c.source_links[0] if c.source_links is not None else "No Link" for c in chunks - ] - logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}") - - -def lemmatize_text(text: str) -> list[str]: - lemmatizer = WordNetLemmatizer() - word_tokens = word_tokenize(text) - return [lemmatizer.lemmatize(word) for word in word_tokens] - - -def remove_stop_words_and_punctuation(text: str) -> list[str]: - stop_words = set(stopwords.words("english")) - word_tokens = word_tokenize(text) - text_trimmed = [ - word - for word in word_tokens - if (word.casefold() not in stop_words and word not in string.punctuation) - ] - return text_trimmed or word_tokens - - -def query_processing( - query: str, -) -> str: - query = " ".join(remove_stop_words_and_punctuation(query)) - query = " ".join(lemmatize_text(query)) - return query - - -def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]: - search_docs = ( - [ - SearchDoc( - document_id=chunk.document_id, - chunk_ind=chunk.chunk_id, - semantic_identifier=chunk.semantic_identifier or "Unknown", - link=chunk.source_links.get(0) if chunk.source_links else None, - blurb=chunk.blurb, - source_type=chunk.source_type, - boost=chunk.boost, - hidden=chunk.hidden, - metadata=chunk.metadata, - score=chunk.score, - match_highlights=chunk.match_highlights, - updated_at=chunk.updated_at, - primary_owners=chunk.primary_owners, - secondary_owners=chunk.secondary_owners, - ) - for chunk in chunks - ] - if chunks - else [] - ) - return search_docs - - -def 
combine_retrieval_results( - chunk_sets: list[list[InferenceChunk]], -) -> list[InferenceChunk]: - all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set] - - unique_chunks: dict[tuple[str, int], InferenceChunk] = {} - for chunk in all_chunks: - key = (chunk.document_id, chunk.chunk_id) - if key not in unique_chunks: - unique_chunks[key] = chunk - continue - - stored_chunk_score = unique_chunks[key].score or 0 - this_chunk_score = chunk.score or 0 - if stored_chunk_score < this_chunk_score: - unique_chunks[key] = chunk - - sorted_chunks = sorted( - unique_chunks.values(), key=lambda x: x.score or 0, reverse=True - ) - - return sorted_chunks - - -@log_function_time(print_only=True) -def doc_index_retrieval( - query: SearchQuery, - document_index: DocumentIndex, - db_session: Session, - hybrid_alpha: float = HYBRID_ALPHA, -) -> list[InferenceChunk]: - if query.search_type == SearchType.KEYWORD: - top_chunks = document_index.keyword_retrieval( - query=query.query, - filters=query.filters, - time_decay_multiplier=query.recency_bias_multiplier, - num_to_retrieve=query.num_hits, - ) - else: - db_embedding_model = get_current_db_embedding_model(db_session) - - model = EmbeddingModel( - model_name=db_embedding_model.model_name, - query_prefix=db_embedding_model.query_prefix, - passage_prefix=db_embedding_model.passage_prefix, - normalize=db_embedding_model.normalize, - # The below are globally set, this flow always uses the indexing one - server_host=MODEL_SERVER_HOST, - server_port=MODEL_SERVER_PORT, - ) - - query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0] - - if query.search_type == SearchType.SEMANTIC: - top_chunks = document_index.semantic_retrieval( - query=query.query, - query_embedding=query_embedding, - filters=query.filters, - time_decay_multiplier=query.recency_bias_multiplier, - num_to_retrieve=query.num_hits, - ) - - elif query.search_type == SearchType.HYBRID: - top_chunks = document_index.hybrid_retrieval( - query=query.query, - query_embedding=query_embedding, - filters=query.filters, - time_decay_multiplier=query.recency_bias_multiplier, - num_to_retrieve=query.num_hits, - offset=query.offset, - hybrid_alpha=hybrid_alpha, - ) - - else: - raise RuntimeError("Invalid Search Flow") - - return top_chunks - - -@log_function_time(print_only=True) -def semantic_reranking( - query: str, - chunks: list[InferenceChunk], - rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, - model_min: int = CROSS_ENCODER_RANGE_MIN, - model_max: int = CROSS_ENCODER_RANGE_MAX, -) -> tuple[list[InferenceChunk], list[int]]: - """Reranks chunks based on cross-encoder models. Additionally provides the original indices - of the chunks in their new sorted order. 
- - Note: this updates the chunks in place, it updates the chunk scores which came from retrieval - """ - cross_encoders = CrossEncoderEnsembleModel() - passages = [chunk.content for chunk in chunks] - sim_scores_floats = cross_encoders.predict(query=query, passages=passages) - - sim_scores = [numpy.array(scores) for scores in sim_scores_floats] - - raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores)) - - cross_models_min = numpy.min(sim_scores) - - shifted_sim_scores = sum( - [enc_n_scores - cross_models_min for enc_n_scores in sim_scores] - ) / len(sim_scores) - - boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks] - recency_multiplier = [chunk.recency_bias for chunk in chunks] - boosted_sim_scores = shifted_sim_scores * boosts * recency_multiplier - normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / ( - model_max - model_min - ) - orig_indices = [i for i in range(len(normalized_b_s_scores))] - scored_results = list( - zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices) - ) - scored_results.sort(key=lambda x: x[0], reverse=True) - ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip( - *scored_results - ) - - logger.debug( - f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}" - ) - - # Assign new chunk scores based on reranking - for ind, chunk in enumerate(ranked_chunks): - chunk.score = ranked_sim_scores[ind] - - if rerank_metrics_callback is not None: - chunk_metrics = [ - ChunkMetric( - document_id=chunk.document_id, - chunk_content_start=chunk.content[:MAX_METRICS_CONTENT], - first_link=chunk.source_links[0] if chunk.source_links else None, - score=chunk.score if chunk.score is not None else 0, - ) - for chunk in ranked_chunks - ] - - rerank_metrics_callback( - RerankMetricsContainer( - metrics=chunk_metrics, raw_similarity_scores=ranked_raw_scores # type: ignore - ) - ) - - return list(ranked_chunks), list(ranked_indices) - - -def apply_boost_legacy( - chunks: list[InferenceChunk], - norm_min: float = SIM_SCORE_RANGE_LOW, - norm_max: float = SIM_SCORE_RANGE_HIGH, -) -> list[InferenceChunk]: - scores = [chunk.score or 0 for chunk in chunks] - boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks] - - logger.debug(f"Raw similarity scores: {scores}") - - score_min = min(scores) - score_max = max(scores) - score_range = score_max - score_min - - if score_range != 0: - boosted_scores = [ - ((score - score_min) / score_range) * boost - for score, boost in zip(scores, boosts) - ] - unnormed_boosted_scores = [ - score * score_range + score_min for score in boosted_scores - ] - else: - unnormed_boosted_scores = [ - score * boost for score, boost in zip(scores, boosts) - ] - - norm_min = min(norm_min, min(scores)) - norm_max = max(norm_max, max(scores)) - # This should never be 0 unless user has done some weird/wrong settings - norm_range = norm_max - norm_min - - # For score display purposes - if norm_range != 0: - re_normed_scores = [ - ((score - norm_min) / norm_range) for score in unnormed_boosted_scores - ] - else: - re_normed_scores = unnormed_boosted_scores - - rescored_chunks = list(zip(re_normed_scores, chunks)) - rescored_chunks.sort(key=lambda x: x[0], reverse=True) - sorted_boosted_scores, boost_sorted_chunks = zip(*rescored_chunks) - - final_chunks = list(boost_sorted_chunks) - final_scores = list(sorted_boosted_scores) - for ind, chunk in enumerate(final_chunks): - chunk.score = final_scores[ind] - - logger.debug(f"Boost 
sorted similary scores: {list(final_scores)}") - - return final_chunks - - -def apply_boost( - chunks: list[InferenceChunk], - # Need the range of values to not be too spread out for applying boost - # therefore norm across only the top few results - norm_cutoff: int = NUM_RERANKED_RESULTS, - norm_min: float = SIM_SCORE_RANGE_LOW, - norm_max: float = SIM_SCORE_RANGE_HIGH, -) -> list[InferenceChunk]: - scores = [chunk.score or 0.0 for chunk in chunks] - logger.debug(f"Raw similarity scores: {scores}") - - boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks] - recency_multiplier = [chunk.recency_bias for chunk in chunks] - - norm_min = min(norm_min, min(scores[:norm_cutoff])) - norm_max = max(norm_max, max(scores[:norm_cutoff])) - # This should never be 0 unless user has done some weird/wrong settings - norm_range = norm_max - norm_min - - boosted_scores = [ - max(0, (score - norm_min) * boost * recency / norm_range) - for score, boost, recency in zip(scores, boosts, recency_multiplier) - ] - - rescored_chunks = list(zip(boosted_scores, chunks)) - rescored_chunks.sort(key=lambda x: x[0], reverse=True) - sorted_boosted_scores, boost_sorted_chunks = zip(*rescored_chunks) - - final_chunks = list(boost_sorted_chunks) - final_scores = list(sorted_boosted_scores) - for ind, chunk in enumerate(final_chunks): - chunk.score = final_scores[ind] - - logger.debug( - f"Boosted + Time Weighted sorted similarity scores: {list(final_scores)}" - ) - - return final_chunks - - -def _simplify_text(text: str) -> str: - return "".join( - char for char in text if char not in string.punctuation and not char.isspace() - ).lower() - - -def retrieve_chunks( - query: SearchQuery, - document_index: DocumentIndex, - db_session: Session, - hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search - multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION, - retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] - | None = None, -) -> list[InferenceChunk]: - """Returns a list of the best chunks from an initial keyword/semantic/ hybrid search.""" - # Don't do query expansion on complex queries, rephrasings likely would not work well - if not multilingual_expansion_str or "\n" in query.query or "\r" in query.query: - top_chunks = doc_index_retrieval( - query=query, - document_index=document_index, - db_session=db_session, - hybrid_alpha=hybrid_alpha, - ) - else: - simplified_queries = set() - run_queries: list[tuple[Callable, tuple]] = [] - - # Currently only uses query expansion on multilingual use cases - query_rephrases = multilingual_query_expansion( - query.query, multilingual_expansion_str - ) - # Just to be extra sure, add the original query. 
- query_rephrases.append(query.query) - for rephrase in set(query_rephrases): - # Sometimes the model rephrases the query in the same language with minor changes - # Avoid doing an extra search with the minor changes as this biases the results - simplified_rephrase = _simplify_text(rephrase) - if simplified_rephrase in simplified_queries: - continue - simplified_queries.add(simplified_rephrase) - - q_copy = query.copy(update={"query": rephrase}, deep=True) - run_queries.append( - ( - doc_index_retrieval, - (q_copy, document_index, db_session, hybrid_alpha), - ) - ) - parallel_search_results = run_functions_tuples_in_parallel(run_queries) - top_chunks = combine_retrieval_results(parallel_search_results) - - if not top_chunks: - logger.info( - f"{query.search_type.value.capitalize()} search returned no results " - f"with filters: {query.filters}" - ) - return [] - - if retrieval_metrics_callback is not None: - chunk_metrics = [ - ChunkMetric( - document_id=chunk.document_id, - chunk_content_start=chunk.content[:MAX_METRICS_CONTENT], - first_link=chunk.source_links[0] if chunk.source_links else None, - score=chunk.score if chunk.score is not None else 0, - ) - for chunk in top_chunks - ] - retrieval_metrics_callback( - RetrievalMetricsContainer( - search_type=query.search_type, metrics=chunk_metrics - ) - ) - - return top_chunks - - -def should_rerank(query: SearchQuery) -> bool: - # Don't re-rank for keyword search - return query.search_type != SearchType.KEYWORD and not query.skip_rerank - - -def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool: - return not query.skip_llm_chunk_filter - - -def rerank_chunks( - query: SearchQuery, - chunks_to_rerank: list[InferenceChunk], - rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> list[InferenceChunk]: - ranked_chunks, _ = semantic_reranking( - query=query.query, - chunks=chunks_to_rerank[: query.num_rerank], - rerank_metrics_callback=rerank_metrics_callback, - ) - lower_chunks = chunks_to_rerank[query.num_rerank :] - # Scores from rerank cannot be meaningfully combined with scores without rerank - for lower_chunk in lower_chunks: - lower_chunk.score = None - ranked_chunks.extend(lower_chunks) - return ranked_chunks - - -@log_function_time(print_only=True) -def filter_chunks( - query: SearchQuery, - chunks_to_filter: list[InferenceChunk], -) -> list[str]: - """Filters chunks based on whether the LLM thought they were relevant to the query. - - Returns a list of the unique chunk IDs that were marked as relevant""" - chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks] - llm_chunk_selection = llm_batch_eval_chunks( - query=query.query, - chunk_contents=[chunk.content for chunk in chunks_to_filter], - ) - return [ - chunk.unique_id - for ind, chunk in enumerate(chunks_to_filter) - if llm_chunk_selection[ind] - ] - - -def full_chunk_search( - query: SearchQuery, - document_index: DocumentIndex, - db_session: Session, - hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search - multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION, - retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] - | None = None, - rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> tuple[list[InferenceChunk], list[bool]]: - """A utility which provides an easier interface than `full_chunk_search_generator`. 
- Rather than returning the chunks and llm relevance filter results in two separate - yields, just returns them both at once.""" - search_generator = full_chunk_search_generator( - search_query=query, - document_index=document_index, - db_session=db_session, - hybrid_alpha=hybrid_alpha, - multilingual_expansion_str=multilingual_expansion_str, - retrieval_metrics_callback=retrieval_metrics_callback, - rerank_metrics_callback=rerank_metrics_callback, - ) - top_chunks = cast(list[InferenceChunk], next(search_generator)) - llm_chunk_selection = cast(list[bool], next(search_generator)) - return top_chunks, llm_chunk_selection - - -def empty_search_generator() -> Iterator[list[InferenceChunk] | list[bool]]: - yield cast(list[InferenceChunk], []) - yield cast(list[bool], []) - - -def full_chunk_search_generator( - search_query: SearchQuery, - document_index: DocumentIndex, - db_session: Session, - hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search - multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION, - retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] - | None = None, - rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> Iterator[list[InferenceChunk] | list[bool]]: - """Always yields twice. Once with the selected chunks and once with the LLM relevance filter result. - If LLM filter results are turned off, returns a list of False - """ - chunks_yielded = False - - retrieved_chunks = retrieve_chunks( - query=search_query, - document_index=document_index, - db_session=db_session, - hybrid_alpha=hybrid_alpha, - multilingual_expansion_str=multilingual_expansion_str, - retrieval_metrics_callback=retrieval_metrics_callback, - ) - - if not retrieved_chunks: - yield cast(list[InferenceChunk], []) - yield cast(list[bool], []) - return - - post_processing_tasks: list[FunctionCall] = [] - - rerank_task_id = None - if should_rerank(search_query): - post_processing_tasks.append( - FunctionCall( - rerank_chunks, - ( - search_query, - retrieved_chunks, - rerank_metrics_callback, - ), - ) - ) - rerank_task_id = post_processing_tasks[-1].result_id - else: - final_chunks = retrieved_chunks - # NOTE: if we don't rerank, we can return the chunks immediately - # since we know this is the final order - _log_top_chunk_links(search_query.search_type.value, final_chunks) - yield final_chunks - chunks_yielded = True - - llm_filter_task_id = None - if should_apply_llm_based_relevance_filter(search_query): - post_processing_tasks.append( - FunctionCall( - filter_chunks, - (search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]), - ) - ) - llm_filter_task_id = post_processing_tasks[-1].result_id - - post_processing_results = ( - run_functions_in_parallel(post_processing_tasks) - if post_processing_tasks - else {} - ) - reranked_chunks = cast( - list[InferenceChunk] | None, - post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None, - ) - if reranked_chunks: - if chunks_yielded: - logger.error( - "Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen." 
- ) - else: - _log_top_chunk_links(search_query.search_type.value, reranked_chunks) - yield reranked_chunks - - llm_chunk_selection = cast( - list[str] | None, - post_processing_results.get(str(llm_filter_task_id)) - if llm_filter_task_id - else None, - ) - if llm_chunk_selection is not None: - yield [ - chunk.unique_id in llm_chunk_selection - for chunk in reranked_chunks or retrieved_chunks - ] - else: - yield [False for _ in reranked_chunks or retrieved_chunks] - - -def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc: - if not inf_chunks: - raise ValueError("Cannot combine empty list of chunks") - - # Use the first link of the document - first_chunk = inf_chunks[0] - chunk_texts = [chunk.content for chunk in inf_chunks] - return LlmDoc( - document_id=first_chunk.document_id, - content="\n".join(chunk_texts), - semantic_identifier=first_chunk.semantic_identifier, - source_type=first_chunk.source_type, - metadata=first_chunk.metadata, - updated_at=first_chunk.updated_at, - link=first_chunk.source_links[0] if first_chunk.source_links else None, - ) - - -def inference_documents_from_ids( - doc_identifiers: list[tuple[str, int]], - document_index: DocumentIndex, -) -> list[LlmDoc]: - # Currently only fetches whole docs - doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers) - - # No need for ACL here because the doc ids were validated beforehand - filters = IndexFilters(access_control_list=None) - - functions_with_args: list[tuple[Callable, tuple]] = [ - (document_index.id_based_retrieval, (doc_id, None, filters)) - for doc_id in doc_ids_set - ] - - parallel_results = run_functions_tuples_in_parallel( - functions_with_args, allow_failures=True - ) - - # Any failures to retrieve would give a None, drop the Nones and empty lists - inference_chunks_sets = [res for res in parallel_results if res] - - return [combine_inference_chunks(chunk_set) for chunk_set in inference_chunks_sets] diff --git a/backend/danswer/search/utils.py b/backend/danswer/search/utils.py new file mode 100644 index 00000000000..4b01f70eb90 --- /dev/null +++ b/backend/danswer/search/utils.py @@ -0,0 +1,29 @@ +from danswer.indexing.models import InferenceChunk +from danswer.search.models import SearchDoc + + +def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]: + search_docs = ( + [ + SearchDoc( + document_id=chunk.document_id, + chunk_ind=chunk.chunk_id, + semantic_identifier=chunk.semantic_identifier or "Unknown", + link=chunk.source_links.get(0) if chunk.source_links else None, + blurb=chunk.blurb, + source_type=chunk.source_type, + boost=chunk.boost, + hidden=chunk.hidden, + metadata=chunk.metadata, + score=chunk.score, + match_highlights=chunk.match_highlights, + updated_at=chunk.updated_at, + primary_owners=chunk.primary_owners, + secondary_owners=chunk.secondary_owners, + ) + for chunk in chunks + ] + if chunks + else [] + ) + return search_docs diff --git a/backend/danswer/server/documents/document.py b/backend/danswer/server/documents/document.py index ea080b0335d..3abab330293 100644 --- a/backend/danswer/server/documents/document.py +++ b/backend/danswer/server/documents/document.py @@ -11,8 +11,8 @@ from danswer.document_index.factory import get_default_document_index from danswer.llm.utils import get_default_llm_token_encode from danswer.prompts.prompt_utils import build_doc_context_str -from danswer.search.access_filters import build_access_filters_for_user from danswer.search.models import IndexFilters +from 
 from danswer.server.documents.models import ChunkInfo
 from danswer.server.documents.models import DocumentInfo
diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py
index 8762f40b51a..d75ff694809 100644
--- a/backend/danswer/server/features/persona/api.py
+++ b/backend/danswer/server/features/persona/api.py
@@ -14,8 +14,8 @@
 from danswer.db.engine import get_session
 from danswer.db.models import User
 from danswer.db.persona import create_update_persona
+from danswer.llm.answering.prompts.utils import build_dummy_prompt
 from danswer.llm.utils import get_default_llm_version
-from danswer.one_shot_answer.qa_block import build_dummy_prompt
 from danswer.server.features.persona.models import CreatePersonaRequest
 from danswer.server.features.persona.models import PersonaSnapshot
 from danswer.server.features.persona.models import PromptTemplateResponse
diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py
index a724ac5f3e2..4cc80eec0ee 100644
--- a/backend/danswer/server/features/persona/models.py
+++ b/backend/danswer/server/features/persona/models.py
@@ -4,7 +4,7 @@
 from danswer.db.models import Persona
 from danswer.db.models import StarterMessage
-from danswer.search.models import RecencyBiasSetting
+from danswer.search.enums import RecencyBiasSetting
 from danswer.server.features.document_set.models import DocumentSet
 from danswer.server.features.prompt.models import PromptSnapshot
diff --git a/backend/danswer/server/gpts/api.py b/backend/danswer/server/gpts/api.py
index 9800032520e..bfada9b5593 100644
--- a/backend/danswer/server/gpts/api.py
+++ b/backend/danswer/server/gpts/api.py
@@ -6,13 +6,9 @@
 from pydantic import BaseModel
 from sqlalchemy.orm import Session
 
-from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_session
-from danswer.document_index.factory import get_default_document_index
-from danswer.search.access_filters import build_access_filters_for_user
-from danswer.search.models import IndexFilters
-from danswer.search.models import SearchQuery
-from danswer.search.search_runner import full_chunk_search
+from danswer.search.models import SearchRequest
+from danswer.search.pipeline import SearchPipeline
 from danswer.server.danswer_api.ingestion import api_key_dep
 from danswer.utils.logger import setup_logger
@@ -70,27 +66,13 @@ def gpt_search(
     _: str | None = Depends(api_key_dep),
     db_session: Session = Depends(get_session),
 ) -> GptSearchResponse:
-    query = search_request.query
-
-    user_acl_filters = build_access_filters_for_user(None, db_session)
-    final_filters = IndexFilters(access_control_list=user_acl_filters)
-
-    search_query = SearchQuery(
-        query=query,
-        filters=final_filters,
-        recency_bias_multiplier=1.0,
-        skip_llm_chunk_filter=True,
-    )
-
-    embedding_model = get_current_db_embedding_model(db_session)
-
-    document_index = get_default_document_index(
-        primary_index_name=embedding_model.index_name, secondary_index_name=None
-    )
-
-    top_chunks, __ = full_chunk_search(
-        query=search_query, document_index=document_index, db_session=db_session
-    )
+    top_chunks = SearchPipeline(
+        search_request=SearchRequest(
+            query=search_request.query,
+        ),
+        user=None,
+        db_session=db_session,
+    ).reranked_docs
 
     return GptSearchResponse(
         matching_document_chunks=[
diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py
index a8076659c61..4fb98c5a156 100644
--- a/backend/danswer/server/query_and_chat/chat_backend.py
+++ b/backend/danswer/server/query_and_chat/chat_backend.py
@@ -6,7 +6,6 @@
 from sqlalchemy.orm import Session
 
 from danswer.auth.users import current_user
-from danswer.chat.chat_utils import compute_max_document_tokens
 from danswer.chat.chat_utils import create_chat_chain
 from danswer.chat.process_message import stream_chat_message
 from danswer.db.chat import create_chat_session
@@ -25,6 +24,7 @@
 from danswer.db.models import User
 from danswer.document_index.document_index_utils import get_both_index_names
 from danswer.document_index.factory import get_default_document_index
+from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
 from danswer.secondary_llm_flows.chat_session_naming import (
     get_renamed_conversation_name,
 )
diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py
index 6d8529486fd..5150eb9ce10 100644
--- a/backend/danswer/server/query_and_chat/query_backend.py
+++ b/backend/danswer/server/query_and_chat/query_backend.py
@@ -15,11 +15,11 @@
 from danswer.document_index.vespa.index import VespaIndex
 from danswer.one_shot_answer.answer_question import stream_search_answer
 from danswer.one_shot_answer.models import DirectQARequest
-from danswer.search.access_filters import build_access_filters_for_user
-from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.models import IndexFilters
 from danswer.search.models import SearchDoc
-from danswer.search.search_runner import chunks_to_search_docs
+from danswer.search.preprocessing.access_filters import build_access_filters_for_user
+from danswer.search.preprocessing.danswer_helper import recommend_search_flow
+from danswer.search.utils import chunks_to_search_docs
 from danswer.secondary_llm_flows.query_validation import get_query_answerability
 from danswer.secondary_llm_flows.query_validation import stream_query_answerability
 from danswer.server.query_and_chat.models import AdminSearchRequest
diff --git a/backend/tests/regression/answer_quality/eval_direct_qa.py b/backend/tests/regression/answer_quality/eval_direct_qa.py
index bd2f70010e2..d32f2754725 100644
--- a/backend/tests/regression/answer_quality/eval_direct_qa.py
+++ b/backend/tests/regression/answer_quality/eval_direct_qa.py
@@ -77,7 +77,6 @@ def get_answer_for_question(
     str | None,
     RetrievalMetricsContainer | None,
     RerankMetricsContainer | None,
-    LLMMetricsContainer | None,
 ]:
     filters = IndexFilters(
         source_type=None,
@@ -103,7 +102,6 @@ def get_answer_for_question(
     retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
     rerank_metrics = MetricsHander[RerankMetricsContainer]()
-    llm_metrics = MetricsHander[LLMMetricsContainer]()
 
     answer = get_search_answer(
         query_req=new_message_request,
@@ -116,14 +114,12 @@ def get_answer_for_question(
         bypass_acl=True,
         retrieval_metrics_callback=retrieval_metrics.record_metric,
         rerank_metrics_callback=rerank_metrics.record_metric,
-        llm_metrics_callback=llm_metrics.record_metric,
     )
 
     return (
         answer.answer,
         retrieval_metrics.metrics,
         rerank_metrics.metrics,
-        llm_metrics.metrics,
     )
@@ -221,7 +217,6 @@ def _print_llm_metrics(metrics_container: LLMMetricsContainer) -> None:
                 answer,
                 retrieval_metrics,
                 rerank_metrics,
-                llm_metrics,
             ) = get_answer_for_question(sample["question"], db_session)
             end_time = datetime.now()
@@ -237,12 +232,6 @@ def _print_llm_metrics(metrics_container: LLMMetricsContainer) -> None:
                 else "\tFailed, either crashed or refused to answer."
             )
             if not args.discard_metrics:
-                print("\nLLM Tokens Usage:")
-                if llm_metrics is None:
-                    print("No LLM Metrics Available")
-                else:
-                    _print_llm_metrics(llm_metrics)
-
                 print("\nRetrieval Metrics:")
                 if retrieval_metrics is None:
                     print("No Retrieval Metrics Available")
diff --git a/backend/tests/regression/search_quality/eval_search.py b/backend/tests/regression/search_quality/eval_search.py
index 7cd3e6068c6..5bf9406b412 100644
--- a/backend/tests/regression/search_quality/eval_search.py
+++ b/backend/tests/regression/search_quality/eval_search.py
@@ -7,16 +7,13 @@
 from sqlalchemy.orm import Session
 
-from danswer.chat.chat_utils import get_chunks_for_qa
-from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_sqlalchemy_engine
-from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.models import InferenceChunk
-from danswer.search.models import IndexFilters
+from danswer.llm.answering.doc_pruning import reorder_docs
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
-from danswer.search.models import SearchQuery
-from danswer.search.search_runner import full_chunk_search
+from danswer.search.models import SearchRequest
+from danswer.search.pipeline import SearchPipeline
 from danswer.utils.callbacks import MetricsHander
@@ -81,46 +78,25 @@ def get_search_results(
     RetrievalMetricsContainer | None,
     RerankMetricsContainer | None,
 ]:
-    filters = IndexFilters(
-        source_type=None,
-        document_set=None,
-        time_cutoff=None,
-        access_control_list=None,
-    )
-    search_query = SearchQuery(
-        query=query,
-        filters=filters,
-        recency_bias_multiplier=1.0,
-    )
-
     retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
     rerank_metrics = MetricsHander[RerankMetricsContainer]()
 
     with Session(get_sqlalchemy_engine()) as db_session:
-        embedding_model = get_current_db_embedding_model(db_session)
-
-        document_index = get_default_document_index(
-            primary_index_name=embedding_model.index_name, secondary_index_name=None
-        )
-
-        top_chunks, llm_chunk_selection = full_chunk_search(
-            query=search_query,
-            document_index=document_index,
-            db_session=db_session,
-            retrieval_metrics_callback=retrieval_metrics.record_metric,
-            rerank_metrics_callback=rerank_metrics.record_metric,
-        )
-
-        llm_chunks_indices = get_chunks_for_qa(
-            chunks=top_chunks,
-            llm_chunk_selection=llm_chunk_selection,
-            token_limit=None,
-        )
-
-        llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
+        search_pipeline = SearchPipeline(
+            search_request=SearchRequest(
+                query=query,
+            ),
+            user=None,
+            db_session=db_session,
+            retrieval_metrics_callback=retrieval_metrics.record_metric,
+            rerank_metrics_callback=rerank_metrics.record_metric,
+        )
+
+        top_chunks = search_pipeline.reranked_docs
+        llm_chunk_selection = search_pipeline.chunk_relevance_list
 
         return (
-            llm_chunks,
+            reorder_docs(top_chunks, llm_chunk_selection),
             retrieval_metrics.metrics,
             rerank_metrics.metrics,
         )
diff --git a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py b/backend/tests/unit/danswer/direct_qa/test_qa_utils.py
index b30d08b1697..b7b30b63d2d 100644
--- a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py
+++ b/backend/tests/unit/danswer/direct_qa/test_qa_utils.py
@@ -3,8 +3,12 @@
 from danswer.configs.constants import DocumentSource
 from danswer.indexing.models import InferenceChunk
-from danswer.one_shot_answer.qa_utils import match_quotes_to_docs
-from danswer.one_shot_answer.qa_utils import separate_answer_quotes
+from danswer.llm.answering.stream_processing.quotes_processing import (
+    match_quotes_to_docs,
+)
+from danswer.llm.answering.stream_processing.quotes_processing import (
+    separate_answer_quotes,
+)
 
 
 class TestQAPostprocessing(unittest.TestCase):