From f09afb30d9f94a0ba0e0d4fbc396f750770461be Mon Sep 17 00:00:00 2001 From: Alex Co Date: Fri, 15 Nov 2024 10:38:37 +0800 Subject: [PATCH] Refactor imports and enhance employee context handling in prompts Signed-off-by: Alex Co --- .../slack/handlers/handle_message.py | 7 ++- .../slack/handlers/handle_regular_answer.py | 43 +++++++++++-------- backend/danswer/llm/answering/answer.py | 9 +++- .../llm/answering/prompts/citations_prompt.py | 40 ++++++++--------- .../llm/answering/prompts/quotes_prompt.py | 33 +++++++++++++- backend/danswer/prompts/prompt_utils.py | 22 +++++++--- 6 files changed, 99 insertions(+), 55 deletions(-) diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index cce45331ee7..a70a2423671 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -1,9 +1,5 @@ import datetime -from slack_sdk import WebClient -from slack_sdk.errors import SlackApiError -from sqlalchemy.orm import Session - from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_REMINDER from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI from danswer.danswerbot.slack.blocks import get_feedback_reminder_blocks @@ -24,6 +20,9 @@ from danswer.db.users import add_non_web_user_if_not_exists from danswer.utils.logger import setup_logger from shared_configs.configs import SLACK_CHANNEL_ID +from slack_sdk import WebClient +from slack_sdk.errors import SlackApiError +from sqlalchemy.orm import Session logger_base = setup_logger() diff --git a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py index f1c9bd077cf..77eebeac58a 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py @@ -5,13 +5,6 @@ from typing import Optional from typing import TypeVar -from fastapi import HTTPException -from retry import retry -from slack_sdk import WebClient -from slack_sdk.models.blocks import DividerBlock -from slack_sdk.models.blocks import SectionBlock -from sqlalchemy.orm import Session - from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT @@ -54,6 +47,15 @@ from danswer.search.models import RerankingDetails from danswer.search.models import RetrievalDetails from danswer.utils.logger import DanswerLoggingAdapter +from danswer.utils.logger import setup_logger +from fastapi import HTTPException +from retry import retry +from slack_sdk import WebClient +from slack_sdk.models.blocks import DividerBlock +from slack_sdk.models.blocks import SectionBlock +from sqlalchemy.orm import Session + +logger = setup_logger() srl = SlackRateLimiter() @@ -101,12 +103,11 @@ def handle_regular_answer( messages = message_info.thread_messages message_ts_to_respond_to = message_info.msg_to_respond is_bot_msg = message_info.is_bot_msg - user = None - if message_info.is_bot_dm: - if message_info.email: - engine = get_sqlalchemy_engine() - with Session(engine) as db_session: - user = get_user_by_email(message_info.email, db_session) + + if message_info.email: + engine = get_sqlalchemy_engine() + with Session(engine) as db_session: + user = get_user_by_email(message_info.email, db_session) document_set_names: list[str] | None = None persona = slack_bot_config.persona if slack_bot_config else None @@ -253,16 +254,20 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non answer = _get_answer( DirectQARequest( messages=messages, - multilingual_query_expansion=saved_search_settings.multilingual_expansion - if saved_search_settings - else None, + multilingual_query_expansion=( + saved_search_settings.multilingual_expansion + if saved_search_settings + else None + ), prompt_id=prompt.id if prompt else None, persona_id=persona.id if persona is not None else 0, retrieval_options=retrieval_details, chain_of_thought=not disable_cot, - rerank_settings=RerankingDetails.from_db_model(saved_search_settings) - if saved_search_settings - else None, + rerank_settings=( + RerankingDetails.from_db_model(saved_search_settings) + if saved_search_settings + else None + ), ) ) except Exception as e: diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index af10884be3e..3964accd6a8 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -20,7 +20,9 @@ from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.answering.prompts.build import default_build_system_message from danswer.llm.answering.prompts.build import default_build_user_message -from danswer.llm.answering.prompts.citations_prompt import build_citations_system_message +from danswer.llm.answering.prompts.citations_prompt import ( + build_citations_system_message, +) from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message from danswer.llm.answering.stream_processing.citation_processing import ( @@ -53,7 +55,9 @@ from danswer.tools.search.search_tool import SearchTool from danswer.tools.tool import Tool from danswer.tools.tool import ToolResponse -from danswer.tools.tool_runner import check_which_tools_should_run_for_non_tool_calling_llm +from danswer.tools.tool_runner import ( + check_which_tools_should_run_for_non_tool_calling_llm, +) from danswer.tools.tool_runner import ToolCallFinalResult from danswer.tools.tool_runner import ToolCallKickoff from danswer.tools.tool_runner import ToolRunner @@ -183,6 +187,7 @@ def _update_prompt_builder_for_search_tool( context_docs=final_context_documents, history_str=self.single_message_history or "", prompt=self.prompt_config, + user_email=self.user_email, ) ) diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/llm/answering/prompts/citations_prompt.py index 7dba38ffa45..705ea1407d7 100644 --- a/backend/danswer/llm/answering/prompts/citations_prompt.py +++ b/backend/danswer/llm/answering/prompts/citations_prompt.py @@ -5,35 +5,31 @@ from danswer.db.search_settings import get_multilingual_expansion from danswer.file_store.utils import InMemoryChatFile from danswer.llm.answering.models import PromptConfig -from danswer.llm.factory import get_llms_for_persona, get_main_llm_from_tuple +from danswer.llm.factory import get_llms_for_persona +from danswer.llm.factory import get_main_llm_from_tuple from danswer.llm.interfaces import LLMConfig -from danswer.llm.utils import ( - build_content_with_imgs, - check_number_of_tokens, - get_max_input_tokens, -) +from danswer.llm.utils import build_content_with_imgs +from danswer.llm.utils import check_number_of_tokens +from danswer.llm.utils import get_max_input_tokens from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT -from danswer.prompts.direct_qa_prompts import ( - CITATIONS_PROMPT, - CITATIONS_PROMPT_FOR_TOOL_CALLING, -) -from danswer.prompts.prompt_utils import ( - add_date_time_to_prompt, - add_employee_context_to_prompt, - build_complete_context_str, - build_task_prompt_reminders, -) +from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT +from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING +from danswer.prompts.prompt_utils import add_date_time_to_prompt +from danswer.prompts.prompt_utils import add_employee_context_to_prompt +from danswer.prompts.prompt_utils import build_complete_context_str +from danswer.prompts.prompt_utils import build_task_prompt_reminders +from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT from danswer.prompts.token_counts import ( - ADDITIONAL_INFO_TOKEN_CNT, CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT, - CITATION_REMINDER_TOKEN_CNT, - CITATION_STATEMENT_TOKEN_CNT, - LANGUAGE_HINT_TOKEN_CNT, ) +from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT +from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT +from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT from danswer.search.models import InferenceChunk from danswer.utils.logger import setup_logger -from langchain.schema.messages import HumanMessage, SystemMessage +from langchain.schema.messages import HumanMessage +from langchain.schema.messages import SystemMessage logger = setup_logger() @@ -136,6 +132,8 @@ def build_citations_system_message( prompt_str=system_prompt, user_email=user_email ) + logger.debug(f"Built system message: {system_prompt}") + return SystemMessage(content=system_prompt) diff --git a/backend/danswer/llm/answering/prompts/quotes_prompt.py b/backend/danswer/llm/answering/prompts/quotes_prompt.py index 07abc4356b6..35a8059810c 100644 --- a/backend/danswer/llm/answering/prompts/quotes_prompt.py +++ b/backend/danswer/llm/answering/prompts/quotes_prompt.py @@ -1,5 +1,3 @@ -from langchain.schema.messages import HumanMessage - from danswer.chat.models import LlmDoc from danswer.configs.chat_configs import LANGUAGE_HINT from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE @@ -10,8 +8,13 @@ from danswer.prompts.direct_qa_prompts import JSON_PROMPT from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT from danswer.prompts.prompt_utils import add_date_time_to_prompt +from danswer.prompts.prompt_utils import add_employee_context_to_prompt from danswer.prompts.prompt_utils import build_complete_context_str from danswer.search.models import InferenceChunk +from danswer.utils.logger import setup_logger +from langchain.schema.messages import HumanMessage + +logger = setup_logger() def _build_weak_llm_quotes_prompt( @@ -19,6 +22,7 @@ def _build_weak_llm_quotes_prompt( context_docs: list[LlmDoc] | list[InferenceChunk], history_str: str, prompt: PromptConfig, + user_email: str | None = None, ) -> HumanMessage: """Since Danswer supports a variety of LLMs, this less demanding prompt is provided as an option to use with weaker LLMs such as small version, low float precision, quantized, @@ -39,6 +43,10 @@ def _build_weak_llm_quotes_prompt( if prompt.datetime_aware: prompt_str = add_date_time_to_prompt(prompt_str=prompt_str) + if user_email: + prompt_str = add_employee_context_to_prompt( + prompt_str=prompt_str, user_email=user_email + ) return HumanMessage(content=prompt_str) @@ -47,7 +55,21 @@ def _build_strong_llm_quotes_prompt( context_docs: list[LlmDoc] | list[InferenceChunk], history_str: str, prompt: PromptConfig, + user_email: str | None = None, ) -> HumanMessage: + """ + Constructs a prompt for the language model based on the provided inputs. + + Args: + question (str): The user's query. + context_docs (list[LlmDoc] | list[InferenceChunk]): List of context documents or inference chunks. + history_str (str): The conversation history. + prompt (PromptConfig): The prompt configuration. + user_email (str, optional): The user's email. Defaults to None. + + Returns: + HumanMessage: The constructed prompt. + """ use_language_hint = bool(get_multilingual_expansion()) context_block = "" @@ -71,6 +93,11 @@ def _build_strong_llm_quotes_prompt( if prompt.datetime_aware: full_prompt = add_date_time_to_prompt(prompt_str=full_prompt) + if user_email: + full_prompt = add_employee_context_to_prompt( + prompt_str=full_prompt, user_email=user_email + ) + return HumanMessage(content=full_prompt) @@ -79,6 +106,7 @@ def build_quotes_user_message( context_docs: list[LlmDoc] | list[InferenceChunk], history_str: str, prompt: PromptConfig, + user_email: str, ) -> HumanMessage: prompt_builder = ( _build_weak_llm_quotes_prompt @@ -91,6 +119,7 @@ def build_quotes_user_message( context_docs=context_docs, history_str=history_str, prompt=prompt, + user_email=user_email, ) diff --git a/backend/danswer/prompts/prompt_utils.py b/backend/danswer/prompts/prompt_utils.py index 0eb41bf1275..e105124a6d1 100644 --- a/backend/danswer/prompts/prompt_utils.py +++ b/backend/danswer/prompts/prompt_utils.py @@ -61,32 +61,40 @@ def add_date_time_to_prompt(prompt_str: str) -> str: + " " + BASIC_TIME_STR.format(datetime_info=get_current_llm_day_time()) ) + + # Initialize Redis client redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER) + def add_employee_context_to_prompt(prompt_str: str, user_email: str) -> str: # Check Redis for cached employee context cached_context = redis_client.get(user_email) if cached_context: logger.info("Employee context retrieved from Redis.") - return prompt_str.replace(DANSWER_EMPLOYEE_REPLACEMENT, cached_context.decode('utf-8')) + return prompt_str.replace( + DANSWER_EMPLOYEE_REPLACEMENT, cached_context.decode("utf-8") + ) airtable_client = AirtableApi(AIRTABLE_API_TOKEN) - all_employees = airtable_client.table(AIRTABLE_EMPLOYEE_BASE_ID, AIRTABLE_EMPLOYEE_TABLE_NAME_OR_ID).all() + all_employees = airtable_client.table( + AIRTABLE_EMPLOYEE_BASE_ID, AIRTABLE_EMPLOYEE_TABLE_NAME_OR_ID + ).all() for employee in all_employees: if "fields" in employee and "MV Email" in employee["fields"]: if employee["fields"]["MV Email"] == user_email: logger.info(f"Employee found: {employee['fields']['Preferred Name']}") - employee_context = f"My Name: {employee['fields']['Preferred Name']}\nMy Title: {employee['fields']['Job Role']}\nMy City Office: {employee['fields']['City Office']}\nMy Division: {employee['fields']['Import: Division']}\nMy Manager: {employee['fields']['Reports To']}\nMy Department: {employee['fields']['Import: Department']}" - + employee_context = f"My Name: {employee['fields']['Preferred Name']}\nMy Title: {employee['fields']['Job Role']}\nMy City Office: {employee['fields']['City Office']}\nMy Division: {employee['fields']['Import: Division']}\nMy Manager: {employee['fields']['Reports To']}\nMy Department: {employee['fields']['Import: Department']}\nMy Employment Status: {employee['fields']['Employment Status']}" + # Store the employee context in Redis with a TTL of 30 days - redis_client.setex(user_email, 30 * 24 * 60 * 60, employee_context) + redis_client.setex(user_email, 7 * 24 * 60 * 60, employee_context) break - + if DANSWER_EMPLOYEE_REPLACEMENT in prompt_str: return prompt_str.replace(DANSWER_EMPLOYEE_REPLACEMENT, employee_context) - + + def build_task_prompt_reminders( prompt: Prompt | PromptConfig, use_language_hint: bool,