diff --git a/backend/danswer/danswerbot/slack/handlers/handle_standard_answers.py b/backend/danswer/danswerbot/slack/handlers/handle_standard_answers.py index 58a2101588d..e008e26e1b9 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_standard_answers.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_standard_answers.py @@ -1,43 +1,83 @@ from slack_sdk import WebClient +from slack_sdk.models.blocks import ActionsBlock +from slack_sdk.models.blocks import Block +from slack_sdk.models.blocks import ButtonElement +from slack_sdk.models.blocks import SectionBlock from sqlalchemy.orm import Session +from danswer.configs.constants import MessageType +from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI +from danswer.danswerbot.slack.blocks import get_restate_blocks +from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID +from danswer.danswerbot.slack.handlers.utils import send_team_member_message from danswer.danswerbot.slack.models import SlackMessageInfo +from danswer.danswerbot.slack.utils import respond_in_thread +from danswer.danswerbot.slack.utils import update_emote_react +from danswer.db.chat import create_chat_session +from danswer.db.chat import create_new_chat_message +from danswer.db.chat import get_chat_messages_by_sessions +from danswer.db.chat import get_chat_sessions_by_slack_thread_id +from danswer.db.chat import get_or_create_root_message from danswer.db.models import Prompt from danswer.db.models import SlackBotConfig +from danswer.db.models import StandardAnswer as StandardAnswerModel from danswer.utils.logger import DanswerLoggingAdapter from danswer.utils.logger import setup_logger -from danswer.utils.variable_functionality import fetch_versioned_implementation +from danswer.db.standard_answer import fetch_standard_answer_categories_by_names +from danswer.db.standard_answer import find_matching_standard_answers +from ee.danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer logger = setup_logger() -def handle_standard_answers( - message_info: SlackMessageInfo, - receiver_ids: list[str] | None, - slack_bot_config: SlackBotConfig | None, - prompt: Prompt | None, - logger: DanswerLoggingAdapter, - client: WebClient, - db_session: Session, -) -> bool: - """Returns whether one or more Standard Answer message blocks were - emitted by the Slack bot""" - versioned_handle_standard_answers = fetch_versioned_implementation( - "danswer.danswerbot.slack.handlers.handle_standard_answers", - "_handle_standard_answers", +def build_standard_answer_blocks( + answer_message: str, +) -> list[Block]: + generate_button_block = ButtonElement( + action_id=GENERATE_ANSWER_BUTTON_ACTION_ID, + text="Generate Full Answer", ) - return versioned_handle_standard_answers( - message_info=message_info, - receiver_ids=receiver_ids, - slack_bot_config=slack_bot_config, - prompt=prompt, - logger=logger, - client=client, + answer_block = SectionBlock(text=answer_message) + return [ + answer_block, + ActionsBlock( + elements=[generate_button_block], + ), + ] + + +def oneoff_standard_answers( + message: str, + slack_bot_categories: list[str], + db_session: Session, +) -> list[PydanticStandardAnswer]: + """ + Respond to the user message if it matches any configured standard answers. + + Returns a list of matching StandardAnswers if found, otherwise None. + """ + configured_standard_answers = { + standard_answer + for category in fetch_standard_answer_categories_by_names( + slack_bot_categories, db_session=db_session + ) + for standard_answer in category.standard_answers + } + + matching_standard_answers = find_matching_standard_answers( + query=message, + id_in=[answer.id for answer in configured_standard_answers], db_session=db_session, ) + server_standard_answers = [ + PydanticStandardAnswer.from_model(standard_answer_model) + for (standard_answer_model, _) in matching_standard_answers + ] + return server_standard_answers + -def _handle_standard_answers( +def handle_standard_answers( message_info: SlackMessageInfo, receiver_ids: list[str] | None, slack_bot_config: SlackBotConfig | None, @@ -47,10 +87,152 @@ def _handle_standard_answers( db_session: Session, ) -> bool: """ - Standard Answers are a paid Enterprise Edition feature. This is the fallback - function handling the case where EE features are not enabled. + Potentially respond to the user message depending on whether the user's message matches + any of the configured standard answers and also whether those answers have already been + provided in the current thread. - Always returns false i.e. since EE features are not enabled, we NEVER create any - Slack message blocks. + Returns True if standard answers are found to match the user's message and therefore, + we still need to respond to the users. """ - return False + # if no channel config, then no standard answers are configured + if not slack_bot_config: + return False + + slack_thread_id = message_info.thread_to_respond + configured_standard_answer_categories = ( + slack_bot_config.standard_answer_categories if slack_bot_config else [] + ) + configured_standard_answers = set( + [ + standard_answer + for standard_answer_category in configured_standard_answer_categories + for standard_answer in standard_answer_category.standard_answers + ] + ) + query_msg = message_info.thread_messages[-1] + + if slack_thread_id is None: + used_standard_answer_ids = set([]) + else: + chat_sessions = get_chat_sessions_by_slack_thread_id( + slack_thread_id=slack_thread_id, + user_id=None, + db_session=db_session, + ) + chat_messages = get_chat_messages_by_sessions( + chat_session_ids=[chat_session.id for chat_session in chat_sessions], + user_id=None, + db_session=db_session, + skip_permission_check=True, + ) + used_standard_answer_ids = set( + [ + standard_answer.id + for chat_message in chat_messages + for standard_answer in chat_message.standard_answers + ] + ) + + usable_standard_answers = configured_standard_answers.difference( + used_standard_answer_ids + ) + + matching_standard_answers: list[tuple[StandardAnswerModel, str]] = [] + if usable_standard_answers: + matching_standard_answers = find_matching_standard_answers( + query=query_msg.message, + id_in=[standard_answer.id for standard_answer in usable_standard_answers], + db_session=db_session, + ) + + if matching_standard_answers: + chat_session = create_chat_session( + db_session=db_session, + description="", + user_id=None, + persona_id=slack_bot_config.persona.id if slack_bot_config.persona else 0, + danswerbot_flow=True, + slack_thread_id=slack_thread_id, + one_shot=True, + ) + + root_message = get_or_create_root_message( + chat_session_id=chat_session.id, db_session=db_session + ) + + new_user_message = create_new_chat_message( + chat_session_id=chat_session.id, + parent_message=root_message, + prompt_id=prompt.id if prompt else None, + message=query_msg.message, + token_count=0, + message_type=MessageType.USER, + db_session=db_session, + commit=True, + ) + + formatted_answers = [] + for standard_answer, match_str in matching_standard_answers: + since_you_mentioned_pretext = ( + f'Since your question contains "_{match_str}_"' + ) + block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ") + formatted_answer = f"{since_you_mentioned_pretext}, I thought this might be useful: \n\n{block_quotified_answer}" + formatted_answers.append(formatted_answer) + answer_message = "\n\n".join(formatted_answers) + + _ = create_new_chat_message( + chat_session_id=chat_session.id, + parent_message=new_user_message, + prompt_id=prompt.id if prompt else None, + message=answer_message, + token_count=0, + message_type=MessageType.ASSISTANT, + error=None, + db_session=db_session, + commit=True, + ) + + update_emote_react( + emoji=DANSWER_REACT_EMOJI, + channel=message_info.channel_to_respond, + message_ts=message_info.msg_to_respond, + remove=True, + client=client, + ) + + restate_question_blocks = get_restate_blocks( + msg=query_msg.message, + is_bot_msg=message_info.is_bot_msg, + ) + + answer_blocks = build_standard_answer_blocks( + answer_message=answer_message, + ) + + all_blocks = restate_question_blocks + answer_blocks + + try: + respond_in_thread( + client=client, + channel=message_info.channel_to_respond, + receiver_ids=receiver_ids, + text="Hello! Danswer has some results for you!", + blocks=all_blocks, + thread_ts=message_info.msg_to_respond, + unfurl=False, + ) + + if receiver_ids and slack_thread_id: + send_team_member_message( + client=client, + channel=message_info.channel_to_respond, + thread_ts=slack_thread_id, + ) + + return True + except Exception as e: + logger.exception(f"Unable to send standard answer message: {e}") + return False + else: + return False diff --git a/backend/danswer/db/standard_answer.py b/backend/danswer/db/standard_answer.py new file mode 100644 index 00000000000..0fa074e36a7 --- /dev/null +++ b/backend/danswer/db/standard_answer.py @@ -0,0 +1,279 @@ +import re +import string +from collections.abc import Sequence + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.db.models import StandardAnswer +from danswer.db.models import StandardAnswerCategory +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def check_category_validity(category_name: str) -> bool: + """If a category name is too long, it should not be used (it will cause an error in Postgres + as the unique constraint can only apply to entries that are less than 2704 bytes). + + Additionally, extremely long categories are not really usable / useful.""" + if len(category_name) > 255: + logger.error( + f"Category with name '{category_name}' is too long, cannot be used" + ) + return False + + return True + + +def insert_standard_answer_category( + category_name: str, db_session: Session +) -> StandardAnswerCategory: + if not check_category_validity(category_name): + raise ValueError(f"Invalid category name: {category_name}") + standard_answer_category = StandardAnswerCategory(name=category_name) + db_session.add(standard_answer_category) + db_session.commit() + + return standard_answer_category + + +def insert_standard_answer( + keyword: str, + answer: str, + category_ids: list[int], + match_regex: bool, + match_any_keywords: bool, + db_session: Session, +) -> StandardAnswer: + existing_categories = fetch_standard_answer_categories_by_ids( + standard_answer_category_ids=category_ids, + db_session=db_session, + ) + if len(existing_categories) != len(category_ids): + raise ValueError(f"Some or all categories with ids {category_ids} do not exist") + + standard_answer = StandardAnswer( + keyword=keyword, + answer=answer, + categories=existing_categories, + active=True, + match_regex=match_regex, + match_any_keywords=match_any_keywords, + ) + db_session.add(standard_answer) + db_session.commit() + return standard_answer + + +def update_standard_answer( + standard_answer_id: int, + keyword: str, + answer: str, + category_ids: list[int], + match_regex: bool, + match_any_keywords: bool, + db_session: Session, +) -> StandardAnswer: + standard_answer = db_session.scalar( + select(StandardAnswer).where(StandardAnswer.id == standard_answer_id) + ) + if standard_answer is None: + raise ValueError(f"No standard answer with id {standard_answer_id}") + + existing_categories = fetch_standard_answer_categories_by_ids( + standard_answer_category_ids=category_ids, + db_session=db_session, + ) + if len(existing_categories) != len(category_ids): + raise ValueError(f"Some or all categories with ids {category_ids} do not exist") + + standard_answer.keyword = keyword + standard_answer.answer = answer + standard_answer.categories = list(existing_categories) + standard_answer.match_regex = match_regex + standard_answer.match_any_keywords = match_any_keywords + + db_session.commit() + + return standard_answer + + +def remove_standard_answer( + standard_answer_id: int, + db_session: Session, +) -> None: + standard_answer = db_session.scalar( + select(StandardAnswer).where(StandardAnswer.id == standard_answer_id) + ) + if standard_answer is None: + raise ValueError(f"No standard answer with id {standard_answer_id}") + + standard_answer.active = False + db_session.commit() + + +def update_standard_answer_category( + standard_answer_category_id: int, + category_name: str, + db_session: Session, +) -> StandardAnswerCategory: + standard_answer_category = db_session.scalar( + select(StandardAnswerCategory).where( + StandardAnswerCategory.id == standard_answer_category_id + ) + ) + if standard_answer_category is None: + raise ValueError( + f"No standard answer category with id {standard_answer_category_id}" + ) + + if not check_category_validity(category_name): + raise ValueError(f"Invalid category name: {category_name}") + + standard_answer_category.name = category_name + + db_session.commit() + + return standard_answer_category + + +def fetch_standard_answer_category( + standard_answer_category_id: int, + db_session: Session, +) -> StandardAnswerCategory | None: + return db_session.scalar( + select(StandardAnswerCategory).where( + StandardAnswerCategory.id == standard_answer_category_id + ) + ) + + +def fetch_standard_answer_categories_by_ids( + standard_answer_category_ids: list[int], + db_session: Session, +) -> Sequence[StandardAnswerCategory]: + return db_session.scalars( + select(StandardAnswerCategory).where( + StandardAnswerCategory.id.in_(standard_answer_category_ids) + ) + ).all() + + +def fetch_standard_answer_categories( + db_session: Session, +) -> Sequence[StandardAnswerCategory]: + return db_session.scalars(select(StandardAnswerCategory)).all() + + +def fetch_standard_answer( + standard_answer_id: int, + db_session: Session, +) -> StandardAnswer | None: + return db_session.scalar( + select(StandardAnswer).where(StandardAnswer.id == standard_answer_id) + ) + + +def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]: + return db_session.scalars( + select(StandardAnswer).where(StandardAnswer.active.is_(True)) + ).all() + + +def create_initial_default_standard_answer_category(db_session: Session) -> None: + default_category_id = 0 + default_category_name = "General" + default_category = fetch_standard_answer_category( + standard_answer_category_id=default_category_id, + db_session=db_session, + ) + if default_category is not None: + if default_category.name != default_category_name: + raise ValueError( + "DB is not in a valid initial state. " + "Default standard answer category does not have expected name." + ) + return + + standard_answer_category = StandardAnswerCategory( + id=default_category_id, + name=default_category_name, + ) + db_session.add(standard_answer_category) + db_session.commit() + + +def fetch_standard_answer_categories_by_names( + standard_answer_category_names: list[str], + db_session: Session, +) -> Sequence[StandardAnswerCategory]: + return db_session.scalars( + select(StandardAnswerCategory).where( + StandardAnswerCategory.name.in_(standard_answer_category_names) + ) + ).all() + + +def find_matching_standard_answers( + id_in: list[int], + query: str, + db_session: Session, +) -> list[tuple[StandardAnswer, str]]: + """ + Returns a list of tuples, where each tuple is a StandardAnswer definition matching + the query and a string representing the match (either the regex match group or the + set of keywords). + + If `answer_instance.match_regex` is true, the definition is considered "matched" + if the query matches the `answer_instance.keyword` using `re.search`. + + Otherwise, the definition is considered "matched" if the space-delimited tokens + in `keyword` exists in `query`, depending on the state of `match_any_keywords` + """ + stmt = ( + select(StandardAnswer) + .where(StandardAnswer.active.is_(True)) + .where(StandardAnswer.id.in_(id_in)) + ) + possible_standard_answers: Sequence[StandardAnswer] = db_session.scalars(stmt).all() + + matching_standard_answers: list[tuple[StandardAnswer, str]] = [] + for standard_answer in possible_standard_answers: + if standard_answer.match_regex: + maybe_matches = re.search(standard_answer.keyword, query, re.IGNORECASE) + if maybe_matches is not None: + match_group = maybe_matches.group(0) + matching_standard_answers.append((standard_answer, match_group)) + + else: + # Remove punctuation and split the keyword into individual words + keyword_words = set( + "".join( + char + for char in standard_answer.keyword.lower() + if char not in string.punctuation + ).split() + ) + + # Remove punctuation and split the query into individual words + query_words = "".join( + char for char in query.lower() if char not in string.punctuation + ).split() + + # Check if all of the keyword words are in the query words + if standard_answer.match_any_keywords: + for word in query_words: + if word in keyword_words: + matching_standard_answers.append((standard_answer, word)) + break + else: + if all(word in query_words for word in keyword_words): + matching_standard_answers.append( + ( + standard_answer, + re.sub(r"\s+?", ", ", standard_answer.keyword), + ) + ) + + return matching_standard_answers diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 9a681c39a13..d0c7c4dc710 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -113,6 +113,7 @@ from danswer.server.token_rate_limits.api import ( router as token_rate_limit_settings_router, ) +from danswer.server.manage.standard_answer import router as standard_answer_router from danswer.tools.built_in_tools import auto_add_search_tool_to_personas from danswer.tools.built_in_tools import load_builtin_tools from danswer.tools.built_in_tools import refresh_built_in_tools_cache @@ -484,6 +485,7 @@ def get_application() -> FastAPI: ) include_router_with_global_prefix_prepended(application, chat_router) + include_router_with_global_prefix_prepended(application, standard_answer_router) include_router_with_global_prefix_prepended(application, query_router) include_router_with_global_prefix_prepended(application, document_router) include_router_with_global_prefix_prepended(application, admin_query_router) diff --git a/backend/danswer/server/manage/standard_answer.py b/backend/danswer/server/manage/standard_answer.py new file mode 100644 index 00000000000..a4f00b2270a --- /dev/null +++ b/backend/danswer/server/manage/standard_answer.py @@ -0,0 +1,143 @@ +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from sqlalchemy.orm import Session + +from danswer.auth.users import current_admin_user +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.db.standard_answer import fetch_standard_answer +from danswer.db.standard_answer import fetch_standard_answer_categories +from danswer.db.standard_answer import fetch_standard_answer_category +from danswer.db.standard_answer import fetch_standard_answers +from danswer.db.standard_answer import insert_standard_answer +from danswer.db.standard_answer import insert_standard_answer_category +from danswer.db.standard_answer import remove_standard_answer +from danswer.db.standard_answer import update_standard_answer +from danswer.db.standard_answer import update_standard_answer_category +from ee.danswer.server.manage.models import StandardAnswer +from ee.danswer.server.manage.models import StandardAnswerCategory +from ee.danswer.server.manage.models import StandardAnswerCategoryCreationRequest +from ee.danswer.server.manage.models import StandardAnswerCreationRequest + +router = APIRouter(prefix="/manage") + + +@router.post("/admin/standard-answer") +def create_standard_answer( + standard_answer_creation_request: StandardAnswerCreationRequest, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> StandardAnswer: + standard_answer_model = insert_standard_answer( + keyword=standard_answer_creation_request.keyword, + answer=standard_answer_creation_request.answer, + category_ids=standard_answer_creation_request.categories, + match_regex=standard_answer_creation_request.match_regex, + match_any_keywords=standard_answer_creation_request.match_any_keywords, + db_session=db_session, + ) + return StandardAnswer.from_model(standard_answer_model) + + +@router.get("/admin/standard-answer") +def list_standard_answers( + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> list[StandardAnswer]: + standard_answer_models = fetch_standard_answers(db_session=db_session) + return [ + StandardAnswer.from_model(standard_answer_model) + for standard_answer_model in standard_answer_models + ] + + +@router.patch("/admin/standard-answer/{standard_answer_id}") +def patch_standard_answer( + standard_answer_id: int, + standard_answer_creation_request: StandardAnswerCreationRequest, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> StandardAnswer: + existing_standard_answer = fetch_standard_answer( + standard_answer_id=standard_answer_id, + db_session=db_session, + ) + + if existing_standard_answer is None: + raise HTTPException(status_code=404, detail="Standard answer not found") + + standard_answer_model = update_standard_answer( + standard_answer_id=standard_answer_id, + keyword=standard_answer_creation_request.keyword, + answer=standard_answer_creation_request.answer, + category_ids=standard_answer_creation_request.categories, + match_regex=standard_answer_creation_request.match_regex, + match_any_keywords=standard_answer_creation_request.match_any_keywords, + db_session=db_session, + ) + return StandardAnswer.from_model(standard_answer_model) + + +@router.delete("/admin/standard-answer/{standard_answer_id}") +def delete_standard_answer( + standard_answer_id: int, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> None: + return remove_standard_answer( + standard_answer_id=standard_answer_id, + db_session=db_session, + ) + + +@router.post("/admin/standard-answer/category") +def create_standard_answer_category( + standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> StandardAnswerCategory: + standard_answer_category_model = insert_standard_answer_category( + category_name=standard_answer_category_creation_request.name, + db_session=db_session, + ) + return StandardAnswerCategory.from_model(standard_answer_category_model) + + +@router.get("/admin/standard-answer/category") +def list_standard_answer_categories( + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> list[StandardAnswerCategory]: + standard_answer_category_models = fetch_standard_answer_categories( + db_session=db_session + ) + return [ + StandardAnswerCategory.from_model(standard_answer_category_model) + for standard_answer_category_model in standard_answer_category_models + ] + + +@router.patch("/admin/standard-answer/category/{standard_answer_category_id}") +def patch_standard_answer_category( + standard_answer_category_id: int, + standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> StandardAnswerCategory: + existing_standard_answer_category = fetch_standard_answer_category( + standard_answer_category_id=standard_answer_category_id, + db_session=db_session, + ) + + if existing_standard_answer_category is None: + raise HTTPException( + status_code=404, detail="Standard answer category not found" + ) + + standard_answer_category_model = update_standard_answer_category( + standard_answer_category_id=standard_answer_category_id, + category_name=standard_answer_category_creation_request.name, + db_session=db_session, + ) + return StandardAnswerCategory.from_model(standard_answer_category_model) diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index e20de5a3027..13a4cc7415c 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -40,6 +40,11 @@ from danswer.server.query_and_chat.models import TagResponse from danswer.server.query_and_chat.token_limit import check_token_rate_limits from danswer.utils.logger import setup_logger +from danswer.danswerbot.slack.handlers.handle_standard_answers import ( + oneoff_standard_answers, +) +from ee.danswer.server.query_and_chat.models import StandardAnswerRequest +from ee.danswer.server.query_and_chat.models import StandardAnswerResponse logger = setup_logger() @@ -259,3 +264,20 @@ def get_answer_with_quote( max_history_tokens=0, ) return StreamingResponse(packets, media_type="application/json") + +@basic_router.get("/standard-answer") +def get_standard_answer( + request: StandardAnswerRequest, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_user), +) -> StandardAnswerResponse: + try: + standard_answers = oneoff_standard_answers( + message=request.message, + slack_bot_categories=request.slack_bot_categories, + db_session=db_session, + ) + return StandardAnswerResponse(standard_answers=standard_answers) + except Exception as e: + logger.error(f"Error in get_standard_answer: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="An internal server error occurred")