From dce8c32ebcad0e4f3c710fc261130005316d053c Mon Sep 17 00:00:00 2001
From: Alex Co
Date: Tue, 17 Sep 2024 03:56:27 +0000
Subject: [PATCH] matching code with latest upstream

Signed-off-by: Alex Co
---
 backend/Dockerfile | 3 +-
 backend/Dockerfile.model_server | 18 +-
 ...60c3401_embedding_model_search_settings.py | 2 +-
 ...3a51d5_server_default_chosen_assistants.py | 38 +++
 ...1c1ac29467_add_tables_for_ui_based_llm_.py | 1 -
 ...3_match_any_keywords_flag_for_standard_.py | 35 +++
 .../703313b75876_add_tokenratelimit_tables.py | 1 -
 ...298_add_nullable_to_persona_id_in_chat_.py | 31 +++
 .../ef7da92f7213_add_files_to_chatmessage.py | 1 -
 ...76026c_standard_answer_match_regex_flag.py | 32 +++
 ...ad14119fb92_delete_tags_with_wrong_enum.py | 1 -
 backend/danswer/auth/users.py | 15 +-
 .../danswer/background/celery/celery_app.py | 82 +++++-
 .../danswer/background/celery/celery_redis.py | 2 +-
 backend/danswer/chat/process_message.py | 37 +--
 backend/danswer/configs/chat_configs.py | 2 +-
 .../connectors/google_site/connector.py | 8 +-
 backend/danswer/danswerbot/slack/blocks.py | 17 --
 .../slack/handlers/handle_regular_answer.py | 34 ++-
 .../slack/handlers/handle_standard_answers.py | 215 +++-------------
 backend/danswer/danswerbot/slack/listener.py | 3 +
 backend/danswer/db/chat.py | 2 +-
 backend/danswer/db/engine.py | 3 -
 backend/danswer/db/llm.py | 16 ++
 backend/danswer/db/models.py | 106 ++++----
 backend/danswer/db/persona.py | 25 +-
 backend/danswer/db/search_settings.py | 2 +-
 backend/danswer/db/slack_bot_config.py | 52 +++-
 backend/danswer/document_index/vespa/index.py | 2 +-
 backend/danswer/llm/gpt_4_all.py | 77 ------
 backend/danswer/main.py | 10 +-
 .../one_shot_answer/answer_question.py | 45 ++--
 backend/danswer/one_shot_answer/models.py | 53 +++-
 backend/danswer/redis/redis_pool.py | 54 ++--
 backend/danswer/search/models.py | 1 +
 backend/danswer/server/auth_check.py | 2 +
 .../danswer/server/features/persona/api.py | 25 ++
 backend/danswer/server/manage/models.py | 58 +----
 backend/danswer/server/manage/slack_bot.py | 1 +
 .../server/query_and_chat/chat_backend.py | 2 +-
 .../danswer/server/query_and_chat/models.py | 6 +-
 backend/danswer/tools/built_in_tools.py | 1 +
 backend/danswer/tools/custom/custom_tool.py | 31 ++-
 backend/danswer/tools/models.py | 9 +
 backend/danswer/utils/errors.py | 3 +
 backend/danswer/utils/logger.py | 19 +-
 .../danswerbot/slack/handlers/__init__.py | 0
 .../slack/handlers/handle_standard_answers.py | 238 ++++++++++++++++++
 .../{ => ee}/danswer/db/standard_answer.py | 128 ++++++----
 backend/ee/danswer/main.py | 2 +
 .../danswer/server/enterprise_settings/api.py | 87 +++----
 backend/ee/danswer/server/manage/models.py | 98 ++++++++
 .../danswer/server/manage/standard_answer.py | 30 ++-
 .../server/query_and_chat/chat_backend.py | 13 +-
 .../danswer/server/query_and_chat/models.py | 7 +-
 .../server/query_and_chat/query_backend.py | 30 ++-
 .../ee/danswer/server/query_and_chat/utils.py | 83 ++++++
 .../ee/danswer/server/query_history/api.py | 14 +-
 backend/ee/danswer/server/seeding.py | 7 +
 backend/requirements/default.txt | 2 -
 backend/shared_configs/configs.py | 16 +-
 backend/supervisord.conf | 5 +-
 .../integration/common_utils/managers/chat.py | 160 ++++++++++++
 .../integration/common_utils/test_models.py | 25 ++
 .../connector/test_connector_deletion.py | 2 -
 .../tests/dev_apis/test_simple_chat_api.py | 1 +
 .../tests/document_set/test_syncing.py | 91 +++++++
 .../streaming_endpoints/test_answer_stream.py | 25 ++
 .../streaming_endpoints/test_chat_stream.py | 19 ++
 .../tests/usergroup/test_usergroup_syncing.py | 102 ++++++++
 backend/tests/unit/danswer/redis_ca.pem | 91 +++++++
 backend/tests/unit/danswer/test_redis.py | 39 +++
 72 files changed, 1843 insertions(+), 655 deletions(-)
 create mode 100644 backend/alembic/versions/35e6853a51d5_server_default_chosen_assistants.py
 create mode 100644 backend/alembic/versions/5c7fdadae813_match_any_keywords_flag_for_standard_.py
 create mode 100644 backend/alembic/versions/c99d76fcd298_add_nullable_to_persona_id_in_chat_.py
 create mode 100644 backend/alembic/versions/efb35676026c_standard_answer_match_regex_flag.py
 delete mode 100644 backend/danswer/llm/gpt_4_all.py
 create mode 100644 backend/danswer/utils/errors.py
 create mode 100644 backend/ee/danswer/danswerbot/slack/handlers/__init__.py
 create mode 100644 backend/ee/danswer/danswerbot/slack/handlers/handle_standard_answers.py
 rename backend/{ => ee}/danswer/db/standard_answer.py (72%)
 create mode 100644 backend/ee/danswer/server/manage/models.py
 rename backend/{ => ee}/danswer/server/manage/standard_answer.py (79%)
 create mode 100644 backend/ee/danswer/server/query_and_chat/utils.py
 create mode 100644 backend/tests/integration/common_utils/managers/chat.py
 create mode 100644 backend/tests/integration/tests/streaming_endpoints/test_answer_stream.py
 create mode 100644 backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py
 create mode 100644 backend/tests/integration/tests/usergroup/test_usergroup_syncing.py
 create mode 100644 backend/tests/unit/danswer/redis_ca.pem
 create mode 100644 backend/tests/unit/danswer/test_redis.py

diff --git a/backend/Dockerfile b/backend/Dockerfile
index fc7bcc586d7..a9dd411c691 100644
--- a/backend/Dockerfile
+++ b/backend/Dockerfile
@@ -9,7 +9,8 @@ founders@danswer.ai for more information. Please visit https://github.com/danswe
 # Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
 ARG DANSWER_VERSION=0.3-dev
-ENV DANSWER_VERSION=${DANSWER_VERSION}
+ENV DANSWER_VERSION=${DANSWER_VERSION} \
+    DANSWER_RUNNING_IN_DOCKER="true"

 RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"

 # Install system dependencies
diff --git a/backend/Dockerfile.model_server b/backend/Dockerfile.model_server
index c915dc688c4..90ded483122 100644
--- a/backend/Dockerfile.model_server
+++ b/backend/Dockerfile.model_server
@@ -8,7 +8,10 @@ visit https://github.com/danswer-ai/danswer."
 # Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
 ARG DANSWER_VERSION=0.3-dev
-ENV DANSWER_VERSION=${DANSWER_VERSION}
+ENV DANSWER_VERSION=${DANSWER_VERSION} \
+    DANSWER_RUNNING_IN_DOCKER="true"
+
+
 RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"

 COPY ./requirements/model_server.txt /tmp/requirements.txt
@@ -21,16 +24,15 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \
 # Download tokenizers, distilbert for the Danswer model
 # Download model weights
 # Run Nomic to pull in the custom architecture and have it cached locally
-# RUN python -c "from transformers import AutoTokenizer; \
-# from huggingface_hub import snapshot_download; \
-# snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
-# snapshot_download('nomic-ai/nomic-embed-text-v1'); \
-# from sentence_transformers import SentenceTransformer; \
-# SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
+RUN python -c "from transformers import AutoTokenizer; \
+AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
+from huggingface_hub import snapshot_download; \
+snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
+snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1');"

 # In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
 # running Danswer, don't overwrite it with the built in cache folder
-# RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
+RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface

 WORKDIR /app
diff --git a/backend/alembic/versions/1f60f60c3401_embedding_model_search_settings.py b/backend/alembic/versions/1f60f60c3401_embedding_model_search_settings.py
index 42f4c22ed78..f5b21c81d8e 100644
--- a/backend/alembic/versions/1f60f60c3401_embedding_model_search_settings.py
+++ b/backend/alembic/versions/1f60f60c3401_embedding_model_search_settings.py
@@ -30,7 +30,7 @@ def upgrade() -> None:
     op.add_column(
         "search_settings",
         sa.Column(
-            "multipass_indexing", sa.Boolean(), nullable=False, server_default="true"
+            "multipass_indexing", sa.Boolean(), nullable=False, server_default="false"
         ),
     )
     op.add_column(
diff --git a/backend/alembic/versions/35e6853a51d5_server_default_chosen_assistants.py b/backend/alembic/versions/35e6853a51d5_server_default_chosen_assistants.py
new file mode 100644
index 00000000000..44c29c59a72
--- /dev/null
+++ b/backend/alembic/versions/35e6853a51d5_server_default_chosen_assistants.py
@@ -0,0 +1,38 @@
+"""server default chosen assistants
+
+Revision ID: 35e6853a51d5
+Revises: c99d76fcd298
+Create Date: 2024-09-13 13:20:32.885317
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = "35e6853a51d5"
+down_revision = "c99d76fcd298"
+branch_labels = None
+depends_on = None
+
+DEFAULT_ASSISTANTS = [-2, -1, 0]
+
+
+def upgrade() -> None:
+    op.alter_column(
+        "user",
+        "chosen_assistants",
+        type_=postgresql.JSONB(astext_type=sa.Text()),
+        nullable=False,
+        server_default=sa.text(f"'{DEFAULT_ASSISTANTS}'::jsonb"),
+    )
+
+
+def downgrade() -> None:
+    op.alter_column(
+        "user",
+        "chosen_assistants",
+        type_=postgresql.JSONB(astext_type=sa.Text()),
+        nullable=True,
+        server_default=None,
+    )
diff --git a/backend/alembic/versions/401c1ac29467_add_tables_for_ui_based_llm_.py b/backend/alembic/versions/401c1ac29467_add_tables_for_ui_based_llm_.py
index 34469b13c17..dcc766fe287 100644
--- a/backend/alembic/versions/401c1ac29467_add_tables_for_ui_based_llm_.py
+++ b/backend/alembic/versions/401c1ac29467_add_tables_for_ui_based_llm_.py
@@ -5,7 +5,6 @@
 Create Date: 2024-04-13 18:07:29.153817

 """
-
 from alembic import op
 import sqlalchemy as sa
 from sqlalchemy.dialects import postgresql
diff --git a/backend/alembic/versions/5c7fdadae813_match_any_keywords_flag_for_standard_.py b/backend/alembic/versions/5c7fdadae813_match_any_keywords_flag_for_standard_.py
new file mode 100644
index 00000000000..0e49b603cec
--- /dev/null
+++ b/backend/alembic/versions/5c7fdadae813_match_any_keywords_flag_for_standard_.py
@@ -0,0 +1,35 @@
+"""match_any_keywords flag for standard answers
+
+Revision ID: 5c7fdadae813
+Revises: efb35676026c
+Create Date: 2024-09-13 18:52:59.256478
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = "5c7fdadae813"
+down_revision = "efb35676026c"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column(
+        "standard_answer",
+        sa.Column(
+            "match_any_keywords",
+            sa.Boolean(),
+            nullable=False,
+            server_default=sa.false(),
+        ),
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("standard_answer", "match_any_keywords")
+    # ### end Alembic commands ###
diff --git a/backend/alembic/versions/703313b75876_add_tokenratelimit_tables.py b/backend/alembic/versions/703313b75876_add_tokenratelimit_tables.py
index 7dbc63ce345..ed1993efed3 100644
--- a/backend/alembic/versions/703313b75876_add_tokenratelimit_tables.py
+++ b/backend/alembic/versions/703313b75876_add_tokenratelimit_tables.py
@@ -5,7 +5,6 @@
 Create Date: 2024-04-15 01:36:02.952809

 """
-
 import json
 from typing import cast
 from alembic import op
diff --git a/backend/alembic/versions/c99d76fcd298_add_nullable_to_persona_id_in_chat_.py b/backend/alembic/versions/c99d76fcd298_add_nullable_to_persona_id_in_chat_.py
new file mode 100644
index 00000000000..58fcf482c85
--- /dev/null
+++ b/backend/alembic/versions/c99d76fcd298_add_nullable_to_persona_id_in_chat_.py
@@ -0,0 +1,31 @@
+"""add nullable to persona id in Chat Session
+
+Revision ID: c99d76fcd298
+Revises: 5c7fdadae813
+Create Date: 2024-07-09 19:27:01.579697
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = "c99d76fcd298"
+down_revision = "5c7fdadae813"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.alter_column(
+        "chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
+    )
+
+
+def downgrade() -> None:
+    op.alter_column(
+        "chat_session",
+        "persona_id",
+        existing_type=sa.INTEGER(),
+        nullable=False,
+    )
diff --git a/backend/alembic/versions/ef7da92f7213_add_files_to_chatmessage.py b/backend/alembic/versions/ef7da92f7213_add_files_to_chatmessage.py
index 10ec5d93a49..eb04a1b8208 100644
--- a/backend/alembic/versions/ef7da92f7213_add_files_to_chatmessage.py
+++ b/backend/alembic/versions/ef7da92f7213_add_files_to_chatmessage.py
@@ -5,7 +5,6 @@
 Create Date: 2024-04-28 16:59:33.199153

 """
-
 from alembic import op
 import sqlalchemy as sa
 from sqlalchemy.dialects import postgresql
diff --git a/backend/alembic/versions/efb35676026c_standard_answer_match_regex_flag.py b/backend/alembic/versions/efb35676026c_standard_answer_match_regex_flag.py
new file mode 100644
index 00000000000..c85bb68a3b9
--- /dev/null
+++ b/backend/alembic/versions/efb35676026c_standard_answer_match_regex_flag.py
@@ -0,0 +1,32 @@
+"""standard answer match_regex flag
+
+Revision ID: efb35676026c
+Revises: 52a219fb5233
+Create Date: 2024-09-11 13:55:46.101149
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = "efb35676026c"
+down_revision = "0ebb1d516877"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column(
+        "standard_answer",
+        sa.Column(
+            "match_regex", sa.Boolean(), nullable=False, server_default=sa.false()
+        ),
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("standard_answer", "match_regex")
+    # ### end Alembic commands ###
diff --git a/backend/alembic/versions/fad14119fb92_delete_tags_with_wrong_enum.py b/backend/alembic/versions/fad14119fb92_delete_tags_with_wrong_enum.py
index ccd9df6c01b..b9c428640eb 100644
--- a/backend/alembic/versions/fad14119fb92_delete_tags_with_wrong_enum.py
+++ b/backend/alembic/versions/fad14119fb92_delete_tags_with_wrong_enum.py
@@ -5,7 +5,6 @@
 Create Date: 2024-04-25 17:05:09.695703

 """
-
 from alembic import op

 # revision identifiers, used by Alembic.
diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py
index 56d9a99eb33..1776217d39c 100644
--- a/backend/danswer/auth/users.py
+++ b/backend/danswer/auth/users.py
@@ -231,6 +231,7 @@ async def oauth_callback(
     associate_by_email: bool = False,
     is_verified_by_default: bool = False,
 ) -> models.UOAP:
+    verify_email_in_whitelist(account_email)
     verify_email_domain(account_email)
     user = await super().oauth_callback(  # type: ignore
@@ -267,6 +268,7 @@ async def oauth_callback(
     )
     user.is_verified = is_verified_by_default
     user.has_web_login = True
+
     return user

 async def on_after_register(
@@ -413,6 +415,7 @@ async def optional_user(
 async def double_check_user(
     user: User | None,
     optional: bool = DISABLE_AUTH,
+    include_expired: bool = False,
 ) -> User | None:
     if optional:
         return None
@@ -429,7 +432,11 @@ async def double_check_user(
             detail="Access denied. User is not verified.",
         )

-    if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
+    if (
+        user.oidc_expiry
+        and user.oidc_expiry < datetime.now(timezone.utc)
+        and not include_expired
+    ):
         raise HTTPException(
             status_code=status.HTTP_403_FORBIDDEN,
             detail="Access denied. User's OIDC token has expired.",
@@ -438,6 +445,12 @@ async def double_check_user(
     return user

+async def current_user_with_expired_token(
+    user: User | None = Depends(optional_user),
+) -> User | None:
+    return await double_check_user(user, include_expired=True)
+
+
 async def current_user(
     user: User | None = Depends(optional_user),
 ) -> User | None:
diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py
index a48d8aa4a15..9334992423e 100644
--- a/backend/danswer/background/celery/celery_app.py
+++ b/backend/danswer/background/celery/celery_app.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import traceback
 from datetime import timedelta
 from typing import Any
@@ -6,6 +7,7 @@
 import redis
 from celery import Celery
+from celery import current_task
 from celery import signals
 from celery import Task
 from celery.contrib.abortable import AbortableTask  # type: ignore
@@ -64,6 +66,8 @@
 from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.interfaces import UpdateRequest
 from danswer.redis.redis_pool import RedisPool
+from danswer.utils.logger import ColoredFormatter
+from danswer.utils.logger import PlainFormatter
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import fetch_versioned_implementation
 from danswer.utils.variable_functionality import (
@@ -136,8 +140,7 @@ def cleanup_connector_credential_pair_task(
         add_deletion_failure_message(db_session, cc_pair.id, error_message)
         task_logger.exception(
             f"Failed to run connector_deletion. "
-            f"connector_id={connector_id} credential_id={credential_id}\n"
-            f"Stack Trace:\n{stack_trace}"
+            f"connector_id={connector_id} credential_id={credential_id}"
         )
         raise e
@@ -271,6 +274,8 @@ def try_generate_document_set_sync_tasks(
         return None

     # don't generate sync tasks if we're up to date
+    # race condition with the monitor/cleanup function if we use a cached result!
+    db_session.refresh(document_set)
     if document_set.is_up_to_date:
         return None
@@ -313,6 +318,8 @@ def try_generate_user_group_sync_tasks(
     if r.exists(rug.fence_key):
         return None

+    # race condition with the monitor/cleanup function if we use a cached result!
+    db_session.refresh(usergroup)
     if usergroup.is_up_to_date:
         return None
@@ -883,6 +890,77 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
         r.delete(key)

+class CeleryTaskPlainFormatter(PlainFormatter):
+    def format(self, record: logging.LogRecord) -> str:
+        task = current_task
+        if task and task.request:
+            record.__dict__.update(task_id=task.request.id, task_name=task.name)
+            record.msg = f"[{task.name}({task.request.id})] {record.msg}"
+
+        return super().format(record)
+
+
+class CeleryTaskColoredFormatter(ColoredFormatter):
+    def format(self, record: logging.LogRecord) -> str:
+        task = current_task
+        if task and task.request:
+            record.__dict__.update(task_id=task.request.id, task_name=task.name)
+            record.msg = f"[{task.name}({task.request.id})] {record.msg}"
+
+        return super().format(record)
+
+
+@signals.setup_logging.connect
+def on_setup_logging(
+    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
+) -> None:
+    # TODO: could unhardcode format and colorize and accept these as options from
+    # celery's config
+
+    # reformats celery's worker logger
+    root_logger = logging.getLogger()
+
+    root_handler = logging.StreamHandler()  # Set up a handler for the root logger
+    root_formatter = ColoredFormatter(
+        "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
+        datefmt="%m/%d/%Y %I:%M:%S %p",
+    )
+    root_handler.setFormatter(root_formatter)
+    root_logger.addHandler(root_handler)  # Apply the handler to the root logger
+
+    if logfile:
+        root_file_handler = logging.FileHandler(logfile)
+        root_file_formatter = PlainFormatter(
+            "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
+            datefmt="%m/%d/%Y %I:%M:%S %p",
+        )
+        root_file_handler.setFormatter(root_file_formatter)
+        root_logger.addHandler(root_file_handler)
+
+    root_logger.setLevel(loglevel)
+
+    # reformats celery's task logger
+    task_formatter = CeleryTaskColoredFormatter(
+        "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
+        datefmt="%m/%d/%Y %I:%M:%S %p",
+    )
+    task_handler = logging.StreamHandler()  # Set up a handler for the task logger
+    task_handler.setFormatter(task_formatter)
+    task_logger.addHandler(task_handler)  # Apply the handler to the task logger
+
+    if logfile:
+        task_file_handler = logging.FileHandler(logfile)
+        task_file_formatter = CeleryTaskPlainFormatter(
+            "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
+            datefmt="%m/%d/%Y %I:%M:%S %p",
+        )
+        task_file_handler.setFormatter(task_file_formatter)
+        task_logger.addHandler(task_file_handler)
+
+    task_logger.setLevel(loglevel)
+    task_logger.propagate = False
+
+
 #####
 # Celery Beat (Periodic Tasks) Settings
 #####
diff --git a/backend/danswer/background/celery/celery_redis.py b/backend/danswer/background/celery/celery_redis.py
index bf82f0a7274..b3132c59a89 100644
--- a/backend/danswer/background/celery/celery_redis.py
+++ b/backend/danswer/background/celery/celery_redis.py
@@ -124,7 +124,7 @@ def generate_tasks(
         last_lock_time = time.monotonic()

         async_results = []
-        stmt = construct_document_select_by_docset(self._id)
+        stmt = construct_document_select_by_docset(self._id, current_only=False)
         for doc in db_session.scalars(stmt).yield_per(1):
             current_time = time.monotonic()
             if current_time - last_lock_time >= (
diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index fa13f245ccb..223f3b5ce47 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -88,6 +88,7 @@
 )
 from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
 from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
+from danswer.tools.models import DynamicSchemaInfo
 from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
 from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
 from danswer.tools.search.search_tool import SearchResponseSummary
@@ -605,7 +606,11 @@ def stream_chat_message_objects(
                 tool_dict[db_tool_model.id] = cast(
                     list[Tool],
                     build_custom_tools_from_openapi_schema(
-                        db_tool_model.openapi_schema
+                        db_tool_model.openapi_schema,
+                        dynamic_schema_info=DynamicSchemaInfo(
+                            chat_session_id=chat_session_id,
+                            message_id=user_message.id if user_message else None,
+                        ),
                     ),
                 )
@@ -670,9 +675,11 @@ def stream_chat_message_objects(
                         db_session=db_session,
                         selected_search_docs=selected_db_search_docs,
                         # Deduping happens at the last step to avoid harming quality by dropping content early on
-                        dedupe_docs=retrieval_options.dedupe_docs
-                        if retrieval_options
-                        else False,
+                        dedupe_docs=(
+                            retrieval_options.dedupe_docs
+                            if retrieval_options
+                            else False
+                        ),
                     )
                     yield qa_docs_response
                 elif packet.id == SECTION_RELEVANCE_LIST_ID:
@@ -781,16 +788,18 @@ def stream_chat_message_objects(
         if message_specific_citations
         else None,
         error=None,
-        tool_calls=[
-            ToolCall(
-                tool_id=tool_name_to_tool_id[tool_result.tool_name],
-                tool_name=tool_result.tool_name,
-                tool_arguments=tool_result.tool_args,
-                tool_result=tool_result.tool_result,
-            )
-        ]
-        if tool_result
-        else [],
+        tool_calls=(
+            [
+                ToolCall(
+                    tool_id=tool_name_to_tool_id[tool_result.tool_name],
+                    tool_name=tool_result.tool_name,
+                    tool_arguments=tool_result.tool_args,
+                    tool_result=tool_result.tool_result,
+                )
+            ]
+            if tool_result
+            else []
+        ),
     )

     logger.debug("Committing messages")
diff --git a/backend/danswer/configs/chat_configs.py b/backend/danswer/configs/chat_configs.py
index b7c10ea36fb..e67e4258fec 100644
--- a/backend/danswer/configs/chat_configs.py
+++ b/backend/danswer/configs/chat_configs.py
@@ -5,7 +5,7 @@
 PERSONAS_YAML = "./danswer/chat/personas.yaml"
 INPUT_PROMPT_YAML = "./danswer/chat/input_prompts.yaml"

-NUM_RETURNED_HITS = os.environ.get("TOOL_SEARCH_NUM_RETURNED_HITS") or 50
+NUM_RETURNED_HITS = 50
 # Used for LLM filtering and reranking
 # We want this to be approximately the number of results we want to show on the first page
 # It cannot be too large due to cost and latency implications
diff --git a/backend/danswer/connectors/google_site/connector.py b/backend/danswer/connectors/google_site/connector.py
index 3720ff5f433..9cfcf224e3f 100644
--- a/backend/danswer/connectors/google_site/connector.py
+++ b/backend/danswer/connectors/google_site/connector.py
@@ -119,11 +119,9 @@ def load_from_state(self) -> GenerateDocumentsOutput:
                     semantic_identifier=title,
                     sections=[
                         Section(
-                            link=(
-                                (self.base_url.rstrip("/") + "/" + path.lstrip("/"))
-                                if path
-                                else ""
-                            ),
+                            link=(self.base_url.rstrip("/") + "/" + path.lstrip("/"))
+                            if path
+                            else "",
                             text=parsed_html.cleaned_text,
                         )
                     ],
diff --git a/backend/danswer/danswerbot/slack/blocks.py b/backend/danswer/danswerbot/slack/blocks.py
index da4a867e233..4107a381554 100644
--- a/backend/danswer/danswerbot/slack/blocks.py
+++ b/backend/danswer/danswerbot/slack/blocks.py
@@ -25,7 +25,6 @@
 from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
 from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
 from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
-from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
 from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
 from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
 from danswer.danswerbot.slack.icons import source_to_github_img_link
@@ -360,22 +359,6 @@ def build_quotes_block(
     return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]

-def build_standard_answer_blocks(
-    answer_message: str,
-) -> list[Block]:
-    generate_button_block = ButtonElement(
-        action_id=GENERATE_ANSWER_BUTTON_ACTION_ID,
-        text="Generate Full Answer",
-    )
-    answer_block = SectionBlock(text=answer_message)
-    return [
-        answer_block,
-        ActionsBlock(
-            elements=[generate_button_block],
-        ),
-    ]
-
-
 def build_qa_response_blocks(
     message_id: int | None,
     answer: str | None,
diff --git a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py
index 7057d7c2e4b..f1c9bd077cf 100644
--- a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py
+++ b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py
@@ -5,6 +5,7 @@
 from typing import Optional
 from typing import TypeVar

+from fastapi import HTTPException
 from retry import retry
 from slack_sdk import WebClient
 from slack_sdk.models.blocks import DividerBlock
@@ -135,7 +136,8 @@ def handle_regular_answer(
         else slack_bot_config.response_type == SlackBotResponseType.CITATIONS
     )

-    if not message_ts_to_respond_to:
+    if not message_ts_to_respond_to and not is_bot_msg:
+        # if the message is not "/danswer" command, then it should have a message ts to respond to
         raise RuntimeError(
             "No message timestamp to respond to in `handle_message`. This should never happen."
         )
@@ -152,15 +154,23 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non
     with Session(get_sqlalchemy_engine()) as db_session:
         if len(new_message_request.messages) > 1:
-            persona = cast(
-                Persona,
-                fetch_persona_by_id(
-                    db_session,
-                    new_message_request.persona_id,
-                    user=None,
-                    get_editable=False,
-                ),
-            )
+            if new_message_request.persona_config:
+                raise HTTPException(
+                    status_code=403,
+                    detail="Slack bot does not support persona config",
+                )
+
+            elif new_message_request.persona_id:
+                persona = cast(
+                    Persona,
+                    fetch_persona_by_id(
+                        db_session,
+                        new_message_request.persona_id,
+                        user=None,
+                        get_editable=False,
+                    ),
+                )

             llm, _ = get_llms_for_persona(persona)

             # In cases of threads, split the available tokens between docs and thread context
@@ -470,7 +480,9 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non
     # For DM (ephemeral message), we need to create a thread via a normal message so the user can see
     # the ephemeral message. This also will give the user a notification which ephemeral message does not.
-    if receiver_ids:
+    # if there is no message_ts_to_respond_to, and we have made it this far, then this is a /danswer message
+    # so we shouldn't send_team_member_message
+    if receiver_ids and message_ts_to_respond_to is not None:
         send_team_member_message(
             client=client,
             channel=channel,
diff --git a/backend/danswer/danswerbot/slack/handlers/handle_standard_answers.py b/backend/danswer/danswerbot/slack/handlers/handle_standard_answers.py
index 8e1663c1a4c..58a2101588d 100644
--- a/backend/danswer/danswerbot/slack/handlers/handle_standard_answers.py
+++ b/backend/danswer/danswerbot/slack/handlers/handle_standard_answers.py
@@ -1,61 +1,43 @@
 from slack_sdk import WebClient
 from sqlalchemy.orm import Session

-from danswer.configs.constants import MessageType
-from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
-from danswer.danswerbot.slack.blocks import build_standard_answer_blocks
-from danswer.danswerbot.slack.blocks import get_restate_blocks
-from danswer.danswerbot.slack.handlers.utils import send_team_member_message
 from danswer.danswerbot.slack.models import SlackMessageInfo
-from danswer.danswerbot.slack.utils import respond_in_thread
-from danswer.danswerbot.slack.utils import update_emote_react
-from danswer.db.chat import create_chat_session
-from danswer.db.chat import create_new_chat_message
-from danswer.db.chat import get_chat_messages_by_sessions
-from danswer.db.chat import get_chat_sessions_by_slack_thread_id
-from danswer.db.chat import get_or_create_root_message
 from danswer.db.models import Prompt
 from danswer.db.models import SlackBotConfig
-from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
-from danswer.db.standard_answer import find_matching_standard_answers
-from danswer.server.manage.models import StandardAnswer
 from danswer.utils.logger import DanswerLoggingAdapter
 from danswer.utils.logger import setup_logger
+from danswer.utils.variable_functionality import fetch_versioned_implementation

 logger = setup_logger()

-def oneoff_standard_answers(
-    message: str,
-    slack_bot_categories: list[str],
+def handle_standard_answers(
+    message_info: SlackMessageInfo,
+    receiver_ids: list[str] | None,
+    slack_bot_config: SlackBotConfig | None,
+    prompt: Prompt | None,
+    logger: DanswerLoggingAdapter,
+    client: WebClient,
     db_session: Session,
-) -> list[StandardAnswer]:
-    """
-    Respond to the user message if it matches any configured standard answers.
-
-    Returns a list of matching StandardAnswers if found, otherwise None.
-    """
-    configured_standard_answers = {
-        standard_answer
-        for category in fetch_standard_answer_categories_by_names(
-            slack_bot_categories, db_session=db_session
-        )
-        for standard_answer in category.standard_answers
-    }
-
-    matching_standard_answers = find_matching_standard_answers(
-        query=message,
-        id_in=[answer.id for answer in configured_standard_answers],
+) -> bool:
+    """Returns whether one or more Standard Answer message blocks were
+    emitted by the Slack bot"""
+    versioned_handle_standard_answers = fetch_versioned_implementation(
+        "danswer.danswerbot.slack.handlers.handle_standard_answers",
+        "_handle_standard_answers",
+    )
+    return versioned_handle_standard_answers(
+        message_info=message_info,
+        receiver_ids=receiver_ids,
+        slack_bot_config=slack_bot_config,
+        prompt=prompt,
+        logger=logger,
+        client=client,
         db_session=db_session,
     )
-    server_standard_answers = [
-        StandardAnswer.from_model(db_answer) for db_answer in matching_standard_answers
-    ]
-    return server_standard_answers

-def handle_standard_answers(
+def _handle_standard_answers(
     message_info: SlackMessageInfo,
     receiver_ids: list[str] | None,
     slack_bot_config: SlackBotConfig | None,
@@ -65,151 +47,10 @@
     db_session: Session,
 ) -> bool:
     """
-    Potentially respond to the user message depending on whether the user's message matches
-    any of the configured standard answers and also whether those answers have already been
-    provided in the current thread.
+    Standard Answers are a paid Enterprise Edition feature. This is the fallback
+    function handling the case where EE features are not enabled.

-    Returns True if standard answers are found to match the user's message and therefore,
-    we still need to respond to the users.
+    Always returns false i.e. since EE features are not enabled, we NEVER create any
+    Slack message blocks.
+    """
-    # if no channel config, then no standard answers are configured
-    if not slack_bot_config:
-        return False
-
-    slack_thread_id = message_info.thread_to_respond
-    configured_standard_answer_categories = (
-        slack_bot_config.standard_answer_categories if slack_bot_config else []
-    )
-    configured_standard_answers = set(
-        [
-            standard_answer
-            for standard_answer_category in configured_standard_answer_categories
-            for standard_answer in standard_answer_category.standard_answers
-        ]
-    )
-    query_msg = message_info.thread_messages[-1]
-
-    if slack_thread_id is None:
-        used_standard_answer_ids = set([])
-    else:
-        chat_sessions = get_chat_sessions_by_slack_thread_id(
-            slack_thread_id=slack_thread_id,
-            user_id=None,
-            db_session=db_session,
-        )
-        chat_messages = get_chat_messages_by_sessions(
-            chat_session_ids=[chat_session.id for chat_session in chat_sessions],
-            user_id=None,
-            db_session=db_session,
-            skip_permission_check=True,
-        )
-        used_standard_answer_ids = set(
-            [
-                standard_answer.id
-                for chat_message in chat_messages
-                for standard_answer in chat_message.standard_answers
-            ]
-        )
-
-    usable_standard_answers = configured_standard_answers.difference(
-        used_standard_answer_ids
-    )
-    if usable_standard_answers:
-        matching_standard_answers = find_matching_standard_answers(
-            query=query_msg.message,
-            id_in=[standard_answer.id for standard_answer in usable_standard_answers],
-            db_session=db_session,
-        )
-    else:
-        matching_standard_answers = []
-    if matching_standard_answers:
-        chat_session = create_chat_session(
-            db_session=db_session,
-            description="",
-            user_id=None,
-            persona_id=slack_bot_config.persona.id if slack_bot_config.persona else 0,
-            danswerbot_flow=True,
-            slack_thread_id=slack_thread_id,
-            one_shot=True,
-        )
-
-        root_message = get_or_create_root_message(
-            chat_session_id=chat_session.id, db_session=db_session
-        )
-
-        new_user_message = create_new_chat_message(
-            chat_session_id=chat_session.id,
-            parent_message=root_message,
-            prompt_id=prompt.id if prompt else None,
-            message=query_msg.message,
-            token_count=0,
-            message_type=MessageType.USER,
-            db_session=db_session,
-            commit=True,
-        )
-
-        formatted_answers = []
-        for standard_answer in matching_standard_answers:
-            block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
-            formatted_answer = (
-                f'Since you mentioned _"{standard_answer.keyword}"_, '
-                f"I thought this might be useful: \n\n{block_quotified_answer}"
-            )
-            formatted_answers.append(formatted_answer)
-        answer_message = "\n\n".join(formatted_answers)
-
-        _ = create_new_chat_message(
-            chat_session_id=chat_session.id,
-            parent_message=new_user_message,
-            prompt_id=prompt.id if prompt else None,
-            message=answer_message,
-            token_count=0,
-            message_type=MessageType.ASSISTANT,
-            error=None,
-            db_session=db_session,
-            commit=True,
-        )
-
-        update_emote_react(
-            emoji=DANSWER_REACT_EMOJI,
-            channel=message_info.channel_to_respond,
-            message_ts=message_info.msg_to_respond,
-            remove=True,
-            client=client,
-        )
-
-        restate_question_blocks = get_restate_blocks(
-            msg=query_msg.message,
-            is_bot_msg=message_info.is_bot_msg,
-        )
-
-        answer_blocks = build_standard_answer_blocks(
-            answer_message=answer_message,
-        )
-
-        all_blocks = restate_question_blocks + answer_blocks
-
-        try:
-            respond_in_thread(
-                client=client,
-                channel=message_info.channel_to_respond,
-                receiver_ids=receiver_ids,
-                text="Hello! Danswer has some results for you!",
-                blocks=all_blocks,
-                thread_ts=message_info.msg_to_respond,
-                unfurl=False,
-            )
-
-            if receiver_ids and slack_thread_id:
-                send_team_member_message(
-                    client=client,
-                    channel=message_info.channel_to_respond,
-                    thread_ts=slack_thread_id,
-                )
-
-            return True
-        except Exception as e:
-            logger.exception(f"Unable to send standard answer message: {e}")
-            return False
-    else:
-        return False
+    return False
diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py
index 63f8bcfcd9c..c430f1b31b7 100644
--- a/backend/danswer/danswerbot/slack/listener.py
+++ b/backend/danswer/danswerbot/slack/listener.py
@@ -56,6 +56,7 @@
 from danswer.search.retrieval.search_runner import download_nltk_data
 from danswer.server.manage.models import SlackBotTokens
 from danswer.utils.logger import setup_logger
+from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
 from shared_configs.configs import MODEL_SERVER_HOST
 from shared_configs.configs import MODEL_SERVER_PORT
 from shared_configs.configs import SLACK_CHANNEL_ID
@@ -481,6 +482,8 @@ def _initialize_socket_client(socket_client: SocketModeClient) -> None:
     slack_bot_tokens: SlackBotTokens | None = None
     socket_client: SocketModeClient | None = None

+    set_is_ee_based_on_env_variable()
+
     logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
     download_nltk_data()
diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py
index 8485bb4f0ae..8599714ce8b 100644
--- a/backend/danswer/db/chat.py
+++ b/backend/danswer/db/chat.py
@@ -226,7 +226,7 @@ def create_chat_session(
     db_session: Session,
     description: str,
     user_id: UUID | None,
-    persona_id: int,
+    persona_id: int | None,  # Can be none if temporary persona is used
     llm_override: LLMOverride | None = None,
     prompt_override: PromptOverride | None = None,
     one_shot: bool = False,
diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py
index 208dbdba3c1..94b5d0123cc 100644
--- a/backend/danswer/db/engine.py
+++ b/backend/danswer/db/engine.py
@@ -14,7 +14,6 @@
 from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker
-from sqlalchemy.pool import NullPool

 from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
 from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
@@ -131,7 +130,6 @@ def init_sqlalchemy_engine(app_name: str) -> None:

 def get_sqlalchemy_engine() -> Engine:
-    connect_args = {"sslmode": "disable"}
     global _SYNC_ENGINE
     if _SYNC_ENGINE is None:
         connection_string = build_connection_string(
@@ -148,7 +146,6 @@ def get_sqlalchemy_engine() -> Engine:

 def get_sqlalchemy_async_engine() -> AsyncEngine:
-    connect_args = {"ssl": "disable"}
     global _ASYNC_ENGINE
     if _ASYNC_ENGINE is None:
         # underlying asyncpg cannot accept application_name directly in the connection string
diff --git a/backend/danswer/db/llm.py b/backend/danswer/db/llm.py
index a68beadc084..36d05948be5 100644
--- a/backend/danswer/db/llm.py
+++ b/backend/danswer/db/llm.py
@@ -4,9 +4,11 @@
 from sqlalchemy.orm import Session

 from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
+from danswer.db.models import DocumentSet
 from danswer.db.models import LLMProvider as LLMProviderModel
 from danswer.db.models import LLMProvider__UserGroup
 from danswer.db.models import SearchSettings
+from danswer.db.models import Tool as ToolModel
 from danswer.db.models import User
 from danswer.db.models import User__UserGroup
 from danswer.server.manage.embedding.models import CloudEmbeddingProvider
@@ -103,6 +105,20 @@ def fetch_existing_embedding_providers(
     return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())

+def fetch_existing_doc_sets(
+    db_session: Session, doc_ids: list[int]
+) -> list[DocumentSet]:
+    return list(
+        db_session.scalars(select(DocumentSet).where(DocumentSet.id.in_(doc_ids))).all()
+    )
+
+
+def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]:
+    return list(
+        db_session.scalars(select(ToolModel).where(ToolModel.id.in_(tool_ids))).all()
+    )
+
+
 def fetch_existing_llm_providers(
     db_session: Session,
     user: User | None = None,
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index c0d24770704..79d8206586b 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -122,7 +122,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
     # if specified, controls the assistants that are shown to the user + their order
     # if not specified, all assistants are shown
     chosen_assistants: Mapped[list[int]] = mapped_column(
-        postgresql.JSONB(), nullable=True
+        postgresql.JSONB(), nullable=False, default=[-2, -1, 0]
     )

     oidc_expiry: Mapped[datetime.datetime] = mapped_column(
@@ -866,7 +866,9 @@ class ChatSession(Base):

     id: Mapped[int] = mapped_column(primary_key=True)
     user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
-    persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
+    persona_id: Mapped[int | None] = mapped_column(
+        ForeignKey("persona.id"), nullable=True
+    )
     description: Mapped[str] = mapped_column(Text)
     # One-shot direct answering, currently the two types of chats are not mixed
     one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
@@ -900,7 +902,6 @@ class ChatSession(Base):
     prompt_override: Mapped[PromptOverride | None] = mapped_column(
         PydanticType(PromptOverride), nullable=True
     )
-
     time_updated: Mapped[datetime.datetime] = mapped_column(
         DateTime(timezone=True),
         server_default=func.now(),
@@ -909,7 +910,6 @@ class ChatSession(Base):
     time_created: Mapped[datetime.datetime] = mapped_column(
         DateTime(timezone=True), server_default=func.now()
     )
-
     user: Mapped[User] = relationship("User", back_populates="chat_sessions")
     folder: Mapped["ChatFolder"] = relationship(
         "ChatFolder", back_populates="chat_sessions"
@@ -1347,53 +1347,6 @@ class ChannelConfig(TypedDict):
     follow_up_tags: NotRequired[list[str]]

-class StandardAnswerCategory(Base):
-    __tablename__ = "standard_answer_category"
-
-    id: Mapped[int] = mapped_column(primary_key=True)
-    name: Mapped[str] = mapped_column(String, unique=True)
-    standard_answers: Mapped[list["StandardAnswer"]] = relationship(
-        "StandardAnswer",
-        secondary=StandardAnswer__StandardAnswerCategory.__table__,
-        back_populates="categories",
-    )
-    slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
-        "SlackBotConfig",
-        secondary=SlackBotConfig__StandardAnswerCategory.__table__,
-        back_populates="standard_answer_categories",
-    )
-
-
-class StandardAnswer(Base):
-    __tablename__ = "standard_answer"
-
-    id: Mapped[int] = mapped_column(primary_key=True)
-    keyword: Mapped[str] = mapped_column(String)
-    answer: Mapped[str] = mapped_column(String)
-    active: Mapped[bool] = mapped_column(Boolean)
-
-    __table_args__ = (
-        Index(
-            "unique_keyword_active",
-            keyword,
-            active,
-            unique=True,
-            postgresql_where=(active == True),  # noqa: E712
-        ),
-    )
-
-    categories: Mapped[list[StandardAnswerCategory]] = relationship(
-        "StandardAnswerCategory",
-        secondary=StandardAnswer__StandardAnswerCategory.__table__,
-        back_populates="standard_answers",
-    )
-    chat_messages: Mapped[list[ChatMessage]] = relationship(
-        "ChatMessage",
-        secondary=ChatMessage__StandardAnswer.__table__,
-        back_populates="standard_answers",
-    )
-
-
 class SlackBotResponseType(str, PyEnum):
     QUOTES = "quotes"
     CITATIONS = "citations"
@@ -1419,7 +1372,7 @@ class SlackBotConfig(Base):
     )
     persona: Mapped[Persona | None] = relationship("Persona")

-    standard_answer_categories: Mapped[list[StandardAnswerCategory]] = relationship(
+    standard_answer_categories: Mapped[list["StandardAnswerCategory"]] = relationship(
         "StandardAnswerCategory",
         secondary=SlackBotConfig__StandardAnswerCategory.__table__,
         back_populates="slack_bot_configs",
@@ -1649,6 +1602,55 @@ class TokenRateLimit__UserGroup(Base):
     )

+class StandardAnswerCategory(Base):
+    __tablename__ = "standard_answer_category"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    name: Mapped[str] = mapped_column(String, unique=True)
+    standard_answers: Mapped[list["StandardAnswer"]] = relationship(
+        "StandardAnswer",
+        secondary=StandardAnswer__StandardAnswerCategory.__table__,
+        back_populates="categories",
+    )
+    slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
+        "SlackBotConfig",
+        secondary=SlackBotConfig__StandardAnswerCategory.__table__,
+        back_populates="standard_answer_categories",
+    )
+
+
+class StandardAnswer(Base):
+    __tablename__ = "standard_answer"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    keyword: Mapped[str] = mapped_column(String)
+    answer: Mapped[str] = mapped_column(String)
+    active: Mapped[bool] = mapped_column(Boolean)
+    match_regex: Mapped[bool] = mapped_column(Boolean)
+    match_any_keywords: Mapped[bool] = mapped_column(Boolean)
+
+    __table_args__ = (
+        Index(
+            "unique_keyword_active",
+            keyword,
+            active,
+            unique=True,
+            postgresql_where=(active == True),  # noqa: E712
+        ),
+    )
+
+    categories: Mapped[list[StandardAnswerCategory]] = relationship(
+        "StandardAnswerCategory",
+        secondary=StandardAnswer__StandardAnswerCategory.__table__,
+        back_populates="standard_answers",
+    )
+    chat_messages: Mapped[list[ChatMessage]] = relationship(
+        "ChatMessage",
+        secondary=ChatMessage__StandardAnswer.__table__,
+        back_populates="standard_answers",
+    )
+
+
 """Tables related to Permission Sync"""
diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py
index bbf45a1d9ad..3ed3c1230bf 100644
--- a/backend/danswer/db/persona.py
+++ b/backend/danswer/db/persona.py
@@ -210,6 +210,22 @@ def update_persona_shared_users(
     )

+def update_persona_public_status(
+    persona_id: int,
+    is_public: bool,
+    db_session: Session,
+    user: User | None,
+) -> None:
+    persona = fetch_persona_by_id(
+        db_session=db_session, persona_id=persona_id, user=user, get_editable=True
+    )
+    if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
+        raise ValueError("You don't have permission to modify this persona")
+
+    persona.is_public = is_public
+    db_session.commit()
+
+
 def get_prompts(
     user_id: UUID | None,
     db_session: Session,
@@ -551,6 +567,7 @@ def update_persona_visibility(
     persona = fetch_persona_by_id(
         db_session=db_session, persona_id=persona_id, user=user, get_editable=True
     )
+
     persona.is_visible = is_visible
     db_session.commit()
@@ -563,13 +580,15 @@ def validate_persona_tools(tools: list[Tool]) -> None:
         )

-def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]:
+def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
     """Unsafe, can fetch prompts from all users"""
     if not prompt_ids:
         return []
-    prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all()
+    prompts = db_session.scalars(
+        select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
+    ).all()

-    return prompts
+    return list(prompts)

 def get_prompt_by_id(
diff --git a/backend/danswer/db/search_settings.py b/backend/danswer/db/search_settings.py
index bb869c471dc..e3f35e31007 100644
--- a/backend/danswer/db/search_settings.py
+++ b/backend/danswer/db/search_settings.py
@@ -183,7 +183,7 @@ def update_current_search_settings(
     # Whenever we update the current search settings, we should ensure that the local reranking model is warmed up.
     if (
-        current_settings.provider_type is None
+        search_settings.rerank_provider_type is None
         and search_settings.rerank_model_name is not None
         and current_settings.rerank_model_name != search_settings.rerank_model_name
     ):
diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py
index 322dc4c4ed9..a37bd18c0ec 100644
--- a/backend/danswer/db/slack_bot_config.py
+++ b/backend/danswer/db/slack_bot_config.py
@@ -1,4 +1,5 @@
 from collections.abc import Sequence
+from typing import Any

 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -14,8 +15,11 @@
 from danswer.db.persona import get_default_prompt
 from danswer.db.persona import mark_persona_as_deleted
 from danswer.db.persona import upsert_persona
-from danswer.db.standard_answer import fetch_standard_answer_categories_by_ids
 from danswer.search.enums import RecencyBiasSetting
+from danswer.utils.errors import EERequiredError
+from danswer.utils.variable_functionality import (
+    fetch_versioned_implementation_with_fallback,
+)

 def _build_persona_name(channel_names: list[str]) -> str:
@@ -70,6 +74,10 @@ def create_slack_bot_persona(
     return persona

+def _no_ee_standard_answer_categories(*args: Any, **kwargs: Any) -> list:
+    return []
+
+
 def insert_slack_bot_config(
     persona_id: int | None,
     channel_config: ChannelConfig,
@@ -78,14 +86,29 @@ def insert_slack_bot_config(
     enable_auto_filters: bool,
     db_session: Session,
 ) -> SlackBotConfig:
-    existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
-        standard_answer_category_ids=standard_answer_category_ids,
-        db_session=db_session,
+    versioned_fetch_standard_answer_categories_by_ids = (
+        fetch_versioned_implementation_with_fallback(
+            "danswer.db.standard_answer",
+            "fetch_standard_answer_categories_by_ids",
+            _no_ee_standard_answer_categories,
+        )
     )
-    if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
-        raise ValueError(
-            f"Some or all categories with ids {standard_answer_category_ids} do not exist"
+    existing_standard_answer_categories = (
+        versioned_fetch_standard_answer_categories_by_ids(
+            standard_answer_category_ids=standard_answer_category_ids,
+            db_session=db_session,
         )
+    )
+
+    if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
+        if len(existing_standard_answer_categories) == 0:
+            raise EERequiredError(
+                "Standard answers are a paid Enterprise Edition feature - enable EE or remove standard answer categories"
+            )
+        else:
+            raise ValueError(
+                f"Some or all categories with ids {standard_answer_category_ids} do not exist"
+            )

     slack_bot_config = SlackBotConfig(
         persona_id=persona_id,
@@ -117,9 +140,18 @@ def update_slack_bot_config(
             f"Unable to find slack bot config with ID {slack_bot_config_id}"
         )

-    existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
-        standard_answer_category_ids=standard_answer_category_ids,
-        db_session=db_session,
+    versioned_fetch_standard_answer_categories_by_ids = (
+        fetch_versioned_implementation_with_fallback(
+            "danswer.db.standard_answer",
+            "fetch_standard_answer_categories_by_ids",
+            _no_ee_standard_answer_categories,
+        )
+    )
+    existing_standard_answer_categories = (
+        versioned_fetch_standard_answer_categories_by_ids(
+            standard_answer_category_ids=standard_answer_category_ids,
+            db_session=db_session,
+        )
     )
     if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
         raise ValueError(
diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py
index 0153f372fd4..c63f4b626be 100644
--- a/backend/danswer/document_index/vespa/index.py
+++ b/backend/danswer/document_index/vespa/index.py
@@ -120,7 +120,7 @@ def ensure_indices_exist(
     secondary_index_embedding_dim: int | None,
 ) -> None:
     deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
-    logger.debug(f"Sending Vespa zip to {deploy_url}")
+    logger.info(f"Deploying Vespa application package to {deploy_url}")

     vespa_schema_path = os.path.join(
         os.getcwd(), "danswer", "document_index", "vespa", "app_config"
diff --git a/backend/danswer/llm/gpt_4_all.py b/backend/danswer/llm/gpt_4_all.py
deleted file mode 100644
index c7cf6a61557..00000000000
--- a/backend/danswer/llm/gpt_4_all.py
+++ /dev/null
@@ -1,77 +0,0 @@
-from collections.abc import Iterator
-from typing import Any
-
-from langchain.schema.language_model import LanguageModelInput
-
-from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
-from danswer.configs.model_configs import GEN_AI_TEMPERATURE
-from danswer.llm.interfaces import LLM
-from danswer.llm.utils import convert_lm_input_to_basic_string
-from danswer.utils.logger import setup_logger
-
-
-logger = setup_logger()
-
-
-class DummyGPT4All:
-    """In the case of import failure due to architectural incompatibilities,
-    this module does not raise exceptions during server startup,
-    as long as the module isn't actually used"""
-
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        raise RuntimeError("GPT4All library not installed.")
-
-
-try:
-    from gpt4all import GPT4All  # type:ignore
-except ImportError:
-    # Setting a low log level because users get scared when they see this
-    logger.debug(
-        "GPT4All library not installed. "
-        "If you wish to run GPT4ALL (in memory) to power Danswer's "
-        "Generative AI features, please install gpt4all==2.0.2."
-    )
-    GPT4All = DummyGPT4All
-
-
-class DanswerGPT4All(LLM):
-    """Option to run an LLM locally, however this is significantly slower and
-    answers tend to be much worse
-
-    NOTE: currently unused, but kept for future reference / if we want to add this back.
- """ - - @property - def requires_warm_up(self) -> bool: - """GPT4All models are lazy loaded, load them on server start so that the - first inference isn't extremely delayed""" - return True - - @property - def requires_api_key(self) -> bool: - return False - - def __init__( - self, - timeout: int, - model_version: str, - max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, - temperature: float = GEN_AI_TEMPERATURE, - ): - self.timeout = timeout - self.max_output_tokens = max_output_tokens - self.temperature = temperature - self.gpt4all_model = GPT4All(model_version) - - def log_model_configs(self) -> None: - logger.debug( - f"GPT4All Model: {self.gpt4all_model}, Temperature: {self.temperature}" - ) - - def invoke(self, prompt: LanguageModelInput) -> str: - prompt_basic = convert_lm_input_to_basic_string(prompt) - return self.gpt4all_model.generate(prompt_basic) - - def stream(self, prompt: LanguageModelInput) -> Iterator[str]: - prompt_basic = convert_lm_input_to_basic_string(prompt) - return self.gpt4all_model.generate(prompt_basic, streaming=True) diff --git a/backend/danswer/main.py b/backend/danswer/main.py index a00826f11c8..9a681c39a13 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -62,7 +62,6 @@ from danswer.db.search_settings import get_secondary_search_settings from danswer.db.search_settings import update_current_search_settings from danswer.db.search_settings import update_secondary_search_settings -from danswer.db.standard_answer import create_initial_default_standard_answer_category from danswer.db.swap_index import check_index_swap from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import DocumentIndex @@ -102,7 +101,6 @@ from danswer.server.manage.llm.models import LLMProviderUpsertRequest from danswer.server.manage.search_settings import router as search_settings_router from danswer.server.manage.slack_bot import router as slack_bot_management_router -from danswer.server.manage.standard_answer import router as standard_answer_router from danswer.server.manage.users import router as user_router from danswer.server.middleware.latency_logging import add_latency_logging_middleware from danswer.server.query_and_chat.chat_backend import router as chat_router @@ -126,10 +124,10 @@ from danswer.utils.variable_functionality import fetch_versioned_implementation from danswer.utils.variable_functionality import global_version from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable +from shared_configs.configs import CORS_ALLOWED_ORIGIN from shared_configs.configs import MODEL_SERVER_HOST from shared_configs.configs import MODEL_SERVER_PORT - logger = setup_logger() @@ -187,9 +185,6 @@ def setup_postgres(db_session: Session) -> None: create_initial_default_connector(db_session) associate_default_cc_pair(db_session) - logger.notice("Verifying default standard answer category exists.") - create_initial_default_standard_answer_category(db_session) - logger.notice("Loading default Prompts and Personas") delete_old_default_personas(db_session) load_chat_yamls() @@ -503,7 +498,6 @@ def get_application() -> FastAPI: include_router_with_global_prefix_prepended( application, slack_bot_management_router ) - include_router_with_global_prefix_prepended(application, standard_answer_router) include_router_with_global_prefix_prepended(application, persona_router) include_router_with_global_prefix_prepended(application, admin_persona_router) include_router_with_global_prefix_prepended(application, 
input_prompt_router) @@ -593,7 +587,7 @@ def get_application() -> FastAPI: application.add_middleware( CORSMiddleware, - allow_origins=["*"], # Change this to the list of allowed origins if needed + allow_origins=CORS_ALLOWED_ORIGIN, # Configurable via environment variable allow_credentials=True, allow_methods=["*"], allow_headers=["*"], diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 3f83ad19551..f051da82f14 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -26,6 +26,7 @@ from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.chat import update_search_docs_table_with_relevance from danswer.db.engine import get_session_context_manager +from danswer.db.models import Persona from danswer.db.models import User from danswer.db.persona import get_prompt_by_id from danswer.llm.answering.answer import Answer @@ -60,7 +61,7 @@ from danswer.tools.tool_runner import ToolCallKickoff from danswer.utils.logger import setup_logger from danswer.utils.timing import log_generator_function_time - +from ee.danswer.server.query_and_chat.utils import create_temporary_persona logger = setup_logger() @@ -118,7 +119,17 @@ def stream_answer_objects( one_shot=True, danswerbot_flow=danswerbot_flow, ) - llm, fast_llm = get_llms_for_persona(persona=chat_session.persona) + + temporary_persona: Persona | None = None + if query_req.persona_config is not None: + new_persona = create_temporary_persona( + db_session=db_session, persona_config=query_req.persona_config, user=user + ) + temporary_persona = new_persona + + persona = temporary_persona if temporary_persona else chat_session.persona + + llm, fast_llm = get_llms_for_persona(persona=persona) llm_tokenizer = get_tokenizer( model_name=llm.config.model_name, @@ -153,11 +164,11 @@ def stream_answer_objects( prompt_id=query_req.prompt_id, user=None, db_session=db_session ) if prompt is None: - if not chat_session.persona.prompts: + if not persona.prompts: raise RuntimeError( "Persona does not have any prompts - this should never happen" ) - prompt = chat_session.persona.prompts[0] + prompt = persona.prompts[0] # Create the first User query message new_user_message = create_new_chat_message( @@ -174,9 +185,7 @@ def stream_answer_objects( prompt_config = PromptConfig.from_model(prompt) document_pruning_config = DocumentPruningConfig( max_chunks=int( - chat_session.persona.num_chunks - if chat_session.persona.num_chunks is not None - else default_num_chunks + persona.num_chunks if persona.num_chunks is not None else default_num_chunks ), max_tokens=max_document_tokens, ) @@ -187,16 +196,16 @@ def stream_answer_objects( evaluation_type=LLMEvaluationType.SKIP if DISABLE_LLM_DOC_RELEVANCE else query_req.evaluation_type, - persona=chat_session.persona, + persona=persona, retrieval_options=query_req.retrieval_options, prompt_config=prompt_config, llm=llm, fast_llm=fast_llm, pruning_config=document_pruning_config, + bypass_acl=bypass_acl, chunks_above=query_req.chunks_above, chunks_below=query_req.chunks_below, full_doc=query_req.full_doc, - bypass_acl=bypass_acl, ) answer_config = AnswerStyleConfig( @@ -209,13 +218,15 @@ def stream_answer_objects( question=query_msg.message, answer_style_config=answer_config, prompt_config=PromptConfig.from_model(prompt), - llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)), + 
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=persona)), single_message_history=history_str, - tools=[search_tool], - force_use_tool=ForceUseTool( - force_use=True, - tool_name=search_tool.name, - args={"query": rephrased_query}, + tools=[search_tool] if search_tool else [], + force_use_tool=( + ForceUseTool( + tool_name=search_tool.name, + args={"query": rephrased_query}, + force_use=True, + ) ), # for now, don't use tool calling for this flow, as we haven't # tested quotes with tool calling too much yet @@ -223,9 +234,7 @@ def stream_answer_objects( return_contexts=query_req.return_contexts, skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation, ) - # won't be any ImageGenerationDisplay responses since that tool is never passed in - for packet in cast(AnswerObjectIterator, answer.processed_streamed_output): # for one-shot flow, don't currently do anything with these if isinstance(packet, ToolResponse): @@ -261,6 +270,7 @@ def stream_answer_objects( applied_time_cutoff=search_response_summary.final_filters.time_cutoff, recency_bias_multiplier=search_response_summary.recency_bias_multiplier, ) + yield initial_response elif packet.id == SEARCH_DOC_CONTENT_ID: @@ -287,6 +297,7 @@ def stream_answer_objects( relevance_summary=evaluation_response, ) yield evaluation_response + else: yield packet diff --git a/backend/danswer/one_shot_answer/models.py b/backend/danswer/one_shot_answer/models.py index fceb78de7aa..735fc12bbb9 100644 --- a/backend/danswer/one_shot_answer/models.py +++ b/backend/danswer/one_shot_answer/models.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel from pydantic import Field from pydantic import model_validator @@ -8,6 +10,8 @@ from danswer.chat.models import QADocsResponse from danswer.configs.constants import MessageType from danswer.search.enums import LLMEvaluationType +from danswer.search.enums import RecencyBiasSetting +from danswer.search.enums import SearchType from danswer.search.models import ChunkContext from danswer.search.models import RerankingDetails from danswer.search.models import RetrievalDetails @@ -23,10 +27,49 @@ class ThreadMessage(BaseModel): role: MessageType = MessageType.USER +class PromptConfig(BaseModel): + name: str + description: str = "" + system_prompt: str + task_prompt: str = "" + include_citations: bool = True + datetime_aware: bool = True + + +class DocumentSetConfig(BaseModel): + id: int + + +class ToolConfig(BaseModel): + id: int + + +class PersonaConfig(BaseModel): + name: str + description: str + search_type: SearchType = SearchType.SEMANTIC + num_chunks: float | None = None + llm_relevance_filter: bool = False + llm_filter_extraction: bool = False + recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO + llm_model_provider_override: str | None = None + llm_model_version_override: str | None = None + + prompts: list[PromptConfig] = Field(default_factory=list) + prompt_ids: list[int] = Field(default_factory=list) + + document_set_ids: list[int] = Field(default_factory=list) + tools: list[ToolConfig] = Field(default_factory=list) + tool_ids: list[int] = Field(default_factory=list) + custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list) + + class DirectQARequest(ChunkContext): + persona_config: PersonaConfig | None = None + persona_id: int | None = None + messages: list[ThreadMessage] - prompt_id: int | None - persona_id: int + prompt_id: int | None = None multilingual_query_expansion: list[str] | None = None retrieval_options: RetrievalDetails = 
Field(default_factory=RetrievalDetails) rerank_settings: RerankingDetails | None = None @@ -43,6 +86,12 @@ class DirectQARequest(ChunkContext): # If True, skips generative an AI response to the search query skip_gen_ai_answer_generation: bool = False + @model_validator(mode="after") + def check_persona_fields(self) -> "DirectQARequest": + if (self.persona_config is None) == (self.persona_id is None): + raise ValueError("Exactly one of persona_config or persona_id must be set") + return self + @model_validator(mode="after") def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest": if self.chain_of_thought and self.prompt_id is not None: diff --git a/backend/danswer/redis/redis_pool.py b/backend/danswer/redis/redis_pool.py index 1ca2e07ecd3..25b932dbcd3 100644 --- a/backend/danswer/redis/redis_pool.py +++ b/backend/danswer/redis/redis_pool.py @@ -30,30 +30,44 @@ def __new__(cls) -> "RedisPool": return cls._instance def _init_pool(self) -> None: - if REDIS_SSL: - # Examples: https://github.com/redis/redis-py/issues/780 - self._pool = redis.ConnectionPool( - host=REDIS_HOST, - port=REDIS_PORT, - db=REDIS_DB_NUMBER, - password=REDIS_PASSWORD, - max_connections=REDIS_POOL_MAX_CONNECTIONS, - connection_class=redis.SSLConnection, - ssl_ca_certs=REDIS_SSL_CA_CERTS, - ssl_cert_reqs=REDIS_SSL_CERT_REQS, - ) - else: - self._pool = redis.ConnectionPool( - host=REDIS_HOST, - port=REDIS_PORT, - db=REDIS_DB_NUMBER, - password=REDIS_PASSWORD, - max_connections=REDIS_POOL_MAX_CONNECTIONS, - ) + self._pool = RedisPool.create_pool(ssl=REDIS_SSL) def get_client(self) -> Redis: return redis.Redis(connection_pool=self._pool) + @staticmethod + def create_pool( + host: str = REDIS_HOST, + port: int = REDIS_PORT, + db: int = REDIS_DB_NUMBER, + password: str = REDIS_PASSWORD, + max_connections: int = REDIS_POOL_MAX_CONNECTIONS, + ssl_ca_certs: str = REDIS_SSL_CA_CERTS, + ssl_cert_reqs: str = REDIS_SSL_CERT_REQS, + ssl: bool = False, + ) -> redis.ConnectionPool: + # Using ConnectionPool is not well documented. 
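The new `check_persona_fields` validator enforces that callers supply exactly one of `persona_config` or `persona_id`. A minimal standalone sketch of the same mutually-exclusive-fields pattern, with the config model trimmed to a stand-in (names here are illustrative, not the project's):

    from pydantic import BaseModel, ValidationError, model_validator

    class PersonaConfigStub(BaseModel):  # stand-in for the richer PersonaConfig above
        name: str

    class OneShotRequest(BaseModel):
        persona_config: PersonaConfigStub | None = None
        persona_id: int | None = None

        @model_validator(mode="after")
        def check_persona_fields(self) -> "OneShotRequest":
            # (a is None) == (b is None) is True when both or neither are set
            if (self.persona_config is None) == (self.persona_id is None):
                raise ValueError("Exactly one of persona_config or persona_id must be set")
            return self

    OneShotRequest(persona_id=1)                                # ok
    OneShotRequest(persona_config=PersonaConfigStub(name="x"))  # ok
    try:
        OneShotRequest()  # neither field set -> rejected
    except ValidationError as err:
        print(err)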
+ # Useful examples: https://github.com/redis/redis-py/issues/780 + if ssl: + return redis.ConnectionPool( + host=host, + port=port, + db=db, + password=password, + max_connections=max_connections, + connection_class=redis.SSLConnection, + ssl_ca_certs=ssl_ca_certs, + ssl_cert_reqs=ssl_cert_reqs, + ) + + return redis.ConnectionPool( + host=host, + port=port, + db=db, + password=password, + max_connections=max_connections, + ) + # # Usage example # redis_pool = RedisPool() diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index 678877812a2..503b07653ef 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -84,6 +84,7 @@ def from_db_model(cls, search_settings: SearchSettings) -> "SavedSearchSettings" # Multilingual Expansion multilingual_expansion=search_settings.multilingual_expansion, rerank_api_url=search_settings.rerank_api_url, + disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming, ) diff --git a/backend/danswer/server/auth_check.py b/backend/danswer/server/auth_check.py index 12258eba29b..8a35a560a24 100644 --- a/backend/danswer/server/auth_check.py +++ b/backend/danswer/server/auth_check.py @@ -7,6 +7,7 @@ from danswer.auth.users import current_admin_user from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user +from danswer.auth.users import current_user_with_expired_token from danswer.configs.app_configs import APP_API_PREFIX from danswer.server.danswer_api.ingestion import api_key_dep @@ -96,6 +97,7 @@ def check_router_auth( or depends_fn == current_admin_user or depends_fn == current_curator_or_admin_user or depends_fn == api_key_dep + or depends_fn == current_user_with_expired_token ): found_auth = True break diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index 72b16d719ff..bcc4800b860 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -3,6 +3,7 @@ from fastapi import APIRouter from fastapi import Depends +from fastapi import HTTPException from fastapi import Query from fastapi import UploadFile from pydantic import BaseModel @@ -20,6 +21,7 @@ from danswer.db.persona import mark_persona_as_deleted from danswer.db.persona import mark_persona_as_not_deleted from danswer.db.persona import update_all_personas_display_priority +from danswer.db.persona import update_persona_public_status from danswer.db.persona import update_persona_shared_users from danswer.db.persona import update_persona_visibility from danswer.file_store.file_store import get_default_file_store @@ -43,6 +45,10 @@ class IsVisibleRequest(BaseModel): is_visible: bool +class IsPublicRequest(BaseModel): + is_public: bool + + @admin_router.patch("/{persona_id}/visible") def patch_persona_visibility( persona_id: int, @@ -58,6 +64,25 @@ def patch_persona_visibility( ) +@basic_router.patch("/{persona_id}/public") +def patch_user_persona_public_status( + persona_id: int, + is_public_request: IsPublicRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + try: + update_persona_public_status( + persona_id=persona_id, + is_public=is_public_request.is_public, + db_session=db_session, + user=user, + ) + except ValueError as e: + logger.exception("Failed to update persona public status") + raise HTTPException(status_code=403, detail=str(e)) + + @admin_router.put("/display-priority") def 
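The refactor above collapses the duplicated SSL/non-SSL branches into a single `create_pool` helper. A hedged usage sketch, assuming a reachable Redis on localhost:6379 (values are examples, not project defaults):

    import redis
    from danswer.redis.redis_pool import RedisPool

    # Non-SSL pool with explicit overrides; omitted kwargs fall back to the
    # REDIS_* config constants used as parameter defaults above.
    pool = RedisPool.create_pool(host="localhost", port=6379, ssl=False)
    client = redis.Redis(connection_pool=pool)
    client.set("health-check", "ok")
    assert client.get("health-check") == b"ok"
    # With ssl=True, create_pool swaps in redis.SSLConnection and forwards
    # ssl_ca_certs / ssl_cert_reqs to the pool.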
patch_persona_display_priority( display_priority_request: DisplayPriorityRequest, diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py index 160c90bdb78..7b0a3813a82 100644 --- a/backend/danswer/server/manage/models.py +++ b/backend/danswer/server/manage/models.py @@ -15,13 +15,12 @@ from danswer.db.models import ChannelConfig from danswer.db.models import SlackBotConfig as SlackBotConfigModel from danswer.db.models import SlackBotResponseType -from danswer.db.models import StandardAnswer as StandardAnswerModel -from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel from danswer.db.models import User from danswer.search.models import SavedSearchSettings from danswer.server.features.persona.models import PersonaSnapshot from danswer.server.models import FullUserSnapshot from danswer.server.models import InvitedUserSnapshot +from ee.danswer.server.manage.models import StandardAnswerCategory if TYPE_CHECKING: @@ -117,58 +116,6 @@ class HiddenUpdateRequest(BaseModel): hidden: bool -class StandardAnswerCategoryCreationRequest(BaseModel): - name: str - - -class StandardAnswerCategory(BaseModel): - id: int - name: str - - @classmethod - def from_model( - cls, standard_answer_category: StandardAnswerCategoryModel - ) -> "StandardAnswerCategory": - return cls( - id=standard_answer_category.id, - name=standard_answer_category.name, - ) - - -class StandardAnswer(BaseModel): - id: int - keyword: str - answer: str - categories: list[StandardAnswerCategory] - - @classmethod - def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer": - return cls( - id=standard_answer_model.id, - keyword=standard_answer_model.keyword, - answer=standard_answer_model.answer, - categories=[ - StandardAnswerCategory.from_model(standard_answer_category_model) - for standard_answer_category_model in standard_answer_model.categories - ], - ) - - -class StandardAnswerCreationRequest(BaseModel): - keyword: str - answer: str - categories: list[int] - - @field_validator("categories", mode="before") - @classmethod - def validate_categories(cls, value: list[int]) -> list[int]: - if len(value) < 1: - raise ValueError( - "At least one category must be attached to a standard answer" - ) - return value - - class SlackBotTokens(BaseModel): bot_token: str app_token: str @@ -194,6 +141,7 @@ class SlackBotConfigCreationRequest(BaseModel): # list of user emails follow_up_tags: list[str] | None = None response_type: SlackBotResponseType + # XXX this is going away soon standard_answer_categories: list[int] = Field(default_factory=list) @field_validator("answer_filters", mode="before") @@ -218,6 +166,7 @@ class SlackBotConfig(BaseModel): persona: PersonaSnapshot | None channel_config: ChannelConfig response_type: SlackBotResponseType + # XXX this is going away soon standard_answer_categories: list[StandardAnswerCategory] enable_auto_filters: bool @@ -236,6 +185,7 @@ def from_model( ), channel_config=slack_bot_config_model.channel_config, response_type=slack_bot_config_model.response_type, + # XXX this is going away soon standard_answer_categories=[ StandardAnswerCategory.from_model(standard_answer_category_model) for standard_answer_category_model in slack_bot_config_model.standard_answer_categories diff --git a/backend/danswer/server/manage/slack_bot.py b/backend/danswer/server/manage/slack_bot.py index 0fb1459072b..9a06b225cce 100644 --- a/backend/danswer/server/manage/slack_bot.py +++ b/backend/danswer/server/manage/slack_bot.py @@ -108,6 +108,7 
@@ def create_slack_bot_config( persona_id=persona_id, channel_config=channel_config, response_type=slack_bot_config_creation_request.response_type, + # XXX this is going away soon standard_answer_category_ids=slack_bot_config_creation_request.standard_answer_categories, db_session=db_session, enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters, diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 20ae7124fa1..c7f5983417d 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -164,7 +164,7 @@ def get_chat_session( chat_session_id=session_id, description=chat_session.description, persona_id=chat_session.persona_id, - persona_name=chat_session.persona.name, + persona_name=chat_session.persona.name if chat_session.persona else None, current_alternate_model=chat_session.current_alternate_model, messages=[ translate_db_message_to_chat_message_detail( diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 55d1094ea86..c9109b141c3 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -136,7 +136,7 @@ class RenameChatSessionResponse(BaseModel): class ChatSessionDetails(BaseModel): id: int name: str - persona_id: int + persona_id: int | None = None time_created: str shared_status: ChatSessionSharedStatus folder_id: int | None = None @@ -196,8 +196,8 @@ class SearchSessionDetailResponse(BaseModel): class ChatSessionDetailResponse(BaseModel): chat_session_id: int description: str - persona_id: int - persona_name: str + persona_id: int | None = None + persona_name: str | None messages: list[ChatMessageDetail] time_created: datetime shared_status: ChatSessionSharedStatus diff --git a/backend/danswer/tools/built_in_tools.py b/backend/danswer/tools/built_in_tools.py index 1bfecef7ce4..99b2ae3bbb6 100644 --- a/backend/danswer/tools/built_in_tools.py +++ b/backend/danswer/tools/built_in_tools.py @@ -146,6 +146,7 @@ def auto_add_search_tool_to_personas(db_session: Session) -> None: db_session.commit() logger.notice("Completed adding SearchTool to relevant Personas.") + _built_in_tools_cache: dict[int, Type[Tool]] | None = None diff --git a/backend/danswer/tools/custom/custom_tool.py b/backend/danswer/tools/custom/custom_tool.py index f7cbf236f2b..0272b4ad607 100644 --- a/backend/danswer/tools/custom/custom_tool.py +++ b/backend/danswer/tools/custom/custom_tool.py @@ -24,6 +24,9 @@ from danswer.tools.custom.openapi_parsing import openapi_to_url from danswer.tools.custom.openapi_parsing import REQUEST_BODY from danswer.tools.custom.openapi_parsing import validate_openapi_schema +from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER +from danswer.tools.models import DynamicSchemaInfo +from danswer.tools.models import MESSAGE_ID_PLACEHOLDER from danswer.tools.tool import Tool from danswer.tools.tool import ToolResponse from danswer.utils.logger import setup_logger @@ -39,7 +42,11 @@ class CustomToolCallSummary(BaseModel): class CustomTool(Tool): - def __init__(self, method_spec: MethodSpec, base_url: str) -> None: + def __init__( + self, + method_spec: MethodSpec, + base_url: str, + ) -> None: self._base_url = base_url self._method_spec = method_spec self._tool_definition = self._method_spec.to_tool_definition() @@ -141,6 +148,7 @@ def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: 
request_body = kwargs.get(REQUEST_BODY) path_params = {} + for path_param_schema in self._method_spec.get_path_param_schemas(): path_params[path_param_schema["name"]] = kwargs[path_param_schema["name"]] @@ -168,8 +176,23 @@ def final_result(self, *args: ToolResponse) -> JSON_ro: def build_custom_tools_from_openapi_schema( - openapi_schema: dict[str, Any] + openapi_schema: dict[str, Any], + dynamic_schema_info: DynamicSchemaInfo | None = None, ) -> list[CustomTool]: + if dynamic_schema_info: + # Process dynamic schema information + schema_str = json.dumps(openapi_schema) + placeholders = { + CHAT_SESSION_ID_PLACEHOLDER: dynamic_schema_info.chat_session_id, + MESSAGE_ID_PLACEHOLDER: dynamic_schema_info.message_id, + } + + for placeholder, value in placeholders.items(): + if value: + schema_str = schema_str.replace(placeholder, str(value)) + + openapi_schema = json.loads(schema_str) + url = openapi_to_url(openapi_schema) method_specs = openapi_to_method_specs(openapi_schema) return [CustomTool(method_spec, url) for method_spec in method_specs] @@ -223,7 +246,9 @@ def build_custom_tools_from_openapi_schema( } validate_openapi_schema(openapi_schema) - tools = build_custom_tools_from_openapi_schema(openapi_schema) + tools = build_custom_tools_from_openapi_schema( + openapi_schema, dynamic_schema_info=None + ) openai_client = openai.OpenAI() response = openai_client.chat.completions.create( diff --git a/backend/danswer/tools/models.py b/backend/danswer/tools/models.py index 052e4293a53..6317a95e2d3 100644 --- a/backend/danswer/tools/models.py +++ b/backend/danswer/tools/models.py @@ -37,3 +37,12 @@ class ToolCallFinalResult(ToolCallKickoff): tool_result: Any = ( None # we would like to use JSON_ro, but can't due to its recursive nature ) + + +class DynamicSchemaInfo(BaseModel): + chat_session_id: int | None + message_id: int | None + + +CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID" +MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID" diff --git a/backend/danswer/utils/errors.py b/backend/danswer/utils/errors.py new file mode 100644 index 00000000000..86b9d4252f3 --- /dev/null +++ b/backend/danswer/utils/errors.py @@ -0,0 +1,3 @@ +class EERequiredError(Exception): + """This error is thrown if an Enterprise Edition feature or API is + requested but the Enterprise Edition flag is not set.""" diff --git a/backend/danswer/utils/logger.py b/backend/danswer/utils/logger.py index 9489a6244ff..96d4ae2a25e 100644 --- a/backend/danswer/utils/logger.py +++ b/backend/danswer/utils/logger.py @@ -80,6 +80,16 @@ def notice(self, msg: Any, *args: Any, **kwargs: Any) -> None: ) +class PlainFormatter(logging.Formatter): + """Adds log levels.""" + + def format(self, record: logging.LogRecord) -> str: + levelname = record.levelname + level_display = f"{levelname}:" + formatted_message = super().format(record) + return f"{level_display.ljust(9)} {formatted_message}" + + class ColoredFormatter(logging.Formatter): """Custom formatter to add colors to log levels.""" @@ -114,6 +124,13 @@ def get_standard_formatter() -> ColoredFormatter: ) +DANSWER_DOCKER_ENV_STR = "DANSWER_RUNNING_IN_DOCKER" + + +def is_running_in_container() -> bool: + return os.getenv(DANSWER_DOCKER_ENV_STR) == "true" + + def setup_logger( name: str = __name__, log_level: int = get_log_level_from_str(), @@ -141,7 +158,7 @@ def setup_logger( uvicorn_logger.addHandler(handler) uvicorn_logger.setLevel(log_level) - is_containerized = os.path.exists("/.dockerenv") + is_containerized = is_running_in_container() if LOG_FILE_NAME and (is_containerized or 
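The placeholder handling in `build_custom_tools_from_openapi_schema` works by serializing the schema, doing plain string replacement, and re-parsing. A condensed sketch of just that substitution step, on a toy schema:

    import json

    CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
    MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"

    # Toy schema; only the path template matters for this illustration.
    schema = {"paths": {"/sessions/CHAT_SESSION_ID/messages/MESSAGE_ID": {}}}

    def substitute(schema: dict, chat_session_id: int | None, message_id: int | None) -> dict:
        schema_str = json.dumps(schema)
        placeholders = {
            CHAT_SESSION_ID_PLACEHOLDER: chat_session_id,
            MESSAGE_ID_PLACEHOLDER: message_id,
        }
        for placeholder, value in placeholders.items():
            if value:  # note: falsy ids (None, but also 0) are left unsubstituted
                schema_str = schema_str.replace(placeholder, str(value))
        return json.loads(schema_str)

    print(substitute(schema, chat_session_id=42, message_id=7))
    # {'paths': {'/sessions/42/messages/7': {}}}

The truthiness check mirrors the diff exactly; an id of 0 would be skipped, which may or may not be intended.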
DEV_LOGGING_ENABLED): log_levels = ["debug", "info", "notice"] for level in log_levels: diff --git a/backend/ee/danswer/danswerbot/slack/handlers/__init__.py b/backend/ee/danswer/danswerbot/slack/handlers/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/ee/danswer/danswerbot/slack/handlers/handle_standard_answers.py b/backend/ee/danswer/danswerbot/slack/handlers/handle_standard_answers.py new file mode 100644 index 00000000000..6807e77135a --- /dev/null +++ b/backend/ee/danswer/danswerbot/slack/handlers/handle_standard_answers.py @@ -0,0 +1,238 @@ +from slack_sdk import WebClient +from slack_sdk.models.blocks import ActionsBlock +from slack_sdk.models.blocks import Block +from slack_sdk.models.blocks import ButtonElement +from slack_sdk.models.blocks import SectionBlock +from sqlalchemy.orm import Session + +from danswer.configs.constants import MessageType +from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI +from danswer.danswerbot.slack.blocks import get_restate_blocks +from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID +from danswer.danswerbot.slack.handlers.utils import send_team_member_message +from danswer.danswerbot.slack.models import SlackMessageInfo +from danswer.danswerbot.slack.utils import respond_in_thread +from danswer.danswerbot.slack.utils import update_emote_react +from danswer.db.chat import create_chat_session +from danswer.db.chat import create_new_chat_message +from danswer.db.chat import get_chat_messages_by_sessions +from danswer.db.chat import get_chat_sessions_by_slack_thread_id +from danswer.db.chat import get_or_create_root_message +from danswer.db.models import Prompt +from danswer.db.models import SlackBotConfig +from danswer.db.models import StandardAnswer as StandardAnswerModel +from danswer.utils.logger import DanswerLoggingAdapter +from danswer.utils.logger import setup_logger +from ee.danswer.db.standard_answer import fetch_standard_answer_categories_by_names +from ee.danswer.db.standard_answer import find_matching_standard_answers +from ee.danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer + +logger = setup_logger() + + +def build_standard_answer_blocks( + answer_message: str, +) -> list[Block]: + generate_button_block = ButtonElement( + action_id=GENERATE_ANSWER_BUTTON_ACTION_ID, + text="Generate Full Answer", + ) + answer_block = SectionBlock(text=answer_message) + return [ + answer_block, + ActionsBlock( + elements=[generate_button_block], + ), + ] + + +def oneoff_standard_answers( + message: str, + slack_bot_categories: list[str], + db_session: Session, +) -> list[PydanticStandardAnswer]: + """ + Respond to the user message if it matches any configured standard answers. + + Returns a list of the matching StandardAnswers; the returned list is empty if none match. 
+ """ + configured_standard_answers = { + standard_answer + for category in fetch_standard_answer_categories_by_names( + slack_bot_categories, db_session=db_session + ) + for standard_answer in category.standard_answers + } + + matching_standard_answers = find_matching_standard_answers( + query=message, + id_in=[answer.id for answer in configured_standard_answers], + db_session=db_session, + ) + + server_standard_answers = [ + PydanticStandardAnswer.from_model(standard_answer_model) + for (standard_answer_model, _) in matching_standard_answers + ] + return server_standard_answers + + +def _handle_standard_answers( + message_info: SlackMessageInfo, + receiver_ids: list[str] | None, + slack_bot_config: SlackBotConfig | None, + prompt: Prompt | None, + logger: DanswerLoggingAdapter, + client: WebClient, + db_session: Session, +) -> bool: + """ + Potentially respond to the user message depending on whether the user's message matches + any of the configured standard answers and also whether those answers have already been + provided in the current thread. + + Returns True if standard answers are found to match the user's message and therefore, + we still need to respond to the users. + """ + # if no channel config, then no standard answers are configured + if not slack_bot_config: + return False + + slack_thread_id = message_info.thread_to_respond + configured_standard_answer_categories = ( + slack_bot_config.standard_answer_categories if slack_bot_config else [] + ) + configured_standard_answers = set( + [ + standard_answer + for standard_answer_category in configured_standard_answer_categories + for standard_answer in standard_answer_category.standard_answers + ] + ) + query_msg = message_info.thread_messages[-1] + + if slack_thread_id is None: + used_standard_answer_ids = set([]) + else: + chat_sessions = get_chat_sessions_by_slack_thread_id( + slack_thread_id=slack_thread_id, + user_id=None, + db_session=db_session, + ) + chat_messages = get_chat_messages_by_sessions( + chat_session_ids=[chat_session.id for chat_session in chat_sessions], + user_id=None, + db_session=db_session, + skip_permission_check=True, + ) + used_standard_answer_ids = set( + [ + standard_answer.id + for chat_message in chat_messages + for standard_answer in chat_message.standard_answers + ] + ) + + usable_standard_answers = configured_standard_answers.difference( + used_standard_answer_ids + ) + + matching_standard_answers: list[tuple[StandardAnswerModel, str]] = [] + if usable_standard_answers: + matching_standard_answers = find_matching_standard_answers( + query=query_msg.message, + id_in=[standard_answer.id for standard_answer in usable_standard_answers], + db_session=db_session, + ) + + if matching_standard_answers: + chat_session = create_chat_session( + db_session=db_session, + description="", + user_id=None, + persona_id=slack_bot_config.persona.id if slack_bot_config.persona else 0, + danswerbot_flow=True, + slack_thread_id=slack_thread_id, + one_shot=True, + ) + + root_message = get_or_create_root_message( + chat_session_id=chat_session.id, db_session=db_session + ) + + new_user_message = create_new_chat_message( + chat_session_id=chat_session.id, + parent_message=root_message, + prompt_id=prompt.id if prompt else None, + message=query_msg.message, + token_count=0, + message_type=MessageType.USER, + db_session=db_session, + commit=True, + ) + + formatted_answers = [] + for standard_answer, match_str in matching_standard_answers: + since_you_mentioned_pretext = ( + f'Since your question contains 
"_{match_str}_"' + ) + block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ") + formatted_answer = f"{since_you_mentioned_pretext}, I thought this might be useful: \n\n{block_quotified_answer}" + formatted_answers.append(formatted_answer) + answer_message = "\n\n".join(formatted_answers) + + _ = create_new_chat_message( + chat_session_id=chat_session.id, + parent_message=new_user_message, + prompt_id=prompt.id if prompt else None, + message=answer_message, + token_count=0, + message_type=MessageType.ASSISTANT, + error=None, + db_session=db_session, + commit=True, + ) + + update_emote_react( + emoji=DANSWER_REACT_EMOJI, + channel=message_info.channel_to_respond, + message_ts=message_info.msg_to_respond, + remove=True, + client=client, + ) + + restate_question_blocks = get_restate_blocks( + msg=query_msg.message, + is_bot_msg=message_info.is_bot_msg, + ) + + answer_blocks = build_standard_answer_blocks( + answer_message=answer_message, + ) + + all_blocks = restate_question_blocks + answer_blocks + + try: + respond_in_thread( + client=client, + channel=message_info.channel_to_respond, + receiver_ids=receiver_ids, + text="Hello! Danswer has some results for you!", + blocks=all_blocks, + thread_ts=message_info.msg_to_respond, + unfurl=False, + ) + + if receiver_ids and slack_thread_id: + send_team_member_message( + client=client, + channel=message_info.channel_to_respond, + thread_ts=slack_thread_id, + ) + + return True + except Exception as e: + logger.exception(f"Unable to send standard answer message: {e}") + return False + else: + return False diff --git a/backend/danswer/db/standard_answer.py b/backend/ee/danswer/db/standard_answer.py similarity index 72% rename from backend/danswer/db/standard_answer.py rename to backend/ee/danswer/db/standard_answer.py index 064a5fa59ef..0fa074e36a7 100644 --- a/backend/danswer/db/standard_answer.py +++ b/backend/ee/danswer/db/standard_answer.py @@ -1,3 +1,4 @@ +import re import string from collections.abc import Sequence @@ -41,6 +42,8 @@ def insert_standard_answer( keyword: str, answer: str, category_ids: list[int], + match_regex: bool, + match_any_keywords: bool, db_session: Session, ) -> StandardAnswer: existing_categories = fetch_standard_answer_categories_by_ids( @@ -55,6 +58,8 @@ def insert_standard_answer( answer=answer, categories=existing_categories, active=True, + match_regex=match_regex, + match_any_keywords=match_any_keywords, ) db_session.add(standard_answer) db_session.commit() @@ -66,6 +71,8 @@ def update_standard_answer( keyword: str, answer: str, category_ids: list[int], + match_regex: bool, + match_any_keywords: bool, db_session: Session, ) -> StandardAnswer: standard_answer = db_session.scalar( @@ -84,6 +91,8 @@ def update_standard_answer( standard_answer.keyword = keyword standard_answer.answer = answer standard_answer.categories = list(existing_categories) + standard_answer.match_regex = match_regex + standard_answer.match_any_keywords = match_any_keywords db_session.commit() @@ -140,17 +149,6 @@ def fetch_standard_answer_category( ) -def fetch_standard_answer_categories_by_names( - standard_answer_category_names: list[str], - db_session: Session, -) -> Sequence[StandardAnswerCategory]: - return db_session.scalars( - select(StandardAnswerCategory).where( - StandardAnswerCategory.name.in_(standard_answer_category_names) - ) - ).all() - - def fetch_standard_answer_categories_by_ids( standard_answer_category_ids: list[int], db_session: Session, @@ -177,39 +175,6 @@ def fetch_standard_answer( ) -def 
find_matching_standard_answers( - id_in: list[int], - query: str, - db_session: Session, -) -> list[StandardAnswer]: - stmt = ( - select(StandardAnswer) - .where(StandardAnswer.active.is_(True)) - .where(StandardAnswer.id.in_(id_in)) - ) - possible_standard_answers = db_session.scalars(stmt).all() - - matching_standard_answers: list[StandardAnswer] = [] - for standard_answer in possible_standard_answers: - # Remove punctuation and split the keyword into individual words - keyword_words = "".join( - char - for char in standard_answer.keyword.lower() - if char not in string.punctuation - ).split() - - # Remove punctuation and split the query into individual words - query_words = "".join( - char for char in query.lower() if char not in string.punctuation - ).split() - - # Check if all of the keyword words are in the query words - if all(word in query_words for word in keyword_words): - matching_standard_answers.append(standard_answer) - - return matching_standard_answers - - def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]: return db_session.scalars( select(StandardAnswer).where(StandardAnswer.active.is_(True)) @@ -237,3 +202,78 @@ def create_initial_default_standard_answer_category(db_session: Session) -> None ) db_session.add(standard_answer_category) db_session.commit() + + +def fetch_standard_answer_categories_by_names( + standard_answer_category_names: list[str], + db_session: Session, +) -> Sequence[StandardAnswerCategory]: + return db_session.scalars( + select(StandardAnswerCategory).where( + StandardAnswerCategory.name.in_(standard_answer_category_names) + ) + ).all() + + +def find_matching_standard_answers( + id_in: list[int], + query: str, + db_session: Session, +) -> list[tuple[StandardAnswer, str]]: + """ + Returns a list of tuples, where each tuple is a StandardAnswer definition matching + the query and a string representing the match (either the regex match group or the + set of keywords). + + If `answer_instance.match_regex` is true, the definition is considered "matched" + if the query matches the `answer_instance.keyword` using `re.search`. 
+ + Otherwise, the definition is considered "matched" if the space-delimited tokens + in `keyword` exist in `query`, depending on the state of `match_any_keywords`. + """ + stmt = ( + select(StandardAnswer) + .where(StandardAnswer.active.is_(True)) + .where(StandardAnswer.id.in_(id_in)) + ) + possible_standard_answers: Sequence[StandardAnswer] = db_session.scalars(stmt).all() + + matching_standard_answers: list[tuple[StandardAnswer, str]] = [] + for standard_answer in possible_standard_answers: + if standard_answer.match_regex: + maybe_matches = re.search(standard_answer.keyword, query, re.IGNORECASE) + if maybe_matches is not None: + match_group = maybe_matches.group(0) + matching_standard_answers.append((standard_answer, match_group)) + + else: + # Remove punctuation and split the keyword into individual words + keyword_words = set( + "".join( + char + for char in standard_answer.keyword.lower() + if char not in string.punctuation + ).split() + ) + + # Remove punctuation and split the query into individual words + query_words = "".join( + char for char in query.lower() if char not in string.punctuation + ).split() + + # Check whether the keyword words appear in the query words, per the match mode + if standard_answer.match_any_keywords: + for word in query_words: + if word in keyword_words: + matching_standard_answers.append((standard_answer, word)) + break + else: + if all(word in query_words for word in keyword_words): + matching_standard_answers.append( + ( + standard_answer, + re.sub(r"\s+", ", ", standard_answer.keyword), + ) + ) + + return matching_standard_answers diff --git a/backend/ee/danswer/main.py b/backend/ee/danswer/main.py index d7d1d6406a3..7d150107c75 100644 --- a/backend/ee/danswer/main.py +++ b/backend/ee/danswer/main.py @@ -23,6 +23,7 @@ from ee.danswer.server.enterprise_settings.api import ( basic_router as enterprise_settings_router, ) +from ee.danswer.server.manage.standard_answer import router as standard_answer_router from ee.danswer.server.query_and_chat.chat_backend import ( router as chat_router, ) @@ -86,6 +87,7 @@ def get_application() -> FastAPI: # EE only backend APIs include_router_with_global_prefix_prepended(application, query_router) include_router_with_global_prefix_prepended(application, chat_router) + include_router_with_global_prefix_prepended(application, standard_answer_router) # Enterprise-only global settings include_router_with_global_prefix_prepended( application, enterprise_settings_admin_router diff --git a/backend/ee/danswer/server/enterprise_settings/api.py b/backend/ee/danswer/server/enterprise_settings/api.py index 8590fd6c5e7..385adcf689e 100644 --- a/backend/ee/danswer/server/enterprise_settings/api.py +++ b/backend/ee/danswer/server/enterprise_settings/api.py @@ -1,6 +1,6 @@ from datetime import datetime -from datetime import timedelta from datetime import timezone +from typing import Any import httpx from fastapi import APIRouter @@ -9,10 +9,12 @@ from fastapi import Response from fastapi import status from fastapi import UploadFile +from pydantic import BaseModel +from pydantic import Field from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user -from danswer.auth.users import current_user +from danswer.auth.users import current_user_with_expired_token from danswer.auth.users import get_user_manager from danswer.auth.users import UserManager from danswer.db.engine import get_session @@ -28,7 +30,6 @@ from ee.danswer.server.enterprise_settings.store import store_analytics_script from 
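To make the three matching modes above concrete, here is a condensed, DB-free sketch of the same rules (the real function operates on StandardAnswer rows; names here are stand-ins):

    import re
    import string

    def match_standard_answer(
        keyword: str, query: str, match_regex: bool, match_any: bool
    ) -> str | None:
        """Return the matched text, or None if the answer does not apply."""
        if match_regex:
            match = re.search(keyword, query, re.IGNORECASE)
            return match.group(0) if match else None

        def words(text: str) -> list[str]:
            # strip punctuation, lowercase, split on whitespace
            return "".join(
                char for char in text.lower() if char not in string.punctuation
            ).split()

        keyword_words = set(words(keyword))
        query_words = words(query)
        if match_any:
            # first query word that appears in the keyword set
            return next((word for word in query_words if word in keyword_words), None)
        if all(word in query_words for word in keyword_words):
            return ", ".join(keyword.split())
        return None

    assert match_standard_answer(r"reset.*password", "How do I reset my password?", True, False)
    assert match_standard_answer("vpn wifi", "the wifi is down", False, True) == "wifi"
    assert match_standard_answer("expense report", "expense report help", False, False) == "expense, report"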
ee.danswer.server.enterprise_settings.store import store_settings from ee.danswer.server.enterprise_settings.store import upload_logo -from shared_configs.configs import CUSTOM_REFRESH_URL admin_router = APIRouter(prefix="/admin/enterprise-settings") basic_router = APIRouter(prefix="/enterprise-settings") @@ -36,69 +37,37 @@ logger = setup_logger() -def mocked_refresh_token() -> dict: - """ - This function mocks the response from a token refresh endpoint. - It generates a mock access token, refresh token, and user information - with an expiration time set to 1 hour from now. - This is useful for testing or development when the actual refresh endpoint is not available. - """ - mock_exp = int((datetime.now() + timedelta(hours=1)).timestamp() * 1000) - data = { - "access_token": "asdf Mock access token", - "refresh_token": "asdf Mock refresh token", - "session": {"exp": mock_exp}, - "userinfo": { - "sub": "Mock email", - "familyName": "Mock name", - "givenName": "Mock name", - "fullName": "Mock name", - "userId": "Mock User ID", - "email": "test_email@danswer.ai", - }, - } - return data - - -@basic_router.get("/refresh-token") -async def refresh_access_token( - user: User = Depends(current_user), - user_manager: UserManager = Depends(get_user_manager), -) -> None: - # return - if CUSTOM_REFRESH_URL is None: - logger.error( - "Custom refresh URL is not set and client is attempting to custom refresh" - ) - raise HTTPException( - status_code=500, - detail="Custom refresh URL is not set", - ) +class RefreshTokenData(BaseModel): + access_token: str + refresh_token: str + session: dict = Field(..., description="Contains session information") + userinfo: dict = Field(..., description="Contains user information") - try: - async with httpx.AsyncClient() as client: - logger.debug(f"Sending request to custom refresh URL for user {user.id}") - access_token = user.oauth_accounts[0].access_token - - response = await client.get( - CUSTOM_REFRESH_URL, - params={"info": "json", "access_token_refresh_interval": 3600}, - headers={"Authorization": f"Bearer {access_token}"}, + def __init__(self, **data: Any) -> None: + super().__init__(**data) + if "exp" not in self.session: + raise ValueError("'exp' must be set in the session dictionary") + if "userId" not in self.userinfo or "email" not in self.userinfo: + raise ValueError( + "'userId' and 'email' must be set in the userinfo dictionary" ) - response.raise_for_status() - data = response.json() - # NOTE: Here is where we can mock the response - # data = mocked_refresh_token() +@basic_router.post("/refresh-token") +async def refresh_access_token( + refresh_token: RefreshTokenData, + user: User = Depends(current_user_with_expired_token), + user_manager: UserManager = Depends(get_user_manager), +) -> None: + try: logger.debug(f"Received response from Meechum auth URL for user {user.id}") # Extract new tokens - new_access_token = data["access_token"] - new_refresh_token = data["refresh_token"] + new_access_token = refresh_token.access_token + new_refresh_token = refresh_token.refresh_token new_expiry = datetime.fromtimestamp( - data["session"]["exp"] / 1000, tz=timezone.utc + refresh_token.session["exp"] / 1000, tz=timezone.utc ) expires_at_timestamp = int(new_expiry.timestamp()) @@ -107,8 +76,8 @@ async def refresh_access_token( await user_manager.oauth_callback( oauth_name="custom", access_token=new_access_token, - account_id=data["userinfo"]["userId"], - account_email=data["userinfo"]["email"], + account_id=refresh_token.userinfo["userId"], + 
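The endpoint now trusts the client to POST the refreshed token material directly, so the payload shape is enforced up front. A trimmed, standalone copy of the model to show what passes and what fails (timestamps are made up):

    from typing import Any
    from pydantic import BaseModel, Field

    class RefreshTokenData(BaseModel):  # trimmed copy of the model above
        access_token: str
        refresh_token: str
        session: dict = Field(..., description="Contains session information")
        userinfo: dict = Field(..., description="Contains user information")

        def __init__(self, **data: Any) -> None:
            super().__init__(**data)
            if "exp" not in self.session:
                raise ValueError("'exp' must be set in the session dictionary")
            if "userId" not in self.userinfo or "email" not in self.userinfo:
                raise ValueError("'userId' and 'email' must be set in the userinfo dictionary")

    RefreshTokenData(
        access_token="at",
        refresh_token="rt",
        session={"exp": 1726545600000},  # milliseconds; the endpoint divides by 1000
        userinfo={"userId": "u1", "email": "u1@example.com"},
    )
    # Omitting "exp" raises a bare ValueError from __init__; a
    # model_validator(mode="after") would surface it as a pydantic
    # ValidationError instead, which is the more conventional v2 spot.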
account_email=refresh_token.userinfo["email"], expires_at=expires_at_timestamp, refresh_token=new_refresh_token, associate_by_email=True, diff --git a/backend/ee/danswer/server/manage/models.py b/backend/ee/danswer/server/manage/models.py new file mode 100644 index 00000000000..ae2c401a2fa --- /dev/null +++ b/backend/ee/danswer/server/manage/models.py @@ -0,0 +1,98 @@ +import re +from typing import Any + +from pydantic import BaseModel +from pydantic import field_validator +from pydantic import model_validator + +from danswer.db.models import StandardAnswer as StandardAnswerModel +from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel + + +class StandardAnswerCategoryCreationRequest(BaseModel): + name: str + + +class StandardAnswerCategory(BaseModel): + id: int + name: str + + @classmethod + def from_model( + cls, standard_answer_category: StandardAnswerCategoryModel + ) -> "StandardAnswerCategory": + return cls( + id=standard_answer_category.id, + name=standard_answer_category.name, + ) + + +class StandardAnswer(BaseModel): + id: int + keyword: str + answer: str + categories: list[StandardAnswerCategory] + match_regex: bool + match_any_keywords: bool + + @classmethod + def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer": + return cls( + id=standard_answer_model.id, + keyword=standard_answer_model.keyword, + answer=standard_answer_model.answer, + match_regex=standard_answer_model.match_regex, + match_any_keywords=standard_answer_model.match_any_keywords, + categories=[ + StandardAnswerCategory.from_model(standard_answer_category_model) + for standard_answer_category_model in standard_answer_model.categories + ], + ) + + +class StandardAnswerCreationRequest(BaseModel): + keyword: str + answer: str + categories: list[int] + match_regex: bool + match_any_keywords: bool + + @field_validator("categories", mode="before") + @classmethod + def validate_categories(cls, value: list[int]) -> list[int]: + if len(value) < 1: + raise ValueError( + "At least one category must be attached to a standard answer" + ) + return value + + @model_validator(mode="after") + def validate_only_match_any_if_not_regex(self) -> Any: + if self.match_regex and self.match_any_keywords: + raise ValueError( + "Can only match any keywords in keyword mode, not regex mode" + ) + + return self + + @model_validator(mode="after") + def validate_keyword_if_regex(self) -> Any: + if not self.match_regex: + # no validation for keywords + return self + + try: + re.compile(self.keyword) + return self + except re.error as err: + if isinstance(err.pattern, bytes): + raise ValueError( + f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}' + ) + else: + pattern = f'r"{err.pattern}"' if err.pattern is not None else "" + raise ValueError( + " ".join( + ["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"] + ) + ) diff --git a/backend/danswer/server/manage/standard_answer.py b/backend/ee/danswer/server/manage/standard_answer.py similarity index 79% rename from backend/danswer/server/manage/standard_answer.py rename to backend/ee/danswer/server/manage/standard_answer.py index 69f9e8146df..e832fa19078 100644 --- a/backend/danswer/server/manage/standard_answer.py +++ b/backend/ee/danswer/server/manage/standard_answer.py @@ -6,19 +6,19 @@ from danswer.auth.users import current_admin_user from danswer.db.engine import get_session from danswer.db.models import User -from danswer.db.standard_answer import fetch_standard_answer -from danswer.db.standard_answer 
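The new StandardAnswerCreationRequest rejects keyword-mode options in regex mode and pre-compiles the pattern, so a bad regex fails at request time rather than at match time. A trimmed sketch of the same checks:

    import re
    from pydantic import BaseModel, ValidationError, model_validator

    class CreationRequestStub(BaseModel):  # trimmed to the validated fields
        keyword: str
        match_regex: bool = False
        match_any_keywords: bool = False

        @model_validator(mode="after")
        def validate_flags_and_pattern(self) -> "CreationRequestStub":
            if self.match_regex and self.match_any_keywords:
                raise ValueError("Can only match any keywords in keyword mode, not regex mode")
            if self.match_regex:
                try:
                    re.compile(self.keyword)
                except re.error as err:
                    raise ValueError(f"invalid regex pattern in `keyword`: {err.msg}")
            return self

    CreationRequestStub(keyword=r"\bvpn\b", match_regex=True)  # ok
    try:
        CreationRequestStub(keyword=r"vpn(", match_regex=True)  # unbalanced paren
    except ValidationError as err:
        print(err)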
import fetch_standard_answer_categories -from danswer.db.standard_answer import fetch_standard_answer_category -from danswer.db.standard_answer import fetch_standard_answers -from danswer.db.standard_answer import insert_standard_answer -from danswer.db.standard_answer import insert_standard_answer_category -from danswer.db.standard_answer import remove_standard_answer -from danswer.db.standard_answer import update_standard_answer -from danswer.db.standard_answer import update_standard_answer_category -from danswer.server.manage.models import StandardAnswer -from danswer.server.manage.models import StandardAnswerCategory -from danswer.server.manage.models import StandardAnswerCategoryCreationRequest -from danswer.server.manage.models import StandardAnswerCreationRequest +from ee.danswer.db.standard_answer import fetch_standard_answer +from ee.danswer.db.standard_answer import fetch_standard_answer_categories +from ee.danswer.db.standard_answer import fetch_standard_answer_category +from ee.danswer.db.standard_answer import fetch_standard_answers +from ee.danswer.db.standard_answer import insert_standard_answer +from ee.danswer.db.standard_answer import insert_standard_answer_category +from ee.danswer.db.standard_answer import remove_standard_answer +from ee.danswer.db.standard_answer import update_standard_answer +from ee.danswer.db.standard_answer import update_standard_answer_category +from ee.danswer.server.manage.models import StandardAnswer +from ee.danswer.server.manage.models import StandardAnswerCategory +from ee.danswer.server.manage.models import StandardAnswerCategoryCreationRequest +from ee.danswer.server.manage.models import StandardAnswerCreationRequest router = APIRouter(prefix="/manage") @@ -33,6 +33,8 @@ def create_standard_answer( keyword=standard_answer_creation_request.keyword, answer=standard_answer_creation_request.answer, category_ids=standard_answer_creation_request.categories, + match_regex=standard_answer_creation_request.match_regex, + match_any_keywords=standard_answer_creation_request.match_any_keywords, db_session=db_session, ) return StandardAnswer.from_model(standard_answer_model) @@ -70,6 +72,8 @@ def patch_standard_answer( keyword=standard_answer_creation_request.keyword, answer=standard_answer_creation_request.answer, category_ids=standard_answer_creation_request.categories, + match_regex=standard_answer_creation_request.match_regex, + match_any_keywords=standard_answer_creation_request.match_any_keywords, db_session=db_session, ) return StandardAnswer.from_model(standard_answer_model) diff --git a/backend/ee/danswer/server/query_and_chat/chat_backend.py b/backend/ee/danswer/server/query_and_chat/chat_backend.py index 55561982325..1e163942533 100644 --- a/backend/ee/danswer/server/query_and_chat/chat_backend.py +++ b/backend/ee/danswer/server/query_and_chat/chat_backend.py @@ -28,6 +28,7 @@ from danswer.one_shot_answer.qa_utils import combine_message_thread from danswer.search.models import OptionalSearchSetting from danswer.search.models import RetrievalDetails +from danswer.search.models import SavedSearchDoc from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.query_and_chat.models import CreateChatMessageRequest @@ -65,18 +66,18 @@ def _translate_doc_response_to_simple_doc( def _get_final_context_doc_indices( final_context_docs: list[LlmDoc] | None, - simple_search_docs: list[SimpleDoc] | None, + top_docs: list[SavedSearchDoc] | None, 
) -> list[int] | None: """ this function returns a list of indices of the simple search docs that were actually fed to the LLM. """ - if final_context_docs is None or simple_search_docs is None: + if final_context_docs is None or top_docs is None: return None final_context_doc_ids = {doc.document_id for doc in final_context_docs} return [ - i for i, doc in enumerate(simple_search_docs) if doc.id in final_context_doc_ids + i for i, doc in enumerate(top_docs) if doc.document_id in final_context_doc_ids ] @@ -148,6 +149,7 @@ def handle_simplified_chat_message( answer += packet.answer_piece elif isinstance(packet, QADocsResponse): response.simple_search_docs = _translate_doc_response_to_simple_doc(packet) + response.top_documents = packet.top_documents elif isinstance(packet, StreamingError): response.error_msg = packet.error elif isinstance(packet, ChatMessageDetail): @@ -161,7 +163,7 @@ def handle_simplified_chat_message( } response.final_context_doc_indices = _get_final_context_doc_indices( - final_context_docs, response.simple_search_docs + final_context_docs, response.top_documents ) response.answer = answer @@ -296,6 +298,7 @@ def handle_send_message_simple_with_history( answer += packet.answer_piece elif isinstance(packet, QADocsResponse): response.simple_search_docs = _translate_doc_response_to_simple_doc(packet) + response.top_documents = packet.top_documents elif isinstance(packet, StreamingError): response.error_msg = packet.error elif isinstance(packet, ChatMessageDetail): @@ -311,7 +314,7 @@ def handle_send_message_simple_with_history( } response.final_context_doc_indices = _get_final_context_doc_indices( - final_context_docs, response.simple_search_docs + final_context_docs, response.top_documents ) response.answer = answer diff --git a/backend/ee/danswer/server/query_and_chat/models.py b/backend/ee/danswer/server/query_and_chat/models.py index b1ea648c8f0..be1cd3c6ef6 100644 --- a/backend/ee/danswer/server/query_and_chat/models.py +++ b/backend/ee/danswer/server/query_and_chat/models.py @@ -8,7 +8,8 @@ from danswer.search.models import ChunkContext from danswer.search.models import RerankingDetails from danswer.search.models import RetrievalDetails -from danswer.server.manage.models import StandardAnswer +from danswer.search.models import SavedSearchDoc +from ee.danswer.server.manage.models import StandardAnswer class StandardAnswerRequest(BaseModel): @@ -73,7 +74,11 @@ class ChatBasicResponse(BaseModel): # This is built piece by piece, any of these can be None as the flow could break answer: str | None = None answer_citationless: str | None = None + + # TODO: deprecate `simple_search_docs` simple_search_docs: list[SimpleDoc] | None = None + top_documents: list[SavedSearchDoc] | None = None + error_msg: str | None = None message_id: int | None = None llm_selected_doc_indices: list[int] | None = None diff --git a/backend/ee/danswer/server/query_and_chat/query_backend.py b/backend/ee/danswer/server/query_and_chat/query_backend.py index aef3648220e..59e61ba12df 100644 --- a/backend/ee/danswer/server/query_and_chat/query_backend.py +++ b/backend/ee/danswer/server/query_and_chat/query_backend.py @@ -6,9 +6,6 @@ from danswer.auth.users import current_user from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE -from danswer.danswerbot.slack.handlers.handle_standard_answers import ( - oneoff_standard_answers, -) from danswer.db.engine import get_session from danswer.db.models import User from danswer.db.persona import get_persona_by_id @@ -29,9 +26,13 @@ from 
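The reworked `_get_final_context_doc_indices` maps the docs actually fed to the LLM back to their positions in `top_documents` by `document_id`. A tiny illustration with stand-in types:

    from dataclasses import dataclass

    @dataclass
    class LlmDocStub:  # stand-in: only document_id matters here
        document_id: str

    @dataclass
    class SavedSearchDocStub:
        document_id: str

    def final_context_doc_indices(
        final_context_docs: list[LlmDocStub] | None,
        top_docs: list[SavedSearchDocStub] | None,
    ) -> list[int] | None:
        if final_context_docs is None or top_docs is None:
            return None
        fed_ids = {doc.document_id for doc in final_context_docs}
        return [i for i, doc in enumerate(top_docs) if doc.document_id in fed_ids]

    top = [SavedSearchDocStub("a"), SavedSearchDocStub("b"), SavedSearchDocStub("c")]
    print(final_context_doc_indices([LlmDocStub("c"), LlmDocStub("a")], top))  # [0, 2]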
danswer.search.utils import drop_llm_indices from danswer.search.utils import relevant_sections_to_indices from danswer.utils.logger import setup_logger +from ee.danswer.danswerbot.slack.handlers.handle_standard_answers import ( + oneoff_standard_answers, +) from ee.danswer.server.query_and_chat.models import DocumentSearchRequest from ee.danswer.server.query_and_chat.models import StandardAnswerRequest from ee.danswer.server.query_and_chat.models import StandardAnswerResponse +from ee.danswer.server.query_and_chat.utils import create_temporary_persona logger = setup_logger() @@ -133,12 +134,23 @@ def get_answer_with_quote( query = query_request.messages[0].message logger.notice(f"Received query for one shot answer API with quotes: {query}") - persona = get_persona_by_id( - persona_id=query_request.persona_id, - user=user, - db_session=db_session, - is_for_edit=False, - ) + if query_request.persona_config is not None: + new_persona = create_temporary_persona( + db_session=db_session, + persona_config=query_request.persona_config, + user=user, + ) + persona = new_persona + + elif query_request.persona_id is not None: + persona = get_persona_by_id( + persona_id=query_request.persona_id, + user=user, + db_session=db_session, + is_for_edit=False, + ) + else: + raise KeyError("Must provide persona ID or Persona Config") llm = get_main_llm_from_tuple( get_default_llms() if not persona else get_llms_for_persona(persona) diff --git a/backend/ee/danswer/server/query_and_chat/utils.py b/backend/ee/danswer/server/query_and_chat/utils.py new file mode 100644 index 00000000000..beb970fd1b8 --- /dev/null +++ b/backend/ee/danswer/server/query_and_chat/utils.py @@ -0,0 +1,83 @@ +from typing import cast + +from fastapi import HTTPException +from sqlalchemy.orm import Session + +from danswer.auth.users import is_user_admin +from danswer.db.llm import fetch_existing_doc_sets +from danswer.db.llm import fetch_existing_tools +from danswer.db.models import Persona +from danswer.db.models import Prompt +from danswer.db.models import Tool +from danswer.db.models import User +from danswer.db.persona import get_prompts_by_ids +from danswer.one_shot_answer.models import PersonaConfig +from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema + + +def create_temporary_persona( + persona_config: PersonaConfig, db_session: Session, user: User | None = None +) -> Persona: + """Create a temporary Persona object from the provided configuration.""" + if not is_user_admin(user): + raise HTTPException( + status_code=403, + detail="User is not authorized to create a persona in one shot queries", + ) + + persona = Persona( + name=persona_config.name, + description=persona_config.description, + num_chunks=persona_config.num_chunks, + llm_relevance_filter=persona_config.llm_relevance_filter, + llm_filter_extraction=persona_config.llm_filter_extraction, + recency_bias=persona_config.recency_bias, + llm_model_provider_override=persona_config.llm_model_provider_override, + llm_model_version_override=persona_config.llm_model_version_override, + ) + + if persona_config.prompts: + persona.prompts = [ + Prompt( + name=p.name, + description=p.description, + system_prompt=p.system_prompt, + task_prompt=p.task_prompt, + include_citations=p.include_citations, + datetime_aware=p.datetime_aware, + ) + for p in persona_config.prompts + ] + elif persona_config.prompt_ids: + persona.prompts = get_prompts_by_ids( + db_session=db_session, prompt_ids=persona_config.prompt_ids + ) + + persona.tools = [] + if 
persona_config.custom_tools_openapi: + for schema in persona_config.custom_tools_openapi: + tools = cast( + list[Tool], + build_custom_tools_from_openapi_schema(schema), + ) + persona.tools.extend(tools) + + if persona_config.tools: + tool_ids = [tool.id for tool in persona_config.tools] + persona.tools.extend( + fetch_existing_tools(db_session=db_session, tool_ids=tool_ids) + ) + + if persona_config.tool_ids: + persona.tools.extend( + fetch_existing_tools( + db_session=db_session, tool_ids=persona_config.tool_ids + ) + ) + + fetched_docs = fetch_existing_doc_sets( + db_session=db_session, doc_ids=persona_config.document_set_ids + ) + persona.document_sets = fetched_docs + + return persona diff --git a/backend/ee/danswer/server/query_history/api.py b/backend/ee/danswer/server/query_history/api.py index ed532a85603..dbdf3d8bc40 100644 --- a/backend/ee/danswer/server/query_history/api.py +++ b/backend/ee/danswer/server/query_history/api.py @@ -87,7 +87,7 @@ class ChatSessionMinimal(BaseModel): name: str | None first_user_message: str first_ai_message: str - persona_name: str + persona_name: str | None time_created: datetime feedback_type: QAFeedbackType | Literal["mixed"] | None @@ -97,7 +97,7 @@ class ChatSessionSnapshot(BaseModel): user_email: str name: str | None messages: list[MessageSnapshot] - persona_name: str + persona_name: str | None time_created: datetime @@ -111,7 +111,7 @@ class QuestionAnswerPairSnapshot(BaseModel): retrieved_documents: list[AbridgedSearchDoc] feedback_type: QAFeedbackType | None feedback_text: str | None - persona_name: str + persona_name: str | None user_email: str time_created: datetime @@ -145,7 +145,7 @@ def from_chat_session_snapshot( for ind, (user_message, ai_message) in enumerate(message_pairs) ] - def to_json(self) -> dict[str, str]: + def to_json(self) -> dict[str, str | None]: return { "chat_session_id": str(self.chat_session_id), "message_pair_num": str(self.message_pair_num), @@ -235,7 +235,9 @@ def fetch_and_process_chat_session_history_minimal( name=chat_session.description, first_user_message=first_user_message, first_ai_message=first_ai_message, - persona_name=chat_session.persona.name, + persona_name=chat_session.persona.name + if chat_session.persona + else None, time_created=chat_session.time_created, feedback_type=feedback_type, ) @@ -300,7 +302,7 @@ def snapshot_from_chat_session( for message in messages if message.message_type != MessageType.SYSTEM ], - persona_name=chat_session.persona.name, + persona_name=chat_session.persona.name if chat_session.persona else None, time_created=chat_session.time_created, ) diff --git a/backend/ee/danswer/server/seeding.py b/backend/ee/danswer/server/seeding.py index 10dc1afb972..ab6c4b017f9 100644 --- a/backend/ee/danswer/server/seeding.py +++ b/backend/ee/danswer/server/seeding.py @@ -13,6 +13,9 @@ from danswer.server.settings.models import Settings from danswer.server.settings.store import store_settings as store_base_settings from danswer.utils.logger import setup_logger +from ee.danswer.db.standard_answer import ( + create_initial_default_standard_answer_category, +) from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload from ee.danswer.server.enterprise_settings.models import EnterpriseSettings from ee.danswer.server.enterprise_settings.store import store_analytics_script @@ -21,6 +24,7 @@ ) from ee.danswer.server.enterprise_settings.store import upload_logo + logger = setup_logger() _SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION" @@ -146,3 +150,6 @@ def seed_db() -> 
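A hedged sketch of how `create_temporary_persona` might be driven from a one-shot request; field values are illustrative, and the persona is built in memory without being explicitly added to the session:

    from danswer.one_shot_answer.models import PersonaConfig, PromptConfig

    config = PersonaConfig(
        name="ephemeral-support",
        description="Scoped persona for a single one-shot query",
        num_chunks=10,
        prompts=[
            PromptConfig(
                name="default",
                system_prompt="Answer using only the provided context documents.",
            )
        ],
        document_set_ids=[3],  # illustrative id
    )
    # Inside a request handler (admin-only, per the is_user_admin check above):
    # persona = create_temporary_persona(
    #     persona_config=config, db_session=db_session, user=user
    # )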
None: _seed_logo(db_session, seed_config.seeded_logo_path) _seed_enterprise_settings(seed_config) _seed_analytics_script(seed_config) + + logger.notice("Verifying default standard answer category exists.") + create_initial_default_standard_answer_category(db_session) diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 82a1ee320c9..5b9d57b9d35 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -5,13 +5,11 @@ atlassian-python-api==3.37.0 beautifulsoup4==4.12.2 boto3==1.34.84 celery==5.3.4 -boto3==1.34.84 chardet==5.2.0 dask==2023.8.1 ddtrace==2.6.5 distributed==2023.8.1 fastapi==0.109.2 -fastapi-health==0.4.0 fastapi-users==12.1.3 fastapi-users-db-sqlalchemy==5.0.0 filelock==3.15.4 diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index fe933227009..ea37b031c7a 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -1,4 +1,5 @@ import os +from urllib.parse import urlparse # Used for logging SLACK_CHANNEL_ID = "channel_id" @@ -74,4 +75,21 @@ "query_prefix", ] -CUSTOM_REFRESH_URL = os.environ.get("CUSTOM_REFRESH_URL") or "/settings/refresh-token" + +# CORS +def validate_cors_origin(origin: str) -> None: + parsed = urlparse(origin) + if parsed.scheme not in ["http", "https"] or not parsed.netloc: + raise ValueError(f"Invalid CORS origin: '{origin}'") + + +# Strip entries so CORSMiddleware compares against clean origins; drop empty +# entries so the wildcard fallback can actually trigger. +CORS_ALLOWED_ORIGIN = [ + origin.strip() + for origin in os.environ.get("CORS_ALLOWED_ORIGIN", "*").split(",") + if origin.strip() +] or ["*"] + +# Validate non-wildcard origins +for origin in CORS_ALLOWED_ORIGIN: + if origin != "*": + validate_cors_origin(origin) diff --git a/backend/supervisord.conf b/backend/supervisord.conf index 697866b6c0a..3fa466c045c 100644 --- a/backend/supervisord.conf +++ b/backend/supervisord.conf @@ -12,6 +12,7 @@ command=python danswer/background/update.py redirect_stderr=true autorestart=true + # Background jobs that must be run async due to long time to completion # NOTE: due to an issue with Celery + SQLAlchemy # (https://github.com/celery/celery/issues/7007#issuecomment-1740139367) @@ -37,11 +38,9 @@ autorestart=true # Job scheduler for periodic tasks [program:celery_beat] command=celery -A danswer.background.celery.celery_run:celery_app beat - --loglevel=INFO --logfile=/var/log/celery_beat_supervisor.log environment=LOG_FILE_NAME=celery_beat redirect_stderr=true -autorestart=true # Listens for Slack messages and responds with answers # for all channels that the DanswerBot has been added to. 
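The CORS handling above is small enough to trace end to end; a sketch with made-up origins, mirroring the cleaned-up parsing where entries are stripped before they reach CORSMiddleware:

    import os
    from urllib.parse import urlparse

    def validate_cors_origin(origin: str) -> None:
        parsed = urlparse(origin)
        if parsed.scheme not in ["http", "https"] or not parsed.netloc:
            raise ValueError(f"Invalid CORS origin: '{origin}'")

    os.environ["CORS_ALLOWED_ORIGIN"] = "https://app.example.com, http://localhost:3000"
    origins = [
        origin.strip()
        for origin in os.environ.get("CORS_ALLOWED_ORIGIN", "*").split(",")
        if origin.strip()
    ] or ["*"]
    for origin in origins:
        if origin != "*":
            validate_cors_origin(origin)  # raises on e.g. "example.com" (no scheme)
    print(origins)  # ['https://app.example.com', 'http://localhost:3000']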
@@ -68,4 +67,4 @@ command=tail -qF stdout_logfile=/dev/stdout stdout_logfile_maxbytes=0 redirect_stderr=true -autorestart=true +autorestart=true \ No newline at end of file diff --git a/backend/tests/integration/common_utils/managers/chat.py b/backend/tests/integration/common_utils/managers/chat.py new file mode 100644 index 00000000000..3d62817641d --- /dev/null +++ b/backend/tests/integration/common_utils/managers/chat.py @@ -0,0 +1,160 @@ +import json + +import requests +from requests.models import Response + +from danswer.file_store.models import FileDescriptor +from danswer.llm.override_models import LLMOverride +from danswer.llm.override_models import PromptOverride +from danswer.one_shot_answer.models import DirectQARequest +from danswer.one_shot_answer.models import ThreadMessage +from danswer.search.models import RetrievalDetails +from danswer.server.query_and_chat.models import ChatSessionCreationRequest +from danswer.server.query_and_chat.models import CreateChatMessageRequest +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import StreamedResponse +from tests.integration.common_utils.test_models import TestChatMessage +from tests.integration.common_utils.test_models import TestChatSession +from tests.integration.common_utils.test_models import TestUser + + +class ChatSessionManager: + @staticmethod + def create( + persona_id: int = -1, + description: str = "Test chat session", + user_performing_action: TestUser | None = None, + ) -> TestChatSession: + chat_session_creation_req = ChatSessionCreationRequest( + persona_id=persona_id, description=description + ) + response = requests.post( + f"{API_SERVER_URL}/chat/create-chat-session", + json=chat_session_creation_req.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + chat_session_id = response.json()["chat_session_id"] + return TestChatSession( + id=chat_session_id, persona_id=persona_id, description=description + ) + + @staticmethod + def send_message( + chat_session_id: int, + message: str, + parent_message_id: int | None = None, + user_performing_action: TestUser | None = None, + file_descriptors: list[FileDescriptor] = [], + prompt_id: int | None = None, + search_doc_ids: list[int] | None = None, + retrieval_options: RetrievalDetails | None = None, + query_override: str | None = None, + regenerate: bool | None = None, + llm_override: LLMOverride | None = None, + prompt_override: PromptOverride | None = None, + alternate_assistant_id: int | None = None, + use_existing_user_message: bool = False, + ) -> StreamedResponse: + chat_message_req = CreateChatMessageRequest( + chat_session_id=chat_session_id, + parent_message_id=parent_message_id, + message=message, + file_descriptors=file_descriptors or [], + prompt_id=prompt_id, + search_doc_ids=search_doc_ids or [], + retrieval_options=retrieval_options, + query_override=query_override, + regenerate=regenerate, + llm_override=llm_override, + prompt_override=prompt_override, + alternate_assistant_id=alternate_assistant_id, + use_existing_user_message=use_existing_user_message, + ) + + response = requests.post( + f"{API_SERVER_URL}/chat/send-message", + json=chat_message_req.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + stream=True, + ) + + return ChatSessionManager.analyze_response(response) + + 
@staticmethod + def get_answer_with_quote( + persona_id: int, + message: str, + user_performing_action: TestUser | None = None, + ) -> StreamedResponse: + direct_qa_request = DirectQARequest( + messages=[ThreadMessage(message=message)], + prompt_id=None, + persona_id=persona_id, + ) + + response = requests.post( + f"{API_SERVER_URL}/query/stream-answer-with-quote", + json=direct_qa_request.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + stream=True, + ) + response.raise_for_status() + + return ChatSessionManager.analyze_response(response) + + @staticmethod + def analyze_response(response: Response) -> StreamedResponse: + response_data = [ + json.loads(line.decode("utf-8")) for line in response.iter_lines() if line + ] + + analyzed = StreamedResponse() + + for data in response_data: + if "rephrased_query" in data: + analyzed.rephrased_query = data["rephrased_query"] + elif "tool_name" in data: + analyzed.tool_name = data["tool_name"] + analyzed.tool_result = ( + data.get("tool_result") + if analyzed.tool_name == "run_search" + else None + ) + elif "relevance_summaries" in data: + analyzed.relevance_summaries = data["relevance_summaries"] + elif "answer_piece" in data and data["answer_piece"]: + analyzed.full_message += data["answer_piece"] + + return analyzed + + @staticmethod + def get_chat_history( + chat_session: TestChatSession, + user_performing_action: TestUser | None = None, + ) -> list[TestChatMessage]: + response = requests.get( + f"{API_SERVER_URL}/chat/history/{chat_session.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + return [ + TestChatMessage( + id=msg["id"], + chat_session_id=chat_session.id, + parent_message_id=msg.get("parent_message_id"), + message=msg["message"], + response=msg.get("response", ""), + ) + for msg in response.json() + ] diff --git a/backend/tests/integration/common_utils/test_models.py b/backend/tests/integration/common_utils/test_models.py index 04db0851e3d..2d8744327df 100644 --- a/backend/tests/integration/common_utils/test_models.py +++ b/backend/tests/integration/common_utils/test_models.py @@ -118,3 +118,28 @@ class TestPersona(BaseModel): llm_model_version_override: str | None users: list[str] groups: list[int] + + +# +class TestChatSession(BaseModel): + id: int + persona_id: int + description: str + + +class TestChatMessage(BaseModel): + id: str | None = None + chat_session_id: int + parent_message_id: str | None + message: str + response: str + + +class StreamedResponse(BaseModel): + full_message: str = "" + rephrased_query: str | None = None + tool_name: str | None = None + top_documents: list[dict[str, Any]] | None = None + relevance_summaries: list[dict[str, Any]] | None = None + tool_result: Any | None = None + user: str | None = None diff --git a/backend/tests/integration/tests/connector/test_connector_deletion.py b/backend/tests/integration/tests/connector/test_connector_deletion.py index e6f1b474170..f0a83034b32 100644 --- a/backend/tests/integration/tests/connector/test_connector_deletion.py +++ b/backend/tests/integration/tests/connector/test_connector_deletion.py @@ -292,8 +292,6 @@ def test_connector_deletion_for_overlapping_connectors( doc_creating_user=admin_user, ) - # EVERYTHING BELOW HERE IS CURRENTLY BROKEN AND NEEDS TO BE FIXED SERVER SIDE - # delete connector 1 CCPairManager.pause_cc_pair( cc_pair=cc_pair_1, diff --git 
a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py index 981a9cbd026..d4edcc583aa 100644 --- a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py +++ b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py @@ -51,6 +51,7 @@ def test_send_message_simple_with_history(reset: None) -> None: # Check that the top document is the correct document assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id + assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id # assert that the metadata is correct for doc in cc_pair_1.documents: diff --git a/backend/tests/integration/tests/document_set/test_syncing.py b/backend/tests/integration/tests/document_set/test_syncing.py index ab31b751471..217d106af4d 100644 --- a/backend/tests/integration/tests/document_set/test_syncing.py +++ b/backend/tests/integration/tests/document_set/test_syncing.py @@ -64,3 +64,94 @@ def test_multiple_document_sets_syncing_same_connnector( doc_set_names=[doc_set_1.name, doc_set_2.name], doc_creating_user=admin_user, ) + + +def test_removing_connector(reset: None, vespa_client: TestVespaClient) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + + # add api key to user + api_key: TestAPIKey = APIKeyManager.create( + user_performing_action=admin_user, + ) + + # create connectors + cc_pair_1 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + cc_pair_2 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + + # seed documents + cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_1, + num_docs=NUM_DOCS, + api_key=api_key, + ) + + cc_pair_2 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_2, + num_docs=NUM_DOCS, + api_key=api_key, + ) + + # Create document sets + doc_set_1 = DocumentSetManager.create( + cc_pair_ids=[cc_pair_1.id, cc_pair_2.id], + user_performing_action=admin_user, + ) + + DocumentSetManager.wait_for_sync( + user_performing_action=admin_user, + ) + + DocumentSetManager.verify( + document_set=doc_set_1, + user_performing_action=admin_user, + ) + + # make sure cc_pair_1 docs are doc_set_1 only + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + doc_set_names=[doc_set_1.name], + doc_creating_user=admin_user, + ) + + # make sure cc_pair_2 docs are doc_set_1 only + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + doc_set_names=[doc_set_1.name], + doc_creating_user=admin_user, + ) + + # remove cc_pair_2 from document set + doc_set_1.cc_pair_ids = [cc_pair_1.id] + DocumentSetManager.edit( + doc_set_1, + user_performing_action=admin_user, + ) + + DocumentSetManager.wait_for_sync( + user_performing_action=admin_user, + ) + + # make sure cc_pair_1 docs are doc_set_1 only + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + doc_set_names=[doc_set_1.name], + doc_creating_user=admin_user, + ) + + # make sure cc_pair_2 docs have no doc set + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + doc_set_names=[], + doc_creating_user=admin_user, + ) diff --git a/backend/tests/integration/tests/streaming_endpoints/test_answer_stream.py b/backend/tests/integration/tests/streaming_endpoints/test_answer_stream.py new file mode 100644 
index 00000000000..1b8a4c7906d
--- /dev/null
+++ b/backend/tests/integration/tests/streaming_endpoints/test_answer_stream.py
@@ -0,0 +1,25 @@
+from tests.integration.common_utils.llm import LLMProviderManager
+from tests.integration.common_utils.managers.chat import ChatSessionManager
+from tests.integration.common_utils.managers.user import UserManager
+from tests.integration.common_utils.test_models import TestUser
+
+
+def test_answer_stream_simple(reset: None) -> None:
+    admin_user: TestUser = UserManager.create(name="admin_user")
+    LLMProviderManager.create(user_performing_action=admin_user)
+
+    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
+
+    response = ChatSessionManager.get_answer_with_quote(
+        persona_id=test_chat_session.persona_id,
+        message="Hello, this is a test.",
+        user_performing_action=admin_user,
+    )
+
+    assert (
+        response.tool_name is not None
+    ), "Tool name should be specified (always search)"
+    assert (
+        response.relevance_summaries is not None
+    ), "Relevance summaries should be present for all search streams"
+    assert len(response.full_message) > 0, "Response message should not be empty"
diff --git a/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py b/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py
new file mode 100644
index 00000000000..4346e18483f
--- /dev/null
+++ b/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py
@@ -0,0 +1,19 @@
+from tests.integration.common_utils.llm import LLMProviderManager
+from tests.integration.common_utils.managers.chat import ChatSessionManager
+from tests.integration.common_utils.managers.user import UserManager
+from tests.integration.common_utils.test_models import TestUser
+
+
+def test_chat_stream_simple(reset: None) -> None:
+    admin_user: TestUser = UserManager.create(name="admin_user")
+    LLMProviderManager.create(user_performing_action=admin_user)
+
+    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
+
+    response = ChatSessionManager.send_message(
+        chat_session_id=test_chat_session.id,
+        message="this is a test message",
+        user_performing_action=admin_user,
+    )
+
+    assert len(response.full_message) > 0
diff --git a/backend/tests/integration/tests/usergroup/test_usergroup_syncing.py b/backend/tests/integration/tests/usergroup/test_usergroup_syncing.py
new file mode 100644
index 00000000000..fbb976f9f0f
--- /dev/null
+++ b/backend/tests/integration/tests/usergroup/test_usergroup_syncing.py
@@ -0,0 +1,102 @@
+from danswer.server.documents.models import DocumentSource
+from tests.integration.common_utils.constants import NUM_DOCS
+from tests.integration.common_utils.managers.api_key import APIKeyManager
+from tests.integration.common_utils.managers.cc_pair import CCPairManager
+from tests.integration.common_utils.managers.document import DocumentManager
+from tests.integration.common_utils.managers.user import UserManager
+from tests.integration.common_utils.managers.user_group import UserGroupManager
+from tests.integration.common_utils.test_models import TestAPIKey
+from tests.integration.common_utils.test_models import TestUser
+from tests.integration.common_utils.test_models import TestUserGroup
+from tests.integration.common_utils.vespa import TestVespaClient
+
+
+def test_removing_connector(reset: None, vespa_client: TestVespaClient) -> None:
+    # Creating an admin user (first user created is automatically an admin)
+    admin_user: TestUser =
UserManager.create(name="admin_user")
+
+    # add api key to user
+    api_key: TestAPIKey = APIKeyManager.create(
+        user_performing_action=admin_user,
+    )
+
+    # create connectors
+    cc_pair_1 = CCPairManager.create_from_scratch(
+        source=DocumentSource.INGESTION_API,
+        user_performing_action=admin_user,
+    )
+    cc_pair_2 = CCPairManager.create_from_scratch(
+        source=DocumentSource.INGESTION_API,
+        user_performing_action=admin_user,
+    )
+
+    # seed documents
+    cc_pair_1 = DocumentManager.seed_and_attach_docs(
+        cc_pair=cc_pair_1,
+        num_docs=NUM_DOCS,
+        api_key=api_key,
+    )
+
+    cc_pair_2 = DocumentManager.seed_and_attach_docs(
+        cc_pair=cc_pair_2,
+        num_docs=NUM_DOCS,
+        api_key=api_key,
+    )
+
+    # Create user group
+    user_group_1: TestUserGroup = UserGroupManager.create(
+        cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
+        user_performing_action=admin_user,
+    )
+
+    UserGroupManager.wait_for_sync(
+        user_groups_to_check=[user_group_1], user_performing_action=admin_user
+    )
+
+    UserGroupManager.verify(
+        user_group=user_group_1,
+        user_performing_action=admin_user,
+    )
+
+    # make sure cc_pair_1 docs are user_group_1 only
+    DocumentManager.verify(
+        vespa_client=vespa_client,
+        cc_pair=cc_pair_1,
+        group_names=[user_group_1.name],
+        doc_creating_user=admin_user,
+    )
+
+    # make sure cc_pair_2 docs are user_group_1 only
+    DocumentManager.verify(
+        vespa_client=vespa_client,
+        cc_pair=cc_pair_2,
+        group_names=[user_group_1.name],
+        doc_creating_user=admin_user,
+    )
+
+    # remove cc_pair_2 from the user group
+    user_group_1.cc_pair_ids = [cc_pair_1.id]
+    UserGroupManager.edit(
+        user_group_1,
+        user_performing_action=admin_user,
+    )
+
+    UserGroupManager.wait_for_sync(
+        user_performing_action=admin_user,
+    )
+
+    # make sure cc_pair_1 docs are user_group_1 only
+    DocumentManager.verify(
+        vespa_client=vespa_client,
+        cc_pair=cc_pair_1,
+        group_names=[user_group_1.name],
+        doc_creating_user=admin_user,
+    )
+
+    # make sure cc_pair_2 docs have no user group
+    DocumentManager.verify(
+        vespa_client=vespa_client,
+        cc_pair=cc_pair_2,
+        group_names=[],
+        doc_creating_user=admin_user,
+    )
diff --git a/backend/tests/unit/danswer/redis_ca.pem b/backend/tests/unit/danswer/redis_ca.pem
new file mode 100644
index 00000000000..8a44124d2fe
--- /dev/null
+++ b/backend/tests/unit/danswer/redis_ca.pem
@@ -0,0 +1,91 @@
+-----BEGIN CERTIFICATE-----
+MIIDXzCCAkegAwIBAgILBAAAAAABIVhTCKIwDQYJKoZIhvcNAQELBQAwTDEgMB4G
+A1UECxMXR2xvYmFsU2lnbiBSb290IENBIC0gUjMxEzARBgNVBAoTCkdsb2JhbFNp
+Z24xEzARBgNVBAMTCkdsb2JhbFNpZ24wHhcNMDkwMzE4MTAwMDAwWhcNMjkwMzE4
+MTAwMDAwWjBMMSAwHgYDVQQLExdHbG9iYWxTaWduIFJvb3QgQ0EgLSBSMzETMBEG
+A1UEChMKR2xvYmFsU2lnbjETMBEGA1UEAxMKR2xvYmFsU2lnbjCCASIwDQYJKoZI
+hvcNAQEBBQADggEPADCCAQoCggEBAMwldpB5BngiFvXAg7aEyiie/QV2EcWtiHL8
+RgJDx7KKnQRfJMsuS+FggkbhUqsMgUdwbN1k0ev1LKMPgj0MK66X17YUhhB5uzsT
+gHeMCOFJ0mpiLx9e+pZo34knlTifBtc+ycsmWQ1z3rDI6SYOgxXG71uL0gRgykmm
+KPZpO/bLyCiR5Z2KYVc3rHQU3HTgOu5yLy6c+9C7v/U9AOEGM+iCK65TpjoWc4zd
+QQ4gOsC0p6Hpsk+QLjJg6VfLuQSSaGjlOCZgdbKfd/+RFO+uIEn8rUAVSNECMWEZ
+XriX7613t2Saer9fwRPvm2L7DWzgVGkWqQPabumDk3F2xmmFghcCAwEAAaNCMEAw
+DgYDVR0PAQH/BAQDAgEGMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYEFI/wS3+o
+LkUkrk1Q+mOai97i3Ru8MA0GCSqGSIb3DQEBCwUAA4IBAQBLQNvAUKr+yAzv95ZU
+RUm7lgAJQayzE4aGKAczymvmdLm6AC2upArT9fHxD4q/c2dKg8dEe3jgr25sbwMp
+jjM5RcOO5LlXbKr8EpbsU8Yt5CRsuZRj+9xTaGdWPoO4zzUhw8lo/s7awlOqzJCK
+6fBdRoyV3XpYKBovHd7NADdBj+1EbddTKJd+82cEHhXXipa0095MJ6RMG3NzdvQX
+mcIfeg7jLQitChws/zyrVQ4PkX4268NXSb7hLi18YIvDQVETI53O9zJrlAGomecs
+Mx86OyXShkDOOyyGeMlhLxS67ttVb9+E7gUJTb0o2HLO02JQZR7rkpeDMdmztcpH
+WD9f
+-----END
CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIGMTCCBBmgAwIBAgICEAAwDQYJKoZIhvcNAQELBQAwajELMAkGA1UEBhMCVVMx +CzAJBgNVBAgMAkNBMQswCQYDVQQHDAJDQTESMBAGA1UECgwJUmVkaXNMYWJzMS0w +KwYDVQQDDCRSZWRpc0xhYnMgUm9vdCBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkwHhcN +MTgwMjI1MTUzNzM3WhcNMjgwMjIzMTUzNzM3WjBfMQswCQYDVQQGEwJVUzELMAkG +A1UECAwCQ0ExEjAQBgNVBAoMCVJlZGlzTGFiczEvMC0GA1UEAwwmUkNQIEludGVy +bWVkaWF0ZSBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkwggIiMA0GCSqGSIb3DQEBAQUA +A4ICDwAwggIKAoICAQDf9dqbxc8Bq7Ctq9rWcxrGNKKHivqLAFpPq02yLPx6fsOv +Tq7GsDChAYBBc4v7Y2Ap9RD5Vs3dIhEANcnolf27QwrG9RMnnvzk8pCvp1o6zSU4 +VuOE1W66/O1/7e2rVxyrnTcP7UgK43zNIXu7+tiAqWsO92uSnuMoGPGpeaUm1jym +hjWKtkAwDFSqvHY+XL5qDVBEjeUe+WHkYUg40cAXjusAqgm2hZt29c2wnVrxW25W +P0meNlzHGFdA2AC5z54iRiqj57dTfBTkHoBczQxcyw6hhzxZQ4e5I5zOKjXXEhZN +r0tA3YC14CTabKRus/JmZieyZzRgEy2oti64tmLYTqSlAD78pRL40VNoaSYetXLw +hhNsXCHgWaY6d5bLOc/aIQMAV5oLvZQKvuXAF1IDmhPA+bZbpWipp0zagf1P1H3s +UzsMdn2KM0ejzgotbtNlj5TcrVwpmvE3ktvUAuA+hi3FkVx1US+2Gsp5x4YOzJ7u +P1WPk6ShF0JgnJH2ILdj6kttTWwFzH17keSFICWDfH/+kM+k7Y1v3EXMQXE7y0T9 +MjvJskz6d/nv+sQhY04xt64xFMGTnZjlJMzfQNi7zWFLTZnDD0lPowq7l3YiPoTT +t5Xky83lu0KZsZBo0WlWaDG00gLVdtRgVbcuSWxpi5BdLb1kRab66JptWjxwXQID +AQABo4HrMIHoMDoGA1UdHwQzMDEwL6AtoCuGKWh0dHBzOi8vcmwtY2Etc2VydmVy +LnJlZGlzbGFicy5jb20vdjEvY3JsMEYGCCsGAQUFBwEBBDowODA2BggrBgEFBQcw +AYYqaHR0cHM6Ly9ybC1jYS1zZXJ2ZXIucmVkaXNsYWJzLmNvbS92MS9vY3NwMB0G +A1UdDgQWBBQHar5OKvQUpP2qWt6mckzToeCOHDAfBgNVHSMEGDAWgBQi42wH6hM4 +L2sujEvLM0/u8lRXTzASBgNVHRMBAf8ECDAGAQH/AgEAMA4GA1UdDwEB/wQEAwIB +hjANBgkqhkiG9w0BAQsFAAOCAgEAirEn/iTsAKyhd+pu2W3Z5NjCko4NPU0EYUbr +AP7+POK2rzjIrJO3nFYQ/LLuC7KCXG+2qwan2SAOGmqWst13Y+WHp44Kae0kaChW +vcYLXXSoGQGC8QuFSNUdaeg3RbMDYFT04dOkqufeWVccoHVxyTSg9eD8LZuHn5jw +7QDLiEECBmIJHk5Eeo2TAZrx4Yx6ufSUX5HeVjlAzqwtAqdt99uCJ/EL8bgpWbe+ +XoSpvUv0SEC1I1dCAhCKAvRlIOA6VBcmzg5Am12KzkqTul12/VEFIgzqu0Zy2Jbc +AUPrYVu/+tOGXQaijy7YgwH8P8n3s7ZeUa1VABJHcxrxYduDDJBLZi+MjheUDaZ1 +jQRHYevI2tlqeSBqdPKG4zBY5lS0GiAlmuze5oENt0P3XboHoZPHiqcK3VECgTVh +/BkJcuudETSJcZDmQ8YfoKfBzRQNg2sv/hwvUv73Ss51Sco8GEt2lD8uEdib1Q6z +zDT5lXJowSzOD5ZA9OGDjnSRL+2riNtKWKEqvtEG3VBJoBzu9GoxbAc7wIZLxmli +iF5a/Zf5X+UXD3s4TMmy6C4QZJpAA2egsSQCnraWO2ULhh7iXMysSkF/nzVfZn43 +iqpaB8++9a37hWq14ZmOv0TJIDz//b2+KC4VFXWQ5W5QC6whsjT+OlG4p5ZYG0jo +616pxqo= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIFujCCA6KgAwIBAgIJAJ1aTT1lu2ScMA0GCSqGSIb3DQEBCwUAMGoxCzAJBgNV +BAYTAlVTMQswCQYDVQQIDAJDQTELMAkGA1UEBwwCQ0ExEjAQBgNVBAoMCVJlZGlz +TGFiczEtMCsGA1UEAwwkUmVkaXNMYWJzIFJvb3QgQ2VydGlmaWNhdGUgQXV0aG9y +aXR5MB4XDTE4MDIyNTE1MjA0MloXDTM4MDIyMDE1MjA0MlowajELMAkGA1UEBhMC +VVMxCzAJBgNVBAgMAkNBMQswCQYDVQQHDAJDQTESMBAGA1UECgwJUmVkaXNMYWJz +MS0wKwYDVQQDDCRSZWRpc0xhYnMgUm9vdCBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkw +ggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDLEjXy7YrbN5Waau5cd6g1 +G5C2tMmeTpZ0duFAPxNU4oE3RHS5gGiok346fUXuUxbZ6QkuzeN2/2Z+RmRcJhQY +Dm0ZgdG4x59An1TJfnzKKoWj8ISmoHS/TGNBdFzXV7FYNLBuqZouqePI6ReC6Qhl +pp45huV32Q3a6IDrrvx7Wo5ZczEQeFNbCeCOQYNDdTmCyEkHqc2AGo8eoIlSTutT +ULOC7R5gzJVTS0e1hesQ7jmqHjbO+VQS1NAL4/5K6cuTEqUl+XhVhPdLWBXJQ5ag +54qhX4v+ojLzeU1R/Vc6NjMvVtptWY6JihpgplprN0Yh2556ewcXMeturcKgXfGJ +xeYzsjzXerEjrVocX5V8BNrg64NlifzTMKNOOv4fVZszq1SIHR8F9ROrqiOdh8iC +JpUbLpXH9hWCSEO6VRMB2xJoKu3cgl63kF30s77x7wLFMEHiwsQRKxooE1UhgS9K +2sO4TlQ1eWUvFvHSTVDQDlGQ6zu4qjbOpb3Q8bQwoK+ai2alkXVR4Ltxe9QlgYK3 +StsnPhruzZGA0wbXdpw0bnM+YdlEm5ffSTpNIfgHeaa7Dtb801FtA71ZlH7A6TaI +SIQuUST9EKmv7xrJyx0W1pGoPOLw5T029aTjnICSLdtV9bLwysrLhIYG5bnPq78B +cS+jZHFGzD7PUVGQD01nOQIDAQABo2MwYTAdBgNVHQ4EFgQUIuNsB+oTOC9rLoxL +yzNP7vJUV08wHwYDVR0jBBgwFoAUIuNsB+oTOC9rLoxLyzNP7vJUV08wDwYDVR0T 
+AQH/BAUwAwEB/zAOBgNVHQ8BAf8EBAMCAYYwDQYJKoZIhvcNAQELBQADggIBAHfg +z5pMNUAKdMzK1aS1EDdK9yKz4qicILz5czSLj1mC7HKDRy8cVADUxEICis++CsCu +rYOvyCVergHQLREcxPq4rc5Nq1uj6J6649NEeh4WazOOjL4ZfQ1jVznMbGy+fJm3 +3Hoelv6jWRG9iqeJZja7/1s6YC6bWymI/OY1e4wUKeNHAo+Vger7MlHV+RuabaX+ +hSJ8bJAM59NCM7AgMTQpJCncrcdLeceYniGy5Q/qt2b5mJkQVkIdy4TPGGB+AXDJ +D0q3I/JDRkDUFNFdeW0js7fHdsvCR7O3tJy5zIgEV/o/BCkmJVtuwPYOrw/yOlKj +TY/U7ATAx9VFF6/vYEOMYSmrZlFX+98L6nJtwDqfLB5VTltqZ4H/KBxGE3IRSt9l +FXy40U+LnXzhhW+7VBAvyYX8GEXhHkKU8Gqk1xitrqfBXY74xKgyUSTolFSfFVgj +mcM/X4K45bka+qpkj7Kfv/8D4j6aZekwhN2ly6hhC1SmQ8qjMjpG/mrWOSSHZFmf +ybu9iD2AYHeIOkshIl6xYIa++Q/00/vs46IzAbQyriOi0XxlSMMVtPx0Q3isp+ji +n8Mq9eOuxYOEQ4of8twUkUDd528iwGtEdwf0Q01UyT84S62N8AySl1ZBKXJz6W4F +UhWfa/HQYOAPDdEjNgnVwLI23b8t0TozyCWw7q8h +-----END CERTIFICATE----- + diff --git a/backend/tests/unit/danswer/test_redis.py b/backend/tests/unit/danswer/test_redis.py new file mode 100644 index 00000000000..a55c980f618 --- /dev/null +++ b/backend/tests/unit/danswer/test_redis.py @@ -0,0 +1,39 @@ +import os + +import pytest +import redis + +from danswer.redis.redis_pool import RedisPool +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +@pytest.mark.skipif( + os.getenv("REDIS_CLOUD_PYTEST_PASSWORD") is None, + reason="Environment variable REDIS_CLOUD_PYTEST_PASSWORD is not set", +) +def test_redis_ssl() -> None: + REDIS_PASSWORD = os.environ.get("REDIS_CLOUD_PYTEST_PASSWORD") + REDIS_HOST = "redis-15414.c267.us-east-1-4.ec2.redns.redis-cloud.com" + REDIS_PORT = 15414 + REDIS_SSL_CERT_REQS = "required" + + assert REDIS_PASSWORD + + # Construct the path to the CA certificate for the redis ssl test instance + # it contains no secret data, so it's OK to have checked in! + current_dir = os.path.dirname(__file__) + REDIS_SSL_CA_CERTS = os.path.join(current_dir, "redis_ca.pem") + + pool = RedisPool.create_pool( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + ssl=True, + ssl_cert_reqs=REDIS_SSL_CERT_REQS, + ssl_ca_certs=REDIS_SSL_CA_CERTS, + ) + + r = redis.Redis(connection_pool=pool) + assert r.ping()
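
Editor's note on the new CORS handling in backend/shared_configs/configs.py: origins are read from the CORS_ALLOWED_ORIGIN environment variable as a comma-separated list, and every non-wildcard entry must parse as an http(s) URL with a host. The following is a minimal sketch of that behavior, reusing the validate_cors_origin helper from the patch; the example origin values are illustrative, not taken from any deployment.

import os
from urllib.parse import urlparse


def validate_cors_origin(origin: str) -> None:
    # Mirrors the helper added in shared_configs/configs.py.
    parsed = urlparse(origin)
    if parsed.scheme not in ["http", "https"] or not parsed.netloc:
        raise ValueError(f"Invalid CORS origin: '{origin}'")


# Illustrative values only; a bare hostname like "app.example.com"
# (no scheme) would raise ValueError.
os.environ["CORS_ALLOWED_ORIGIN"] = "https://app.example.com, http://localhost:3000"

origins = os.environ.get("CORS_ALLOWED_ORIGIN", "*").split(",")
for origin in origins:
    if origin != "*" and (stripped_origin := origin.strip()):
        validate_cors_origin(stripped_origin)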
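
For readers of ChatSessionManager.analyze_response above: the test helper assumes the chat endpoints stream newline-delimited JSON packets and folds them into a StreamedResponse. Below is a minimal, self-contained sketch of that folding logic with a made-up packet sequence; real packets come from /chat/send-message and may carry additional keys beyond the ones analyze_response reads.

import json

# Hypothetical packet stream; key names match those read by analyze_response.
raw_lines = [
    b'{"rephrased_query": "what is danswer"}',
    b'{"tool_name": "run_search", "tool_result": []}',
    b'{"answer_piece": "Danswer is "}',
    b'{"answer_piece": "an open source QA tool."}',
]

full_message = ""
tool_name = None
for line in raw_lines:
    if not line:
        continue  # the real stream may interleave keep-alive blank lines
    packet = json.loads(line.decode("utf-8"))
    if "tool_name" in packet:
        tool_name = packet["tool_name"]
    elif "answer_piece" in packet and packet["answer_piece"]:
        full_message += packet["answer_piece"]

assert tool_name == "run_search"
assert full_message == "Danswer is an open source QA tool."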