From f180a71e99421b9e02671f83176718ac14364075 Mon Sep 17 00:00:00 2001 From: Alex Co Date: Mon, 11 Nov 2024 16:47:28 +0800 Subject: [PATCH] Injecting employee context into system prompt Signed-off-by: Alex Co --- .gitignore | 2 + backend/danswer/chat/personas.yaml | 7 ++-- backend/danswer/chat/process_message.py | 14 +++---- backend/danswer/chat/prompts.yaml | 35 +++++++++------- backend/danswer/configs/app_configs.py | 6 +++ backend/danswer/llm/answering/answer.py | 21 ++++------ .../llm/answering/prompts/citations_prompt.py | 18 +++++--- backend/danswer/prompts/direct_qa_prompts.py | 19 +++++---- backend/danswer/prompts/prompt_utils.py | 42 +++++++++++++++++-- .../server/query_and_chat/chat_backend.py | 19 ++++----- 10 files changed, 115 insertions(+), 68 deletions(-) diff --git a/.gitignore b/.gitignore index ba50495d7ff..9ebde9f5fa0 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ .vscode/ *.sw? /backend/tests/regression/answer_quality/search_test_config.yaml +env.sh +.cursorrules \ No newline at end of file diff --git a/backend/danswer/chat/personas.yaml b/backend/danswer/chat/personas.yaml index 1eb95d70e25..f15b304c178 100644 --- a/backend/danswer/chat/personas.yaml +++ b/backend/danswer/chat/personas.yaml @@ -19,7 +19,7 @@ personas: # Default number of chunks to include as context, set to 0 to disable retrieval # Remove the field to set to the system default number of chunks/tokens to pass to Gen AI # Each chunk is 512 tokens long - num_chunks: 20 + num_chunks: 50 # Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine # if the chunk is useful or not towards the latest user query # This feature can be overriden for all personas via DISABLE_LLM_DOC_RELEVANCE env variable @@ -80,7 +80,6 @@ personas: is_visible: true internet_search: true - - id: 3 name: "Art" description: > @@ -92,8 +91,8 @@ personas: llm_filter_extraction: false recency_bias: "no_decay" document_sets: [] - icon_shape: 234124 + icon_shape: 234124 icon_color: "#9B59B6" - image_generation: true + image_generation: true display_priority: 3 is_visible: false diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 26c59ebf60c..71896b5c84b 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -4,8 +4,6 @@ from functools import partial from typing import cast -from sqlalchemy.orm import Session - from danswer.chat.chat_utils import create_chat_chain from danswer.chat.models import AllCitations from danswer.chat.models import CitationInfo @@ -80,12 +78,8 @@ from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID from danswer.tools.images.image_generation_tool import ImageGenerationResponse from danswer.tools.images.image_generation_tool import ImageGenerationTool -from danswer.tools.internet_search.internet_search_tool import ( - INTERNET_SEARCH_RESPONSE_ID, -) -from danswer.tools.internet_search.internet_search_tool import ( - internet_search_response_to_search_docs, -) +from danswer.tools.internet_search.internet_search_tool import INTERNET_SEARCH_RESPONSE_ID +from danswer.tools.internet_search.internet_search_tool import internet_search_response_to_search_docs from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse from danswer.tools.internet_search.internet_search_tool import InternetSearchTool from danswer.tools.models import DynamicSchemaInfo @@ -101,6 +95,7 @@ from danswer.tools.utils import explicit_tool_calling_supported from danswer.utils.logger import setup_logger from danswer.utils.timing import log_generator_function_time +from sqlalchemy.orm import Session logger = setup_logger() @@ -286,6 +281,8 @@ def stream_chat_message_objects( try: user_id = user.id if user is not None else None + user_email = user.email if user is not None else None + chat_session = get_chat_session_by_id( chat_session_id=new_msg_req.chat_session_id, user_id=user_id, @@ -628,6 +625,7 @@ def stream_chat_message_objects( # LLM prompt building, response capturing, etc. answer = Answer( + user_email=user_email, is_connected=is_connected, question=final_msg.message, latest_query_files=latest_query_files, diff --git a/backend/danswer/chat/prompts.yaml b/backend/danswer/chat/prompts.yaml index a5eb34e4ad3..899f6cfd7b6 100644 --- a/backend/danswer/chat/prompts.yaml +++ b/backend/danswer/chat/prompts.yaml @@ -7,21 +7,26 @@ prompts: description: "Answers user questions using retrieved context!" # System Prompt (as shown in UI) system: > - You are a question answering system that is constantly learning and improving. - The current date is DANSWER_DATETIME_REPLACEMENT. - - You can process and comprehend vast amounts of text and utilize this knowledge to provide - grounded, accurate, and concise answers to diverse queries. - + You are an advanced AI assistant, Mindvalley’s AI Assistant, designed to assist in various tasks and provide information relevant to my role and my department. The current date is DANSWER_DATETIME_REPLACEMENT. + + You can process and comprehend vast amounts of text and utilize this knowledge to provide grounded, accurate, and concise answers to diverse queries, specifically tailored to MindValley’s operations and me. + + Below is my employee information in Mindvalley: + + DANSWER_EMPLOYEE_REPLACEMENT + + When answering questions, always keep it relavent to me and my role and you can address me with my name. + You always clearly communicate ANY UNCERTAINTY in your answer. # Task Prompt (as shown in UI) task: > Answer my query based on the documents provided. + The documents may not all be relevant, ignore any documents that are not directly relevant to the most recent user query. - + I have not read or seen any of the documents and do not want to read them. - + If there are no relevant documents, refer to the chat history and your internal knowledge. # Inject a statement at the end of system prompt to inform the LLM of the current date/time # If the DANSWER_DATETIME_REPLACEMENT is set, the date/time is inserted there instead @@ -30,20 +35,20 @@ prompts: # Prompts the LLM to include citations in the for [1], [2] etc. # which get parsed to match the passed in sources include_citations: true - + - name: "ImageGeneration" description: "Generates images based on user prompts!" system: > You are an advanced image generation system capable of creating diverse and detailed images. - + You can interpret user prompts and generate high-quality, creative images that match their descriptions. - + You always strive to create safe and appropriate content, avoiding any harmful or offensive imagery. task: > Generate an image based on the user's description. - + Provide a detailed description of the generated image, including key elements, colors, and composition. - + If the request is not possible or appropriate, explain why and suggest alternatives. datetime_aware: true include_citations: false @@ -70,7 +75,7 @@ prompts: You are a text summarizing assistant that highlights the most important knowledge from the context provided, prioritizing the information that relates to the user query. The current date is DANSWER_DATETIME_REPLACEMENT. - + You ARE NOT creative and always stick to the provided documents. If there are no documents, refer to the conversation history. @@ -104,7 +109,7 @@ prompts: - name: "InternetSearch" description: "Use this Assistant to search the Internet for you (via Bing) and getting the answer" - system: > + system: > You are an intelligent AI agent designed to assist users by providing accurate and relevant information through internet searches. Your primary objectives are: Information Retrieval: Search the internet to find reliable and up-to-date information based on user queries. Ensure that the sources you reference are credible and trustworthy. Context Understanding: Analyze user questions to understand context and intent. Provide answers that are directly related to the user's needs, offering additional context when necessary. diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 2dbe596b635..32fee1714bc 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -72,6 +72,12 @@ if _VALID_EMAIL_DOMAINS_STR else [] ) + +# Airtable Config To Get MV Employee Info +AIRTABLE_API_TOKEN = os.environ.get("AIRTABLE_API_TOKEN") +AIRTABLE_EMPLOYEE_BASE_ID = os.environ.get("AIRTABLE_EMPLOYEE_BASE_ID") +AIRTABLE_EMPLOYEE_TABLE_NAME_OR_ID = os.environ.get("AIRTABLE_EMPLOYEE_TABLE_NAME_OR_ID") + # OAuth Login Flow # Used for both Google OAuth2 and OIDC flows OAUTH_CLIENT_ID = ( diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index 3c0dc4961f1..af10884be3e 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -4,10 +4,6 @@ from typing import cast from uuid import uuid4 -from langchain.schema.messages import BaseMessage -from langchain_core.messages import AIMessageChunk -from langchain_core.messages import HumanMessage - from danswer.chat.chat_utils import llm_doc_from_inference_section from danswer.chat.models import AnswerQuestionPossibleReturn from danswer.chat.models import CitationInfo @@ -24,9 +20,7 @@ from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.answering.prompts.build import default_build_system_message from danswer.llm.answering.prompts.build import default_build_user_message -from danswer.llm.answering.prompts.citations_prompt import ( - build_citations_system_message, -) +from danswer.llm.answering.prompts.citations_prompt import build_citations_system_message from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message from danswer.llm.answering.stream_processing.citation_processing import ( @@ -59,16 +53,16 @@ from danswer.tools.search.search_tool import SearchTool from danswer.tools.tool import Tool from danswer.tools.tool import ToolResponse -from danswer.tools.tool_runner import ( - check_which_tools_should_run_for_non_tool_calling_llm, -) +from danswer.tools.tool_runner import check_which_tools_should_run_for_non_tool_calling_llm from danswer.tools.tool_runner import ToolCallFinalResult from danswer.tools.tool_runner import ToolCallKickoff from danswer.tools.tool_runner import ToolRunner from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm from danswer.tools.utils import explicit_tool_calling_supported from danswer.utils.logger import setup_logger - +from langchain.schema.messages import BaseMessage +from langchain_core.messages import AIMessageChunk +from langchain_core.messages import HumanMessage logger = setup_logger() @@ -99,6 +93,7 @@ def _get_answer_stream_processor( class Answer: def __init__( self, + user_email: str, question: str, answer_style_config: AnswerStyleConfig, llm: LLM, @@ -125,7 +120,7 @@ def __init__( raise ValueError( "Cannot provide both `message_history` and `single_message_history`" ) - + self.user_email = user_email self.question = question self.is_connected: Callable[[], bool] | None = is_connected @@ -166,7 +161,7 @@ def _update_prompt_builder_for_search_tool( ) -> None: if self.answer_style_config.citation_config: prompt_builder.update_system_prompt( - build_citations_system_message(self.prompt_config) + build_citations_system_message(self.prompt_config, self.user_email) ) prompt_builder.update_user_prompt( build_citations_user_message( diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/llm/answering/prompts/citations_prompt.py index eddae9badb4..21a1edf9015 100644 --- a/backend/danswer/llm/answering/prompts/citations_prompt.py +++ b/backend/danswer/llm/answering/prompts/citations_prompt.py @@ -1,6 +1,3 @@ -from langchain.schema.messages import HumanMessage -from langchain.schema.messages import SystemMessage - from danswer.chat.models import LlmDoc from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS from danswer.db.models import Persona @@ -19,16 +16,20 @@ from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING from danswer.prompts.prompt_utils import add_date_time_to_prompt +from danswer.prompts.prompt_utils import add_employee_context_to_prompt from danswer.prompts.prompt_utils import build_complete_context_str from danswer.prompts.prompt_utils import build_task_prompt_reminders from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT -from danswer.prompts.token_counts import ( - CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT, -) +from danswer.prompts.token_counts import CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT from danswer.search.models import InferenceChunk +from danswer.utils.logger import setup_logger +from langchain.schema.messages import HumanMessage +from langchain.schema.messages import SystemMessage + +logger = setup_logger() def get_prompt_tokens(prompt_config: PromptConfig) -> int: @@ -117,12 +118,17 @@ def compute_max_llm_input_tokens(llm_config: LLMConfig) -> int: def build_citations_system_message( prompt_config: PromptConfig, + user_email: str | None = None, ) -> SystemMessage: system_prompt = prompt_config.system_prompt.strip() if prompt_config.include_citations: system_prompt += REQUIRE_CITATION_STATEMENT if prompt_config.datetime_aware: system_prompt = add_date_time_to_prompt(prompt_str=system_prompt) + if user_email: + system_prompt = add_employee_context_to_prompt( + prompt_str=system_prompt, user_email=user_email + ) return SystemMessage(content=system_prompt) diff --git a/backend/danswer/prompts/direct_qa_prompts.py b/backend/danswer/prompts/direct_qa_prompts.py index 16768963931..1b7448c081c 100644 --- a/backend/danswer/prompts/direct_qa_prompts.py +++ b/backend/danswer/prompts/direct_qa_prompts.py @@ -2,12 +2,13 @@ # It is used also for the one shot direct QA flow import json -from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT -from danswer.prompts.constants import FINAL_QUERY_PAT -from danswer.prompts.constants import GENERAL_SEP_PAT -from danswer.prompts.constants import QUESTION_PAT -from danswer.prompts.constants import THOUGHT_PAT - +from danswer.prompts.constants import ( + DEFAULT_IGNORE_STATEMENT, + FINAL_QUERY_PAT, + GENERAL_SEP_PAT, + QUESTION_PAT, + THOUGHT_PAT, +) ONE_SHOT_SYSTEM_PROMPT = """ You are a question answering system that is constantly learning and improving. @@ -90,7 +91,8 @@ # similar to the chat flow, but with the option of including a # "conversation history" block CITATIONS_PROMPT = f""" -Refer to the following context documents when responding to me.{DEFAULT_IGNORE_STATEMENT} +Refer to the following context documents when responding to me. \ +Make sure you take into account my employee information in the system message.{DEFAULT_IGNORE_STATEMENT} CONTEXT: {GENERAL_SEP_PAT} {{context_docs_str}} @@ -106,7 +108,8 @@ # NOTE: need to add the extra line about "getting right to the point" since the # tool calling models from OpenAI tend to be more verbose CITATIONS_PROMPT_FOR_TOOL_CALLING = f""" -Refer to the provided context documents when responding to me.{DEFAULT_IGNORE_STATEMENT} \ +Refer to the provided context documents when responding to me. \ +Make sure you take into account my employee information in the system message.{DEFAULT_IGNORE_STATEMENT} \ You should always get right to the point, and never use extraneous language. {{task_prompt}} diff --git a/backend/danswer/prompts/prompt_utils.py b/backend/danswer/prompts/prompt_utils.py index cd59e97061f..0eb41bf1275 100644 --- a/backend/danswer/prompts/prompt_utils.py +++ b/backend/danswer/prompts/prompt_utils.py @@ -2,9 +2,14 @@ from datetime import datetime from typing import cast -from langchain_core.messages import BaseMessage - +import redis from danswer.chat.models import LlmDoc +from danswer.configs.app_configs import AIRTABLE_API_TOKEN +from danswer.configs.app_configs import AIRTABLE_EMPLOYEE_BASE_ID +from danswer.configs.app_configs import AIRTABLE_EMPLOYEE_TABLE_NAME_OR_ID +from danswer.configs.app_configs import REDIS_DB_NUMBER +from danswer.configs.app_configs import REDIS_HOST +from danswer.configs.app_configs import REDIS_PORT from danswer.configs.chat_configs import LANGUAGE_HINT from danswer.configs.constants import DocumentSource from danswer.db.models import Prompt @@ -13,11 +18,16 @@ from danswer.prompts.chat_prompts import CITATION_REMINDER from danswer.prompts.constants import CODE_BLOCK_PAT from danswer.search.models import InferenceChunk +from danswer.utils.logger import setup_logger +from langchain_core.messages import BaseMessage +from pyairtable import Api as AirtableApi +logger = setup_logger() MOST_BASIC_PROMPT = "You are a helpful AI assistant." DANSWER_DATETIME_REPLACEMENT = "DANSWER_DATETIME_REPLACEMENT" BASIC_TIME_STR = "The current date is {datetime_info}." +DANSWER_EMPLOYEE_REPLACEMENT = "DANSWER_EMPLOYEE_REPLACEMENT" def get_current_llm_day_time( @@ -51,8 +61,32 @@ def add_date_time_to_prompt(prompt_str: str) -> str: + " " + BASIC_TIME_STR.format(datetime_info=get_current_llm_day_time()) ) - - +# Initialize Redis client +redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER) + +def add_employee_context_to_prompt(prompt_str: str, user_email: str) -> str: + # Check Redis for cached employee context + cached_context = redis_client.get(user_email) + if cached_context: + logger.info("Employee context retrieved from Redis.") + return prompt_str.replace(DANSWER_EMPLOYEE_REPLACEMENT, cached_context.decode('utf-8')) + + airtable_client = AirtableApi(AIRTABLE_API_TOKEN) + all_employees = airtable_client.table(AIRTABLE_EMPLOYEE_BASE_ID, AIRTABLE_EMPLOYEE_TABLE_NAME_OR_ID).all() + + for employee in all_employees: + if "fields" in employee and "MV Email" in employee["fields"]: + if employee["fields"]["MV Email"] == user_email: + logger.info(f"Employee found: {employee['fields']['Preferred Name']}") + employee_context = f"My Name: {employee['fields']['Preferred Name']}\nMy Title: {employee['fields']['Job Role']}\nMy City Office: {employee['fields']['City Office']}\nMy Division: {employee['fields']['Import: Division']}\nMy Manager: {employee['fields']['Reports To']}\nMy Department: {employee['fields']['Import: Department']}" + + # Store the employee context in Redis with a TTL of 30 days + redis_client.setex(user_email, 30 * 24 * 60 * 60, employee_context) + break + + if DANSWER_EMPLOYEE_REPLACEMENT in prompt_str: + return prompt_str.replace(DANSWER_EMPLOYEE_REPLACEMENT, employee_context) + def build_task_prompt_reminders( prompt: Prompt | PromptConfig, use_language_hint: bool, diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index c7f5983417d..e6c5fce29d1 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -4,16 +4,6 @@ from collections.abc import Callable from collections.abc import Generator -from fastapi import APIRouter -from fastapi import Depends -from fastapi import HTTPException -from fastapi import Request -from fastapi import Response -from fastapi import UploadFile -from fastapi.responses import StreamingResponse -from pydantic import BaseModel -from sqlalchemy.orm import Session - from danswer.auth.users import current_user from danswer.chat.chat_utils import create_chat_chain from danswer.chat.process_message import stream_chat_message @@ -70,6 +60,15 @@ from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest from danswer.server.query_and_chat.token_limit import check_token_rate_limits from danswer.utils.logger import setup_logger +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from fastapi import Request +from fastapi import Response +from fastapi import UploadFile +from fastapi.responses import StreamingResponse +from pydantic import BaseModel +from sqlalchemy.orm import Session logger = setup_logger()