diff --git a/.github/workflows/docker-build-backend-container-on-merge-group.yml b/.github/workflows/docker-build-backend-container-on-merge-group.yml new file mode 100644 index 00000000000..57ab29b00bf --- /dev/null +++ b/.github/workflows/docker-build-backend-container-on-merge-group.yml @@ -0,0 +1,35 @@ +name: Build Backend Image on Merge Group + +on: + pull_request: + branches: [ "main" ] + merge_group: + types: [checks_requested] + +env: + REGISTRY_IMAGE: danswer/danswer-backend + +jobs: + build: + # TODO: make this a matrix build like the web containers + runs-on: + group: amd64-image-builders + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Backend Image Docker Build + uses: docker/build-push-action@v5 + with: + context: ./backend + file: ./backend/Dockerfile + platforms: linux/amd64,linux/arm64 + push: false + tags: | + ${{ env.REGISTRY_IMAGE }}:latest + build-args: | + DANSWER_VERSION=v0.0.1 diff --git a/.github/workflows/docker-build-push-backend-container-on-tag.yml b/.github/workflows/docker-build-push-backend-container-on-tag.yml index c8d5112e2ae..8b77af29df7 100644 --- a/.github/workflows/docker-build-push-backend-container-on-tag.yml +++ b/.github/workflows/docker-build-push-backend-container-on-tag.yml @@ -5,9 +5,14 @@ on: tags: - '*' +env: + REGISTRY_IMAGE: danswer/danswer-backend + jobs: build-and-push: - runs-on: ubuntu-latest + # TODO: make this a matrix build like the web containers + runs-on: + group: amd64-image-builders steps: - name: Checkout code @@ -30,8 +35,8 @@ jobs: platforms: linux/amd64,linux/arm64 push: true tags: | - danswer/danswer-backend:${{ github.ref_name }} - danswer/danswer-backend:latest + ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} + ${{ env.REGISTRY_IMAGE }}:latest build-args: | DANSWER_VERSION=${{ github.ref_name }} @@ -39,6 +44,6 @@ jobs: uses: aquasecurity/trivy-action@master with: # To run locally: trivy image 
--severity HIGH,CRITICAL danswer/danswer-backend - image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }} + image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} severity: 'CRITICAL,HIGH' trivyignores: ./backend/.trivyignore diff --git a/.github/workflows/docker-build-push-web-container-on-tag.yml b/.github/workflows/docker-build-push-web-container-on-tag.yml index 4e3d6a99137..0a97a01f7c8 100644 --- a/.github/workflows/docker-build-push-web-container-on-tag.yml +++ b/.github/workflows/docker-build-push-web-container-on-tag.yml @@ -10,7 +10,7 @@ env: jobs: build: - runs-on: + runs-on: group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }} strategy: fail-fast: false @@ -34,8 +34,8 @@ jobs: with: images: ${{ env.REGISTRY_IMAGE }} tags: | - type=raw,value=danswer/danswer-web-server:${{ github.ref_name }} - type=raw,value=danswer/danswer-web-server:latest + type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} + type=raw,value=${{ env.REGISTRY_IMAGE }}:latest - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 diff --git a/.github/workflows/docker-build-web-container-on-merge-group.yml b/.github/workflows/docker-build-web-container-on-merge-group.yml new file mode 100644 index 00000000000..7b975f476d9 --- /dev/null +++ b/.github/workflows/docker-build-web-container-on-merge-group.yml @@ -0,0 +1,55 @@ +name: Build Web Image on Merge Group + +on: + pull_request: + branches: [ "main" ] + merge_group: + types: [checks_requested] + +env: + REGISTRY_IMAGE: danswer/danswer-web-server + +jobs: + build: + runs-on: + group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }} + strategy: + fail-fast: false + matrix: + platform: + - linux/amd64 + - linux/arm64 + + steps: + - name: Prepare + run: | + platform=${{ matrix.platform }} + echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV + + - name: Checkout + uses: actions/checkout@v4 + + - name: 
Docker meta + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY_IMAGE }} + tags: | + type=raw,value=${{ env.REGISTRY_IMAGE }}:latest + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Build by digest + id: build + uses: docker/build-push-action@v5 + with: + context: ./web + file: ./web/Dockerfile + platforms: ${{ matrix.platform }} + push: false + build-args: | + DANSWER_VERSION=v0.0.1 + # needed due to weird interactions with the builds for different platforms + no-cache: true + labels: ${{ steps.meta.outputs.labels }} diff --git a/.gitignore b/.gitignore index e1125f5f6c8..68f5348e427 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ .venv .mypy_cache .idea +.python-version /deployment/data/nginx/app.conf .vscode/launch.json *.sw? diff --git a/.vscode/env_template.txt b/.vscode/env_template.txt index 3e30e6c77bc..015672a226c 100644 --- a/.vscode/env_template.txt +++ b/.vscode/env_template.txt @@ -4,17 +4,20 @@ # For local dev, often user Authentication is not needed AUTH_TYPE=disabled -# This passes top N results to LLM an additional time for reranking prior to answer generation, quite token heavy so we disable it for dev generally -DISABLE_LLM_CHUNK_FILTER=True # Always keep these on for Dev # Logs all model prompts to stdout -LOG_ALL_MODEL_INTERACTIONS=True +LOG_DANSWER_MODEL_INTERACTIONS=True # More verbose logging LOG_LEVEL=debug +# This passes top N results to LLM an additional time for reranking prior to answer generation +# This step is quite heavy on token usage so we disable it for dev generally +DISABLE_LLM_CHUNK_FILTER=True + + # Useful if you want to toggle auth on/off (google_oauth/OIDC specifically) OAUTH_CLIENT_ID= OAUTH_CLIENT_SECRET= @@ -22,10 +25,6 @@ OAUTH_CLIENT_SECRET= REQUIRE_EMAIL_VERIFICATION=False -# Toggles on/off the EE Features -NEXT_PUBLIC_ENABLE_PAID_EE_FEATURES=False - - # Set these so if you wipe the DB, you don't end up having to go through the UI every time 
GEN_AI_API_KEY= # If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper @@ -41,3 +40,13 @@ FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo # Python stuff PYTHONPATH=./backend PYTHONUNBUFFERED=1 + + +# Internet Search +BING_API_KEY= + + +# Enable the full set of Danswer Enterprise Edition features +# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development) +ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False + diff --git a/.vscode/launch.template.jsonc b/.vscode/launch.template.jsonc index c5780a65a4f..a4be80fc1c9 100644 --- a/.vscode/launch.template.jsonc +++ b/.vscode/launch.template.jsonc @@ -17,6 +17,7 @@ "request": "launch", "cwd": "${workspaceRoot}/web", "runtimeExecutable": "npm", + "envFile": "${workspaceFolder}/.env", "runtimeArgs": [ "run", "dev" ], @@ -28,6 +29,7 @@ "request": "launch", "module": "uvicorn", "cwd": "${workspaceFolder}/backend", + "envFile": "${workspaceFolder}/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1" @@ -45,8 +47,9 @@ "request": "launch", "module": "uvicorn", "cwd": "${workspaceFolder}/backend", + "envFile": "${workspaceFolder}/.env", "env": { - "LOG_ALL_MODEL_INTERACTIONS": "True", + "LOG_DANSWER_MODEL_INTERACTIONS": "True", "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1" }, @@ -63,6 +66,7 @@ "request": "launch", "program": "danswer/background/update.py", "cwd": "${workspaceFolder}/backend", + "envFile": "${workspaceFolder}/.env", "env": { "ENABLE_MINI_CHUNK": "false", "LOG_LEVEL": "DEBUG", @@ -77,7 +81,9 @@ "request": "launch", "program": "scripts/dev_run_background_jobs.py", "cwd": "${workspaceFolder}/backend", + "envFile": "${workspaceFolder}/.env", "env": { + "LOG_DANSWER_MODEL_INTERACTIONS": "True", "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." @@ -100,6 +106,24 @@ "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." 
} + }, + { + "name": "Pytest", + "type": "python", + "request": "launch", + "module": "pytest", + "cwd": "${workspaceFolder}/backend", + "envFile": "${workspaceFolder}/.env", + "env": { + "LOG_LEVEL": "DEBUG", + "PYTHONUNBUFFERED": "1", + "PYTHONPATH": "." + }, + "args": [ + "-v" + // Specify a sepcific module/test to run or provide nothing to run all tests + //"tests/unit/danswer/llm/answering/test_prune_and_merge.py" + ] } ] -} \ No newline at end of file +} diff --git a/LICENSE b/LICENSE index 314f162560d..d93c97f3141 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,10 @@ -MIT License +Copyright (c) 2023-present DanswerAI, Inc. -Copyright (c) 2023 Yuhong Sun, Chris Weaver +Portions of this software are licensed as follows: + +* All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE". +* All third party components incorporated into the Danswer Software are licensed under the original license provided by the owner of the applicable component. +* Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 658ee682382..8d2b362011d 100644 --- a/README.md +++ b/README.md @@ -105,5 +105,25 @@ Efficiently pulls the latest changes from: * Websites * And more ... +## 📚 Editions + +There are two editions of Danswer: + + * Danswer Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Danswer you will get if you follow the Deployment guide above. 
+ * Danswer Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes: + * Single Sign-On (SSO), with support for both SAML and OIDC + * Role-based access control + * Document permission inheritance from connected sources + * Usage analytics and query history accessible to admins + * Whitelabeling + * API key authentication + * Encryption of secrets + * Any many more! Checkout [our website](https://www.danswer.ai/) for the latest. + +To try the Danswer Enterprise Edition: + + 1. Checkout our [Cloud product](https://app.danswer.ai/signup). + 2. For self-hosting, contact us at [founders@danswer.ai](mailto:founders@danswer.ai) or book a call with us on our [Cal](https://cal.com/team/danswer/founders). + ## 💡 Contributing Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details. diff --git a/backend/.gitignore b/backend/.gitignore index 6b3219cc30e..b1c4f4db71d 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -5,7 +5,7 @@ site_crawls/ .ipynb_checkpoints/ api_keys.py *ipynb -.env +.env* vespa-app.zip dynamic_config_storage/ celerybeat-schedule* diff --git a/backend/Dockerfile b/backend/Dockerfile index 577859ea795..7f9daad94a3 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -1,15 +1,17 @@ FROM python:3.11.7-slim-bookworm LABEL com.danswer.maintainer="founders@danswer.ai" -LABEL com.danswer.description="This image is for the backend of Danswer. It is MIT Licensed and \ -free for all to use. You can find it at https://hub.docker.com/r/danswer/danswer-backend. For \ -more details, visit https://github.com/danswer-ai/danswer." +LABEL com.danswer.description="This image is the web/frontend container of Danswer which \ +contains code for both the Community and Enterprise editions of Danswer. 
If you do not \ +have a contract or agreement with DanswerAI, you are not permitted to use the Enterprise \ +Edition features outside of personal development or testing purposes. Please reach out to \ +founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer" # Default DANSWER_VERSION, typically overriden during builds by GitHub Actions. ARG DANSWER_VERSION=0.3-dev ENV DANSWER_VERSION=${DANSWER_VERSION} -RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}" +RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}" # Install system dependencies # cmake needed for psycopg (postgres) # libpq-dev needed for psycopg (postgres) @@ -17,18 +19,32 @@ RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}" # zip for Vespa step futher down # ca-certificates for HTTPS RUN apt-get update && \ - apt-get install -y cmake curl zip ca-certificates libgnutls30 \ - libblkid1 libmount1 libsmartcols1 \ - libuuid1 && \ + apt-get install -y \ + cmake \ + curl \ + zip \ + ca-certificates \ + libgnutls30=3.7.9-2+deb12u3 \ + libblkid1=2.38.1-5+deb12u1 \ + libmount1=2.38.1-5+deb12u1 \ + libsmartcols1=2.38.1-5+deb12u1 \ + libuuid1=2.38.1-5+deb12u1 \ + libxmlsec1-dev \ + pkg-config \ + gcc && \ rm -rf /var/lib/apt/lists/* && \ apt-get clean # Install Python dependencies # Remove py which is pulled in by retry, py is not needed and is a CVE COPY ./requirements/default.txt /tmp/requirements.txt -RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \ +COPY ./requirements/ee.txt /tmp/ee-requirements.txt +RUN pip install --no-cache-dir --upgrade \ + -r /tmp/requirements.txt \ + -r /tmp/ee-requirements.txt && \ pip uninstall -y py && \ - playwright install chromium && playwright install-deps chromium && \ + playwright install chromium && \ + playwright install-deps chromium && \ ln -s /usr/local/bin/supervisord /usr/bin/supervisord # Cleanup for CVEs and size reduction @@ -36,11 +52,20 @@ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \ # 
xserver-common and xvfb included by playwright installation but not needed after # perl-base is part of the base Python Debian image but not needed for Danswer functionality # perl-base could only be removed with --allow-remove-essential -RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake \ - libldap-2.5-0 libldap-2.5-0 && \ +RUN apt-get update && \ + apt-get remove -y --allow-remove-essential \ + perl-base \ + xserver-common \ + xvfb \ + cmake \ + libldap-2.5-0 \ + libxmlsec1-dev \ + pkg-config \ + gcc && \ + apt-get install -y libxmlsec1-openssl && \ apt-get autoremove -y && \ rm -rf /var/lib/apt/lists/* && \ - rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key + rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key # Pre-downloading models for setups with limited egress RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')" @@ -53,12 +78,24 @@ nltk.download('punkt', quiet=True);" # Set up application files WORKDIR /app + +# Enterprise Version Files +COPY ./ee /app/ee +COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf + +# Set up application files COPY ./danswer /app/danswer COPY ./shared_configs /app/shared_configs COPY ./alembic /app/alembic COPY ./alembic.ini /app/alembic.ini COPY supervisord.conf /usr/etc/supervisord.conf +# Escape hatch +COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py + +# Put logo in assets +COPY ./assets /app/assets + ENV PYTHONPATH /app # Default command which does nothing diff --git a/backend/alembic/versions/0568ccf46a6b_add_thread_specific_model_selection.py b/backend/alembic/versions/0568ccf46a6b_add_thread_specific_model_selection.py new file mode 100644 index 00000000000..2d8912d9bcf --- /dev/null +++ b/backend/alembic/versions/0568ccf46a6b_add_thread_specific_model_selection.py @@ -0,0 +1,31 @@ +"""Add thread specific model selection + +Revision ID: 0568ccf46a6b 
+Revises: e209dc5a8156 +Create Date: 2024-06-19 14:25:36.376046 + +""" + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "0568ccf46a6b" +down_revision = "e209dc5a8156" +branch_labels: None = None +depends_on: None = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "chat_session", + sa.Column("current_alternate_model", sa.String(), nullable=True), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("chat_session", "current_alternate_model") + # ### end Alembic commands ### diff --git a/backend/alembic/versions/23957775e5f5_remove_feedback_foreignkey_constraint.py b/backend/alembic/versions/23957775e5f5_remove_feedback_foreignkey_constraint.py new file mode 100644 index 00000000000..07b5601013b --- /dev/null +++ b/backend/alembic/versions/23957775e5f5_remove_feedback_foreignkey_constraint.py @@ -0,0 +1,86 @@ +"""remove-feedback-foreignkey-constraint + +Revision ID: 23957775e5f5 +Revises: bc9771dccadf +Create Date: 2024-06-27 16:04:51.480437 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = "23957775e5f5" +down_revision = "bc9771dccadf" +branch_labels = None # type: ignore +depends_on = None # type: ignore + + +def upgrade() -> None: + op.drop_constraint( + "chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey" + ) + op.create_foreign_key( + "chat_feedback__chat_message_fk", + "chat_feedback", + "chat_message", + ["chat_message_id"], + ["id"], + ondelete="SET NULL", + ) + op.alter_column( + "chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=True + ) + op.drop_constraint( + "document_retrieval_feedback__chat_message_fk", + "document_retrieval_feedback", + type_="foreignkey", + ) + op.create_foreign_key( + "document_retrieval_feedback__chat_message_fk", + "document_retrieval_feedback", + "chat_message", + ["chat_message_id"], + ["id"], + ondelete="SET NULL", + ) + op.alter_column( + "document_retrieval_feedback", + "chat_message_id", + existing_type=sa.Integer(), + nullable=True, + ) + + +def downgrade() -> None: + op.alter_column( + "chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=False + ) + op.drop_constraint( + "chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey" + ) + op.create_foreign_key( + "chat_feedback__chat_message_fk", + "chat_feedback", + "chat_message", + ["chat_message_id"], + ["id"], + ) + + op.alter_column( + "document_retrieval_feedback", + "chat_message_id", + existing_type=sa.Integer(), + nullable=False, + ) + op.drop_constraint( + "document_retrieval_feedback__chat_message_fk", + "document_retrieval_feedback", + type_="foreignkey", + ) + op.create_foreign_key( + "document_retrieval_feedback__chat_message_fk", + "document_retrieval", + "chat_message", + ["chat_message_id"], + ["id"], + ) diff --git a/backend/alembic/versions/3a7802814195_add_alternate_assistant_to_chat_message.py b/backend/alembic/versions/3a7802814195_add_alternate_assistant_to_chat_message.py new file mode 100644 index 00000000000..5e50c02c8bf --- /dev/null +++ 
b/backend/alembic/versions/3a7802814195_add_alternate_assistant_to_chat_message.py @@ -0,0 +1,38 @@ +"""add alternate assistant to chat message + +Revision ID: 3a7802814195 +Revises: 23957775e5f5 +Create Date: 2024-06-05 11:18:49.966333 + +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "3a7802814195" +down_revision = "23957775e5f5" +branch_labels: None = None +depends_on: None = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True) + ) + op.create_foreign_key( + "fk_chat_message_persona", + "chat_message", + "persona", + ["alternate_assistant_id"], + ["id"], + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey") + op.drop_column("chat_message", "alternate_assistant_id") diff --git a/backend/alembic/versions/4505fd7302e1_added_is_internet_to_dbdoc.py b/backend/alembic/versions/4505fd7302e1_added_is_internet_to_dbdoc.py new file mode 100644 index 00000000000..418f8360372 --- /dev/null +++ b/backend/alembic/versions/4505fd7302e1_added_is_internet_to_dbdoc.py @@ -0,0 +1,23 @@ +"""added is_internet to DBDoc + +Revision ID: 4505fd7302e1 +Revises: c18cdf4b497e +Create Date: 2024-06-18 20:46:09.095034 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = "4505fd7302e1" +down_revision = "c18cdf4b497e" + + +def upgrade() -> None: + op.add_column("search_doc", sa.Column("is_internet", sa.Boolean(), nullable=True)) + op.add_column("tool", sa.Column("display_name", sa.String(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("tool", "display_name") + op.drop_column("search_doc", "is_internet") diff --git a/backend/alembic/versions/48d14957fe80_add_support_for_custom_tools.py b/backend/alembic/versions/48d14957fe80_add_support_for_custom_tools.py index 514389ac216..3a77170d856 100644 --- a/backend/alembic/versions/48d14957fe80_add_support_for_custom_tools.py +++ b/backend/alembic/versions/48d14957fe80_add_support_for_custom_tools.py @@ -13,8 +13,8 @@ # revision identifiers, used by Alembic. revision = "48d14957fe80" down_revision = "b85f02ec1308" -branch_labels = None -depends_on = None +branch_labels: None = None +depends_on: None = None def upgrade() -> None: diff --git a/backend/alembic/versions/7aea705850d5_added_slack_auto_filter.py b/backend/alembic/versions/7aea705850d5_added_slack_auto_filter.py new file mode 100644 index 00000000000..a07b94bd925 --- /dev/null +++ b/backend/alembic/versions/7aea705850d5_added_slack_auto_filter.py @@ -0,0 +1,35 @@ +"""added slack_auto_filter + +Revision ID: 7aea705850d5 +Revises: 4505fd7302e1 +Create Date: 2024-07-10 11:01:23.581015 + +""" +from alembic import op +import sqlalchemy as sa + +revision = "7aea705850d5" +down_revision = "4505fd7302e1" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "slack_bot_config", + sa.Column("enable_auto_filters", sa.Boolean(), nullable=True), + ) + op.execute( + "UPDATE slack_bot_config SET enable_auto_filters = FALSE WHERE enable_auto_filters IS NULL" + ) + op.alter_column( + "slack_bot_config", + "enable_auto_filters", + existing_type=sa.Boolean(), + nullable=False, + server_default=sa.false(), + ) + + +def downgrade() -> None: + op.drop_column("slack_bot_config", 
"enable_auto_filters") diff --git a/backend/alembic/versions/bc9771dccadf_create_usage_reports_table.py b/backend/alembic/versions/bc9771dccadf_create_usage_reports_table.py new file mode 100644 index 00000000000..eab3253a0e0 --- /dev/null +++ b/backend/alembic/versions/bc9771dccadf_create_usage_reports_table.py @@ -0,0 +1,51 @@ +"""create usage reports table + +Revision ID: bc9771dccadf +Revises: 0568ccf46a6b +Create Date: 2024-06-18 10:04:26.800282 + +""" +from alembic import op +import sqlalchemy as sa +import fastapi_users_db_sqlalchemy + +# revision identifiers, used by Alembic. +revision = "bc9771dccadf" +down_revision = "0568ccf46a6b" + +branch_labels: None = None +depends_on: None = None + + +def upgrade() -> None: + op.create_table( + "usage_reports", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("report_name", sa.String(), nullable=False), + sa.Column( + "requestor_user_id", + fastapi_users_db_sqlalchemy.generics.GUID(), + nullable=True, + ), + sa.Column( + "time_created", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column("period_from", sa.DateTime(timezone=True), nullable=True), + sa.Column("period_to", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["report_name"], + ["file_store.file_name"], + ), + sa.ForeignKeyConstraint( + ["requestor_user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + + +def downgrade() -> None: + op.drop_table("usage_reports") diff --git a/backend/alembic/versions/c18cdf4b497e_add_standard_answer_tables.py b/backend/alembic/versions/c18cdf4b497e_add_standard_answer_tables.py new file mode 100644 index 00000000000..8d2de3bf1e1 --- /dev/null +++ b/backend/alembic/versions/c18cdf4b497e_add_standard_answer_tables.py @@ -0,0 +1,75 @@ +"""Add standard_answer tables + +Revision ID: c18cdf4b497e +Revises: 3a7802814195 +Create Date: 2024-06-06 15:15:02.000648 + +""" +from alembic import op +import sqlalchemy as sa + +# revision 
identifiers, used by Alembic. +revision = "c18cdf4b497e" +down_revision = "3a7802814195" +branch_labels: None = None +depends_on: None = None + + +def upgrade() -> None: + op.create_table( + "standard_answer", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("keyword", sa.String(), nullable=False), + sa.Column("answer", sa.String(), nullable=False), + sa.Column("active", sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("keyword"), + ) + op.create_table( + "standard_answer_category", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name"), + ) + op.create_table( + "standard_answer__standard_answer_category", + sa.Column("standard_answer_id", sa.Integer(), nullable=False), + sa.Column("standard_answer_category_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["standard_answer_category_id"], + ["standard_answer_category.id"], + ), + sa.ForeignKeyConstraint( + ["standard_answer_id"], + ["standard_answer.id"], + ), + sa.PrimaryKeyConstraint("standard_answer_id", "standard_answer_category_id"), + ) + op.create_table( + "slack_bot_config__standard_answer_category", + sa.Column("slack_bot_config_id", sa.Integer(), nullable=False), + sa.Column("standard_answer_category_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["slack_bot_config_id"], + ["slack_bot_config.id"], + ), + sa.ForeignKeyConstraint( + ["standard_answer_category_id"], + ["standard_answer_category.id"], + ), + sa.PrimaryKeyConstraint("slack_bot_config_id", "standard_answer_category_id"), + ) + + op.add_column( + "chat_session", sa.Column("slack_thread_id", sa.String(), nullable=True) + ) + + +def downgrade() -> None: + op.drop_column("chat_session", "slack_thread_id") + + op.drop_table("slack_bot_config__standard_answer_category") + op.drop_table("standard_answer__standard_answer_category") + 
op.drop_table("standard_answer_category") + op.drop_table("standard_answer") diff --git a/backend/assets/.gitignore b/backend/assets/.gitignore new file mode 100644 index 00000000000..d6b7ef32c84 --- /dev/null +++ b/backend/assets/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/backend/danswer/access/access.py b/backend/danswer/access/access.py index 8b46c711ead..51f5a300c49 100644 --- a/backend/danswer/access/access.py +++ b/backend/danswer/access/access.py @@ -1,6 +1,7 @@ from sqlalchemy.orm import Session from danswer.access.models import DocumentAccess +from danswer.access.utils import prefix_user from danswer.configs.constants import PUBLIC_DOC_PAT from danswer.db.document import get_acccess_info_for_documents from danswer.db.models import User @@ -19,7 +20,7 @@ def _get_access_for_documents( cc_pair_to_delete=cc_pair_to_delete, ) return { - document_id: DocumentAccess.build(user_ids, is_public) + document_id: DocumentAccess.build(user_ids, [], is_public) for document_id, user_ids, is_public in document_access_info } @@ -38,12 +39,6 @@ def get_access_for_documents( ) # type: ignore -def prefix_user(user_id: str) -> str: - """Prefixes a user ID to eliminate collision with group names. - This assumes that groups are prefixed with a different prefix.""" - return f"user_id:{user_id}" - - def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]: """Returns a list of ACL entries that the user has access to. This is meant to be used downstream to filter out documents that the user does not have access to. 
The diff --git a/backend/danswer/access/models.py b/backend/danswer/access/models.py index 94a18528cb9..a87e2d94f25 100644 --- a/backend/danswer/access/models.py +++ b/backend/danswer/access/models.py @@ -1,20 +1,30 @@ from dataclasses import dataclass from uuid import UUID +from danswer.access.utils import prefix_user +from danswer.access.utils import prefix_user_group from danswer.configs.constants import PUBLIC_DOC_PAT @dataclass(frozen=True) class DocumentAccess: user_ids: set[str] # stringified UUIDs + user_groups: set[str] # names of user groups associated with this document is_public: bool def to_acl(self) -> list[str]: - return list(self.user_ids) + ([PUBLIC_DOC_PAT] if self.is_public else []) + return ( + [prefix_user(user_id) for user_id in self.user_ids] + + [prefix_user_group(group_name) for group_name in self.user_groups] + + ([PUBLIC_DOC_PAT] if self.is_public else []) + ) @classmethod - def build(cls, user_ids: list[UUID | None], is_public: bool) -> "DocumentAccess": + def build( + cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool + ) -> "DocumentAccess": return cls( user_ids={str(user_id) for user_id in user_ids if user_id}, + user_groups=set(user_groups), is_public=is_public, ) diff --git a/backend/danswer/access/utils.py b/backend/danswer/access/utils.py new file mode 100644 index 00000000000..060560eaedc --- /dev/null +++ b/backend/danswer/access/utils.py @@ -0,0 +1,10 @@ +def prefix_user(user_id: str) -> str: + """Prefixes a user ID to eliminate collision with group names. + This assumes that groups are prefixed with a different prefix.""" + return f"user_id:{user_id}" + + +def prefix_user_group(user_group_name: str) -> str: + """Prefixes a user group name to eliminate collision with user IDs. 
+ This assumes that user ids are prefixed with a different prefix.""" + return f"group:{user_group_name}" diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index c0f4de9dc28..479cf07df3d 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -46,6 +46,7 @@ from danswer.configs.constants import DANSWER_API_KEY_PREFIX from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER from danswer.db.auth import get_access_token_db +from danswer.db.auth import get_default_admin_user_emails from danswer.db.auth import get_user_count from danswer.db.auth import get_user_db from danswer.db.engine import get_session @@ -54,7 +55,9 @@ from danswer.utils.logger import setup_logger from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType -from danswer.utils.variable_functionality import fetch_versioned_implementation +from danswer.utils.variable_functionality import ( + fetch_versioned_implementation, +) logger = setup_logger() @@ -148,7 +151,7 @@ async def create( verify_email_domain(user_create.email) if hasattr(user_create, "role"): user_count = await get_user_count() - if user_count == 0: + if user_count == 0 or user_create.email in get_default_admin_user_emails(): user_create.role = UserRole.ADMIN else: user_create.role = UserRole.BASIC diff --git a/backend/danswer/background/celery/celery.py b/backend/danswer/background/celery/celery_app.py similarity index 100% rename from backend/danswer/background/celery/celery.py rename to backend/danswer/background/celery/celery_app.py diff --git a/backend/danswer/background/celery/celery_run.py b/backend/danswer/background/celery/celery_run.py new file mode 100644 index 00000000000..0fdb2f044a8 --- /dev/null +++ b/backend/danswer/background/celery/celery_run.py @@ -0,0 +1,9 @@ +"""Entry point for running celery worker / celery beat.""" +from danswer.utils.variable_functionality import fetch_versioned_implementation +from 
danswer.utils.variable_functionality import set_is_ee_based_on_env_variable + + +set_is_ee_based_on_env_variable() +celery_app = fetch_versioned_implementation( + "danswer.background.celery.celery_app", "celery_app" +) diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py index 48f0295cd09..6b9b5a89603 100644 --- a/backend/danswer/background/celery/celery_utils.py +++ b/backend/danswer/background/celery/celery_utils.py @@ -6,16 +6,23 @@ from danswer.background.task_utils import name_cc_cleanup_task from danswer.background.task_utils import name_cc_prune_task from danswer.background.task_utils import name_document_set_sync_task +from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE +from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING +from danswer.connectors.cross_connector_utils.rate_limit_wrapper import ( + rate_limit_builder, +) from danswer.connectors.interfaces import BaseConnector from danswer.connectors.interfaces import IdConnector from danswer.connectors.interfaces import LoadConnector from danswer.connectors.interfaces import PollConnector +from danswer.connectors.models import Document from danswer.db.engine import get_db_current_time from danswer.db.models import Connector from danswer.db.models import Credential from danswer.db.models import DocumentSet -from danswer.db.tasks import check_live_task_not_timed_out +from danswer.db.tasks import check_task_is_live_and_not_timed_out from danswer.db.tasks import get_latest_task +from danswer.db.tasks import get_latest_task_by_type from danswer.server.documents.models import DeletionAttemptSnapshot from danswer.utils.logger import setup_logger @@ -47,7 +54,7 @@ def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool: task_name = name_document_set_sync_task(document_set.id) latest_sync = get_latest_task(task_name, db_session) - if latest_sync and 
check_live_task_not_timed_out(latest_sync, db_session): + if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session): logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.") return False @@ -73,7 +80,19 @@ def should_prune_cc_pair( return True return False - if check_live_task_not_timed_out(last_pruning_task, db_session): + if PREVENT_SIMULTANEOUS_PRUNING: + pruning_type_task_name = name_cc_prune_task() + last_pruning_type_task = get_latest_task_by_type( + pruning_type_task_name, db_session + ) + + if last_pruning_type_task and check_task_is_live_and_not_timed_out( + last_pruning_type_task, db_session + ): + logger.info("Another Connector is already pruning. Skipping.") + return False + + if check_task_is_live_and_not_timed_out(last_pruning_task, db_session): logger.info(f"Connector '{connector.name}' is already pruning. Skipping.") return False @@ -84,25 +103,36 @@ def should_prune_cc_pair( return time_since_last_pruning.total_seconds() >= connector.prune_freq +def document_batch_to_ids(doc_batch: list[Document]) -> set[str]: + return {doc.id for doc in doc_batch} + + def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]: """ If the PruneConnector hasnt been implemented for the given connector, just pull all docs using the load_from_state and grab out the IDs """ all_connector_doc_ids: set[str] = set() + + doc_batch_generator = None if isinstance(runnable_connector, IdConnector): all_connector_doc_ids = runnable_connector.retrieve_all_source_ids() elif isinstance(runnable_connector, LoadConnector): doc_batch_generator = runnable_connector.load_from_state() - for doc_batch in doc_batch_generator: - all_connector_doc_ids.update(doc.id for doc in doc_batch) elif isinstance(runnable_connector, PollConnector): start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp() end = datetime.now(timezone.utc).timestamp() doc_batch_generator = runnable_connector.poll_source(start=start, end=end) - 
for doc_batch in doc_batch_generator: - all_connector_doc_ids.update(doc.id for doc in doc_batch) else: raise RuntimeError("Pruning job could not find a valid runnable_connector.") + if doc_batch_generator: + doc_batch_processing_func = document_batch_to_ids + if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE: + doc_batch_processing_func = rate_limit_builder( + max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60 + )(document_batch_to_ids) + for doc_batch in doc_batch_generator: + all_connector_doc_ids.update(doc_batch_processing_func(doc_batch)) + return all_connector_doc_ids diff --git a/backend/danswer/background/indexing/job_client.py b/backend/danswer/background/indexing/job_client.py index e9ddad58e0f..72919d690a4 100644 --- a/backend/danswer/background/indexing/job_client.py +++ b/backend/danswer/background/indexing/job_client.py @@ -105,7 +105,9 @@ def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | N """NOTE: `pure` arg is needed so this can be a drop in replacement for Dask""" self._cleanup_completed_jobs() if len(self.jobs) >= self.n_workers: - logger.debug("No available workers to run job") + logger.debug( + f"No available workers to run job. Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'." 
+ ) return None job_id = self.job_id_counter diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index ffbfd8c931b..fa684f020b6 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -31,6 +31,7 @@ from danswer.indexing.indexing_pipeline import build_indexing_pipeline from danswer.utils.logger import IndexAttemptSingleton from danswer.utils.logger import setup_logger +from danswer.utils.variable_functionality import global_version logger = setup_logger() @@ -303,11 +304,14 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA return attempt -def run_indexing_entrypoint(index_attempt_id: int) -> None: +def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None: """Entrypoint for indexing run when using dask distributed. Wraps the actual logic in a `try` block so that we can catch any exceptions and mark the attempt as failed.""" try: + if is_ee: + global_version.set_ee() + # set the indexing attempt ID so that all log messages from this process # will have it added as a prefix IndexAttemptSingleton.set_index_attempt_id(index_attempt_id) diff --git a/backend/danswer/background/task_utils.py b/backend/danswer/background/task_utils.py index 902abdfec86..6e122678813 100644 --- a/backend/danswer/background/task_utils.py +++ b/backend/danswer/background/task_utils.py @@ -22,8 +22,13 @@ def name_document_set_sync_task(document_set_id: int) -> str: return f"sync_doc_set_{document_set_id}" -def name_cc_prune_task(connector_id: int, credential_id: int) -> str: - return f"prune_connector_credential_pair_{connector_id}_{credential_id}" +def name_cc_prune_task( + connector_id: int | None = None, credential_id: int | None = None +) -> str: + task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}" + if not connector_id or not credential_id: + task_name = 
"prune_connector_credential_pair" + return task_name T = TypeVar("T", bound=Callable) diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 8b115e44821..9ca65f8b33a 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -35,6 +35,8 @@ from danswer.db.swap_index import check_index_swap from danswer.search.search_nlp_models import warm_up_encoders from danswer.utils.logger import setup_logger +from danswer.utils.variable_functionality import global_version +from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable from shared_configs.configs import INDEXING_MODEL_SERVER_HOST from shared_configs.configs import LOG_LEVEL from shared_configs.configs import MODEL_SERVER_PORT @@ -307,10 +309,18 @@ def kickoff_indexing_jobs( if use_secondary_index: run = secondary_client.submit( - run_indexing_entrypoint, attempt.id, pure=False + run_indexing_entrypoint, + attempt.id, + global_version.get_is_ee_version(), + pure=False, ) else: - run = client.submit(run_indexing_entrypoint, attempt.id, pure=False) + run = client.submit( + run_indexing_entrypoint, + attempt.id, + global_version.get_is_ee_version(), + pure=False, + ) if run: secondary_str = "(secondary index) " if use_secondary_index else "" @@ -398,6 +408,8 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non def update__main() -> None: + set_is_ee_based_on_env_variable() + logger.info("Starting Indexing Loop") update_loop() diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index f4b0b2e02c5..7e64a118e0d 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -1,5 +1,4 @@ import re -from collections.abc import Sequence from typing import cast from sqlalchemy.orm import Session @@ -9,42 +8,30 @@ from danswer.db.chat import get_chat_messages_by_session from danswer.db.models import ChatMessage from 
danswer.llm.answering.models import PreviousMessage -from danswer.search.models import InferenceChunk from danswer.search.models import InferenceSection from danswer.utils.logger import setup_logger logger = setup_logger() -def llm_doc_from_inference_section(inf_chunk: InferenceSection) -> LlmDoc: +def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc: return LlmDoc( - document_id=inf_chunk.document_id, + document_id=inference_section.center_chunk.document_id, # This one is using the combined content of all the chunks of the section # In default settings, this is the same as just the content of base chunk - content=inf_chunk.combined_content, - blurb=inf_chunk.blurb, - semantic_identifier=inf_chunk.semantic_identifier, - source_type=inf_chunk.source_type, - metadata=inf_chunk.metadata, - updated_at=inf_chunk.updated_at, - link=inf_chunk.source_links[0] if inf_chunk.source_links else None, - source_links=inf_chunk.source_links, + content=inference_section.combined_content, + blurb=inference_section.center_chunk.blurb, + semantic_identifier=inference_section.center_chunk.semantic_identifier, + source_type=inference_section.center_chunk.source_type, + metadata=inference_section.center_chunk.metadata, + updated_at=inference_section.center_chunk.updated_at, + link=inference_section.center_chunk.source_links[0] + if inference_section.center_chunk.source_links + else None, + source_links=inference_section.center_chunk.source_links, ) -def map_document_id_order( - chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True -) -> dict[str, int]: - order_mapping = {} - current = 1 if one_indexed else 0 - for chunk in chunks: - if chunk.document_id not in order_mapping: - order_mapping[chunk.document_id] = current - current += 1 - - return order_mapping - - def create_chat_chain( chat_session_id: int, db_session: Session, diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py index 4036730a9ab..342802e6c5c 
100644 --- a/backend/danswer/chat/load_yamls.py +++ b/backend/danswer/chat/load_yamls.py @@ -1,5 +1,3 @@ -from typing import cast - import yaml from sqlalchemy.orm import Session @@ -50,7 +48,7 @@ def load_personas_from_yaml( with Session(get_sqlalchemy_engine()) as db_session: for persona in all_personas: doc_set_names = persona["document_sets"] - doc_sets: list[DocumentSetDBModel] | None = [ + doc_sets: list[DocumentSetDBModel] = [ get_or_create_document_set_by_name(db_session, name) for name in doc_set_names ] @@ -58,22 +56,24 @@ def load_personas_from_yaml( # Assume if user hasn't set any document sets for the persona, the user may want # to later attach document sets to the persona manually, therefore, don't overwrite/reset # the document sets for the persona - if not doc_sets: - doc_sets = None + doc_set_ids: list[int] | None = None + if doc_sets: + doc_set_ids = [doc_set.id for doc_set in doc_sets] + else: + doc_set_ids = None + prompt_ids: list[int] | None = None prompt_set_names = persona["prompts"] - if not prompt_set_names: - prompts: list[PromptDBModel | None] | None = None - else: - prompts = [ + if prompt_set_names: + prompts: list[PromptDBModel | None] = [ get_prompt_by_name(prompt_name, user=None, db_session=db_session) for prompt_name in prompt_set_names ] if any([prompt is None for prompt in prompts]): raise ValueError("Invalid Persona configs, not all prompts exist") - if not prompts: - prompts = None + if prompts: + prompt_ids = [prompt.id for prompt in prompts if prompt is not None] p_id = persona.get("id") upsert_persona( @@ -91,8 +91,8 @@ def load_personas_from_yaml( llm_model_provider_override=None, llm_model_version_override=None, recency_bias=RecencyBiasSetting(persona["recency_bias"]), - prompts=cast(list[PromptDBModel] | None, prompts), - document_sets=doc_sets, + prompt_ids=prompt_ids, + document_set_ids=doc_set_ids, default_persona=True, is_public=True, db_session=db_session, diff --git a/backend/danswer/chat/process_message.py 
b/backend/danswer/chat/process_message.py index 7733bf523d6..1781c2232f1 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -10,14 +10,15 @@ from danswer.chat.models import CustomToolResponse from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import ImageGenerationDisplay -from danswer.chat.models import LlmDoc from danswer.chat.models import LLMRelevanceFilterResponse from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError +from danswer.configs.chat_configs import BING_API_KEY from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT from danswer.configs.constants import MessageType +from danswer.configs.model_configs import GEN_AI_TEMPERATURE from danswer.db.chat import attach_files_to_chat_message from danswer.db.chat import create_db_search_doc from danswer.db.chat import create_new_chat_message @@ -34,6 +35,7 @@ from danswer.db.models import SearchDoc as DbSearchDoc from danswer.db.models import ToolCall from danswer.db.models import User +from danswer.db.persona import get_persona_by_id from danswer.document_index.factory import get_default_document_index from danswer.file_store.models import ChatFileType from danswer.file_store.models import FileDescriptor @@ -46,10 +48,15 @@ from danswer.llm.answering.models import PreviousMessage from danswer.llm.answering.models import PromptConfig from danswer.llm.exceptions import GenAIDisabledException -from danswer.llm.factory import get_llm_for_persona +from danswer.llm.factory import get_llms_for_persona +from danswer.llm.factory import get_main_llm_from_tuple +from danswer.llm.interfaces import LLMConfig from danswer.llm.utils import get_default_llm_tokenizer from danswer.search.enums import OptionalSearchSetting -from danswer.search.retrieval.search_runner 
import inference_documents_from_ids +from danswer.search.enums import QueryFlow +from danswer.search.enums import SearchType +from danswer.search.models import InferenceSection +from danswer.search.retrieval.search_runner import inference_sections_from_ids from danswer.search.utils import chunks_or_sections_to_search_docs from danswer.search.utils import dedupe_documents from danswer.search.utils import drop_llm_indices @@ -64,6 +71,14 @@ from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID from danswer.tools.images.image_generation_tool import ImageGenerationResponse from danswer.tools.images.image_generation_tool import ImageGenerationTool +from danswer.tools.internet_search.internet_search_tool import ( + INTERNET_SEARCH_RESPONSE_ID, +) +from danswer.tools.internet_search.internet_search_tool import ( + internet_search_response_to_search_docs, +) +from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse +from danswer.tools.internet_search.internet_search_tool import InternetSearchTool from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID from danswer.tools.search.search_tool import SearchResponseSummary from danswer.tools.search.search_tool import SearchTool @@ -141,6 +156,37 @@ def _handle_search_tool_response_summary( ) +def _handle_internet_search_tool_response_summary( + packet: ToolResponse, + db_session: Session, +) -> tuple[QADocsResponse, list[DbSearchDoc]]: + internet_search_response = cast(InternetSearchResponse, packet.response) + server_search_docs = internet_search_response_to_search_docs( + internet_search_response + ) + + reference_db_search_docs = [ + create_db_search_doc(server_search_doc=doc, db_session=db_session) + for doc in server_search_docs + ] + response_docs = [ + translate_db_search_doc_to_server_search_doc(db_search_doc) + for db_search_doc in reference_db_search_docs + ] + return ( + QADocsResponse( + rephrased_query=internet_search_response.revised_query, 
+ top_documents=response_docs, + predicted_flow=QueryFlow.QUESTION_ANSWER, + predicted_search=SearchType.HYBRID, + applied_source_filters=[], + applied_time_cutoff=None, + recency_bias_multiplier=1.0, + ), + reference_db_search_docs, + ) + + def _check_should_force_search( new_msg_req: CreateChatMessageRequest, ) -> ForceUseTool | None: @@ -168,7 +214,7 @@ def _check_should_force_search( args = {"query": new_msg_req.message} return ForceUseTool( - tool_name=SearchTool.NAME, + tool_name=SearchTool._NAME, args=args, ) return None @@ -223,7 +269,15 @@ def stream_chat_message_objects( parent_id = new_msg_req.parent_message_id reference_doc_ids = new_msg_req.search_doc_ids retrieval_options = new_msg_req.retrieval_options - persona = chat_session.persona + alternate_assistant_id = new_msg_req.alternate_assistant_id + + # use alternate persona if alternative assistant id is passed in + if alternate_assistant_id is not None: + persona = get_persona_by_id( + alternate_assistant_id, user=user, db_session=db_session + ) + else: + persona = chat_session.persona prompt_id = new_msg_req.prompt_id if prompt_id is None and persona.prompts: @@ -235,7 +289,7 @@ def stream_chat_message_objects( ) try: - llm = get_llm_for_persona( + llm, fast_llm = get_llms_for_persona( persona=persona, llm_override=new_msg_req.llm_override or chat_session.llm_override, additional_headers=litellm_additional_headers, @@ -328,7 +382,7 @@ def stream_chat_message_objects( ) selected_db_search_docs = None - selected_llm_docs: list[LlmDoc] | None = None + selected_sections: list[InferenceSection] | None = None if reference_doc_ids: identifier_tuples = get_doc_query_identifiers_from_model( search_doc_ids=reference_doc_ids, @@ -338,8 +392,8 @@ def stream_chat_message_objects( ) # Generates full documents currently - # May extend to include chunk ranges - selected_llm_docs = inference_documents_from_ids( + # May extend to use sections instead in the future + selected_sections = inference_sections_from_ids( 
doc_identifiers=identifier_tuples, document_index=document_index, ) @@ -380,6 +434,7 @@ def stream_chat_message_objects( # rephrased_query=, # token_count=, message_type=MessageType.ASSISTANT, + alternate_assistant_id=new_msg_req.alternate_assistant_id, # error=, # reference_docs=, db_session=db_session, @@ -389,11 +444,15 @@ def stream_chat_message_objects( if not final_msg.prompt: raise RuntimeError("No Prompt found") - prompt_config = PromptConfig.from_model( - final_msg.prompt, - prompt_override=( - new_msg_req.prompt_override or chat_session.prompt_override - ), + prompt_config = ( + PromptConfig.from_model( + final_msg.prompt, + prompt_override=( + new_msg_req.prompt_override or chat_session.prompt_override + ), + ) + if not persona + else PromptConfig.from_model(persona.prompts[0]) ) # find out what tools to use @@ -411,21 +470,22 @@ def stream_chat_message_objects( retrieval_options=retrieval_options, prompt_config=prompt_config, llm=llm, + fast_llm=fast_llm, pruning_config=document_pruning_config, - selected_docs=selected_llm_docs, + selected_sections=selected_sections, chunks_above=new_msg_req.chunks_above, chunks_below=new_msg_req.chunks_below, full_doc=new_msg_req.full_doc, ) tool_dict[db_tool_model.id] = [search_tool] elif tool_cls.__name__ == ImageGenerationTool.__name__: - dalle_key = None + img_generation_llm_config: LLMConfig | None = None if ( llm and llm.config.api_key and llm.config.model_provider == "openai" ): - dalle_key = llm.config.api_key + img_generation_llm_config = llm.config else: llm_providers = fetch_existing_llm_providers(db_session) openai_provider = next( @@ -442,9 +502,30 @@ def stream_chat_message_objects( raise ValueError( "Image generation tool requires an OpenAI API key" ) - dalle_key = openai_provider.api_key + img_generation_llm_config = LLMConfig( + model_provider=openai_provider.provider, + model_name=openai_provider.default_model_name, + temperature=GEN_AI_TEMPERATURE, + api_key=openai_provider.api_key, + 
api_base=openai_provider.api_base, + api_version=openai_provider.api_version, + ) + tool_dict[db_tool_model.id] = [ + ImageGenerationTool( + api_key=cast(str, img_generation_llm_config.api_key), + api_base=img_generation_llm_config.api_base, + api_version=img_generation_llm_config.api_version, + additional_headers=litellm_additional_headers, + ) + ] + elif tool_cls.__name__ == InternetSearchTool.__name__: + bing_api_key = BING_API_KEY + if not bing_api_key: + raise ValueError( + "Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!" + ) tool_dict[db_tool_model.id] = [ - ImageGenerationTool(api_key=dalle_key) + InternetSearchTool(api_key=bing_api_key) ] continue @@ -481,10 +562,14 @@ def stream_chat_message_objects( prompt_config=prompt_config, llm=( llm - or get_llm_for_persona( - persona=persona, - llm_override=new_msg_req.llm_override or chat_session.llm_override, - additional_headers=litellm_additional_headers, + or get_main_llm_from_tuple( + get_llms_for_persona( + persona=persona, + llm_override=( + new_msg_req.llm_override or chat_session.llm_override + ), + additional_headers=litellm_additional_headers, + ) ) ), message_history=[ @@ -548,6 +633,15 @@ def stream_chat_message_objects( yield ImageGenerationDisplay( file_ids=[str(file_id) for file_id in file_ids] ) + elif packet.id == INTERNET_SEARCH_RESPONSE_ID: + ( + qa_docs_response, + reference_db_search_docs, + ) = _handle_internet_search_tool_response_summary( + packet=packet, + db_session=db_session, + ) + yield qa_docs_response elif packet.id == CUSTOM_TOOL_RESPONSE_ID: custom_tool_response = cast(CustomToolCallSummary, packet.response) yield CustomToolResponse( @@ -589,7 +683,7 @@ def stream_chat_message_objects( tool_name_to_tool_id: dict[str, int] = {} for tool_id, tool_list in tool_dict.items(): for tool in tool_list: - tool_name_to_tool_id[tool.name()] = tool_id + tool_name_to_tool_id[tool.name] = tool_id gen_ai_response_message = partial_response( 
message=answer.llm_answer, diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 996700aef48..47cfee37fec 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -4,6 +4,7 @@ from danswer.configs.constants import AuthType from danswer.configs.constants import DocumentIndexType +from danswer.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy ##### # App Configs @@ -45,13 +46,14 @@ # Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive # information. This provides an extra layer of security on top of Postgres access controls # and is available in Danswer EE -ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET") +ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET") or "" # Turn off mask if admin users should see full credentials for data connectors. MASK_CREDENTIAL_PREFIX = ( os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false" ) + SESSION_EXPIRE_TIME_SECONDS = int( os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7 ) # 7 days @@ -160,6 +162,11 @@ WEB_CONNECTOR_OAUTH_TOKEN_URL = os.environ.get("WEB_CONNECTOR_OAUTH_TOKEN_URL") WEB_CONNECTOR_VALIDATE_URLS = os.environ.get("WEB_CONNECTOR_VALIDATE_URLS") +HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get( + "HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY", + HtmlBasedConnectorTransformLinksStrategy.STRIP, +) + NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP = ( os.environ.get("NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP", "").lower() == "true" @@ -178,6 +185,12 @@ os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES", "").lower() == "true" ) +# Save pages labels as Danswer metadata tags +# The reason to skip this would be to reduce the number of calls to Confluence due to rate limit concerns +CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = ( + os.environ.get("CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING", "").lower() == 
"true" +) + JIRA_CONNECTOR_LABELS_TO_SKIP = [ ignored_tag for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",") @@ -201,6 +214,15 @@ DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day +PREVENT_SIMULTANEOUS_PRUNING = ( + os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true" +) + +# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation. +MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int( + os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0) +) + ##### # Indexing Configs @@ -244,10 +266,14 @@ CURRENT_PROCESS_IS_AN_INDEXING_JOB = ( os.environ.get("CURRENT_PROCESS_IS_AN_INDEXING_JOB", "").lower() == "true" ) -# Logs every model prompt and output, mostly used for development or exploration purposes +# Sets LiteLLM to verbose logging LOG_ALL_MODEL_INTERACTIONS = ( os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true" ) +# Logs Danswer only model interactions like prompts, responses, messages etc. +LOG_DANSWER_MODEL_INTERACTIONS = ( + os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true" +) # If set to `true` will enable additional logs about Vespa query performance # (time spent on finding the right docs + time spent fetching summaries from disk) LOG_VESPA_TIMING_INFORMATION = ( @@ -266,3 +292,15 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads( os.environ.get("CUSTOM_ANSWER_VALIDITY_CONDITIONS", "[]") ) + + +##### +# Enterprise Edition Configs +##### +# NOTE: this should only be enabled if you have purchased an enterprise license. 
+# if you're interested in an enterprise license, please reach out to us at +# founders@danswer.ai OR message Chris Weaver or Yuhong Sun in the Danswer +# Slack community (https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ) +ENTERPRISE_EDITION_ENABLED = ( + os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true" +) diff --git a/backend/danswer/configs/chat_configs.py b/backend/danswer/configs/chat_configs.py index 66cb8adceb2..198793a0043 100644 --- a/backend/danswer/configs/chat_configs.py +++ b/backend/danswer/configs/chat_configs.py @@ -5,7 +5,10 @@ PERSONAS_YAML = "./danswer/chat/personas.yaml" NUM_RETURNED_HITS = os.environ.get("TOOL_SEARCH_NUM_RETURNED_HITS") or 50 -NUM_RERANKED_RESULTS = 15 +# Used for LLM filtering and reranking +# We want this to be approximately the number of results we want to show on the first page +# It cannot be too large due to cost and latency implications +NUM_RERANKED_RESULTS = 20 # May be less depending on model MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0) @@ -25,9 +28,10 @@ FAVOR_RECENT_DECAY_MULTIPLIER = 2.0 # Currently this next one is not configurable via env DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak" -DISABLE_LLM_FILTER_EXTRACTION = ( - os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true" -) +# For the highest matching base size chunk, how many chunks above and below do we pull in by default +# Note this is not in any of the deployment configs yet +CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0) +CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0) # Whether the LLM should evaluate all of the document chunks passed in for usefulness # in relation to the user query DISABLE_LLM_CHUNK_FILTER = ( @@ -43,8 +47,6 @@ # 1 edit per 20 characters, currently unused due to fuzzy match being too slow QUOTE_ALLOWED_ERROR_PERCENT = 0.05 QA_TIMEOUT = 
int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds -# Include additional document/chunk metadata in prompt to GenerativeAI -INCLUDE_METADATA = False # Keyword Search Drop Stopwords # If user has changed the default model, would most likely be to use a multilingual # model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords @@ -64,9 +66,20 @@ # A list of languages passed to the LLM to rephase the query # For example "English,French,Spanish", be sure to use the "," separator MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None +LANGUAGE_HINT = "\n" + ( + os.environ.get("LANGUAGE_HINT") + or "IMPORTANT: Respond in the same language as my query!" +) +LANGUAGE_CHAT_NAMING_HINT = ( + os.environ.get("LANGUAGE_CHAT_NAMING_HINT") + or "The name of the conversation must be in the same language as the user query." +) # Stops streaming answers back to the UI if this pattern is seen: STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None # The backend logic for this being True isn't fully supported yet HARD_DELETE_CHATS = False + +# Internet Search +BING_API_KEY = os.environ.get("BING_API_KEY") or None diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 9e0d318c5ec..1042d688532 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -41,10 +41,6 @@ SESSION_KEY = "session" QUERY_EVENT_ID = "query_event_id" LLM_CHUNKS = "llm_chunks" -TOKEN_BUDGET = "token_budget" -TOKEN_BUDGET_TIME_PERIOD = "token_budget_time_period" -ENABLE_TOKEN_BUDGET = "enable_token_budget" -TOKEN_BUDGET_SETTINGS = "token_budget_settings" # For chunking/processing chunks TITLE_SEPARATOR = "\n\r\n" @@ -104,6 +100,21 @@ class DocumentSource(str, Enum): CLICKUP = "clickup" MEDIAWIKI = "mediawiki" WIKIPEDIA = "wikipedia" + S3 = "s3" + R2 = "r2" + GOOGLE_CLOUD_STORAGE = "google_cloud_storage" + OCI_STORAGE = "oci_storage" + NOT_APPLICABLE = 
"not_applicable" + + +class BlobType(str, Enum): + R2 = "r2" + S3 = "s3" + GOOGLE_CLOUD_STORAGE = "google_cloud_storage" + OCI_STORAGE = "oci_storage" + + # Special case, for internet search + NOT_APPLICABLE = "not_applicable" class DocumentIndexType(str, Enum): @@ -119,6 +130,11 @@ class AuthType(str, Enum): SAML = "saml" +class QAFeedbackType(str, Enum): + LIKE = "like" # User likes the answer, used for metrics + DISLIKE = "dislike" # User dislikes the answer, used for metrics + + class SearchFeedbackType(str, Enum): ENDORSE = "endorse" # boost this document for all future queries REJECT = "reject" # down-boost this document for all future queries @@ -144,4 +160,5 @@ class FileOrigin(str, Enum): CHAT_UPLOAD = "chat_upload" CHAT_IMAGE_GEN = "chat_image_gen" CONNECTOR = "connector" + GENERATED_REPORT = "generated_report" OTHER = "other" diff --git a/backend/danswer/configs/danswerbot_configs.py b/backend/danswer/configs/danswerbot_configs.py index 5500d28c3ef..b75d69d1f13 100644 --- a/backend/danswer/configs/danswerbot_configs.py +++ b/backend/danswer/configs/danswerbot_configs.py @@ -47,10 +47,6 @@ DANSWER_BOT_RESPOND_EVERY_CHANNEL = ( os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true" ) -# Auto detect query options like time cutoff or heavily favor recently updated docs -DISABLE_DANSWER_BOT_FILTER_DETECT = ( - os.environ.get("DISABLE_DANSWER_BOT_FILTER_DETECT", "").lower() == "true" -) # Add a second LLM call post Answer to verify if the Answer is valid # Throws out answers that don't directly or fully answer the user query # This is the default for all DanswerBot channels unless the channel is configured individually diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py index f5be897792b..b4c0e8cf29e 100644 --- a/backend/danswer/configs/model_configs.py +++ b/backend/danswer/configs/model_configs.py @@ -39,8 +39,8 @@ # Purely an optimization, memory limitation consideration 
BATCH_SIZE_ENCODE_CHUNKS = 8 # For score display purposes, only way is to know the expected ranges -CROSS_ENCODER_RANGE_MAX = 12 -CROSS_ENCODER_RANGE_MIN = -12 +CROSS_ENCODER_RANGE_MAX = 1 +CROSS_ENCODER_RANGE_MIN = 0 # Unused currently, can't be used with the current default encoder model due to its output range SEARCH_DISTANCE_CUTOFF = 0 diff --git a/backend/danswer/connectors/blob/__init__.py b/backend/danswer/connectors/blob/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/danswer/connectors/blob/connector.py b/backend/danswer/connectors/blob/connector.py new file mode 100644 index 00000000000..2446bfd1666 --- /dev/null +++ b/backend/danswer/connectors/blob/connector.py @@ -0,0 +1,277 @@ +import os +from datetime import datetime +from datetime import timezone +from io import BytesIO +from typing import Any +from typing import Optional + +import boto3 +from botocore.client import Config +from mypy_boto3_s3 import S3Client + +from danswer.configs.app_configs import INDEX_BATCH_SIZE +from danswer.configs.constants import BlobType +from danswer.configs.constants import DocumentSource +from danswer.connectors.interfaces import GenerateDocumentsOutput +from danswer.connectors.interfaces import LoadConnector +from danswer.connectors.interfaces import PollConnector +from danswer.connectors.interfaces import SecondsSinceUnixEpoch +from danswer.connectors.models import ConnectorMissingCredentialError +from danswer.connectors.models import Document +from danswer.connectors.models import Section +from danswer.file_processing.extract_file_text import extract_file_text +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +class BlobStorageConnector(LoadConnector, PollConnector): + def __init__( + self, + bucket_type: str, + bucket_name: str, + prefix: str = "", + batch_size: int = INDEX_BATCH_SIZE, + ) -> None: + self.bucket_type: BlobType = BlobType(bucket_type) + self.bucket_name = bucket_name + self.prefix = 
prefix if not prefix or prefix.endswith("/") else prefix + "/" + self.batch_size = batch_size + self.s3_client: Optional[S3Client] = None + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + """Checks for boto3 credentials based on the bucket type. + (1) R2: Access Key ID, Secret Access Key, Account ID + (2) S3: AWS Access Key ID, AWS Secret Access Key + (3) GOOGLE_CLOUD_STORAGE: Access Key ID, Secret Access Key, Project ID + (4) OCI_STORAGE: Namespace, Region, Access Key ID, Secret Access Key + + For each bucket type, the method initializes the appropriate S3 client: + - R2: Uses Cloudflare R2 endpoint with S3v4 signature + - S3: Creates a standard boto3 S3 client + - GOOGLE_CLOUD_STORAGE: Uses Google Cloud Storage endpoint + - OCI_STORAGE: Uses Oracle Cloud Infrastructure Object Storage endpoint + + Raises ConnectorMissingCredentialError if required credentials are missing. + Raises ValueError for unsupported bucket types. + """ + + logger.info( + f"Loading credentials for {self.bucket_name} or type {self.bucket_type}" + ) + + if self.bucket_type == BlobType.R2: + if not all( + credentials.get(key) + for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"] + ): + raise ConnectorMissingCredentialError("Cloudflare R2") + self.s3_client = boto3.client( + "s3", + endpoint_url=f"https://{credentials['account_id']}.r2.cloudflarestorage.com", + aws_access_key_id=credentials["r2_access_key_id"], + aws_secret_access_key=credentials["r2_secret_access_key"], + region_name="auto", + config=Config(signature_version="s3v4"), + ) + + elif self.bucket_type == BlobType.S3: + if not all( + credentials.get(key) + for key in ["aws_access_key_id", "aws_secret_access_key"] + ): + raise ConnectorMissingCredentialError("Google Cloud Storage") + + session = boto3.Session( + aws_access_key_id=credentials["aws_access_key_id"], + aws_secret_access_key=credentials["aws_secret_access_key"], + ) + self.s3_client = session.client("s3") + + elif 
self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE: + if not all( + credentials.get(key) for key in ["access_key_id", "secret_access_key"] + ): + raise ConnectorMissingCredentialError("Google Cloud Storage") + + self.s3_client = boto3.client( + "s3", + endpoint_url="https://storage.googleapis.com", + aws_access_key_id=credentials["access_key_id"], + aws_secret_access_key=credentials["secret_access_key"], + region_name="auto", + ) + + elif self.bucket_type == BlobType.OCI_STORAGE: + if not all( + credentials.get(key) + for key in ["namespace", "region", "access_key_id", "secret_access_key"] + ): + raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure") + + self.s3_client = boto3.client( + "s3", + endpoint_url=f"https://{credentials['namespace']}.compat.objectstorage.{credentials['region']}.oraclecloud.com", + aws_access_key_id=credentials["access_key_id"], + aws_secret_access_key=credentials["secret_access_key"], + region_name=credentials["region"], + ) + + else: + raise ValueError(f"Unsupported bucket type: {self.bucket_type}") + + return None + + def _download_object(self, key: str) -> bytes: + if self.s3_client is None: + raise ConnectorMissingCredentialError("Blob storage") + object = self.s3_client.get_object(Bucket=self.bucket_name, Key=key) + return object["Body"].read() + + # NOTE: Left in as may be useful for one-off access to documents and sharing across orgs. 
+ # def _get_presigned_url(self, key: str) -> str: + # if self.s3_client is None: + # raise ConnectorMissingCredentialError("Blog storage") + + # url = self.s3_client.generate_presigned_url( + # "get_object", + # Params={"Bucket": self.bucket_name, "Key": key}, + # ExpiresIn=self.presign_length, + # ) + # return url + + def _get_blob_link(self, key: str) -> str: + if self.s3_client is None: + raise ConnectorMissingCredentialError("Blob storage") + + if self.bucket_type == BlobType.R2: + account_id = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0] + return f"https://{account_id}.r2.cloudflarestorage.com/{self.bucket_name}/{key}" + + elif self.bucket_type == BlobType.S3: + region = self.s3_client.meta.region_name + return f"https://{self.bucket_name}.s3.{region}.amazonaws.com/{key}" + + elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE: + return f"https://storage.cloud.google.com/{self.bucket_name}/{key}" + + elif self.bucket_type == BlobType.OCI_STORAGE: + namespace = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0] + region = self.s3_client.meta.region_name + return f"https://objectstorage.{region}.oraclecloud.com/n/{namespace}/b/{self.bucket_name}/o/{key}" + + else: + raise ValueError(f"Unsupported bucket type: {self.bucket_type}") + + def _yield_blob_objects( + self, + start: datetime, + end: datetime, + ) -> GenerateDocumentsOutput: + if self.s3_client is None: + raise ConnectorMissingCredentialError("Blog storage") + + paginator = self.s3_client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix) + + batch: list[Document] = [] + for page in pages: + if "Contents" not in page: + continue + + for obj in page["Contents"]: + if obj["Key"].endswith("/"): + continue + + last_modified = obj["LastModified"].replace(tzinfo=timezone.utc) + + if not start <= last_modified <= end: + continue + + downloaded_file = self._download_object(obj["Key"]) + link = self._get_blob_link(obj["Key"]) 
+ name = os.path.basename(obj["Key"]) + + try: + text = extract_file_text( + name, + BytesIO(downloaded_file), + break_on_unprocessable=False, + ) + batch.append( + Document( + id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}", + sections=[Section(link=link, text=text)], + source=DocumentSource(self.bucket_type.value), + semantic_identifier=name, + doc_updated_at=last_modified, + metadata={}, + ) + ) + if len(batch) == self.batch_size: + yield batch + batch = [] + + except Exception as e: + logger.exception( + f"Error decoding object {obj['Key']} as UTF-8: {e}" + ) + if batch: + yield batch + + def load_from_state(self) -> GenerateDocumentsOutput: + logger.info("Loading blob objects") + return self._yield_blob_objects( + start=datetime(1970, 1, 1, tzinfo=timezone.utc), + end=datetime.now(timezone.utc), + ) + + def poll_source( + self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch + ) -> GenerateDocumentsOutput: + if self.s3_client is None: + raise ConnectorMissingCredentialError("Blog storage") + + start_datetime = datetime.fromtimestamp(start, tz=timezone.utc) + end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) + + for batch in self._yield_blob_objects(start_datetime, end_datetime): + yield batch + + return None + + +if __name__ == "__main__": + credentials_dict = { + "aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"), + "aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"), + } + + # Initialize the connector + connector = BlobStorageConnector( + bucket_type=os.environ.get("BUCKET_TYPE") or "s3", + bucket_name=os.environ.get("BUCKET_NAME") or "test", + prefix="", + ) + + try: + connector.load_credentials(credentials_dict) + document_batch_generator = connector.load_from_state() + for document_batch in document_batch_generator: + print("First batch of documents:") + for doc in document_batch: + print(f"Document ID: {doc.id}") + print(f"Semantic Identifier: {doc.semantic_identifier}") + print(f"Source: {doc.source}") + 
print(f"Updated At: {doc.doc_updated_at}") + print("Sections:") + for section in doc.sections: + print(f" - Link: {section.link}") + print(f" - Text: {section.text[:100]}...") + print("---") + break + + except ConnectorMissingCredentialError as e: + print(f"Error: {e}") + except Exception as e: + print(f"An unexpected error occurred: {e}") diff --git a/backend/danswer/connectors/confluence/connector.py b/backend/danswer/connectors/confluence/connector.py index 3c682e3a753..1c348df6a3b 100644 --- a/backend/danswer/connectors/confluence/connector.py +++ b/backend/danswer/connectors/confluence/connector.py @@ -15,6 +15,7 @@ from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP +from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource @@ -36,16 +37,18 @@ logger = setup_logger() # Potential Improvements -# 1. If wiki page instead of space, do a search of all the children of the page instead of index all in the space -# 2. Include attachments, etc -# 3. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost +# 1. Include attachments, etc +# 2. 
Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost -def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str]: +def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]: """Sample - https://danswer.atlassian.net/wiki/spaces/1234abcd/overview + URL w/ page: https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview + URL w/o page: https://danswer.atlassian.net/wiki/spaces/ASAM/overview + wiki_base is https://danswer.atlassian.net/wiki space is 1234abcd + page_id is 5678efgh """ parsed_url = urlparse(wiki_url) wiki_base = ( @@ -54,18 +57,25 @@ def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str]: + parsed_url.netloc + parsed_url.path.split("/spaces")[0] ) - space = parsed_url.path.split("/")[3] - return wiki_base, space + path_parts = parsed_url.path.split("/") + space = path_parts[3] + + page_id = path_parts[5] if len(path_parts) > 5 else "" + return wiki_base, space, page_id -def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str]: + +def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str, str]: """Sample - https://danswer.ai/confluence/display/1234abcd/overview + URL w/ page https://danswer.ai/confluence/display/1234abcd/pages/5678efgh/overview + URL w/o page https://danswer.ai/confluence/display/1234abcd/overview wiki_base is https://danswer.ai/confluence space is 1234abcd + page_id is 5678efgh """ - # /display/ is always right before the space and at the end of the base url + # /display/ is always right before the space and at the end of the base print() DISPLAY = "/display/" + PAGE = "/pages/" parsed_url = urlparse(wiki_url) wiki_base = ( @@ -75,10 +85,13 @@ def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, st + parsed_url.path.split(DISPLAY)[0] ) space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0] - return wiki_base, 
space + page_id = "" + if (content := parsed_url.path.split(PAGE)) and len(content) > 1: + page_id = content[1] + return wiki_base, space, page_id -def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]: +def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]: is_confluence_cloud = ( ".atlassian.net/wiki/spaces/" in wiki_url or ".jira.com/wiki/spaces/" in wiki_url @@ -86,15 +99,19 @@ def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]: try: if is_confluence_cloud: - wiki_base, space = _extract_confluence_keys_from_cloud_url(wiki_url) + wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url( + wiki_url + ) else: - wiki_base, space = _extract_confluence_keys_from_datacenter_url(wiki_url) + wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url( + wiki_url + ) except Exception as e: - error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base and space names. Exception: {e}" + error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base, space, and page id. 
Exception: {e}" logger.error(error_msg) raise ValueError(error_msg) - return wiki_base, space, is_confluence_cloud + return wiki_base, space, page_id, is_confluence_cloud @lru_cache() @@ -195,10 +212,135 @@ def _comment_dfs( return comments_str +class RecursiveIndexer: + def __init__( + self, + batch_size: int, + confluence_client: Confluence, + index_origin: bool, + origin_page_id: str, + ) -> None: + self.batch_size = 1 + # batch_size + self.confluence_client = confluence_client + self.index_origin = index_origin + self.origin_page_id = origin_page_id + self.pages = self.recurse_children_pages(0, self.origin_page_id) + + def get_pages(self, ind: int, size: int) -> list[dict]: + if ind * size > len(self.pages): + return [] + return self.pages[ind * size : (ind + 1) * size] + + def _fetch_origin_page( + self, + ) -> dict[str, Any]: + get_page_by_id = make_confluence_call_handle_rate_limit( + self.confluence_client.get_page_by_id + ) + try: + origin_page = get_page_by_id( + self.origin_page_id, expand="body.storage.value,version" + ) + return origin_page + except Exception as e: + logger.warning( + f"Appending orgin page with id {self.origin_page_id} failed: {e}" + ) + return {} + + def recurse_children_pages( + self, + start_ind: int, + page_id: str, + ) -> list[dict[str, Any]]: + pages: list[dict[str, Any]] = [] + current_level_pages: list[dict[str, Any]] = [] + next_level_pages: list[dict[str, Any]] = [] + + # Initial fetch of first level children + index = start_ind + while batch := self._fetch_single_depth_child_pages( + index, self.batch_size, page_id + ): + current_level_pages.extend(batch) + index += len(batch) + + pages.extend(current_level_pages) + + # Recursively index children and children's children, etc. 
+ while current_level_pages: + for child in current_level_pages: + child_index = 0 + while child_batch := self._fetch_single_depth_child_pages( + child_index, self.batch_size, child["id"] + ): + next_level_pages.extend(child_batch) + child_index += len(child_batch) + + pages.extend(next_level_pages) + current_level_pages = next_level_pages + next_level_pages = [] + + if self.index_origin: + try: + origin_page = self._fetch_origin_page() + pages.append(origin_page) + except Exception as e: + logger.warning(f"Appending origin page with id {page_id} failed: {e}") + + return pages + + def _fetch_single_depth_child_pages( + self, start_ind: int, batch_size: int, page_id: str + ) -> list[dict[str, Any]]: + child_pages: list[dict[str, Any]] = [] + + get_page_child_by_type = make_confluence_call_handle_rate_limit( + self.confluence_client.get_page_child_by_type + ) + + try: + child_page = get_page_child_by_type( + page_id, + type="page", + start=start_ind, + limit=batch_size, + expand="body.storage.value,version", + ) + + child_pages.extend(child_page) + return child_pages + + except Exception: + logger.warning( + f"Batch failed with page {page_id} at offset {start_ind} " + f"with size {batch_size}, processing pages individually..." 
+ ) + + for i in range(batch_size): + ind = start_ind + i + try: + child_page = get_page_child_by_type( + page_id, + type="page", + start=ind, + limit=1, + expand="body.storage.value,version", + ) + child_pages.extend(child_page) + except Exception as e: + logger.warning(f"Page {page_id} at offset {ind} failed: {e}") + raise e + + return child_pages + + class ConfluenceConnector(LoadConnector, PollConnector): def __init__( self, wiki_page_url: str, + index_origin: bool = True, batch_size: int = INDEX_BATCH_SIZE, continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE, # if a page has one of the labels specified in this list, we will just @@ -209,11 +351,27 @@ def __init__( self.batch_size = batch_size self.continue_on_failure = continue_on_failure self.labels_to_skip = set(labels_to_skip) - self.wiki_base, self.space, self.is_cloud = extract_confluence_keys_from_url( - wiki_page_url - ) + self.recursive_indexer: RecursiveIndexer | None = None + self.index_origin = index_origin + ( + self.wiki_base, + self.space, + self.page_id, + self.is_cloud, + ) = extract_confluence_keys_from_url(wiki_page_url) + + self.space_level_scan = False + self.confluence_client: Confluence | None = None + if self.page_id is None or self.page_id == "": + self.space_level_scan = True + + logger.info( + f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id}," + + f" space_level_scan: {self.space_level_scan}, origin: {self.index_origin}" + ) + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: username = credentials["confluence_username"] access_token = credentials["confluence_access_token"] @@ -231,8 +389,8 @@ def _fetch_pages( self, confluence_client: Confluence, start_ind: int, - ) -> Collection[dict[str, Any]]: - def _fetch(start_ind: int, batch_size: int) -> Collection[dict[str, Any]]: + ) -> list[dict[str, Any]]: + def _fetch_space(start_ind: int, batch_size: int) -> list[dict[str, Any]]: get_all_pages_from_space = 
make_confluence_call_handle_rate_limit( confluence_client.get_all_pages_from_space ) @@ -241,9 +399,11 @@ def _fetch(start_ind: int, batch_size: int) -> Collection[dict[str, Any]]: self.space, start=start_ind, limit=batch_size, - status="current" - if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES - else None, + status=( + "current" + if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES + else None + ), expand="body.storage.value,version", ) except Exception: @@ -262,9 +422,11 @@ def _fetch(start_ind: int, batch_size: int) -> Collection[dict[str, Any]]: self.space, start=start_ind + i, limit=1, - status="current" - if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES - else None, + status=( + "current" + if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES + else None + ), expand="body.storage.value,version", ) ) @@ -285,17 +447,41 @@ def _fetch(start_ind: int, batch_size: int) -> Collection[dict[str, Any]]: return view_pages + def _fetch_page(start_ind: int, batch_size: int) -> list[dict[str, Any]]: + if self.recursive_indexer is None: + self.recursive_indexer = RecursiveIndexer( + origin_page_id=self.page_id, + batch_size=self.batch_size, + confluence_client=self.confluence_client, + index_origin=self.index_origin, + ) + + return self.recursive_indexer.get_pages(start_ind, batch_size) + + pages: list[dict[str, Any]] = [] + try: - return _fetch(start_ind, self.batch_size) + pages = ( + _fetch_space(start_ind, self.batch_size) + if self.space_level_scan + else _fetch_page(start_ind, self.batch_size) + ) + return pages + except Exception as e: if not self.continue_on_failure: raise e # error checking phase, only reachable if `self.continue_on_failure=True` - pages: list[dict[str, Any]] = [] for i in range(self.batch_size): try: - pages.extend(_fetch(start_ind + i, 1)) + pages = ( + _fetch_space(start_ind, self.batch_size) + if self.space_level_scan + else _fetch_page(start_ind, self.batch_size) + ) + return pages + except Exception: logger.exception( "Ran into exception when 
fetching pages from Confluence" @@ -307,6 +493,7 @@ def _fetch_comments(self, confluence_client: Confluence, page_id: str) -> str: get_page_child_by_type = make_confluence_call_handle_rate_limit( confluence_client.get_page_child_by_type ) + try: comment_pages = cast( Collection[dict[str, Any]], @@ -355,7 +542,14 @@ def _fetch_attachments( page_id, start=0, limit=500 ) for attachment in attachments_container["results"]: - if attachment["metadata"]["mediaType"] in ["image/jpeg", "image/png"]: + if attachment["metadata"]["mediaType"] in [ + "image/jpeg", + "image/png", + "image/gif", + "image/svg+xml", + "video/mp4", + "video/quicktime", + ]: continue if attachment["title"] not in files_in_used: @@ -366,7 +560,7 @@ def _fetch_attachments( if response.status_code == 200: extract = extract_file_text( - attachment["title"], io.BytesIO(response.content) + attachment["title"], io.BytesIO(response.content), False ) files_attachment_content.append(extract) @@ -386,8 +580,8 @@ def _get_doc_batch( if self.confluence_client is None: raise ConnectorMissingCredentialError("Confluence") - batch = self._fetch_pages(self.confluence_client, start_ind) + for page in batch: last_modified_str = page["version"]["when"] author = cast(str | None, page["version"].get("by", {}).get("email")) @@ -403,15 +597,18 @@ def _get_doc_batch( if time_filter is None or time_filter(last_modified): page_id = page["id"] + if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING: + page_labels = self._fetch_labels(self.confluence_client, page_id) + # check disallowed labels if self.labels_to_skip: - page_labels = self._fetch_labels(self.confluence_client, page_id) label_intersection = self.labels_to_skip.intersection(page_labels) if label_intersection: logger.info( f"Page with ID '{page_id}' has a label which has been " f"designated as disallowed: {label_intersection}. Skipping." 
) + continue page_html = ( @@ -432,6 +629,11 @@ def _get_doc_batch( page_text += attachment_text comments_text = self._fetch_comments(self.confluence_client, page_id) page_text += comments_text + doc_metadata: dict[str, str | list[str]] = { + "Wiki Space Name": self.space + } + if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels: + doc_metadata["labels"] = page_labels doc_batch.append( Document( @@ -440,12 +642,10 @@ def _get_doc_batch( source=DocumentSource.CONFLUENCE, semantic_identifier=page["title"], doc_updated_at=last_modified, - primary_owners=[BasicExpertInfo(email=author)] - if author - else None, - metadata={ - "Wiki Space Name": self.space, - }, + primary_owners=( + [BasicExpertInfo(email=author)] if author else None + ), + metadata=doc_metadata, ) ) return doc_batch, len(batch) diff --git a/backend/danswer/connectors/discourse/connector.py b/backend/danswer/connectors/discourse/connector.py index 1bf64a6b369..a637ff78c44 100644 --- a/backend/danswer/connectors/discourse/connector.py +++ b/backend/danswer/connectors/discourse/connector.py @@ -101,7 +101,11 @@ def _get_latest_topics( if end and end < last_time_dt: continue - if valid_categories and topic.get("category_id") not in valid_categories: + if ( + self.categories + and valid_categories + and topic.get("category_id") not in valid_categories + ): continue topic_ids.append(topic["id"]) @@ -138,10 +142,16 @@ def _get_doc_from_topic(self, topic_id: int) -> Document: sections.append( Section(link=topic_url, text=parse_html_page_basic(post["cooked"])) ) + category_name = self.category_id_map.get(topic["category_id"]) + + metadata: dict[str, str | list[str]] = ( + { + "category": category_name, + } + if category_name + else {} + ) - metadata: dict[str, str | list[str]] = { - "category": self.category_id_map[topic["category_id"]], - } if topic.get("tags"): metadata["tags"] = topic["tags"] diff --git a/backend/danswer/connectors/factory.py b/backend/danswer/connectors/factory.py index 
f70980a8df5..1a3d605d3a5 100644 --- a/backend/danswer/connectors/factory.py +++ b/backend/danswer/connectors/factory.py @@ -5,6 +5,7 @@ from danswer.configs.constants import DocumentSource from danswer.connectors.axero.connector import AxeroConnector +from danswer.connectors.blob.connector import BlobStorageConnector from danswer.connectors.bookstack.connector import BookstackConnector from danswer.connectors.clickup.connector import ClickupConnector from danswer.connectors.confluence.connector import ConfluenceConnector @@ -90,6 +91,10 @@ def identify_connector_class( DocumentSource.CLICKUP: ClickupConnector, DocumentSource.MEDIAWIKI: MediaWikiConnector, DocumentSource.WIKIPEDIA: WikipediaConnector, + DocumentSource.S3: BlobStorageConnector, + DocumentSource.R2: BlobStorageConnector, + DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector, + DocumentSource.OCI_STORAGE: BlobStorageConnector, } connector_by_source = connector_map.get(source, {}) @@ -115,7 +120,6 @@ def identify_connector_class( raise ConnectorMissingException( f"Connector for source={source} does not accept input_type={input_type}" ) - return connector diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index c44d2751551..7fa4aae8b37 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -476,6 +476,11 @@ def _fetch_docs_from_drive( doc_batch = [] for file in files_batch: try: + # Skip files that are shortcuts + if file.get("mimeType") == DRIVE_SHORTCUT_TYPE: + logger.info("Ignoring Drive Shortcut Filetype") + continue + if self.only_org_public: if "permissions" not in file: continue diff --git a/backend/danswer/connectors/notion/connector.py b/backend/danswer/connectors/notion/connector.py index cadf10b328f..fd607e4f97a 100644 --- a/backend/danswer/connectors/notion/connector.py +++ b/backend/danswer/connectors/notion/connector.py @@ -368,7 +368,7 @@ def 
_filter_pages_by_time( compare_time = time.mktime( time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z") ) - if compare_time <= end or compare_time > start: + if compare_time > start and compare_time <= end: filtered_pages += [NotionPage(**page)] return filtered_pages diff --git a/backend/danswer/connectors/salesforce/connector.py b/backend/danswer/connectors/salesforce/connector.py index 9f10da78ba4..03326df4efd 100644 --- a/backend/danswer/connectors/salesforce/connector.py +++ b/backend/danswer/connectors/salesforce/connector.py @@ -79,8 +79,9 @@ def _convert_object_instance_to_document( if self.sf_client is None: raise ConnectorMissingCredentialError("Salesforce") - extracted_id = f"{ID_PREFIX}{object_dict['Id']}" - extracted_link = f"https://{self.sf_client.sf_instance}/{extracted_id}" + salesforce_id = object_dict["Id"] + danswer_salesforce_id = f"{ID_PREFIX}{salesforce_id}" + extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}" extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"]) extracted_object_text = extract_dict_text(object_dict) extracted_semantic_identifier = object_dict.get("Name", "Unknown Object") @@ -91,7 +92,7 @@ def _convert_object_instance_to_document( ] doc = Document( - id=extracted_id, + id=danswer_salesforce_id, sections=[Section(link=extracted_link, text=extracted_object_text)], source=DocumentSource.SALESFORCE, semantic_identifier=extracted_semantic_identifier, diff --git a/backend/danswer/danswerbot/slack/blocks.py b/backend/danswer/danswerbot/slack/blocks.py index a881819c50d..aaf26318851 100644 --- a/backend/danswer/danswerbot/slack/blocks.py +++ b/backend/danswer/danswerbot/slack/blocks.py @@ -25,6 +25,7 @@ from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID +from danswer.danswerbot.slack.constants import 
GENERATE_ANSWER_BUTTON_ACTION_ID from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID from danswer.danswerbot.slack.icons import source_to_github_img_link @@ -353,6 +354,22 @@ def build_quotes_block( return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))] +def build_standard_answer_blocks( + answer_message: str, +) -> list[Block]: + generate_button_block = ButtonElement( + action_id=GENERATE_ANSWER_BUTTON_ACTION_ID, + text="Generate Full Answer", + ) + answer_block = SectionBlock(text=answer_message) + return [ + answer_block, + ActionsBlock( + elements=[generate_button_block], + ), + ] + + def build_qa_response_blocks( message_id: int | None, answer: str | None, diff --git a/backend/danswer/danswerbot/slack/constants.py b/backend/danswer/danswerbot/slack/constants.py index 1e524025fc7..5eaf63c8984 100644 --- a/backend/danswer/danswerbot/slack/constants.py +++ b/backend/danswer/danswerbot/slack/constants.py @@ -8,6 +8,7 @@ FOLLOWUP_BUTTON_RESOLVED_ACTION_ID = "followup-resolved-button" SLACK_CHANNEL_ID = "channel_id" VIEW_DOC_FEEDBACK_ID = "view-doc-feedback" +GENERATE_ANSWER_BUTTON_ACTION_ID = "generate-answer-button" class FeedbackVisibility(str, Enum): diff --git a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py index 3a0209b076f..30c7015b963 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py @@ -1,3 +1,4 @@ +import logging from typing import Any from typing import cast @@ -8,6 +9,7 @@ from slack_sdk.socket_mode.request import SocketModeRequest from sqlalchemy.orm import Session +from danswer.configs.constants import MessageType from danswer.configs.constants import SearchFeedbackType from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI from 
danswer.connectors.slack.utils import make_slack_api_rate_limited @@ -21,12 +23,17 @@ from danswer.danswerbot.slack.handlers.handle_message import ( remove_scheduled_feedback_reminder, ) +from danswer.danswerbot.slack.handlers.handle_regular_answer import ( + handle_regular_answer, +) +from danswer.danswerbot.slack.models import SlackMessageInfo from danswer.danswerbot.slack.utils import build_feedback_id from danswer.danswerbot.slack.utils import decompose_action_id from danswer.danswerbot.slack.utils import fetch_groupids_from_names from danswer.danswerbot.slack.utils import fetch_userids_from_emails from danswer.danswerbot.slack.utils import get_channel_name_from_id from danswer.danswerbot.slack.utils import get_feedback_visibility +from danswer.danswerbot.slack.utils import read_slack_thread from danswer.danswerbot.slack.utils import respond_in_thread from danswer.danswerbot.slack.utils import update_emote_react from danswer.db.engine import get_sqlalchemy_engine @@ -72,6 +79,66 @@ def handle_doc_feedback_button( ) +def handle_generate_answer_button( + req: SocketModeRequest, + client: SocketModeClient, +) -> None: + channel_id = req.payload["channel"]["id"] + channel_name = req.payload["channel"]["name"] + message_ts = req.payload["message"]["ts"] + thread_ts = req.payload["container"]["thread_ts"] + user_id = req.payload["user"]["id"] + + if not thread_ts: + raise ValueError("Missing thread_ts in the payload") + + thread_messages = read_slack_thread( + channel=channel_id, thread=thread_ts, client=client.web_client + ) + # remove all assistant messages till we get to the last user message + # we want the new answer to be generated off of the last "question" in + # the thread + for i in range(len(thread_messages) - 1, -1, -1): + if thread_messages[i].role == MessageType.USER: + break + if thread_messages[i].role == MessageType.ASSISTANT: + thread_messages.pop(i) + + # tell the user that we're working on it + # Send an ephemeral message to the user that we're 
generating the answer + respond_in_thread( + client=client.web_client, + channel=channel_id, + receiver_ids=[user_id], + text="I'm working on generating a full answer for you. This may take a moment...", + thread_ts=thread_ts, + ) + + with Session(get_sqlalchemy_engine()) as db_session: + slack_bot_config = get_slack_bot_config_for_channel( + channel_name=channel_name, db_session=db_session + ) + + handle_regular_answer( + message_info=SlackMessageInfo( + thread_messages=thread_messages, + channel_to_respond=channel_id, + msg_to_respond=cast(str, message_ts or thread_ts), + thread_to_respond=cast(str, thread_ts or message_ts), + sender=user_id or None, + bypass_filters=True, + is_bot_msg=False, + is_bot_dm=False, + ), + slack_bot_config=slack_bot_config, + receiver_ids=None, + client=client.web_client, + channel=channel_id, + logger=cast(logging.Logger, logger_base), + feedback_reminder_id=None, + ) + + def handle_slack_feedback( feedback_id: str, feedback_type: str, diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index c1335aa77bd..a05006dec1e 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -1,91 +1,34 @@ import datetime -import functools import logging -from collections.abc import Callable -from typing import Any from typing import cast -from typing import Optional -from typing import TypeVar -from retry import retry from slack_sdk import WebClient from slack_sdk.errors import SlackApiError -from slack_sdk.models.blocks import DividerBlock -from slack_sdk.models.blocks import SectionBlock from sqlalchemy.orm import Session -from danswer.configs.app_configs import DISABLE_GENERATIVE_AI -from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT -from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT -from danswer.configs.danswerbot_configs 
import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER -from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_REMINDER -from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES -from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE -from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES -from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI -from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT -from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION -from danswer.danswerbot.slack.blocks import build_documents_blocks -from danswer.danswerbot.slack.blocks import build_follow_up_block -from danswer.danswerbot.slack.blocks import build_qa_response_blocks -from danswer.danswerbot.slack.blocks import build_sources_blocks from danswer.danswerbot.slack.blocks import get_feedback_reminder_blocks -from danswer.danswerbot.slack.blocks import get_restate_blocks from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID +from danswer.danswerbot.slack.handlers.handle_regular_answer import ( + handle_regular_answer, +) +from danswer.danswerbot.slack.handlers.handle_standard_answers import ( + handle_standard_answers, +) from danswer.danswerbot.slack.models import SlackMessageInfo from danswer.danswerbot.slack.utils import ChannelIdAdapter from danswer.danswerbot.slack.utils import fetch_userids_from_emails +from danswer.danswerbot.slack.utils import fetch_userids_from_groups from danswer.danswerbot.slack.utils import respond_in_thread from danswer.danswerbot.slack.utils import slack_usage_report -from danswer.danswerbot.slack.utils import SlackRateLimiter from danswer.danswerbot.slack.utils import update_emote_react from danswer.db.engine import get_sqlalchemy_engine -from danswer.db.models import Persona from 
danswer.db.models import SlackBotConfig -from danswer.db.models import SlackBotResponseType -from danswer.db.persona import fetch_persona_by_id -from danswer.llm.answering.prompts.citations_prompt import ( - compute_max_document_tokens_for_persona, -) -from danswer.llm.factory import get_llm_for_persona -from danswer.llm.utils import check_number_of_tokens -from danswer.llm.utils import get_max_input_tokens -from danswer.one_shot_answer.answer_question import get_search_answer -from danswer.one_shot_answer.models import DirectQARequest -from danswer.one_shot_answer.models import OneShotQAResponse -from danswer.search.models import BaseFilters -from danswer.search.models import OptionalSearchSetting -from danswer.search.models import RetrievalDetails from danswer.utils.logger import setup_logger -from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW logger_base = setup_logger() -srl = SlackRateLimiter() - -RT = TypeVar("RT") # return type - - -def rate_limits( - client: WebClient, channel: str, thread_ts: Optional[str] -) -> Callable[[Callable[..., RT]], Callable[..., RT]]: - def decorator(func: Callable[..., RT]) -> Callable[..., RT]: - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> RT: - if not srl.is_available(): - func_randid, position = srl.init_waiter() - srl.notify(client, channel, position, thread_ts) - while not srl.is_available(): - srl.waiter(func_randid) - srl.acquire_slot() - return func(*args, **kwargs) - - return wrapper - - return decorator - def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None: if details.is_bot_msg and details.sender: @@ -173,17 +116,9 @@ def remove_scheduled_feedback_reminder( def handle_message( message_info: SlackMessageInfo, - channel_config: SlackBotConfig | None, + slack_bot_config: SlackBotConfig | None, client: WebClient, feedback_reminder_id: str | None, - num_retries: int = DANSWER_BOT_NUM_RETRIES, - answer_generation_timeout: int = 
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT, - should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS, - disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER, - disable_auto_detect_filters: bool = DISABLE_DANSWER_BOT_FILTER_DETECT, - reflexion: bool = ENABLE_DANSWERBOT_REFLEXION, - disable_cot: bool = DANSWER_BOT_DISABLE_COT, - thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE, ) -> bool: """Potentially respond to the user message depending on filters and if an answer was generated @@ -200,14 +135,22 @@ def handle_message( ) messages = message_info.thread_messages - message_ts_to_respond_to = message_info.msg_to_respond sender_id = message_info.sender bypass_filters = message_info.bypass_filters is_bot_msg = message_info.is_bot_msg is_bot_dm = message_info.is_bot_dm + action = "slack_message" + if is_bot_msg: + action = "slack_slash_message" + elif bypass_filters: + action = "slack_tag_message" + elif is_bot_dm: + action = "slack_dm_message" + slack_usage_report(action=action, sender_id=sender_id, client=client) + document_set_names: list[str] | None = None - persona = channel_config.persona if channel_config else None + persona = slack_bot_config.persona if slack_bot_config else None prompt = None if persona: document_set_names = [ @@ -215,36 +158,16 @@ def handle_message( ] prompt = persona.prompts[0] if persona.prompts else None - should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False - - # figure out if we want to use citations or quotes - use_citations = ( - not DANSWER_BOT_USE_QUOTES - if channel_config is None - else channel_config.response_type == SlackBotResponseType.CITATIONS - ) - # List of user id to send message to, if None, send to everyone in channel send_to: list[str] | None = None respond_tag_only = False respond_team_member_list = None - - bypass_acl = False - if ( - channel_config - and channel_config.persona - and channel_config.persona.document_sets - ): - # For Slack 
channels, use the full document set, admin will be warned when configuring it - # with non-public document sets - bypass_acl = True + respond_slack_group_list = None channel_conf = None - if channel_config and channel_config.channel_config: - channel_conf = channel_config.channel_config + if slack_bot_config and slack_bot_config.channel_config: + channel_conf = slack_bot_config.channel_config if not bypass_filters and "answer_filters" in channel_conf: - reflexion = "well_answered_postfilter" in channel_conf["answer_filters"] - if ( "questionmark_prefilter" in channel_conf["answer_filters"] and "?" not in messages[-1].message @@ -262,6 +185,7 @@ def handle_message( respond_tag_only = channel_conf.get("respond_tag_only") or False respond_team_member_list = channel_conf.get("respond_team_member_list") or None + respond_slack_group_list = channel_conf.get("respond_slack_group_list") or None if respond_tag_only and not bypass_filters: logger.info( @@ -272,10 +196,15 @@ def handle_message( if respond_team_member_list: send_to, _ = fetch_userids_from_emails(respond_team_member_list, client) + if respond_slack_group_list: + user_ids, _ = fetch_userids_from_groups(respond_slack_group_list, client) + send_to = (send_to + user_ids) if send_to else user_ids + if send_to: + send_to = list(set(send_to)) # remove duplicates # If configured to respond to team members only, then cannot be used with a /DanswerBot command # which would just respond to the sender - if respond_team_member_list and is_bot_msg: + if (respond_team_member_list or respond_slack_group_list) and is_bot_msg: if sender_id: respond_in_thread( client=client, @@ -290,324 +219,28 @@ def handle_message( except SlackApiError as e: logger.error(f"Was not able to react to user message due to: {e}") - @retry( - tries=num_retries, - delay=0.25, - backoff=2, - logger=logger, - ) - @rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to) - def _get_answer(new_message_request: DirectQARequest) -> 
OneShotQAResponse | None: - action = "slack_message" - if is_bot_msg: - action = "slack_slash_message" - elif bypass_filters: - action = "slack_tag_message" - elif is_bot_dm: - action = "slack_dm_message" - - slack_usage_report(action=action, sender_id=sender_id, client=client) - - max_document_tokens: int | None = None - max_history_tokens: int | None = None - - with Session(get_sqlalchemy_engine()) as db_session: - if len(new_message_request.messages) > 1: - persona = cast( - Persona, - fetch_persona_by_id(db_session, new_message_request.persona_id), - ) - llm = get_llm_for_persona(persona) - - # In cases of threads, split the available tokens between docs and thread context - input_tokens = get_max_input_tokens( - model_name=llm.config.model_name, - model_provider=llm.config.model_provider, - ) - max_history_tokens = int(input_tokens * thread_context_percent) - - remaining_tokens = input_tokens - max_history_tokens - - query_text = new_message_request.messages[0].message - if persona: - max_document_tokens = compute_max_document_tokens_for_persona( - persona=persona, - actual_user_input=query_text, - max_llm_token_override=remaining_tokens, - ) - else: - max_document_tokens = ( - remaining_tokens - - 512 # Needs to be more than any of the QA prompts - - check_number_of_tokens(query_text) - ) - - if DISABLE_GENERATIVE_AI: - return None - - # This also handles creating the query event in postgres - answer = get_search_answer( - query_req=new_message_request, - user=None, - max_document_tokens=max_document_tokens, - max_history_tokens=max_history_tokens, - db_session=db_session, - answer_generation_timeout=answer_generation_timeout, - enable_reflexion=reflexion, - bypass_acl=bypass_acl, - use_citations=use_citations, - danswerbot_flow=True, - ) - if not answer.error_msg: - return answer - else: - raise RuntimeError(answer.error_msg) - - try: - # By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters - # it allows the slack flow to 
extract out filters from the user query - filters = BaseFilters( - source_type=None, - document_set=document_set_names, - time_cutoff=None, - ) - - # Default True because no other ways to apply filters in Slack (no nice UI) - auto_detect_filters = ( - persona.llm_filter_extraction if persona is not None else True - ) - if disable_auto_detect_filters: - auto_detect_filters = False - - retrieval_details = RetrievalDetails( - run_search=OptionalSearchSetting.ALWAYS, - real_time=False, - filters=filters, - enable_auto_detect_filters=auto_detect_filters, - ) - - # This includes throwing out answer via reflexion - answer = _get_answer( - DirectQARequest( - messages=messages, - prompt_id=prompt.id if prompt else None, - persona_id=persona.id if persona is not None else 0, - retrieval_options=retrieval_details, - chain_of_thought=not disable_cot, - skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW, - ) - ) - except Exception as e: - logger.exception( - f"Unable to process message - did not successfully answer " - f"in {num_retries} attempts" - ) - # Optionally, respond in thread with the error message, Used primarily - # for debugging purposes - if should_respond_with_error_msgs: - respond_in_thread( - client=client, - channel=channel, - receiver_ids=None, - text=f"Encountered exception when trying to answer: \n\n```{e}```", - thread_ts=message_ts_to_respond_to, - ) - - # In case of failures, don't keep the reaction there permanently - try: - update_emote_react( - emoji=DANSWER_REACT_EMOJI, - channel=message_info.channel_to_respond, - message_ts=message_info.msg_to_respond, - remove=True, - client=client, - ) - except SlackApiError as e: - logger.error(f"Failed to remove Reaction due to: {e}") - - return True - - # Edge case handling, for tracking down the Slack usage issue - if answer is None: - assert DISABLE_GENERATIVE_AI is True - try: - respond_in_thread( - client=client, - channel=channel, - receiver_ids=send_to, - text="Hello! 
EVE AI has some results for you!", - blocks=[ - SectionBlock( - text="EVE AI is down for maintenance.\nWe're working hard on recharging the AI!" - ) - ], - thread_ts=message_ts_to_respond_to, - # don't unfurl, since otherwise we will have 5+ previews which makes the message very long - unfurl=False, - ) - - # For DM (ephemeral message), we need to create a thread via a normal message so the user can see - # the ephemeral message. This also will give the user a notification which ephemeral message does not. - if respond_team_member_list: - respond_in_thread( - client=client, - channel=channel, - text=( - "👋 Hi, we've just gathered and forwarded the relevant " - + "information to the team. They'll get back to you shortly!" - ), - thread_ts=message_ts_to_respond_to, - ) - - return False - - except Exception: - logger.exception( - f"Unable to process message - could not respond in slack in {num_retries} attempts" - ) - return True - - # Got an answer at this point, can remove reaction and give results - try: - update_emote_react( - emoji=DANSWER_REACT_EMOJI, - channel=message_info.channel_to_respond, - message_ts=message_info.msg_to_respond, - remove=True, - client=client, - ) - except SlackApiError as e: - logger.error(f"Failed to remove Reaction due to: {e}") - - if answer.answer_valid is False: - logger.info( - "Answer was evaluated to be invalid, throwing it away without responding." 
- ) - update_emote_react( - emoji=DANSWER_FOLLOWUP_EMOJI, - channel=message_info.channel_to_respond, - message_ts=message_info.msg_to_respond, - remove=False, + with Session(get_sqlalchemy_engine()) as db_session: + # first check if we need to respond with a standard answer + used_standard_answer = handle_standard_answers( + message_info=message_info, + receiver_ids=send_to, + slack_bot_config=slack_bot_config, + prompt=prompt, + logger=logger, client=client, + db_session=db_session, ) + if used_standard_answer: + return False - if answer.answer: - logger.debug(answer.answer) - return True - - retrieval_info = answer.docs - if not retrieval_info: - # This should not happen, even with no docs retrieved, there is still info returned - raise RuntimeError("Failed to retrieve docs, cannot answer question.") - - top_docs = retrieval_info.top_documents - if not top_docs and not should_respond_even_with_no_docs: - logger.error( - f"Unable to answer question: '{answer.rephrase}' - no documents found" - ) - # Optionally, respond in thread with the error message - # Used primarily for debugging purposes - if should_respond_with_error_msgs: - respond_in_thread( - client=client, - channel=channel, - receiver_ids=None, - text="Found no documents when trying to answer. 
Did you index any documents?", - thread_ts=message_ts_to_respond_to, - ) - return True - - if not answer.answer and disable_docs_only_answer: - logger.info( - "Unable to find answer - not responding since the " - "`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set" - ) - return True - - # If called with the DanswerBot slash command, the question is lost so we have to reshow it - restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg) - - answer_blocks = build_qa_response_blocks( - message_id=answer.chat_message_id, - answer=answer.answer, - quotes=answer.quotes.quotes if answer.quotes else None, - source_filters=retrieval_info.applied_source_filters, - time_cutoff=retrieval_info.applied_time_cutoff, - favor_recent=retrieval_info.recency_bias_multiplier > 1, - # currently Personas don't support quotes - # if citations are enabled, also don't use quotes - skip_quotes=persona is not None or use_citations, - process_message_for_citations=use_citations, - feedback_reminder_id=feedback_reminder_id, - ) - - # Get the chunks fed to the LLM only, then fill with other docs - llm_doc_inds = answer.llm_chunks_indices or [] - llm_docs = [top_docs[i] for i in llm_doc_inds] - remaining_docs = [ - doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds - ] - priority_ordered_docs = llm_docs + remaining_docs - - document_blocks = [] - citations_block = [] - # if citations are enabled, only show cited documents - if use_citations: - citations = answer.citations or [] - cited_docs = [] - for citation in citations: - matching_doc = next( - (d for d in top_docs if d.document_id == citation.document_id), - None, - ) - if matching_doc: - cited_docs.append((citation.citation_num, matching_doc)) - - cited_docs.sort() - citations_block = build_sources_blocks(cited_documents=cited_docs) - elif priority_ordered_docs: - document_blocks = build_documents_blocks( - documents=priority_ordered_docs, - message_id=answer.chat_message_id, - ) - document_blocks 
= [DividerBlock()] + document_blocks - - all_blocks = ( - restate_question_block + answer_blocks + citations_block + document_blocks - ) - - if channel_conf and channel_conf.get("follow_up_tags") is not None: - all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id)) - - try: - respond_in_thread( + # if no standard answer applies, try a regular answer + issue_with_regular_answer = handle_regular_answer( + message_info=message_info, + slack_bot_config=slack_bot_config, + receiver_ids=send_to, client=client, channel=channel, - receiver_ids=send_to, - text="Hello! Danswer has some results for you!", - blocks=all_blocks, - thread_ts=message_ts_to_respond_to, - # don't unfurl, since otherwise we will have 5+ previews which makes the message very long - unfurl=False, - ) - - # For DM (ephemeral message), we need to create a thread via a normal message so the user can see - # the ephemeral message. This also will give the user a notification which ephemeral message does not. - if respond_team_member_list: - respond_in_thread( - client=client, - channel=channel, - text=( - "👋 Hi, we've just gathered and forwarded the relevant " - + "information to the team. They'll get back to you shortly!" 
- ), - thread_ts=message_ts_to_respond_to, - ) - - return False - - except Exception: - logger.exception( - f"Unable to process message - could not respond in slack in {num_retries} attempts" + logger=logger, + feedback_reminder_id=feedback_reminder_id, ) - return True + return issue_with_regular_answer diff --git a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py new file mode 100644 index 00000000000..212c65b13df --- /dev/null +++ b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py @@ -0,0 +1,465 @@ +import functools +import logging +from collections.abc import Callable +from typing import Any +from typing import cast +from typing import Optional +from typing import TypeVar + +from retry import retry +from slack_sdk import WebClient +from slack_sdk.models.blocks import DividerBlock +from slack_sdk.models.blocks import SectionBlock +from sqlalchemy.orm import Session + +from danswer.configs.app_configs import DISABLE_GENERATIVE_AI +from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT +from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT +from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER +from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS +from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES +from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE +from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES +from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI +from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI +from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION +from danswer.danswerbot.slack.blocks import build_documents_blocks +from danswer.danswerbot.slack.blocks import build_follow_up_block +from danswer.danswerbot.slack.blocks import 
build_qa_response_blocks +from danswer.danswerbot.slack.blocks import build_sources_blocks +from danswer.danswerbot.slack.blocks import get_restate_blocks +from danswer.danswerbot.slack.handlers.utils import send_team_member_message +from danswer.danswerbot.slack.models import SlackMessageInfo +from danswer.danswerbot.slack.utils import respond_in_thread +from danswer.danswerbot.slack.utils import SlackRateLimiter +from danswer.danswerbot.slack.utils import update_emote_react +from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.models import Persona +from danswer.db.models import SlackBotConfig +from danswer.db.models import SlackBotResponseType +from danswer.db.persona import fetch_persona_by_id +from danswer.llm.answering.prompts.citations_prompt import ( + compute_max_document_tokens_for_persona, +) +from danswer.llm.factory import get_llms_for_persona +from danswer.llm.utils import check_number_of_tokens +from danswer.llm.utils import get_max_input_tokens +from danswer.one_shot_answer.answer_question import get_search_answer +from danswer.one_shot_answer.models import DirectQARequest +from danswer.one_shot_answer.models import OneShotQAResponse +from danswer.search.enums import OptionalSearchSetting +from danswer.search.models import BaseFilters +from danswer.search.models import RetrievalDetails +from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW + + +srl = SlackRateLimiter() + +RT = TypeVar("RT") # return type + + +def rate_limits( + client: WebClient, channel: str, thread_ts: Optional[str] +) -> Callable[[Callable[..., RT]], Callable[..., RT]]: + def decorator(func: Callable[..., RT]) -> Callable[..., RT]: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> RT: + if not srl.is_available(): + func_randid, position = srl.init_waiter() + srl.notify(client, channel, position, thread_ts) + while not srl.is_available(): + srl.waiter(func_randid) + srl.acquire_slot() + return func(*args, **kwargs) + + return wrapper + + 
return decorator + + +def handle_regular_answer( + message_info: SlackMessageInfo, + slack_bot_config: SlackBotConfig | None, + receiver_ids: list[str] | None, + client: WebClient, + channel: str, + logger: logging.Logger, + feedback_reminder_id: str | None, + num_retries: int = DANSWER_BOT_NUM_RETRIES, + answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT, + thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE, + should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS, + disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER, + disable_cot: bool = DANSWER_BOT_DISABLE_COT, + reflexion: bool = ENABLE_DANSWERBOT_REFLEXION, +) -> bool: + channel_conf = slack_bot_config.channel_config if slack_bot_config else None + + messages = message_info.thread_messages + message_ts_to_respond_to = message_info.msg_to_respond + is_bot_msg = message_info.is_bot_msg + + document_set_names: list[str] | None = None + persona = slack_bot_config.persona if slack_bot_config else None + prompt = None + if persona: + document_set_names = [ + document_set.name for document_set in persona.document_sets + ] + prompt = persona.prompts[0] if persona.prompts else None + + should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False + + bypass_acl = False + if ( + slack_bot_config + and slack_bot_config.persona + and slack_bot_config.persona.document_sets + ): + # For Slack channels, use the full document set, admin will be warned when configuring it + # with non-public document sets + bypass_acl = True + + # figure out if we want to use citations or quotes + use_citations = ( + not DANSWER_BOT_USE_QUOTES + if slack_bot_config is None + else slack_bot_config.response_type == SlackBotResponseType.CITATIONS + ) + + if not message_ts_to_respond_to: + raise RuntimeError( + "No message timestamp to respond to in `handle_message`. This should never happen." 
+ ) + + @retry( + tries=num_retries, + delay=0.25, + backoff=2, + logger=logger, + ) + @rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to) + def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None: + max_document_tokens: int | None = None + max_history_tokens: int | None = None + + with Session(get_sqlalchemy_engine()) as db_session: + if len(new_message_request.messages) > 1: + persona = cast( + Persona, + fetch_persona_by_id(db_session, new_message_request.persona_id), + ) + llm, _ = get_llms_for_persona(persona) + + # In cases of threads, split the available tokens between docs and thread context + input_tokens = get_max_input_tokens( + model_name=llm.config.model_name, + model_provider=llm.config.model_provider, + ) + max_history_tokens = int(input_tokens * thread_context_percent) + + remaining_tokens = input_tokens - max_history_tokens + + query_text = new_message_request.messages[0].message + if persona: + max_document_tokens = compute_max_document_tokens_for_persona( + persona=persona, + actual_user_input=query_text, + max_llm_token_override=remaining_tokens, + ) + else: + max_document_tokens = ( + remaining_tokens + - 512 # Needs to be more than any of the QA prompts + - check_number_of_tokens(query_text) + ) + + if DISABLE_GENERATIVE_AI: + return None + + # This also handles creating the query event in postgres + answer = get_search_answer( + query_req=new_message_request, + user=None, + max_document_tokens=max_document_tokens, + max_history_tokens=max_history_tokens, + db_session=db_session, + answer_generation_timeout=answer_generation_timeout, + enable_reflexion=reflexion, + bypass_acl=bypass_acl, + use_citations=use_citations, + danswerbot_flow=True, + ) + if not answer.error_msg: + return answer + else: + raise RuntimeError(answer.error_msg) + + try: + # By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters + # it allows the slack flow to extract out filters from 
the user query + filters = BaseFilters( + source_type=None, + document_set=document_set_names, + time_cutoff=None, + ) + + # Default True because no other ways to apply filters in Slack (no nice UI) + # Commenting this out because this is only available to the slackbot for now + # later we plan to implement this at the persona level where this will get + # commented back in + # auto_detect_filters = ( + # persona.llm_filter_extraction if persona is not None else True + # ) + auto_detect_filters = ( + slack_bot_config.enable_auto_filters + if slack_bot_config is not None + else False + ) + retrieval_details = RetrievalDetails( + run_search=OptionalSearchSetting.ALWAYS, + real_time=False, + filters=filters, + enable_auto_detect_filters=auto_detect_filters, + ) + + # This includes throwing out answer via reflexion + answer = _get_answer( + DirectQARequest( + messages=messages, + prompt_id=prompt.id if prompt else None, + persona_id=persona.id if persona is not None else 0, + retrieval_options=retrieval_details, + chain_of_thought=not disable_cot, + skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW, + ) + ) + except Exception as e: + logger.exception( + f"Unable to process message - did not successfully answer " + f"in {num_retries} attempts" + ) + # Optionally, respond in thread with the error message, Used primarily + # for debugging purposes + if should_respond_with_error_msgs: + respond_in_thread( + client=client, + channel=channel, + receiver_ids=None, + text=f"Encountered exception when trying to answer: \n\n```{e}```", + thread_ts=message_ts_to_respond_to, + ) + + # In case of failures, don't keep the reaction there permanently + update_emote_react( + emoji=DANSWER_REACT_EMOJI, + channel=message_info.channel_to_respond, + message_ts=message_info.msg_to_respond, + remove=True, + client=client, + ) + + return True + + # Edge case handling, for tracking down the Slack usage issue + if answer is None: + assert DISABLE_GENERATIVE_AI is True + try: + respond_in_thread( + 
client=client, + channel=channel, + receiver_ids=receiver_ids, + text="Hello! Danswer has some results for you!", + blocks=[ + SectionBlock( + text="Danswer is down for maintenance.\nWe're working hard on recharging the AI!" + ) + ], + thread_ts=message_ts_to_respond_to, + # don't unfurl, since otherwise we will have 5+ previews which makes the message very long + unfurl=False, + ) + + # For DM (ephemeral message), we need to create a thread via a normal message so the user can see + # the ephemeral message. This also will give the user a notification which ephemeral message does not. + if receiver_ids: + respond_in_thread( + client=client, + channel=channel, + text=( + "👋 Hi, we've just gathered and forwarded the relevant " + + "information to the team. They'll get back to you shortly!" + ), + thread_ts=message_ts_to_respond_to, + ) + + return False + + except Exception: + logger.exception( + f"Unable to process message - could not respond in slack in {num_retries} attempts" + ) + return True + + # Got an answer at this point, can remove reaction and give results + update_emote_react( + emoji=DANSWER_REACT_EMOJI, + channel=message_info.channel_to_respond, + message_ts=message_info.msg_to_respond, + remove=True, + client=client, + ) + + if answer.answer_valid is False: + logger.info( + "Answer was evaluated to be invalid, throwing it away without responding." 
+ ) + update_emote_react( + emoji=DANSWER_FOLLOWUP_EMOJI, + channel=message_info.channel_to_respond, + message_ts=message_info.msg_to_respond, + remove=False, + client=client, + ) + + if answer.answer: + logger.debug(answer.answer) + return True + + retrieval_info = answer.docs + if not retrieval_info: + # This should not happen, even with no docs retrieved, there is still info returned + raise RuntimeError("Failed to retrieve docs, cannot answer question.") + + top_docs = retrieval_info.top_documents + if not top_docs and not should_respond_even_with_no_docs: + logger.error( + f"Unable to answer question: '{answer.rephrase}' - no documents found" + ) + # Optionally, respond in thread with the error message + # Used primarily for debugging purposes + if should_respond_with_error_msgs: + respond_in_thread( + client=client, + channel=channel, + receiver_ids=None, + text="Found no documents when trying to answer. Did you index any documents?", + thread_ts=message_ts_to_respond_to, + ) + return True + + if not answer.answer and disable_docs_only_answer: + logger.info( + "Unable to find answer - not responding since the " + "`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set" + ) + return True + + only_respond_with_citations_or_quotes = ( + channel_conf + and "well_answered_postfilter" in channel_conf.get("answer_filters", []) + ) + has_citations_or_quotes = bool(answer.citations or answer.quotes) + if ( + only_respond_with_citations_or_quotes + and not has_citations_or_quotes + and not message_info.bypass_filters + ): + logger.error( + f"Unable to find citations or quotes to answer: '{answer.rephrase}' - not answering!" 
+ ) + # Optionally, respond in thread with the error message + # Used primarily for debugging purposes + if should_respond_with_error_msgs: + respond_in_thread( + client=client, + channel=channel, + receiver_ids=None, + text="Found no citations or quotes when trying to answer.", + thread_ts=message_ts_to_respond_to, + ) + return True + + # If called with the DanswerBot slash command, the question is lost so we have to reshow it + restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg) + + answer_blocks = build_qa_response_blocks( + message_id=answer.chat_message_id, + answer=answer.answer, + quotes=answer.quotes.quotes if answer.quotes else None, + source_filters=retrieval_info.applied_source_filters, + time_cutoff=retrieval_info.applied_time_cutoff, + favor_recent=retrieval_info.recency_bias_multiplier > 1, + # currently Personas don't support quotes + # if citations are enabled, also don't use quotes + skip_quotes=persona is not None or use_citations, + process_message_for_citations=use_citations, + feedback_reminder_id=feedback_reminder_id, + ) + + # Get the chunks fed to the LLM only, then fill with other docs + llm_doc_inds = answer.llm_chunks_indices or [] + llm_docs = [top_docs[i] for i in llm_doc_inds] + remaining_docs = [ + doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds + ] + priority_ordered_docs = llm_docs + remaining_docs + + document_blocks = [] + citations_block = [] + # if citations are enabled, only show cited documents + if use_citations: + citations = answer.citations or [] + cited_docs = [] + for citation in citations: + matching_doc = next( + (d for d in top_docs if d.document_id == citation.document_id), + None, + ) + if matching_doc: + cited_docs.append((citation.citation_num, matching_doc)) + + cited_docs.sort() + citations_block = build_sources_blocks(cited_documents=cited_docs) + elif priority_ordered_docs: + document_blocks = build_documents_blocks( + documents=priority_ordered_docs, + 
message_id=answer.chat_message_id, + ) + document_blocks = [DividerBlock()] + document_blocks + + all_blocks = ( + restate_question_block + answer_blocks + citations_block + document_blocks + ) + + if channel_conf and channel_conf.get("follow_up_tags") is not None: + all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id)) + + try: + respond_in_thread( + client=client, + channel=channel, + receiver_ids=receiver_ids, + text="Hello! Danswer has some results for you!", + blocks=all_blocks, + thread_ts=message_ts_to_respond_to, + # don't unfurl, since otherwise we will have 5+ previews which makes the message very long + unfurl=False, + ) + + # For DM (ephemeral message), we need to create a thread via a normal message so the user can see + # the ephemeral message. This also will give the user a notification which ephemeral message does not. + if receiver_ids: + send_team_member_message( + client=client, + channel=channel, + thread_ts=message_ts_to_respond_to, + ) + + return False + + except Exception: + logger.exception( + f"Unable to process message - could not respond in slack in {num_retries} attempts" + ) + return True diff --git a/backend/danswer/danswerbot/slack/handlers/handle_standard_answers.py b/backend/danswer/danswerbot/slack/handlers/handle_standard_answers.py new file mode 100644 index 00000000000..b853b298b70 --- /dev/null +++ b/backend/danswer/danswerbot/slack/handlers/handle_standard_answers.py @@ -0,0 +1,181 @@ +import logging + +from slack_sdk import WebClient +from sqlalchemy.orm import Session + +from danswer.configs.constants import MessageType +from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI +from danswer.danswerbot.slack.blocks import build_standard_answer_blocks +from danswer.danswerbot.slack.blocks import get_restate_blocks +from danswer.danswerbot.slack.handlers.utils import send_team_member_message +from danswer.danswerbot.slack.models import SlackMessageInfo +from danswer.danswerbot.slack.utils import 
respond_in_thread +from danswer.danswerbot.slack.utils import update_emote_react +from danswer.db.chat import create_chat_session +from danswer.db.chat import create_new_chat_message +from danswer.db.chat import get_chat_messages_by_sessions +from danswer.db.chat import get_chat_sessions_by_slack_thread_id +from danswer.db.chat import get_or_create_root_message +from danswer.db.models import Prompt +from danswer.db.models import SlackBotConfig +from danswer.db.standard_answer import find_matching_standard_answers + + +def handle_standard_answers( + message_info: SlackMessageInfo, + receiver_ids: list[str] | None, + slack_bot_config: SlackBotConfig | None, + prompt: Prompt | None, + logger: logging.Logger, + client: WebClient, + db_session: Session, +) -> bool: + """ + Potentially respond to the user message depending on whether the user's message matches + any of the configured standard answers and also whether those answers have already been + provided in the current thread. + + Returns True if standard answers are found to match the user's message and therefore, + we still need to respond to the users. 
+ """ + # if no channel config, then no standard answers are configured + if not slack_bot_config: + return False + + slack_thread_id = message_info.thread_to_respond + configured_standard_answer_categories = ( + slack_bot_config.standard_answer_categories if slack_bot_config else [] + ) + configured_standard_answers = set( + [ + standard_answer + for standard_answer_category in configured_standard_answer_categories + for standard_answer in standard_answer_category.standard_answers + ] + ) + query_msg = message_info.thread_messages[-1] + + if slack_thread_id is None: + used_standard_answer_ids = set([]) + else: + chat_sessions = get_chat_sessions_by_slack_thread_id( + slack_thread_id=slack_thread_id, + user_id=None, + db_session=db_session, + ) + chat_messages = get_chat_messages_by_sessions( + chat_session_ids=[chat_session.id for chat_session in chat_sessions], + user_id=None, + db_session=db_session, + skip_permission_check=True, + ) + used_standard_answer_ids = set( + [ + standard_answer.id + for chat_message in chat_messages + for standard_answer in chat_message.standard_answers + ] + ) + + usable_standard_answers = configured_standard_answers.difference( + used_standard_answer_ids + ) + if usable_standard_answers: + matching_standard_answers = find_matching_standard_answers( + query=query_msg.message, + id_in=[standard_answer.id for standard_answer in usable_standard_answers], + db_session=db_session, + ) + else: + matching_standard_answers = [] + if matching_standard_answers: + chat_session = create_chat_session( + db_session=db_session, + description="", + user_id=None, + persona_id=slack_bot_config.persona.id if slack_bot_config.persona else 0, + danswerbot_flow=True, + slack_thread_id=slack_thread_id, + one_shot=True, + ) + + root_message = get_or_create_root_message( + chat_session_id=chat_session.id, db_session=db_session + ) + + new_user_message = create_new_chat_message( + chat_session_id=chat_session.id, + parent_message=root_message, + 
prompt_id=prompt.id if prompt else None, + message=query_msg.message, + token_count=0, + message_type=MessageType.USER, + db_session=db_session, + commit=True, + ) + + formatted_answers = [] + for standard_answer in matching_standard_answers: + block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ") + formatted_answer = ( + f'Since you mentioned _"{standard_answer.keyword}"_, ' + f"I thought this might be useful: \n\n{block_quotified_answer}" + ) + formatted_answers.append(formatted_answer) + answer_message = "\n\n".join(formatted_answers) + + _ = create_new_chat_message( + chat_session_id=chat_session.id, + parent_message=new_user_message, + prompt_id=prompt.id if prompt else None, + message=answer_message, + token_count=0, + message_type=MessageType.ASSISTANT, + error=None, + db_session=db_session, + commit=True, + ) + + update_emote_react( + emoji=DANSWER_REACT_EMOJI, + channel=message_info.channel_to_respond, + message_ts=message_info.msg_to_respond, + remove=True, + client=client, + ) + + restate_question_blocks = get_restate_blocks( + msg=query_msg.message, + is_bot_msg=message_info.is_bot_msg, + ) + + answer_blocks = build_standard_answer_blocks( + answer_message=answer_message, + ) + + all_blocks = restate_question_blocks + answer_blocks + + try: + respond_in_thread( + client=client, + channel=message_info.channel_to_respond, + receiver_ids=receiver_ids, + text="Hello! 
Danswer has some results for you!", + blocks=all_blocks, + thread_ts=message_info.msg_to_respond, + unfurl=False, + ) + + if receiver_ids and slack_thread_id: + send_team_member_message( + client=client, + channel=message_info.channel_to_respond, + thread_ts=slack_thread_id, + ) + + return True + except Exception as e: + logger.exception(f"Unable to send standard answer message: {e}") + return False + else: + return False diff --git a/backend/danswer/danswerbot/slack/handlers/utils.py b/backend/danswer/danswerbot/slack/handlers/utils.py new file mode 100644 index 00000000000..296b7b90d41 --- /dev/null +++ b/backend/danswer/danswerbot/slack/handlers/utils.py @@ -0,0 +1,19 @@ +from slack_sdk import WebClient + +from danswer.danswerbot.slack.utils import respond_in_thread + + +def send_team_member_message( + client: WebClient, + channel: str, + thread_ts: str, +) -> None: + respond_in_thread( + client=client, + channel=channel, + text=( + "👋 Hi, we've just gathered and forwarded the relevant " + + "information to the team. They'll get back to you shortly!" 
+ ), + thread_ts=thread_ts, + ) diff --git a/backend/danswer/danswerbot/slack/icons.py b/backend/danswer/danswerbot/slack/icons.py index d2e8ea917ee..895fcfdc897 100644 --- a/backend/danswer/danswerbot/slack/icons.py +++ b/backend/danswer/danswerbot/slack/icons.py @@ -4,9 +4,9 @@ def source_to_github_img_link(source: DocumentSource) -> str | None: # TODO: store these images somewhere better if source == DocumentSource.WEB.value: - return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Web.png" + return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Web.png" if source == DocumentSource.FILE.value: - return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png" + return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png" if source == DocumentSource.GOOGLE_SITES.value: return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/GoogleSites.png" if source == DocumentSource.SLACK.value: @@ -20,13 +20,13 @@ def source_to_github_img_link(source: DocumentSource) -> str | None: if source == DocumentSource.GITLAB.value: return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Gitlab.png" if source == DocumentSource.CONFLUENCE.value: - return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Confluence.png" + return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Confluence.png" if source == DocumentSource.JIRA.value: - return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Jira.png" + return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Jira.png" if source == DocumentSource.NOTION.value: return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Notion.png" if source == 
DocumentSource.ZENDESK.value: - return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Zendesk.png" + return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Zendesk.png" if source == DocumentSource.GONG.value: return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Gong.png" if source == DocumentSource.LINEAR.value: @@ -38,7 +38,7 @@ def source_to_github_img_link(source: DocumentSource) -> str | None: if source == DocumentSource.ZULIP.value: return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Zulip.png" if source == DocumentSource.GURU.value: - return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Guru.png" + return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Guru.png" if source == DocumentSource.HUBSPOT.value: return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/HubSpot.png" if source == DocumentSource.DOCUMENT360.value: @@ -51,8 +51,8 @@ def source_to_github_img_link(source: DocumentSource) -> str | None: return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Sharepoint.png" if source == DocumentSource.REQUESTTRACKER.value: # just use file icon for now - return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png" + return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png" if source == DocumentSource.INGESTION_API.value: - return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png" + return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png" - return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png" + return 
"https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png" diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py index 56a70fcd151..4f6df76545b 100644 --- a/backend/danswer/danswerbot/slack/listener.py +++ b/backend/danswer/danswerbot/slack/listener.py @@ -18,6 +18,7 @@ from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID +from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID @@ -27,6 +28,9 @@ from danswer.danswerbot.slack.handlers.handle_buttons import ( handle_followup_resolved_button, ) +from danswer.danswerbot.slack.handlers.handle_buttons import ( + handle_generate_answer_button, +) from danswer.danswerbot.slack.handlers.handle_buttons import handle_slack_feedback from danswer.danswerbot.slack.handlers.handle_message import handle_message from danswer.danswerbot.slack.handlers.handle_message import ( @@ -56,7 +60,6 @@ logger = setup_logger() - # In rare cases, some users have been experiencing a massive amount of trivial messages coming through # to the Slack Bot with trivial messages. Adding this to avoid exploding LLM costs while we track down # the cause. 
@@ -70,6 +73,9 @@ ":wave:", } +# this is always (currently) the user id of Slack's official slackbot +_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT" + def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool: """True to keep going, False to ignore this Slack request""" @@ -92,6 +98,15 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool channel_specific_logger.error("Cannot respond to empty message - skipping") return False + if ( + req.payload.setdefault("event", {}).get("user", "") + == _OFFICIAL_SLACKBOT_USER_ID + ): + channel_specific_logger.info( + "Ignoring messages from Slack's official Slackbot" + ) + return False + if ( msg in _SLACK_GREETINGS_TO_IGNORE or remove_danswer_bot_tag(msg, client=client.web_client) @@ -255,6 +270,7 @@ def build_request_details( thread_messages=thread_messages, channel_to_respond=channel, msg_to_respond=cast(str, message_ts or thread_ts), + thread_to_respond=cast(str, thread_ts or message_ts), sender=event.get("user") or None, bypass_filters=tagged, is_bot_msg=False, @@ -272,6 +288,7 @@ def build_request_details( thread_messages=[single_msg], channel_to_respond=channel, msg_to_respond=None, + thread_to_respond=None, sender=sender, bypass_filters=True, is_bot_msg=True, @@ -341,7 +358,7 @@ def process_message( failed = handle_message( message_info=details, - channel_config=slack_bot_config, + slack_bot_config=slack_bot_config, client=client.web_client, feedback_reminder_id=feedback_reminder_id, ) @@ -379,6 +396,8 @@ def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None: return handle_followup_resolved_button(req, client, immediate=True) elif action["action_id"] == FOLLOWUP_BUTTON_RESOLVED_ACTION_ID: return handle_followup_resolved_button(req, client, immediate=False) + elif action["action_id"] == GENERATE_ANSWER_BUTTON_ACTION_ID: + return handle_generate_answer_button(req, client) def view_routing(req: SocketModeRequest, client: SocketModeClient) -> None: diff 
--git a/backend/danswer/danswerbot/slack/models.py b/backend/danswer/danswerbot/slack/models.py index 57a92a29753..e4521a759a7 100644 --- a/backend/danswer/danswerbot/slack/models.py +++ b/backend/danswer/danswerbot/slack/models.py @@ -7,6 +7,7 @@ class SlackMessageInfo(BaseModel): thread_messages: list[ThreadMessage] channel_to_respond: str msg_to_respond: str | None + thread_to_respond: str | None sender: str | None bypass_filters: bool # User has tagged @DanswerBot is_bot_msg: bool # User is using /DanswerBot diff --git a/backend/danswer/danswerbot/slack/utils.py b/backend/danswer/danswerbot/slack/utils.py index 892734e9995..1e5ffcc52f1 100644 --- a/backend/danswer/danswerbot/slack/utils.py +++ b/backend/danswer/danswerbot/slack/utils.py @@ -30,7 +30,7 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.users import get_user_by_email from danswer.llm.exceptions import GenAIDisabledException -from danswer.llm.factory import get_default_llm +from danswer.llm.factory import get_default_llms from danswer.llm.utils import dict_based_prompt_to_langchain_prompt from danswer.llm.utils import message_to_string from danswer.one_shot_answer.models import ThreadMessage @@ -58,7 +58,7 @@ def _get_rephrase_message() -> list[dict[str, str]]: return messages try: - llm = get_default_llm(use_fast_llm=False, timeout=5) + llm, _ = get_default_llms(timeout=5) except GenAIDisabledException: logger.warning("Unable to rephrase Slack user message, Gen AI disabled") return msg @@ -77,17 +77,25 @@ def update_emote_react( remove: bool, client: WebClient, ) -> None: - if not message_ts: - logger.error(f"Tried to remove a react in {channel} but no message specified") - return + try: + if not message_ts: + logger.error( + f"Tried to remove a react in {channel} but no message specified" + ) + return - func = client.reactions_remove if remove else client.reactions_add - slack_call = make_slack_api_rate_limited(func) # type: ignore - slack_call( - name=emoji, - 
channel=channel, - timestamp=message_ts, - ) + func = client.reactions_remove if remove else client.reactions_add + slack_call = make_slack_api_rate_limited(func) # type: ignore + slack_call( + name=emoji, + channel=channel, + timestamp=message_ts, + ) + except SlackApiError as e: + if remove: + logger.error(f"Failed to remove Reaction due to: {e}") + else: + logger.error(f"Was not able to react to user message due to: {e}") def get_danswer_bot_app_id(web_client: WebClient) -> Any: @@ -136,16 +144,13 @@ def respond_in_thread( receiver_ids: list[str] | None = None, metadata: Metadata | None = None, unfurl: bool = True, -) -> None: +) -> list[str]: if not text and not blocks: raise ValueError("One of `text` or `blocks` must be provided") + message_ids: list[str] = [] if not receiver_ids: slack_call = make_slack_api_rate_limited(client.chat_postMessage) - else: - slack_call = make_slack_api_rate_limited(client.chat_postEphemeral) - - if not receiver_ids: response = slack_call( channel=channel, text=text, @@ -157,7 +162,9 @@ def respond_in_thread( ) if not response.get("ok"): raise RuntimeError(f"Failed to post message: {response}") + message_ids.append(response["message_ts"]) else: + slack_call = make_slack_api_rate_limited(client.chat_postEphemeral) for receiver in receiver_ids: response = slack_call( channel=channel, @@ -171,6 +178,9 @@ def respond_in_thread( ) if not response.get("ok"): raise RuntimeError(f"Failed to post message: {response}") + message_ids.append(response["message_ts"]) + + return message_ids def build_feedback_id( @@ -308,6 +318,31 @@ def fetch_userids_from_emails( return user_ids, failed_to_find +def fetch_userids_from_groups( + group_names: list[str], client: WebClient +) -> tuple[list[str], list[str]]: + user_ids: list[str] = [] + failed_to_find: list[str] = [] + for group_name in group_names: + try: + # First, find the group ID from the group name + response = client.usergroups_list() + groups = {group["name"]: group["id"] for group in 
response["usergroups"]} + group_id = groups.get(group_name) + + if group_id: + # Fetch user IDs for the group + response = client.usergroups_users_list(usergroup=group_id) + user_ids.extend(response["users"]) + else: + failed_to_find.append(group_name) + except Exception as e: + logger.error(f"Error fetching user IDs for group {group_name}: {str(e)}") + failed_to_find.append(group_name) + + return user_ids, failed_to_find + + def fetch_groupids_from_names( names: list[str], client: WebClient ) -> tuple[list[str], list[str]]: diff --git a/backend/danswer/db/auth.py b/backend/danswer/db/auth.py index 6d726c2f92a..161fdc8f10b 100644 --- a/backend/danswer/db/auth.py +++ b/backend/danswer/db/auth.py @@ -1,4 +1,5 @@ from collections.abc import AsyncGenerator +from collections.abc import Callable from typing import Any from typing import Dict @@ -16,6 +17,20 @@ from danswer.db.models import AccessToken from danswer.db.models import OAuthAccount from danswer.db.models import User +from danswer.utils.variable_functionality import ( + fetch_versioned_implementation_with_fallback, +) + + +def get_default_admin_user_emails() -> list[str]: + """Returns a list of emails who should default to Admin role. + Only used in the EE version. 
For MIT, just return empty list.""" + get_default_admin_user_emails_fn: Callable[ + [], list[str] + ] = fetch_versioned_implementation_with_fallback( + "danswer.auth.users", "get_default_admin_user_emails_", lambda: [] + ) + return get_default_admin_user_emails_fn() async def get_user_count() -> int: @@ -32,7 +47,7 @@ async def get_user_count() -> int: class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase): async def create(self, create_dict: Dict[str, Any]) -> UP: user_count = await get_user_count() - if user_count == 0: + if user_count == 0 or create_dict["email"] in get_default_admin_user_emails(): create_dict["role"] = UserRole.ADMIN else: create_dict["role"] = UserRole.BASIC diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 7f31ccb264e..e2208266c84 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -1,3 +1,6 @@ +from collections.abc import Sequence +from datetime import datetime +from datetime import timedelta from uuid import UUID from sqlalchemy import delete @@ -12,6 +15,7 @@ from danswer.configs.chat_configs import HARD_DELETE_CHATS from danswer.configs.constants import MessageType from danswer.db.models import ChatMessage +from danswer.db.models import ChatMessage__SearchDoc from danswer.db.models import ChatSession from danswer.db.models import ChatSessionSharedStatus from danswer.db.models import Prompt @@ -19,6 +23,7 @@ from danswer.db.models import SearchDoc as DBSearchDoc from danswer.db.models import ToolCall from danswer.db.models import User +from danswer.db.pg_file_store import delete_lobj_by_name from danswer.file_store.models import FileDescriptor from danswer.llm.override_models import LLMOverride from danswer.llm.override_models import PromptOverride @@ -63,6 +68,19 @@ def get_chat_session_by_id( return chat_session +def get_chat_sessions_by_slack_thread_id( + slack_thread_id: str, + user_id: UUID | None, + db_session: Session, +) -> Sequence[ChatSession]: + stmt = 
select(ChatSession).where(ChatSession.slack_thread_id == slack_thread_id) + if user_id is not None: + stmt = stmt.where( + or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None)) + ) + return db_session.scalars(stmt).all() + + def get_chat_sessions_by_user( user_id: UUID | None, deleted: bool | None, @@ -83,15 +101,64 @@ def get_chat_sessions_by_user( return list(chat_sessions) +def delete_search_doc_message_relationship( + message_id: int, db_session: Session +) -> None: + db_session.query(ChatMessage__SearchDoc).filter( + ChatMessage__SearchDoc.chat_message_id == message_id + ).delete(synchronize_session=False) + + db_session.commit() + + +def delete_orphaned_search_docs(db_session: Session) -> None: + orphaned_docs = ( + db_session.query(SearchDoc) + .outerjoin(ChatMessage__SearchDoc) + .filter(ChatMessage__SearchDoc.chat_message_id.is_(None)) + .all() + ) + for doc in orphaned_docs: + db_session.delete(doc) + db_session.commit() + + +def delete_messages_and_files_from_chat_session( + chat_session_id: int, db_session: Session +) -> None: + # Select messages older than cutoff_time with files + messages_with_files = db_session.execute( + select(ChatMessage.id, ChatMessage.files).where( + ChatMessage.chat_session_id == chat_session_id, + ) + ).fetchall() + + for id, files in messages_with_files: + delete_search_doc_message_relationship(message_id=id, db_session=db_session) + for file_info in files or {}: + lobj_name = file_info.get("id") + if lobj_name: + logger.info(f"Deleting file with name: {lobj_name}") + delete_lobj_by_name(lobj_name, db_session) + + db_session.execute( + delete(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id) + ) + db_session.commit() + + delete_orphaned_search_docs(db_session) + + def create_chat_session( db_session: Session, description: str, user_id: UUID | None, - persona_id: int | None = None, + persona_id: int, llm_override: LLMOverride | None = None, prompt_override: PromptOverride | None = None, one_shot: 
bool = False, danswerbot_flow: bool = False, + slack_thread_id: str | None = None, ) -> ChatSession: chat_session = ChatSession( user_id=user_id, @@ -101,6 +168,7 @@ def create_chat_session( prompt_override=prompt_override, one_shot=one_shot, danswerbot_flow=danswerbot_flow, + slack_thread_id=slack_thread_id, ) db_session.add(chat_session) @@ -139,25 +207,30 @@ def delete_chat_session( db_session: Session, hard_delete: bool = HARD_DELETE_CHATS, ) -> None: - chat_session = get_chat_session_by_id( - chat_session_id=chat_session_id, user_id=user_id, db_session=db_session - ) - if hard_delete: - stmt_messages = delete(ChatMessage).where( - ChatMessage.chat_session_id == chat_session_id - ) - db_session.execute(stmt_messages) - - stmt = delete(ChatSession).where(ChatSession.id == chat_session_id) - db_session.execute(stmt) - + delete_messages_and_files_from_chat_session(chat_session_id, db_session) + db_session.execute(delete(ChatSession).where(ChatSession.id == chat_session_id)) else: + chat_session = get_chat_session_by_id( + chat_session_id=chat_session_id, user_id=user_id, db_session=db_session + ) chat_session.deleted = True db_session.commit() +def delete_chat_sessions_older_than(days_old: int, db_session: Session) -> None: + cutoff_time = datetime.utcnow() - timedelta(days=days_old) + old_sessions = db_session.execute( + select(ChatSession.user_id, ChatSession.id).where( + ChatSession.time_created < cutoff_time + ) + ).fetchall() + + for user_id, session_id in old_sessions: + delete_chat_session(user_id, session_id, db_session, hard_delete=True) + + def get_chat_message( chat_message_id: int, user_id: UUID | None, @@ -183,6 +256,25 @@ def get_chat_message( return chat_message +def get_chat_messages_by_sessions( + chat_session_ids: list[int], + user_id: UUID | None, + db_session: Session, + skip_permission_check: bool = False, +) -> Sequence[ChatMessage]: + if not skip_permission_check: + for chat_session_id in chat_session_ids: + get_chat_session_by_id( + 
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session + ) + stmt = ( + select(ChatMessage) + .where(ChatMessage.chat_session_id.in_(chat_session_ids)) + .order_by(nullsfirst(ChatMessage.parent_message)) + ) + return db_session.execute(stmt).scalars().all() + + def get_chat_messages_by_session( chat_session_id: int, user_id: UUID | None, @@ -259,6 +351,7 @@ def create_new_chat_message( rephrased_query: str | None = None, error: str | None = None, reference_docs: list[DBSearchDoc] | None = None, + alternate_assistant_id: int | None = None, # Maps the citation number [n] to the DB SearchDoc citations: dict[int, int] | None = None, tool_calls: list[ToolCall] | None = None, @@ -277,6 +370,7 @@ def create_new_chat_message( files=files, tool_calls=tool_calls if tool_calls else [], error=error, + alternate_assistant_id=alternate_assistant_id, ) # SQL Alchemy will propagate this to update the reference_docs' foreign keys @@ -371,10 +465,19 @@ def get_doc_query_identifiers_from_model( ) raise ValueError("Docs references do not belong to user") - if any( - [doc.chat_messages[0].chat_session_id != chat_session.id for doc in search_docs] - ): - raise ValueError("Invalid reference doc, not from this chat session.") + try: + if any( + [ + doc.chat_messages[0].chat_session_id != chat_session.id + for doc in search_docs + ] + ): + raise ValueError("Invalid reference doc, not from this chat session.") + except IndexError: + # This happens when the doc has no chat_messages associated with it. + # which happens as an edge case where the chat message failed to save + # This usually happens when the LLM fails either immediately or partially through. 
+ raise RuntimeError("Chat session failed, please start a new session.") doc_query_identifiers = [(doc.document_id, doc.chunk_ind) for doc in search_docs] @@ -401,6 +504,7 @@ def create_db_search_doc( updated_at=server_search_doc.updated_at, primary_owners=server_search_doc.primary_owners, secondary_owners=server_search_doc.secondary_owners, + is_internet=server_search_doc.is_internet, ) db_session.add(db_search_doc) @@ -439,6 +543,7 @@ def translate_db_search_doc_to_server_search_doc( secondary_owners=( db_search_doc.secondary_owners if not remove_doc_content else [] ), + is_internet=db_search_doc.is_internet, ) @@ -479,6 +584,7 @@ def translate_db_message_to_chat_message_detail( ) for tool_call in chat_message.tool_calls ], + alternate_assistant_id=chat_message.alternate_assistant_id, ) return chat_msg_detail diff --git a/backend/danswer/db/connector_credential_pair.py b/backend/danswer/db/connector_credential_pair.py index 31f2982db67..314a31eddcd 100644 --- a/backend/danswer/db/connector_credential_pair.py +++ b/backend/danswer/db/connector_credential_pair.py @@ -151,7 +151,8 @@ def add_credential_to_connector( connector_id: int, credential_id: int, cc_pair_name: str | None, - user: User, + is_public: bool, + user: User | None, db_session: Session, ) -> StatusResponse[int]: connector = fetch_connector_by_id(connector_id, db_session) @@ -185,6 +186,7 @@ def add_credential_to_connector( connector_id=connector_id, credential_id=credential_id, name=cc_pair_name, + is_public=is_public, ) db_session.add(association) db_session.commit() @@ -199,7 +201,7 @@ def add_credential_to_connector( def remove_credential_from_connector( connector_id: int, credential_id: int, - user: User, + user: User | None, db_session: Session, ) -> StatusResponse[int]: connector = fetch_connector_by_id(connector_id, db_session) diff --git a/backend/danswer/db/document.py b/backend/danswer/db/document.py index 1ff3391b77e..befb8675748 100644 --- a/backend/danswer/db/document.py +++ 
b/backend/danswer/db/document.py @@ -327,6 +327,7 @@ def prepare_to_modify_documents( Multiple commits will result in a sqlalchemy.exc.InvalidRequestError. NOTE: this function will commit any existing transaction. """ + db_session.commit() # ensure that we're not in a transaction lock_acquired = False diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 29ca74bfc89..f7d16743c65 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -246,6 +246,39 @@ class Persona__Tool(Base): tool_id: Mapped[int] = mapped_column(ForeignKey("tool.id"), primary_key=True) +class StandardAnswer__StandardAnswerCategory(Base): + __tablename__ = "standard_answer__standard_answer_category" + + standard_answer_id: Mapped[int] = mapped_column( + ForeignKey("standard_answer.id"), primary_key=True + ) + standard_answer_category_id: Mapped[int] = mapped_column( + ForeignKey("standard_answer_category.id"), primary_key=True + ) + + +class SlackBotConfig__StandardAnswerCategory(Base): + __tablename__ = "slack_bot_config__standard_answer_category" + + slack_bot_config_id: Mapped[int] = mapped_column( + ForeignKey("slack_bot_config.id"), primary_key=True + ) + standard_answer_category_id: Mapped[int] = mapped_column( + ForeignKey("standard_answer_category.id"), primary_key=True + ) + + +class ChatMessage__StandardAnswer(Base): + __tablename__ = "chat_message__standard_answer" + + chat_message_id: Mapped[int] = mapped_column( + ForeignKey("chat_message.id"), primary_key=True + ) + standard_answer_id: Mapped[int] = mapped_column( + ForeignKey("standard_answer.id"), primary_key=True + ) + + """ Documents/Indexing Tables """ @@ -612,6 +645,7 @@ class SearchDoc(Base): secondary_owners: Mapped[list[str] | None] = mapped_column( postgresql.ARRAY(String), nullable=True ) + is_internet: Mapped[bool] = mapped_column(Boolean, default=False, nullable=True) chat_messages = relationship( "ChatMessage", @@ -661,6 +695,12 @@ class ChatSession(Base): 
ForeignKey("chat_folder.id"), nullable=True ) + current_alternate_model: Mapped[str | None] = mapped_column(String, default=None) + + slack_thread_id: Mapped[str | None] = mapped_column( + String, nullable=True, default=None + ) + # the latest "overrides" specified by the user. These take precedence over # the attached persona. However, overrides specified directly in the # `send-message` call will take precedence over these. @@ -688,7 +728,7 @@ class ChatSession(Base): "ChatFolder", back_populates="chat_sessions" ) messages: Mapped[list["ChatMessage"]] = relationship( - "ChatMessage", back_populates="chat_session", cascade="delete" + "ChatMessage", back_populates="chat_session" ) persona: Mapped["Persona"] = relationship("Persona") @@ -706,6 +746,11 @@ class ChatMessage(Base): id: Mapped[int] = mapped_column(primary_key=True) chat_session_id: Mapped[int] = mapped_column(ForeignKey("chat_session.id")) + + alternate_assistant_id = mapped_column( + Integer, ForeignKey("persona.id"), nullable=True + ) + parent_message: Mapped[int | None] = mapped_column(Integer, nullable=True) latest_child_message: Mapped[int | None] = mapped_column(Integer, nullable=True) message: Mapped[str] = mapped_column(Text) @@ -734,11 +779,15 @@ class ChatMessage(Base): chat_session: Mapped[ChatSession] = relationship("ChatSession") prompt: Mapped[Optional["Prompt"]] = relationship("Prompt") + chat_message_feedbacks: Mapped[list["ChatMessageFeedback"]] = relationship( - "ChatMessageFeedback", back_populates="chat_message" + "ChatMessageFeedback", + back_populates="chat_message", ) + document_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship( - "DocumentRetrievalFeedback", back_populates="chat_message" + "DocumentRetrievalFeedback", + back_populates="chat_message", ) search_docs: Mapped[list["SearchDoc"]] = relationship( "SearchDoc", @@ -749,6 +798,11 @@ class ChatMessage(Base): "ToolCall", back_populates="message", ) + standard_answers: Mapped[list["StandardAnswer"]] = 
relationship( + "StandardAnswer", + secondary=ChatMessage__StandardAnswer.__table__, + back_populates="chat_messages", + ) class ChatFolder(Base): @@ -785,7 +839,9 @@ class DocumentRetrievalFeedback(Base): __tablename__ = "document_retrieval_feedback" id: Mapped[int] = mapped_column(primary_key=True) - chat_message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id")) + chat_message_id: Mapped[int | None] = mapped_column( + ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True + ) document_id: Mapped[str] = mapped_column(ForeignKey("document.id")) # How high up this document is in the results, 1 for first document_rank: Mapped[int] = mapped_column(Integer) @@ -795,7 +851,9 @@ class DocumentRetrievalFeedback(Base): ) chat_message: Mapped[ChatMessage] = relationship( - "ChatMessage", back_populates="document_feedbacks" + "ChatMessage", + back_populates="document_feedbacks", + foreign_keys=[chat_message_id], ) document: Mapped[Document] = relationship( "Document", back_populates="retrieval_feedbacks" @@ -806,14 +864,18 @@ class ChatMessageFeedback(Base): __tablename__ = "chat_feedback" id: Mapped[int] = mapped_column(Integer, primary_key=True) - chat_message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id")) + chat_message_id: Mapped[int | None] = mapped_column( + ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True + ) is_positive: Mapped[bool | None] = mapped_column(Boolean, nullable=True) required_followup: Mapped[bool | None] = mapped_column(Boolean, nullable=True) feedback_text: Mapped[str | None] = mapped_column(Text, nullable=True) predefined_feedback: Mapped[str | None] = mapped_column(String, nullable=True) chat_message: Mapped[ChatMessage] = relationship( - "ChatMessage", back_populates="chat_message_feedbacks" + "ChatMessage", + back_populates="chat_message_feedbacks", + foreign_keys=[chat_message_id], ) @@ -929,6 +991,7 @@ class Tool(Base): # ID of the tool in the codebase, only applies for in-code tools. 
# tools defined via the UI will have this as None in_code_tool_id: Mapped[str | None] = mapped_column(String, nullable=True) + display_name: Mapped[str] = mapped_column(String, nullable=True) # OpenAPI scheme for the tool. Only applies to tools defined via the UI. openapi_schema: Mapped[dict[str, Any] | None] = mapped_column( @@ -1059,12 +1122,60 @@ class ChannelConfig(TypedDict): respond_tag_only: NotRequired[bool] # defaults to False respond_to_bots: NotRequired[bool] # defaults to False respond_team_member_list: NotRequired[list[str]] + respond_slack_group_list: NotRequired[list[str]] answer_filters: NotRequired[list[AllowedAnswerFilters]] # If None then no follow up # If empty list, follow up with no tags follow_up_tags: NotRequired[list[str]] +class StandardAnswerCategory(Base): + __tablename__ = "standard_answer_category" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String, unique=True) + standard_answers: Mapped[list["StandardAnswer"]] = relationship( + "StandardAnswer", + secondary=StandardAnswer__StandardAnswerCategory.__table__, + back_populates="categories", + ) + slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship( + "SlackBotConfig", + secondary=SlackBotConfig__StandardAnswerCategory.__table__, + back_populates="standard_answer_categories", + ) + + +class StandardAnswer(Base): + __tablename__ = "standard_answer" + + id: Mapped[int] = mapped_column(primary_key=True) + keyword: Mapped[str] = mapped_column(String) + answer: Mapped[str] = mapped_column(String) + active: Mapped[bool] = mapped_column(Boolean) + + __table_args__ = ( + Index( + "unique_keyword_active", + keyword, + active, + unique=True, + postgresql_where=(active == True), # noqa: E712 + ), + ) + + categories: Mapped[list[StandardAnswerCategory]] = relationship( + "StandardAnswerCategory", + secondary=StandardAnswer__StandardAnswerCategory.__table__, + back_populates="standard_answers", + ) + chat_messages: Mapped[list[ChatMessage]] = 
relationship( + "ChatMessage", + secondary=ChatMessage__StandardAnswer.__table__, + back_populates="standard_answers", + ) + + class SlackBotResponseType(str, PyEnum): QUOTES = "quotes" CITATIONS = "citations" @@ -1084,8 +1195,16 @@ class SlackBotConfig(Base): response_type: Mapped[SlackBotResponseType] = mapped_column( Enum(SlackBotResponseType, native_enum=False), nullable=False ) + enable_auto_filters: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False + ) persona: Mapped[Persona | None] = relationship("Persona") + standard_answer_categories: Mapped[list[StandardAnswerCategory]] = relationship( + "StandardAnswerCategory", + secondary=SlackBotConfig__StandardAnswerCategory.__table__, + back_populates="slack_bot_configs", + ) class TaskQueueState(Base): @@ -1365,3 +1484,30 @@ class EmailToExternalUserCache(Base): ) user = relationship("User") + + +class UsageReport(Base): + """This stores metadata about usage reports generated by admin including user who generated + them as well las the period they cover. 
The actual zip file of the report is stored as a lo + using the PGFileStore + """ + + __tablename__ = "usage_reports" + + id: Mapped[int] = mapped_column(primary_key=True) + report_name: Mapped[str] = mapped_column(ForeignKey("file_store.file_name")) + + # if None, report was auto-generated + requestor_user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id"), nullable=True + ) + time_created: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + period_from: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True) + ) + period_to: Mapped[datetime.datetime | None] = mapped_column(DateTime(timezone=True)) + + requestor = relationship("User") + file = relationship("PGFileStore") diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py index 538b8441b86..872bca89b0f 100644 --- a/backend/danswer/db/persona.py +++ b/backend/danswer/db/persona.py @@ -12,8 +12,8 @@ from sqlalchemy.orm import Session from danswer.auth.schemas import UserRole +from danswer.configs.chat_configs import BING_API_KEY from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX -from danswer.db.document_set import get_document_sets_by_ids from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import DocumentSet from danswer.db.models import Persona @@ -62,19 +62,6 @@ def create_update_persona( ) -> PersonaSnapshot: """Higher level function than upsert_persona, although either is valid to use.""" # Permission to actually use these is checked later - document_sets = list( - get_document_sets_by_ids( - document_set_ids=create_persona_request.document_set_ids, - db_session=db_session, - ) - ) - prompts = list( - get_prompts_by_ids( - prompt_ids=create_persona_request.prompt_ids, - db_session=db_session, - ) - ) - try: persona = upsert_persona( persona_id=persona_id, @@ -85,9 +72,9 @@ def create_update_persona( llm_relevance_filter=create_persona_request.llm_relevance_filter, 
llm_filter_extraction=create_persona_request.llm_filter_extraction, recency_bias=create_persona_request.recency_bias, - prompts=prompts, + prompt_ids=create_persona_request.prompt_ids, tool_ids=create_persona_request.tool_ids, - document_sets=document_sets, + document_set_ids=create_persona_request.document_set_ids, llm_model_provider_override=create_persona_request.llm_model_provider_override, llm_model_version_override=create_persona_request.llm_model_version_override, starter_messages=create_persona_request.starter_messages, @@ -330,13 +317,13 @@ def upsert_persona( llm_relevance_filter: bool, llm_filter_extraction: bool, recency_bias: RecencyBiasSetting, - prompts: list[Prompt] | None, - document_sets: list[DocumentSet] | None, llm_model_provider_override: str | None, llm_model_version_override: str | None, starter_messages: list[StarterMessage] | None, is_public: bool, db_session: Session, + prompt_ids: list[int] | None = None, + document_set_ids: list[int] | None = None, tool_ids: list[int] | None = None, persona_id: int | None = None, default_persona: bool = False, @@ -356,6 +343,28 @@ def upsert_persona( if not tools and tool_ids: raise ValueError("Tools not found") + # Fetch and attach document_sets by IDs + document_sets = None + if document_set_ids is not None: + document_sets = ( + db_session.query(DocumentSet) + .filter(DocumentSet.id.in_(document_set_ids)) + .all() + ) + if not document_sets and document_set_ids: + raise ValueError("document_sets not found") + + # Fetch and attach prompts by IDs + prompts = None + if prompt_ids is not None: + prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all() + if not prompts and prompt_ids: + raise ValueError("prompts not found") + + # ensure all specified tools are valid + if tools: + validate_persona_tools(tools) + if persona: if not default_persona and persona.default_persona: raise ValueError("Cannot update default persona with non-default.") @@ -383,10 +392,10 @@ def upsert_persona( if 
prompts is not None: persona.prompts.clear() - persona.prompts = prompts + persona.prompts = prompts or [] if tools is not None: - persona.tools = tools + persona.tools = tools or [] else: persona = Persona( @@ -453,6 +462,14 @@ def update_persona_visibility( db_session.commit() +def validate_persona_tools(tools: list[Tool]) -> None: + for tool in tools: + if tool.name == "InternetSearchTool" and not BING_API_KEY: + raise ValueError( + "Bing API key not found, please contact your Danswer admin to get it added!" + ) + + def check_user_can_edit_persona(user: User | None, persona: Persona) -> None: # if user is None, assume that no-auth is turned on if user is None: @@ -537,12 +554,22 @@ def get_persona_by_id( user: User | None, db_session: Session, include_deleted: bool = False, + is_for_edit: bool = True, # NOTE: assume true for safety ) -> Persona: stmt = select(Persona).where(Persona.id == persona_id) + or_conditions = [] + # if user is an admin, they should have access to all Personas if user is not None and user.role != UserRole.ADMIN: - stmt = stmt.where(or_(Persona.user_id == user.id, Persona.user_id.is_(None))) + or_conditions.extend([Persona.user_id == user.id, Persona.user_id.is_(None)]) + + # if we aren't editing, also give access to all public personas + if not is_for_edit: + or_conditions.append(Persona.is_public.is_(True)) + + if or_conditions: + stmt = stmt.where(or_(*or_conditions)) if not include_deleted: stmt = stmt.where(Persona.deleted.is_(False)) diff --git a/backend/danswer/db/pg_file_store.py b/backend/danswer/db/pg_file_store.py index 7146fc75b6b..1333dcd6cee 100644 --- a/backend/danswer/db/pg_file_store.py +++ b/backend/danswer/db/pg_file_store.py @@ -1,3 +1,4 @@ +import tempfile from io import BytesIO from typing import IO @@ -6,6 +7,8 @@ from danswer.configs.constants import FileOrigin from danswer.db.models import PGFileStore +from danswer.file_store.constants import MAX_IN_MEMORY_SIZE +from danswer.file_store.constants import 
STANDARD_CHUNK_SIZE from danswer.utils.logger import setup_logger logger = setup_logger() @@ -15,6 +18,25 @@ def get_pg_conn_from_session(db_session: Session) -> connection: return db_session.connection().connection.connection # type: ignore +def get_pgfilestore_by_file_name( + file_name: str, + db_session: Session, +) -> PGFileStore: + pgfilestore = db_session.query(PGFileStore).filter_by(file_name=file_name).first() + + if not pgfilestore: + raise RuntimeError(f"File by name {file_name} does not exist or was deleted") + + return pgfilestore + + +def delete_pgfilestore_by_file_name( + file_name: str, + db_session: Session, +) -> None: + db_session.query(PGFileStore).filter_by(file_name=file_name).delete() + + def create_populate_lobj( content: IO, db_session: Session, @@ -26,18 +48,40 @@ def create_populate_lobj( pg_conn = get_pg_conn_from_session(db_session) large_object = pg_conn.lobject() - large_object.write(content.read()) + # write in multiple chunks to avoid loading the whole file into memory + while True: + chunk = content.read(STANDARD_CHUNK_SIZE) + if not chunk: + break + large_object.write(chunk) + large_object.close() return large_object.oid -def read_lobj(lobj_oid: int, db_session: Session, mode: str | None = None) -> IO: +def read_lobj( + lobj_oid: int, + db_session: Session, + mode: str | None = None, + use_tempfile: bool = False, +) -> IO: pg_conn = get_pg_conn_from_session(db_session) large_object = ( pg_conn.lobject(lobj_oid, mode=mode) if mode else pg_conn.lobject(lobj_oid) ) - return BytesIO(large_object.read()) + + if use_tempfile: + temp_file = tempfile.SpooledTemporaryFile(max_size=MAX_IN_MEMORY_SIZE) + while True: + chunk = large_object.read(STANDARD_CHUNK_SIZE) + if not chunk: + break + temp_file.write(chunk) + temp_file.seek(0) + return temp_file + else: + return BytesIO(large_object.read()) def delete_lobj_by_id( @@ -48,6 +92,23 @@ def delete_lobj_by_id( pg_conn.lobject(lobj_oid).unlink() +def delete_lobj_by_name( + lobj_name: str, + 
db_session: Session, +) -> None: + try: + pgfilestore = get_pgfilestore_by_file_name(lobj_name, db_session) + except RuntimeError: + logger.info(f"no file with name {lobj_name} found") + return + + pg_conn = get_pg_conn_from_session(db_session) + pg_conn.lobject(pgfilestore.lobj_oid).unlink() + + delete_pgfilestore_by_file_name(lobj_name, db_session) + db_session.commit() + + def upsert_pgfilestore( file_name: str, display_name: str | None, @@ -87,22 +148,3 @@ def upsert_pgfilestore( db_session.commit() return pgfilestore - - -def get_pgfilestore_by_file_name( - file_name: str, - db_session: Session, -) -> PGFileStore: - pgfilestore = db_session.query(PGFileStore).filter_by(file_name=file_name).first() - - if not pgfilestore: - raise RuntimeError(f"File by name {file_name} does not exist or was deleted") - - return pgfilestore - - -def delete_pgfilestore_by_file_name( - file_name: str, - db_session: Session, -) -> None: - db_session.query(PGFileStore).filter_by(file_name=file_name).delete() diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py index 8683b3222fb..502d99608f9 100644 --- a/backend/danswer/db/slack_bot_config.py +++ b/backend/danswer/db/slack_bot_config.py @@ -5,7 +5,6 @@ from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX -from danswer.db.document_set import get_document_sets_by_ids from danswer.db.models import ChannelConfig from danswer.db.models import Persona from danswer.db.models import Persona__DocumentSet @@ -15,6 +14,7 @@ from danswer.db.persona import get_default_prompt from danswer.db.persona import mark_persona_as_deleted from danswer.db.persona import upsert_persona +from danswer.db.standard_answer import fetch_standard_answer_categories_by_ids from danswer.search.enums import RecencyBiasSetting @@ -42,12 +42,6 @@ def create_slack_bot_persona( num_chunks: float = MAX_CHUNKS_FED_TO_CHAT, ) -> Persona: """NOTE: does not commit 
changes""" - document_sets = list( - get_document_sets_by_ids( - document_set_ids=document_set_ids, - db_session=db_session, - ) - ) # create/update persona associated with the slack bot persona_name = _build_persona_name(channel_names) @@ -59,10 +53,10 @@ def create_slack_bot_persona( description="", num_chunks=num_chunks, llm_relevance_filter=True, - llm_filter_extraction=True, + llm_filter_extraction=False, recency_bias=RecencyBiasSetting.AUTO, - prompts=[default_prompt], - document_sets=document_sets, + prompt_ids=[default_prompt.id], + document_set_ids=document_set_ids, llm_model_provider_override=None, llm_model_version_override=None, starter_messages=None, @@ -79,12 +73,25 @@ def insert_slack_bot_config( persona_id: int | None, channel_config: ChannelConfig, response_type: SlackBotResponseType, + standard_answer_category_ids: list[int], + enable_auto_filters: bool, db_session: Session, ) -> SlackBotConfig: + existing_standard_answer_categories = fetch_standard_answer_categories_by_ids( + standard_answer_category_ids=standard_answer_category_ids, + db_session=db_session, + ) + if len(existing_standard_answer_categories) != len(standard_answer_category_ids): + raise ValueError( + f"Some or all categories with ids {standard_answer_category_ids} do not exist" + ) + slack_bot_config = SlackBotConfig( persona_id=persona_id, channel_config=channel_config, response_type=response_type, + standard_answer_categories=existing_standard_answer_categories, + enable_auto_filters=enable_auto_filters, ) db_session.add(slack_bot_config) db_session.commit() @@ -97,6 +104,8 @@ def update_slack_bot_config( persona_id: int | None, channel_config: ChannelConfig, response_type: SlackBotResponseType, + standard_answer_category_ids: list[int], + enable_auto_filters: bool, db_session: Session, ) -> SlackBotConfig: slack_bot_config = db_session.scalar( @@ -106,6 +115,16 @@ def update_slack_bot_config( raise ValueError( f"Unable to find slack bot config with ID {slack_bot_config_id}" ) + 
+ existing_standard_answer_categories = fetch_standard_answer_categories_by_ids( + standard_answer_category_ids=standard_answer_category_ids, + db_session=db_session, + ) + if len(existing_standard_answer_categories) != len(standard_answer_category_ids): + raise ValueError( + f"Some or all categories with ids {standard_answer_category_ids} do not exist" + ) + # get the existing persona id before updating the object existing_persona_id = slack_bot_config.persona_id @@ -115,6 +134,10 @@ def update_slack_bot_config( slack_bot_config.persona_id = persona_id slack_bot_config.channel_config = channel_config slack_bot_config.response_type = response_type + slack_bot_config.standard_answer_categories = list( + existing_standard_answer_categories + ) + slack_bot_config.enable_auto_filters = enable_auto_filters # if the persona has changed, then clean up the old persona if persona_id != existing_persona_id and existing_persona_id: diff --git a/backend/danswer/db/standard_answer.py b/backend/danswer/db/standard_answer.py new file mode 100644 index 00000000000..3b56c52e641 --- /dev/null +++ b/backend/danswer/db/standard_answer.py @@ -0,0 +1,228 @@ +import string +from collections.abc import Sequence + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.db.models import StandardAnswer +from danswer.db.models import StandardAnswerCategory +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def check_category_validity(category_name: str) -> bool: + """If a category name is too long, it should not be used (it will cause an error in Postgres + as the unique constraint can only apply to entries that are less than 2704 bytes). 
+ + Additionally, extremely long categories are not really usable / useful.""" + if len(category_name) > 255: + logger.error( + f"Category with name '{category_name}' is too long, cannot be used" + ) + return False + + return True + + +def insert_standard_answer_category( + category_name: str, db_session: Session +) -> StandardAnswerCategory: + if not check_category_validity(category_name): + raise ValueError(f"Invalid category name: {category_name}") + standard_answer_category = StandardAnswerCategory(name=category_name) + db_session.add(standard_answer_category) + db_session.commit() + + return standard_answer_category + + +def insert_standard_answer( + keyword: str, + answer: str, + category_ids: list[int], + db_session: Session, +) -> StandardAnswer: + existing_categories = fetch_standard_answer_categories_by_ids( + standard_answer_category_ids=category_ids, + db_session=db_session, + ) + if len(existing_categories) != len(category_ids): + raise ValueError(f"Some or all categories with ids {category_ids} do not exist") + + standard_answer = StandardAnswer( + keyword=keyword, + answer=answer, + categories=existing_categories, + active=True, + ) + db_session.add(standard_answer) + db_session.commit() + return standard_answer + + +def update_standard_answer( + standard_answer_id: int, + keyword: str, + answer: str, + category_ids: list[int], + db_session: Session, +) -> StandardAnswer: + standard_answer = db_session.scalar( + select(StandardAnswer).where(StandardAnswer.id == standard_answer_id) + ) + if standard_answer is None: + raise ValueError(f"No standard answer with id {standard_answer_id}") + + existing_categories = fetch_standard_answer_categories_by_ids( + standard_answer_category_ids=category_ids, + db_session=db_session, + ) + if len(existing_categories) != len(category_ids): + raise ValueError(f"Some or all categories with ids {category_ids} do not exist") + + standard_answer.keyword = keyword + standard_answer.answer = answer + 
standard_answer.categories = list(existing_categories) + + db_session.commit() + + return standard_answer + + +def remove_standard_answer( + standard_answer_id: int, + db_session: Session, +) -> None: + standard_answer = db_session.scalar( + select(StandardAnswer).where(StandardAnswer.id == standard_answer_id) + ) + if standard_answer is None: + raise ValueError(f"No standard answer with id {standard_answer_id}") + + standard_answer.active = False + db_session.commit() + + +def update_standard_answer_category( + standard_answer_category_id: int, + category_name: str, + db_session: Session, +) -> StandardAnswerCategory: + standard_answer_category = db_session.scalar( + select(StandardAnswerCategory).where( + StandardAnswerCategory.id == standard_answer_category_id + ) + ) + if standard_answer_category is None: + raise ValueError( + f"No standard answer category with id {standard_answer_category_id}" + ) + + if not check_category_validity(category_name): + raise ValueError(f"Invalid category name: {category_name}") + + standard_answer_category.name = category_name + + db_session.commit() + + return standard_answer_category + + +def fetch_standard_answer_category( + standard_answer_category_id: int, + db_session: Session, +) -> StandardAnswerCategory | None: + return db_session.scalar( + select(StandardAnswerCategory).where( + StandardAnswerCategory.id == standard_answer_category_id + ) + ) + + +def fetch_standard_answer_categories_by_ids( + standard_answer_category_ids: list[int], + db_session: Session, +) -> Sequence[StandardAnswerCategory]: + return db_session.scalars( + select(StandardAnswerCategory).where( + StandardAnswerCategory.id.in_(standard_answer_category_ids) + ) + ).all() + + +def fetch_standard_answer_categories( + db_session: Session, +) -> Sequence[StandardAnswerCategory]: + return db_session.scalars(select(StandardAnswerCategory)).all() + + +def fetch_standard_answer( + standard_answer_id: int, + db_session: Session, +) -> StandardAnswer | None: + 
return db_session.scalar( + select(StandardAnswer).where(StandardAnswer.id == standard_answer_id) + ) + + +def find_matching_standard_answers( + id_in: list[int], + query: str, + db_session: Session, +) -> list[StandardAnswer]: + stmt = ( + select(StandardAnswer) + .where(StandardAnswer.active.is_(True)) + .where(StandardAnswer.id.in_(id_in)) + ) + possible_standard_answers = db_session.scalars(stmt).all() + + matching_standard_answers: list[StandardAnswer] = [] + for standard_answer in possible_standard_answers: + # Remove punctuation and split the keyword into individual words + keyword_words = "".join( + char + for char in standard_answer.keyword.lower() + if char not in string.punctuation + ).split() + + # Remove punctuation and split the query into individual words + query_words = "".join( + char for char in query.lower() if char not in string.punctuation + ).split() + + # Check if all of the keyword words are in the query words + if all(word in query_words for word in keyword_words): + matching_standard_answers.append(standard_answer) + + return matching_standard_answers + + +def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]: + return db_session.scalars( + select(StandardAnswer).where(StandardAnswer.active.is_(True)) + ).all() + + +def create_initial_default_standard_answer_category(db_session: Session) -> None: + default_category_id = 0 + default_category_name = "General" + default_category = fetch_standard_answer_category( + standard_answer_category_id=default_category_id, + db_session=db_session, + ) + if default_category is not None: + if default_category.name != default_category_name: + raise ValueError( + "DB is not in a valid initial state. " + "Default standard answer category does not have expected name." 
+ ) + return + + standard_answer_category = StandardAnswerCategory( + id=default_category_id, + name=default_category_name, + ) + db_session.add(standard_answer_category) + db_session.commit() diff --git a/backend/danswer/db/tag.py b/backend/danswer/db/tag.py index 240065a17d9..66418b948e7 100644 --- a/backend/danswer/db/tag.py +++ b/backend/danswer/db/tag.py @@ -91,9 +91,10 @@ def create_or_add_document_tag_list( new_tags.append(new_tag) existing_tag_values.add(tag_value) - logger.debug( - f"Created new tags: {', '.join([f'{tag.tag_key}:{tag.tag_value}' for tag in new_tags])}" - ) + if new_tags: + logger.debug( + f"Created new tags: {', '.join([f'{tag.tag_key}:{tag.tag_value}' for tag in new_tags])}" + ) all_tags = existing_tags + new_tags diff --git a/backend/danswer/db/tasks.py b/backend/danswer/db/tasks.py index a12f988613c..23a7edc9882 100644 --- a/backend/danswer/db/tasks.py +++ b/backend/danswer/db/tasks.py @@ -26,6 +26,23 @@ def get_latest_task( return latest_task +def get_latest_task_by_type( + task_name: str, + db_session: Session, +) -> TaskQueueState | None: + stmt = ( + select(TaskQueueState) + .where(TaskQueueState.task_name.like(f"%{task_name}%")) + .order_by(desc(TaskQueueState.id)) + .limit(1) + ) + + result = db_session.execute(stmt) + latest_task = result.scalars().first() + + return latest_task + + def register_task( task_id: str, task_name: str, @@ -66,7 +83,7 @@ def mark_task_finished( db_session.commit() -def check_live_task_not_timed_out( +def check_task_is_live_and_not_timed_out( task: TaskQueueState, db_session: Session, timeout: int = JOB_TIMEOUT, diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py index d96350199e9..e29d2a010fa 100644 --- a/backend/danswer/document_index/vespa/index.py +++ b/backend/danswer/document_index/vespa/index.py @@ -119,6 +119,7 @@ def _does_document_exist( chunk. 
This checks for whether the chunk exists already in the index""" doc_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}" doc_fetch_response = http_client.get(doc_url) + if doc_fetch_response.status_code == 404: return False @@ -559,7 +560,9 @@ def _process_dynamic_summary( return processed_summary -def _vespa_hit_to_inference_chunk(hit: dict[str, Any]) -> InferenceChunk: +def _vespa_hit_to_inference_chunk( + hit: dict[str, Any], null_score: bool = False +) -> InferenceChunk: fields = cast(dict[str, Any], hit["fields"]) # parse fields that are stored as strings, but are really json / datetime @@ -615,7 +618,7 @@ def _vespa_hit_to_inference_chunk(hit: dict[str, Any]) -> InferenceChunk: semantic_identifier=fields[SEMANTIC_IDENTIFIER], boost=fields.get(BOOST, 1), recency_bias=fields.get("matchfeatures", {}).get(RECENCY_BIAS, 1.0), - score=hit.get("relevance", 0), + score=None if null_score else hit.get("relevance", 0), hidden=fields.get(HIDDEN, False), primary_owners=fields.get(PRIMARY_OWNERS), secondary_owners=fields.get(SECONDARY_OWNERS), @@ -992,7 +995,8 @@ def id_based_retrieval( return [] inference_chunks = [ - _vespa_hit_to_inference_chunk(chunk) for chunk in vespa_chunks + _vespa_hit_to_inference_chunk(chunk, null_score=True) + for chunk in vespa_chunks ] inference_chunks.sort(key=lambda chunk: chunk.chunk_id) return inference_chunks diff --git a/backend/danswer/file_processing/enums.py b/backend/danswer/file_processing/enums.py new file mode 100644 index 00000000000..f532d0ebfcc --- /dev/null +++ b/backend/danswer/file_processing/enums.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class HtmlBasedConnectorTransformLinksStrategy(str, Enum): + # remove links entirely + STRIP = "strip" + # turn HTML links into markdown links + MARKDOWN = "markdown" diff --git a/backend/danswer/file_processing/extract_file_text.py b/backend/danswer/file_processing/extract_file_text.py index 14cca7f6f9b..f96d4a4153d 100644 --- 
a/backend/danswer/file_processing/extract_file_text.py +++ b/backend/danswer/file_processing/extract_file_text.py @@ -3,6 +3,7 @@ import os import re import zipfile +from collections.abc import Callable from collections.abc import Iterator from email.parser import Parser as EmailParser from pathlib import Path @@ -65,6 +66,16 @@ def check_file_ext_is_valid(ext: str) -> bool: return ext in VALID_FILE_EXTENSIONS +def is_text_file(file: IO[bytes]) -> bool: + """ + checks if the first 1024 bytes only contain printable or whitespace characters + if it does, then we say its a plaintext file + """ + raw_data = file.read(1024) + text_chars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F}) + return all(c in text_chars for c in raw_data) + + def detect_encoding(file: IO[bytes]) -> str: raw_data = file.read(50000) encoding = chardet.detect(raw_data)["encoding"] or "utf-8" @@ -261,37 +272,32 @@ def extract_file_text( file: IO[Any], break_on_unprocessable: bool = True, ) -> str: - if not file_name: - return file_io_to_text(file) + extension_to_function: dict[str, Callable[[IO[Any]], str]] = { + ".pdf": pdf_to_text, + ".docx": docx_to_text, + ".pptx": pptx_to_text, + ".xlsx": xlsx_to_text, + ".eml": eml_to_text, + ".epub": epub_to_text, + ".html": parse_html_page_basic, + } + + def _process_file() -> str: + if file_name: + extension = get_file_ext(file_name) + if check_file_ext_is_valid(extension): + return extension_to_function.get(extension, file_io_to_text)(file) + + # Either the file somehow has no name or the extension is not one that we are familiar with + if is_text_file(file): + return file_io_to_text(file) + + raise ValueError("Unknown file extension and unknown text encoding") - extension = get_file_ext(file_name) - if not check_file_ext_is_valid(extension): + try: + return _process_file() + except Exception as e: if break_on_unprocessable: - raise RuntimeError(f"Unprocessable file type: {file_name}") - else: - logger.warning(f"Unprocessable 
file type: {file_name}") - return "" - - if extension == ".pdf": - return pdf_to_text(file=file) - - elif extension == ".docx": - return docx_to_text(file) - - elif extension == ".pptx": - return pptx_to_text(file) - - elif extension == ".xlsx": - return xlsx_to_text(file) - - elif extension == ".eml": - return eml_to_text(file) - - elif extension == ".epub": - return epub_to_text(file) - - elif extension == ".html": - return parse_html_page_basic(file) - - else: - return file_io_to_text(file) + raise RuntimeError(f"Failed to process file: {str(e)}") from e + logger.warning(f"Failed to process file: {str(e)}") + return "" diff --git a/backend/danswer/file_processing/html_utils.py b/backend/danswer/file_processing/html_utils.py index 9b5875227a0..48782981f89 100644 --- a/backend/danswer/file_processing/html_utils.py +++ b/backend/danswer/file_processing/html_utils.py @@ -5,8 +5,10 @@ import bs4 +from danswer.configs.app_configs import HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY from danswer.configs.app_configs import WEB_CONNECTOR_IGNORED_CLASSES from danswer.configs.app_configs import WEB_CONNECTOR_IGNORED_ELEMENTS +from danswer.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy MINTLIFY_UNWANTED = ["sticky", "hidden"] @@ -32,6 +34,19 @@ def strip_newlines(document: str) -> str: return re.sub(r"[\n\r]+", " ", document) +def format_element_text(element_text: str, link_href: str | None) -> str: + element_text_no_newlines = strip_newlines(element_text) + + if ( + not link_href + or HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY + == HtmlBasedConnectorTransformLinksStrategy.STRIP + ): + return element_text_no_newlines + + return f"[{element_text_no_newlines}]({link_href})" + + def format_document_soup( document: bs4.BeautifulSoup, table_cell_separator: str = "\t" ) -> str: @@ -49,6 +64,8 @@ def format_document_soup( verbatim_output = 0 in_table = False last_added_newline = False + link_href: str | None = None + for e in document.descendants: 
verbatim_output -= 1 if isinstance(e, bs4.element.NavigableString): @@ -71,7 +88,7 @@ def format_document_soup( content_to_add = ( element_text if verbatim_output > 0 - else strip_newlines(element_text) + else format_element_text(element_text, link_href) ) # Don't join separate elements without any spacing @@ -98,7 +115,14 @@ def format_document_soup( elif in_table: # don't handle other cases while in table pass - + elif e.name == "a": + href_value = e.get("href", None) + # mostly for typing, having multiple hrefs is not valid HTML + link_href = ( + href_value[0] if isinstance(href_value, list) else href_value + ) + elif e.name == "/a": + link_href = None elif e.name in ["p", "div"]: if not list_element_start: text += "\n" diff --git a/backend/danswer/file_store/constants.py b/backend/danswer/file_store/constants.py new file mode 100644 index 00000000000..a0845d35ef7 --- /dev/null +++ b/backend/danswer/file_store/constants.py @@ -0,0 +1,2 @@ +MAX_IN_MEMORY_SIZE = 30 * 1024 * 1024 # 30MB +STANDARD_CHUNK_SIZE = 10 * 1024 * 1024 # 10MB chunks diff --git a/backend/danswer/file_store/file_store.py b/backend/danswer/file_store/file_store.py index 9e131d38c9d..9bc4c41d361 100644 --- a/backend/danswer/file_store/file_store.py +++ b/backend/danswer/file_store/file_store.py @@ -5,6 +5,7 @@ from sqlalchemy.orm import Session from danswer.configs.constants import FileOrigin +from danswer.db.models import PGFileStore from danswer.db.pg_file_store import create_populate_lobj from danswer.db.pg_file_store import delete_lobj_by_id from danswer.db.pg_file_store import delete_pgfilestore_by_file_name @@ -26,6 +27,7 @@ def save_file( display_name: str | None, file_origin: FileOrigin, file_type: str, + file_metadata: dict | None = None, ) -> None: """ Save a file to the blob store @@ -41,12 +43,17 @@ def save_file( raise NotImplementedError @abstractmethod - def read_file(self, file_name: str, mode: str | None) -> IO: + def read_file( + self, file_name: str, mode: str | None, 
use_tempfile: bool = False + ) -> IO: """ Read the content of a given file by the name Parameters: - file_name: Name of file to read + - mode: Mode to open the file (e.g. 'b' for binary) + - use_tempfile: Whether to use a temporary file to store the contents + in order to avoid loading the entire file into memory Returns: Contents of the file and metadata dict @@ -73,6 +80,7 @@ def save_file( display_name: str | None, file_origin: FileOrigin, file_type: str, + file_metadata: dict | None = None, ) -> None: try: # The large objects in postgres are saved as special objects can be listed with @@ -85,20 +93,33 @@ def save_file( file_type=file_type, lobj_oid=obj_id, db_session=self.db_session, + file_metadata=file_metadata, ) self.db_session.commit() except Exception: self.db_session.rollback() raise - def read_file(self, file_name: str, mode: str | None = None) -> IO: + def read_file( + self, file_name: str, mode: str | None = None, use_tempfile: bool = False + ) -> IO: file_record = get_pgfilestore_by_file_name( file_name=file_name, db_session=self.db_session ) return read_lobj( - lobj_oid=file_record.lobj_oid, db_session=self.db_session, mode=mode + lobj_oid=file_record.lobj_oid, + db_session=self.db_session, + mode=mode, + use_tempfile=use_tempfile, ) + def read_file_record(self, file_name: str) -> PGFileStore: + file_record = get_pgfilestore_by_file_name( + file_name=file_name, db_session=self.db_session + ) + + return file_record + def delete_file(self, file_name: str) -> None: try: file_record = get_pgfilestore_by_file_name( diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index 55b2be36002..2ce36d0746d 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -31,6 +31,8 @@ from danswer.llm.answering.stream_processing.quotes_processing import ( build_quotes_processor, ) +from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping +from 
danswer.llm.answering.stream_processing.utils import map_document_id_order from danswer.llm.interfaces import LLM from danswer.llm.utils import get_default_llm_tokenizer from danswer.llm.utils import message_generator_to_string_generator @@ -43,9 +45,11 @@ from danswer.tools.images.image_generation_tool import ImageGenerationResponse from danswer.tools.images.image_generation_tool import ImageGenerationTool from danswer.tools.images.prompt import build_image_generation_user_prompt +from danswer.tools.internet_search.internet_search_tool import InternetSearchTool from danswer.tools.message import build_tool_message from danswer.tools.message import ToolCallSummary from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS +from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID from danswer.tools.search.search_tool import SearchResponseSummary from danswer.tools.search.search_tool import SearchTool @@ -57,17 +61,22 @@ from danswer.tools.tool_runner import ToolCallFinalResult from danswer.tools.tool_runner import ToolCallKickoff from danswer.tools.tool_runner import ToolRunner +from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm from danswer.tools.utils import explicit_tool_calling_supported +from danswer.utils.logger import setup_logger + + +logger = setup_logger() def _get_answer_stream_processor( context_docs: list[LlmDoc], - search_order_docs: list[LlmDoc], + doc_id_to_rank_map: DocumentIdOrderMapping, answer_style_configs: AnswerStyleConfig, ) -> StreamProcessor: if answer_style_configs.citation_config: return build_citation_processor( - context_docs=context_docs, search_order_docs=search_order_docs + context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map ) if answer_style_configs.quotes_config: return build_quotes_processor( @@ -101,6 +110,8 @@ def __init__( force_use_tool: ForceUseTool | None = None, # if set to True, then 
never use the LLMs provided tool-calling functonality skip_explicit_tool_calling: bool = False, + # Returns the full document sections text from the search tool + return_contexts: bool = False, ) -> None: if single_message_history and message_history: raise ValueError( @@ -134,6 +145,8 @@ def __init__( AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff ] | None = None + self._return_contexts = return_contexts + def _update_prompt_builder_for_search_tool( self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc] ) -> None: @@ -224,7 +237,7 @@ def _raw_output_for_explicit_tool_calling_llms( tool_call_requests = tool_call_chunk.tool_calls for tool_call_request in tool_call_requests: tool = [ - tool for tool in self.tools if tool.name() == tool_call_request["name"] + tool for tool in self.tools if tool.name == tool_call_request["name"] ][0] tool_args = ( self.force_use_tool.args @@ -243,15 +256,14 @@ def _raw_output_for_explicit_tool_calling_llms( ), ) - if tool.name() == SearchTool.NAME: + if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}: self._update_prompt_builder_for_search_tool(prompt_builder, []) - elif tool.name() == ImageGenerationTool.NAME: + elif tool.name == ImageGenerationTool._NAME: prompt_builder.update_user_prompt( build_image_generation_user_prompt( query=self.question, ) ) - yield tool_runner.tool_final_result() prompt = prompt_builder.build(tool_call_summary=tool_call_summary) @@ -277,7 +289,7 @@ def _raw_output_for_non_explicit_tool_calling_llms( [ tool for tool in self.tools - if tool.name() == self.force_use_tool.tool_name + if tool.name == self.force_use_tool.tool_name ] ), None, @@ -297,21 +309,39 @@ def _raw_output_for_non_explicit_tool_calling_llms( ) if tool_args is None: - raise RuntimeError(f"Tool '{tool.name()}' did not return args") + raise RuntimeError(f"Tool '{tool.name}' did not return args") chosen_tool_and_args = (tool, tool_args) else: - all_tool_args = 
check_which_tools_should_run_for_non_tool_calling_llm( + tool_options = check_which_tools_should_run_for_non_tool_calling_llm( tools=self.tools, query=self.question, history=self.message_history, llm=self.llm, ) - for ind, args in enumerate(all_tool_args): - if args is not None: - chosen_tool_and_args = (self.tools[ind], args) - # for now, just pick the first tool selected - break + + available_tools_and_args = [ + (self.tools[ind], args) + for ind, args in enumerate(tool_options) + if args is not None + ] + + logger.info( + f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}" + ) + + chosen_tool_and_args = ( + select_single_tool_for_non_tool_calling_llm( + tools_and_args=available_tools_and_args, + history=self.message_history, + query=self.question, + llm=self.llm, + ) + if available_tools_and_args + else None + ) + + logger.info(f"Chosen tool: {chosen_tool_and_args}") if not chosen_tool_and_args: prompt_builder.update_system_prompt( @@ -332,7 +362,7 @@ def _raw_output_for_non_explicit_tool_calling_llms( tool_runner = ToolRunner(tool, tool_args) yield tool_runner.kickoff() - if tool.name() == SearchTool.NAME: + if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}: final_context_documents = None for response in tool_runner.tool_responses(): if response.id == FINAL_CONTEXT_DOCUMENTS: @@ -340,12 +370,14 @@ def _raw_output_for_non_explicit_tool_calling_llms( yield response if final_context_documents is None: - raise RuntimeError("SearchTool did not return final context documents") + raise RuntimeError( + f"{tool.name} did not return final context documents" + ) self._update_prompt_builder_for_search_tool( prompt_builder, final_context_documents ) - elif tool.name() == ImageGenerationTool.NAME: + elif tool.name == ImageGenerationTool._NAME: img_urls = [] for response in tool_runner.tool_responses(): if response.id == IMAGE_GENERATION_RESPONSE_ID: @@ -353,7 +385,7 @@ def 
_raw_output_for_non_explicit_tool_calling_llms( list[ImageGenerationResponse], response.response ) img_urls = [img.url for img in img_generation_response] - break + yield response prompt_builder.update_user_prompt( @@ -367,7 +399,7 @@ def _raw_output_for_non_explicit_tool_calling_llms( HumanMessage( content=build_user_message_for_custom_tool_for_non_tool_calling_llm( self.question, - tool.name(), + tool.name, *tool_runner.tool_responses(), ) ) @@ -413,6 +445,10 @@ def _process_stream( yield message elif isinstance(message, ToolResponse): if message.id == SEARCH_RESPONSE_SUMMARY_ID: + # We don't need to run section merging in this flow, this variable is only used + # below to specify the ordering of the documents for the purpose of matching + # citations to the right search documents. The deduplication logic is more lightweight + # there and we don't need to do it twice search_results = [ llm_doc_from_inference_section(section) for section in cast( @@ -421,6 +457,12 @@ def _process_stream( ] elif message.id == FINAL_CONTEXT_DOCUMENTS: final_context_docs = cast(list[LlmDoc], message.response) + elif ( + message.id == SEARCH_DOC_CONTENT_ID + and not self._return_contexts + ): + continue + yield message else: # assumes all tool responses will come first, then the final answer @@ -430,7 +472,9 @@ def _process_stream( context_docs=final_context_docs or [], # if doc selection is enabled, then search_results will be None, # so we need to use the final_context_docs - search_order_docs=search_results or final_context_docs or [], + doc_id_to_rank_map=map_document_id_order( + search_results or final_context_docs or [] + ), answer_style_configs=self.answer_style_config, ) diff --git a/backend/danswer/llm/answering/doc_pruning.py b/backend/danswer/llm/answering/doc_pruning.py deleted file mode 100644 index 5a43ab3c6fd..00000000000 --- a/backend/danswer/llm/answering/doc_pruning.py +++ /dev/null @@ -1,230 +0,0 @@ -import json -from copy import deepcopy -from typing import TypeVar 
- -from danswer.chat.models import ( - LlmDoc, -) -from danswer.configs.constants import IGNORE_FOR_QA -from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE -from danswer.llm.answering.models import DocumentPruningConfig -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens -from danswer.llm.interfaces import LLMConfig -from danswer.llm.utils import get_default_llm_tokenizer -from danswer.llm.utils import tokenizer_trim_content -from danswer.prompts.prompt_utils import build_doc_context_str -from danswer.search.models import InferenceChunk -from danswer.tools.search.search_utils import llm_doc_to_dict -from danswer.utils.logger import setup_logger - - -logger = setup_logger() - -T = TypeVar("T", bound=LlmDoc | InferenceChunk) - -_METADATA_TOKEN_ESTIMATE = 75 - - -class PruningError(Exception): - pass - - -def _compute_limit( - prompt_config: PromptConfig, - llm_config: LLMConfig, - question: str, - max_chunks: int | None, - max_window_percentage: float | None, - max_tokens: int | None, - tool_token_count: int, -) -> int: - llm_max_document_tokens = compute_max_document_tokens( - prompt_config=prompt_config, - llm_config=llm_config, - tool_token_count=tool_token_count, - actual_user_input=question, - ) - - window_percentage_based_limit = ( - max_window_percentage * llm_max_document_tokens - if max_window_percentage - else None - ) - chunk_count_based_limit = ( - max_chunks * DOC_EMBEDDING_CONTEXT_SIZE if max_chunks else None - ) - - limit_options = [ - lim - for lim in [ - window_percentage_based_limit, - chunk_count_based_limit, - max_tokens, - llm_max_document_tokens, - ] - if lim - ] - return int(min(limit_options)) - - -def reorder_docs( - docs: list[T], - doc_relevance_list: list[bool] | None, -) -> list[T]: - if doc_relevance_list is None: - return docs - - reordered_docs: list[T] = [] - if doc_relevance_list is not None: - for selection_target in [True, 
False]: - for doc, is_relevant in zip(docs, doc_relevance_list): - if is_relevant == selection_target: - reordered_docs.append(doc) - return reordered_docs - - -def _remove_docs_to_ignore(docs: list[LlmDoc]) -> list[LlmDoc]: - return [doc for doc in docs if not doc.metadata.get(IGNORE_FOR_QA)] - - -def _apply_pruning( - docs: list[LlmDoc], - doc_relevance_list: list[bool] | None, - token_limit: int, - is_manually_selected_docs: bool, - use_sections: bool, - using_tool_message: bool, -) -> list[LlmDoc]: - llm_tokenizer = get_default_llm_tokenizer() - docs = deepcopy(docs) # don't modify in place - - # re-order docs with all the "relevant" docs at the front - docs = reorder_docs(docs=docs, doc_relevance_list=doc_relevance_list) - # remove docs that are explicitly marked as not for QA - docs = _remove_docs_to_ignore(docs=docs) - - tokens_per_doc: list[int] = [] - final_doc_ind = None - total_tokens = 0 - for ind, llm_doc in enumerate(docs): - doc_str = ( - json.dumps(llm_doc_to_dict(llm_doc, ind)) - if using_tool_message - else build_doc_context_str( - semantic_identifier=llm_doc.semantic_identifier, - source_type=llm_doc.source_type, - content=llm_doc.content, - metadata_dict=llm_doc.metadata, - updated_at=llm_doc.updated_at, - ind=ind, - ) - ) - - doc_tokens = len(llm_tokenizer.encode(doc_str)) - # if chunks, truncate chunks that are way too long - # this can happen if the embedding model tokenizer is different - # than the LLM tokenizer - if ( - not is_manually_selected_docs - and not use_sections - and doc_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE - ): - logger.warning( - "Found more tokens in chunk than expected, " - "likely mismatch between embedding and LLM tokenizers. Trimming content..." 
- ) - llm_doc.content = tokenizer_trim_content( - content=llm_doc.content, - desired_length=DOC_EMBEDDING_CONTEXT_SIZE, - tokenizer=llm_tokenizer, - ) - doc_tokens = DOC_EMBEDDING_CONTEXT_SIZE - tokens_per_doc.append(doc_tokens) - total_tokens += doc_tokens - if total_tokens > token_limit: - final_doc_ind = ind - break - - if final_doc_ind is not None: - if is_manually_selected_docs or use_sections: - # for document selection, only allow the final document to get truncated - # if more than that, then the user message is too long - if final_doc_ind != len(docs) - 1: - if use_sections: - # Truncate the rest of the list since we're over the token limit - # for the last one, trim it. In this case, the Sections can be rather long - # so better to trim the back than throw away the whole thing. - docs = docs[: final_doc_ind + 1] - else: - raise PruningError( - "LLM context window exceeded. Please de-select some documents or shorten your query." - ) - - amount_to_truncate = total_tokens - token_limit - # NOTE: need to recalculate the length here, since the previous calculation included - # overhead from JSON-fying the doc / the metadata - final_doc_content_length = len( - llm_tokenizer.encode(docs[final_doc_ind].content) - ) - (amount_to_truncate) - # this could occur if we only have space for the title / metadata - # not ideal, but it's the most reasonable thing to do - # NOTE: the frontend prevents documents from being selected if - # less than 75 tokens are available to try and avoid this situation - # from occurring in the first place - if final_doc_content_length <= 0: - logger.error( - f"Final doc ({docs[final_doc_ind].semantic_identifier}) content " - "length is less than 0. Removing this doc from the final prompt." 
- ) - docs.pop() - else: - docs[final_doc_ind].content = tokenizer_trim_content( - content=docs[final_doc_ind].content, - desired_length=final_doc_content_length, - tokenizer=llm_tokenizer, - ) - else: - # For regular search, don't truncate the final document unless it's the only one - # If it's not the only one, we can throw it away, if it's the only one, we have to truncate - if final_doc_ind != 0: - docs = docs[:final_doc_ind] - else: - docs[0].content = tokenizer_trim_content( - content=docs[0].content, - desired_length=token_limit - _METADATA_TOKEN_ESTIMATE, - tokenizer=llm_tokenizer, - ) - docs = [docs[0]] - - return docs - - -def prune_documents( - docs: list[LlmDoc], - doc_relevance_list: list[bool] | None, - prompt_config: PromptConfig, - llm_config: LLMConfig, - question: str, - document_pruning_config: DocumentPruningConfig, -) -> list[LlmDoc]: - if doc_relevance_list is not None: - assert len(docs) == len(doc_relevance_list) - - doc_token_limit = _compute_limit( - prompt_config=prompt_config, - llm_config=llm_config, - question=question, - max_chunks=document_pruning_config.max_chunks, - max_window_percentage=document_pruning_config.max_window_percentage, - max_tokens=document_pruning_config.max_tokens, - tool_token_count=document_pruning_config.tool_num_tokens, - ) - return _apply_pruning( - docs=docs, - doc_relevance_list=doc_relevance_list, - token_limit=doc_token_limit, - is_manually_selected_docs=document_pruning_config.is_manually_selected_docs, - use_sections=document_pruning_config.use_sections, - using_tool_message=document_pruning_config.using_tool_message, - ) diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py index a5248fac27a..94ca91703ab 100644 --- a/backend/danswer/llm/answering/models.py +++ b/backend/danswer/llm/answering/models.py @@ -70,9 +70,11 @@ class DocumentPruningConfig(BaseModel): # e.g. 
we don't want to truncate each document to be no more # than one chunk long is_manually_selected_docs: bool = False - # If user specifies to include additional context chunks for each match, then different pruning + # If user specifies to include additional context Chunks for each match, then different pruning # is used. As many Sections as possible are included, and the last Section is truncated - use_sections: bool = False + # If this is false, all of the Sections are truncated if they are longer than the expected Chunk size. + # Sections are often expected to be longer than the maximum Chunk size but Chunks should not be. + use_sections: bool = True # If using tools, then we need to consider the tool length tool_num_tokens: int = 0 # If using a tool message to represent the docs, then we have to JSON serialize diff --git a/backend/danswer/llm/answering/prompts/build.py b/backend/danswer/llm/answering/prompts/build.py index 4907bde88ea..7656e729246 100644 --- a/backend/danswer/llm/answering/prompts/build.py +++ b/backend/danswer/llm/answering/prompts/build.py @@ -15,7 +15,7 @@ from danswer.llm.utils import get_default_llm_tokenizer from danswer.llm.utils import translate_history_to_basemessages from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT -from danswer.prompts.prompt_utils import add_time_to_system_prompt +from danswer.prompts.prompt_utils import add_date_time_to_prompt from danswer.prompts.prompt_utils import drop_messages_history_overflow from danswer.tools.message import ToolCallSummary @@ -25,7 +25,7 @@ def default_build_system_message( ) -> SystemMessage | None: system_prompt = prompt_config.system_prompt.strip() if prompt_config.datetime_aware: - system_prompt = add_time_to_system_prompt(system_prompt=system_prompt) + system_prompt = add_date_time_to_prompt(prompt_str=system_prompt) if not system_prompt: return None diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py 
b/backend/danswer/llm/answering/prompts/citations_prompt.py index 6c4ebfc40d2..69f727318d0 100644 --- a/backend/danswer/llm/answering/prompts/citations_prompt.py +++ b/backend/danswer/llm/answering/prompts/citations_prompt.py @@ -8,7 +8,8 @@ from danswer.db.persona import get_default_prompt__read_only from danswer.file_store.utils import InMemoryChatFile from danswer.llm.answering.models import PromptConfig -from danswer.llm.factory import get_llm_for_persona +from danswer.llm.factory import get_llms_for_persona +from danswer.llm.factory import get_main_llm_from_tuple from danswer.llm.interfaces import LLMConfig from danswer.llm.utils import build_content_with_imgs from danswer.llm.utils import check_number_of_tokens @@ -17,7 +18,7 @@ from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING -from danswer.prompts.prompt_utils import add_time_to_system_prompt +from danswer.prompts.prompt_utils import add_date_time_to_prompt from danswer.prompts.prompt_utils import build_complete_context_str from danswer.prompts.prompt_utils import build_task_prompt_reminders from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT @@ -99,7 +100,7 @@ def compute_max_document_tokens_for_persona( prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only() return compute_max_document_tokens( prompt_config=PromptConfig.from_model(prompt), - llm_config=get_llm_for_persona(persona).config, + llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config, actual_user_input=actual_user_input, max_llm_token_override=max_llm_token_override, ) @@ -121,7 +122,7 @@ def build_citations_system_message( if prompt_config.include_citations: system_prompt += REQUIRE_CITATION_STATEMENT if prompt_config.datetime_aware: - system_prompt = add_time_to_system_prompt(system_prompt=system_prompt) + system_prompt = 
add_date_time_to_prompt(prompt_str=system_prompt) return SystemMessage(content=system_prompt) diff --git a/backend/danswer/llm/answering/prompts/quotes_prompt.py b/backend/danswer/llm/answering/prompts/quotes_prompt.py index c0a36b10ec9..b2b67c65b37 100644 --- a/backend/danswer/llm/answering/prompts/quotes_prompt.py +++ b/backend/danswer/llm/answering/prompts/quotes_prompt.py @@ -1,14 +1,15 @@ from langchain.schema.messages import HumanMessage from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import LANGUAGE_HINT from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE from danswer.llm.answering.models import PromptConfig from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK from danswer.prompts.direct_qa_prompts import JSON_PROMPT -from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT +from danswer.prompts.prompt_utils import add_date_time_to_prompt from danswer.prompts.prompt_utils import build_complete_context_str from danswer.search.models import InferenceChunk @@ -35,6 +36,10 @@ def _build_weak_llm_quotes_prompt( task_prompt=prompt.task_prompt, user_query=question, ) + + if prompt.datetime_aware: + prompt_str = add_date_time_to_prompt(prompt_str=prompt_str) + return HumanMessage(content=prompt_str) @@ -62,6 +67,10 @@ def _build_strong_llm_quotes_prompt( user_query=question, language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "", ).strip() + + if prompt.datetime_aware: + full_prompt = add_date_time_to_prompt(prompt_str=full_prompt) + return HumanMessage(content=full_prompt) diff --git a/backend/danswer/llm/answering/prune_and_merge.py b/backend/danswer/llm/answering/prune_and_merge.py new file mode 100644 index 00000000000..3fee5266d8d --- /dev/null +++ b/backend/danswer/llm/answering/prune_and_merge.py @@ -0,0 
+1,359 @@ +import json +from collections import defaultdict +from copy import deepcopy +from typing import TypeVar + +from pydantic import BaseModel + +from danswer.chat.models import ( + LlmDoc, +) +from danswer.configs.constants import IGNORE_FOR_QA +from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE +from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.models import PromptConfig +from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens +from danswer.llm.interfaces import LLMConfig +from danswer.llm.utils import get_default_llm_tokenizer +from danswer.llm.utils import tokenizer_trim_content +from danswer.prompts.prompt_utils import build_doc_context_str +from danswer.search.models import InferenceChunk +from danswer.search.models import InferenceSection +from danswer.tools.search.search_utils import section_to_dict +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + +T = TypeVar("T", bound=LlmDoc | InferenceChunk | InferenceSection) + +_METADATA_TOKEN_ESTIMATE = 75 + + +class PruningError(Exception): + pass + + +class ChunkRange(BaseModel): + chunks: list[InferenceChunk] + start: int + end: int + + +def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]: + """ + This acts on a single document to merge the overlapping ranges of chunks + Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals + + NOTE: this is used to merge chunk ranges for retrieving the right chunk_ids against the + document index, this does not merge the actual contents so it should not be used to actually + merge chunks post retrieval. 
+ """ + sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start) + + combined_ranges: list[ChunkRange] = [] + + for new_chunk_range in sorted_ranges: + if not combined_ranges or combined_ranges[-1].end < new_chunk_range.start - 1: + combined_ranges.append(new_chunk_range) + else: + current_range = combined_ranges[-1] + current_range.end = max(current_range.end, new_chunk_range.end) + current_range.chunks.extend(new_chunk_range.chunks) + + return combined_ranges + + +def _compute_limit( + prompt_config: PromptConfig, + llm_config: LLMConfig, + question: str, + max_chunks: int | None, + max_window_percentage: float | None, + max_tokens: int | None, + tool_token_count: int, +) -> int: + llm_max_document_tokens = compute_max_document_tokens( + prompt_config=prompt_config, + llm_config=llm_config, + tool_token_count=tool_token_count, + actual_user_input=question, + ) + + window_percentage_based_limit = ( + max_window_percentage * llm_max_document_tokens + if max_window_percentage + else None + ) + chunk_count_based_limit = ( + max_chunks * DOC_EMBEDDING_CONTEXT_SIZE if max_chunks else None + ) + + limit_options = [ + lim + for lim in [ + window_percentage_based_limit, + chunk_count_based_limit, + max_tokens, + llm_max_document_tokens, + ] + if lim + ] + return int(min(limit_options)) + + +def reorder_sections( + sections: list[InferenceSection], + section_relevance_list: list[bool] | None, +) -> list[InferenceSection]: + if section_relevance_list is None: + return sections + + reordered_sections: list[InferenceSection] = [] + if section_relevance_list is not None: + for selection_target in [True, False]: + for section, is_relevant in zip(sections, section_relevance_list): + if is_relevant == selection_target: + reordered_sections.append(section) + return reordered_sections + + +def _remove_sections_to_ignore( + sections: list[InferenceSection], +) -> list[InferenceSection]: + return [ + section + for section in sections + if not 
section.center_chunk.metadata.get(IGNORE_FOR_QA) + ] + + +def _apply_pruning( + sections: list[InferenceSection], + section_relevance_list: list[bool] | None, + token_limit: int, + is_manually_selected_docs: bool, + use_sections: bool, + using_tool_message: bool, +) -> list[InferenceSection]: + llm_tokenizer = get_default_llm_tokenizer() + sections = deepcopy(sections) # don't modify in place + + # re-order docs with all the "relevant" docs at the front + sections = reorder_sections( + sections=sections, section_relevance_list=section_relevance_list + ) + # remove docs that are explicitly marked as not for QA + sections = _remove_sections_to_ignore(sections=sections) + + final_section_ind = None + total_tokens = 0 + for ind, section in enumerate(sections): + section_str = ( + # If using tool message, it will be a bit of an overestimate as the extra json text around the section + # will be counted towards the token count. However, once the Sections are merged, the extra json parts + # that overlap will not be counted multiple times like it is in the pruning step. + json.dumps(section_to_dict(section, ind)) + if using_tool_message + else build_doc_context_str( + semantic_identifier=section.center_chunk.semantic_identifier, + source_type=section.center_chunk.source_type, + content=section.combined_content, + metadata_dict=section.center_chunk.metadata, + updated_at=section.center_chunk.updated_at, + ind=ind, + ) + ) + + section_tokens = len(llm_tokenizer.encode(section_str)) + # if not using sections (specifically, using Sections where each section maps exactly to the one center chunk), + # truncate chunks that are way too long. 
This can happen if the embedding model tokenizer is different + # than the LLM tokenizer + if ( + not is_manually_selected_docs + and not use_sections + and section_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE + ): + logger.warning( + "Found more tokens in Section than expected, " + "likely mismatch between embedding and LLM tokenizers. Trimming content..." + ) + section.combined_content = tokenizer_trim_content( + content=section.combined_content, + desired_length=DOC_EMBEDDING_CONTEXT_SIZE, + tokenizer=llm_tokenizer, + ) + section_tokens = DOC_EMBEDDING_CONTEXT_SIZE + + total_tokens += section_tokens + if total_tokens > token_limit: + final_section_ind = ind + break + + if final_section_ind is not None: + if is_manually_selected_docs or use_sections: + if final_section_ind != len(sections) - 1: + # If using Sections, then the final section could be more than we need, in this case we are willing to + # truncate the final section to fit the specified context window + sections = sections[: final_section_ind + 1] + + if is_manually_selected_docs: + # For document selection flow, only allow the final document/section to get truncated + # if more than that needs to be throw away then some documents are completely thrown away in which + # case this should be reported to the user as an error + raise PruningError( + "LLM context window exceeded. Please de-select some documents or shorten your query." 
+ ) + + amount_to_truncate = total_tokens - token_limit + # NOTE: need to recalculate the length here, since the previous calculation included + # overhead from JSON-fying the doc / the metadata + final_doc_content_length = len( + llm_tokenizer.encode(sections[final_section_ind].combined_content) + ) - (amount_to_truncate) + # this could occur if we only have space for the title / metadata + # not ideal, but it's the most reasonable thing to do + # NOTE: the frontend prevents documents from being selected if + # less than 75 tokens are available to try and avoid this situation + # from occurring in the first place + if final_doc_content_length <= 0: + logger.error( + f"Final section ({sections[final_section_ind].center_chunk.semantic_identifier}) content " + "length is less than 0. Removing this section from the final prompt." + ) + sections.pop() + else: + sections[final_section_ind].combined_content = tokenizer_trim_content( + content=sections[final_section_ind].combined_content, + desired_length=final_doc_content_length, + tokenizer=llm_tokenizer, + ) + else: + # For search on chunk level (Section is just a chunk), don't truncate the final Chunk/Section unless it's the only one + # If it's not the only one, we can throw it away, if it's the only one, we have to truncate + if final_section_ind != 0: + sections = sections[:final_section_ind] + else: + sections[0].combined_content = tokenizer_trim_content( + content=sections[0].combined_content, + desired_length=token_limit - _METADATA_TOKEN_ESTIMATE, + tokenizer=llm_tokenizer, + ) + sections = [sections[0]] + + return sections + + +def prune_sections( + sections: list[InferenceSection], + section_relevance_list: list[bool] | None, + prompt_config: PromptConfig, + llm_config: LLMConfig, + question: str, + document_pruning_config: DocumentPruningConfig, +) -> list[InferenceSection]: + # Assumes the sections are score ordered with highest first + if section_relevance_list is not None: + assert len(sections) == 
len(section_relevance_list) + + token_limit = _compute_limit( + prompt_config=prompt_config, + llm_config=llm_config, + question=question, + max_chunks=document_pruning_config.max_chunks, + max_window_percentage=document_pruning_config.max_window_percentage, + max_tokens=document_pruning_config.max_tokens, + tool_token_count=document_pruning_config.tool_num_tokens, + ) + return _apply_pruning( + sections=sections, + section_relevance_list=section_relevance_list, + token_limit=token_limit, + is_manually_selected_docs=document_pruning_config.is_manually_selected_docs, + use_sections=document_pruning_config.use_sections, # Now default True + using_tool_message=document_pruning_config.using_tool_message, + ) + + +def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection: + # Assuming there are no duplicates by this point + sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id) + + center_chunk = max( + chunks, key=lambda x: x.score if x.score is not None else float("-inf") + ) + + merged_content = [] + for i, chunk in enumerate(sorted_chunks): + if i > 0: + prev_chunk_id = sorted_chunks[i - 1].chunk_id + if chunk.chunk_id == prev_chunk_id + 1: + merged_content.append("\n") + else: + merged_content.append("\n\n...\n\n") + merged_content.append(chunk.content) + + combined_content = "".join(merged_content) + + return InferenceSection( + center_chunk=center_chunk, + chunks=sorted_chunks, + combined_content=combined_content, + ) + + +def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]: + docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict) + doc_order: dict[str, int] = {} + for index, section in enumerate(sections): + if section.center_chunk.document_id not in doc_order: + doc_order[section.center_chunk.document_id] = index + for chunk in [section.center_chunk] + section.chunks: + chunks_map = docs_map[section.center_chunk.document_id] + existing_chunk = chunks_map.get(chunk.chunk_id) + if ( + existing_chunk is None 
+ or existing_chunk.score is None + or chunk.score is not None + and chunk.score > existing_chunk.score + ): + chunks_map[chunk.chunk_id] = chunk + + new_sections = [] + for section_chunks in docs_map.values(): + new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values()))) + + # Sort by highest score, then by original document order + # It is now 1 large section per doc, the center chunk being the one with the highest score + new_sections.sort( + key=lambda x: ( + x.center_chunk.score if x.center_chunk.score is not None else 0, + -1 * doc_order[x.center_chunk.document_id], + ), + reverse=True, + ) + + return new_sections + + +def prune_and_merge_sections( + sections: list[InferenceSection], + section_relevance_list: list[bool] | None, + prompt_config: PromptConfig, + llm_config: LLMConfig, + question: str, + document_pruning_config: DocumentPruningConfig, +) -> list[InferenceSection]: + # Assumes the sections are score ordered with highest first + remaining_sections = prune_sections( + sections=sections, + section_relevance_list=section_relevance_list, + prompt_config=prompt_config, + llm_config=llm_config, + question=question, + document_pruning_config=document_pruning_config, + ) + + merged_sections = _merge_sections(sections=remaining_sections) + + return merged_sections diff --git a/backend/danswer/llm/answering/stream_processing/citation_processing.py b/backend/danswer/llm/answering/stream_processing/citation_processing.py index fa774660cde..d1d1eb0783a 100644 --- a/backend/danswer/llm/answering/stream_processing/citation_processing.py +++ b/backend/danswer/llm/answering/stream_processing/citation_processing.py @@ -7,7 +7,7 @@ from danswer.chat.models import LlmDoc from danswer.configs.chat_configs import STOP_STREAM_PAT from danswer.llm.answering.models import StreamProcessor -from danswer.llm.answering.stream_processing.utils import map_document_id_order +from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping from 
danswer.prompts.constants import TRIPLE_BACKTICK from danswer.utils.logger import setup_logger @@ -23,104 +23,167 @@ def in_code_block(llm_text: str) -> bool: def extract_citations_from_stream( tokens: Iterator[str], context_docs: list[LlmDoc], - doc_id_to_rank_map: dict[str, int], + doc_id_to_rank_map: DocumentIdOrderMapping, stop_stream: str | None = STOP_STREAM_PAT, ) -> Iterator[DanswerAnswerPiece | CitationInfo]: + """ + Key aspects: + + 1. Stream Processing: + - Processes tokens one by one, allowing for real-time handling of large texts. + + 2. Citation Detection: + - Uses regex to find citations in the format [number]. + - Example: [1], [2], etc. + + 3. Citation Mapping: + - Maps detected citation numbers to actual document ranks using doc_id_to_rank_map. + - Example: [1] might become [3] if doc_id_to_rank_map maps it to 3. + + 4. Citation Formatting: + - Replaces citations with properly formatted versions. + - Adds links if available: [[1]](https://example.com) + - Handles cases where links are not available: [[1]]() + + 5. Duplicate Handling: + - Skips consecutive citations of the same document to avoid redundancy. + + 6. Output Generation: + - Yields DanswerAnswerPiece objects for regular text. + - Yields CitationInfo objects for each unique citation encountered. + + 7. Context Awareness: + - Uses context_docs to access document information for citations. + + This function effectively processes a stream of text, identifies and reformats citations, + and provides both the processed text and citation information as output. 
+ """ + order_mapping = doc_id_to_rank_map.order_mapping llm_out = "" max_citation_num = len(context_docs) + citation_order = [] curr_segment = "" - prepend_bracket = False cited_inds = set() hold = "" + + raw_out = "" + current_citations: list[int] = [] + past_cite_count = 0 for raw_token in tokens: + raw_out += raw_token if stop_stream: next_hold = hold + raw_token - if stop_stream in next_hold: break - if next_hold == stop_stream[: len(next_hold)]: hold = next_hold continue - token = next_hold hold = "" else: token = raw_token - # Special case of [1][ where ][ is a single token - # This is where the model attempts to do consecutive citations like [1][2] - if prepend_bracket: - curr_segment += "[" + curr_segment - prepend_bracket = False - curr_segment += token llm_out += token + citation_pattern = r"\[(\d+)\]" + + citations_found = list(re.finditer(citation_pattern, curr_segment)) possible_citation_pattern = r"(\[\d*$)" # [1, [, etc possible_citation_found = re.search(possible_citation_pattern, curr_segment) - citation_pattern = r"\[(\d+)\]" # [1], [2] etc - citation_found = re.search(citation_pattern, curr_segment) - - if citation_found and not in_code_block(llm_out): - numerical_value = int(citation_found.group(1)) - if 1 <= numerical_value <= max_citation_num: - context_llm_doc = context_docs[ - numerical_value - 1 - ] # remove 1 index offset - - link = context_llm_doc.link - target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id] - - # Use the citation number for the document's rank in - # the search (or selected docs) results - curr_segment = re.sub( - rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment - ) - - if target_citation_num not in cited_inds: - cited_inds.add(target_citation_num) - yield CitationInfo( - citation_num=target_citation_num, - document_id=context_llm_doc.document_id, + # `past_cite_count`: number of characters since past citation + # 5 to ensure a citation hasn't occured + if len(citations_found) == 0 and 
len(llm_out) - past_cite_count > 5: + current_citations = [] + + if citations_found and not in_code_block(llm_out): + last_citation_end = 0 + length_to_add = 0 + while len(citations_found) > 0: + citation = citations_found.pop(0) + numerical_value = int(citation.group(1)) + + if 1 <= numerical_value <= max_citation_num: + context_llm_doc = context_docs[numerical_value - 1] + real_citation_num = order_mapping[context_llm_doc.document_id] + + if real_citation_num not in citation_order: + citation_order.append(real_citation_num) + + target_citation_num = citation_order.index(real_citation_num) + 1 + + # Skip consecutive citations of the same work + if target_citation_num in current_citations: + start, end = citation.span() + real_start = length_to_add + start + diff = end - start + curr_segment = ( + curr_segment[: length_to_add + start] + + curr_segment[real_start + diff :] + ) + length_to_add -= diff + continue + + link = context_llm_doc.link + + # Replace the citation in the current segment + start, end = citation.span() + curr_segment = ( + curr_segment[: start + length_to_add] + + f"[{target_citation_num}]" + + curr_segment[end + length_to_add :] ) - if link: - curr_segment = re.sub(r"\[", "[[", curr_segment, count=1) - curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1) - - # In case there's another open bracket like [1][, don't want to match this - possible_citation_found = None - - # if we see "[", but haven't seen the right side, hold back - this may be a - # citation that needs to be replaced with a link + past_cite_count = len(llm_out) + current_citations.append(target_citation_num) + + if target_citation_num not in cited_inds: + cited_inds.add(target_citation_num) + yield CitationInfo( + citation_num=target_citation_num, + document_id=context_llm_doc.document_id, + ) + + if link: + prev_length = len(curr_segment) + curr_segment = ( + curr_segment[: start + length_to_add] + + f"[[{target_citation_num}]]({link})" + + curr_segment[end + 
length_to_add :] + ) + length_to_add += len(curr_segment) - prev_length + + else: + prev_length = len(curr_segment) + curr_segment = ( + curr_segment[: start + length_to_add] + + f"[[{target_citation_num}]]()" + + curr_segment[end + length_to_add :] + ) + length_to_add += len(curr_segment) - prev_length + last_citation_end = end + length_to_add + + if last_citation_end > 0: + yield DanswerAnswerPiece(answer_piece=curr_segment[:last_citation_end]) + curr_segment = curr_segment[last_citation_end:] if possible_citation_found: continue - - # Special case with back to back citations [1][2] - if curr_segment and curr_segment[-1] == "[": - curr_segment = curr_segment[:-1] - prepend_bracket = True - yield DanswerAnswerPiece(answer_piece=curr_segment) curr_segment = "" if curr_segment: - if prepend_bracket: - yield DanswerAnswerPiece(answer_piece="[" + curr_segment) - else: - yield DanswerAnswerPiece(answer_piece=curr_segment) + yield DanswerAnswerPiece(answer_piece=curr_segment) def build_citation_processor( - context_docs: list[LlmDoc], search_order_docs: list[LlmDoc] + context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping ) -> StreamProcessor: def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn: yield from extract_citations_from_stream( tokens=tokens, context_docs=context_docs, - doc_id_to_rank_map=map_document_id_order(search_order_docs), + doc_id_to_rank_map=doc_id_to_rank_map, ) return stream_processor diff --git a/backend/danswer/llm/answering/stream_processing/utils.py b/backend/danswer/llm/answering/stream_processing/utils.py index 9f21e6a34e2..b4fb83747de 100644 --- a/backend/danswer/llm/answering/stream_processing/utils.py +++ b/backend/danswer/llm/answering/stream_processing/utils.py @@ -1,12 +1,18 @@ from collections.abc import Sequence +from pydantic import BaseModel + from danswer.chat.models import LlmDoc from danswer.search.models import InferenceChunk +class DocumentIdOrderMapping(BaseModel): + order_mapping: 
dict[str, int] + + def map_document_id_order( chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True -) -> dict[str, int]: +) -> DocumentIdOrderMapping: order_mapping = {} current = 1 if one_indexed else 0 for chunk in chunks: @@ -14,4 +20,4 @@ def map_document_id_order( order_mapping[chunk.document_id] = current current += 1 - return order_mapping + return DocumentIdOrderMapping(order_mapping=order_mapping) diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py index d450efa9107..90ad481d453 100644 --- a/backend/danswer/llm/chat_llm.py +++ b/backend/danswer/llm/chat_llm.py @@ -5,6 +5,7 @@ from typing import cast import litellm # type: ignore +from httpx import RemoteProtocolError from langchain.schema.language_model import LanguageModelInput from langchain_core.messages import AIMessage from langchain_core.messages import AIMessageChunk @@ -22,6 +23,7 @@ from langchain_core.messages.tool import ToolMessage from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS +from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING from danswer.configs.model_configs import GEN_AI_API_ENDPOINT from danswer.configs.model_configs import GEN_AI_API_VERSION @@ -41,6 +43,8 @@ litellm.drop_params = True litellm.telemetry = False +litellm.set_verbose = LOG_ALL_MODEL_INTERACTIONS + def _base_msg_to_role(msg: BaseMessage) -> str: if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk): @@ -303,6 +307,8 @@ def config(self) -> LLMConfig: model_name=self._model_version, temperature=self._temperature, api_key=self._api_key, + api_base=self._api_base, + api_version=self._api_version, ) def invoke( @@ -311,7 +317,7 @@ def invoke( tools: list[dict] | None = None, tool_choice: ToolChoiceOptions | None = None, ) -> BaseMessage: - if LOG_ALL_MODEL_INTERACTIONS: + if LOG_DANSWER_MODEL_INTERACTIONS: self.log_model_configs() self._log_prompt(prompt) @@ 
-328,7 +334,7 @@ def stream( tools: list[dict] | None = None, tool_choice: ToolChoiceOptions | None = None, ) -> Iterator[BaseMessage]: - if LOG_ALL_MODEL_INTERACTIONS: + if LOG_DANSWER_MODEL_INTERACTIONS: self.log_model_configs() self._log_prompt(prompt) @@ -338,19 +344,25 @@ def stream( output = None response = self._completion(prompt, tools, tool_choice, True) - for part in response: - if len(part["choices"]) == 0: - continue - delta = part["choices"][0]["delta"] - message_chunk = _convert_delta_to_message_chunk(delta, output) - if output is None: - output = message_chunk - else: - output += message_chunk + try: + for part in response: + if len(part["choices"]) == 0: + continue + delta = part["choices"][0]["delta"] + message_chunk = _convert_delta_to_message_chunk(delta, output) + if output is None: + output = message_chunk + else: + output += message_chunk + + yield message_chunk - yield message_chunk + except RemoteProtocolError: + raise RuntimeError( + "The AI model failed partway through generation, please try again." 
+ ) - if LOG_ALL_MODEL_INTERACTIONS and output: + if LOG_DANSWER_MODEL_INTERACTIONS and output: content = output.content or "" if isinstance(output, AIMessage): if content: diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py index d85dbdc9f61..f57bfb524b9 100644 --- a/backend/danswer/llm/factory.py +++ b/backend/danswer/llm/factory.py @@ -1,76 +1,103 @@ from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.chat_configs import QA_TIMEOUT from danswer.configs.model_configs import GEN_AI_TEMPERATURE -from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS from danswer.db.engine import get_session_context_manager from danswer.db.llm import fetch_default_provider from danswer.db.llm import fetch_provider from danswer.db.models import Persona from danswer.llm.chat_llm import DefaultMultiLLM from danswer.llm.exceptions import GenAIDisabledException +from danswer.llm.headers import build_llm_extra_headers from danswer.llm.interfaces import LLM from danswer.llm.override_models import LLMOverride -def get_llm_for_persona( +def get_main_llm_from_tuple( + llms: tuple[LLM, LLM], +) -> LLM: + return llms[0] + + +def get_llms_for_persona( persona: Persona, llm_override: LLMOverride | None = None, additional_headers: dict[str, str] | None = None, -) -> LLM: +) -> tuple[LLM, LLM]: model_provider_override = llm_override.model_provider if llm_override else None model_version_override = llm_override.model_version if llm_override else None temperature_override = llm_override.temperature if llm_override else None - return get_default_llm( - model_provider_name=( - model_provider_override or persona.llm_model_provider_override - ), - model_version=(model_version_override or persona.llm_model_version_override), - temperature=temperature_override or GEN_AI_TEMPERATURE, - additional_headers=additional_headers, - ) + provider_name = model_provider_override or persona.llm_model_provider_override + if not provider_name: + return 
get_default_llms( + temperature=temperature_override or GEN_AI_TEMPERATURE, + additional_headers=additional_headers, + ) + + with get_session_context_manager() as db_session: + llm_provider = fetch_provider(db_session, provider_name) + + if not llm_provider: + raise ValueError("No LLM provider found") + + model = model_version_override or persona.llm_model_version_override + fast_model = llm_provider.fast_default_model_name or llm_provider.default_model_name + if not model: + raise ValueError("No model name found") + if not fast_model: + raise ValueError("No fast model name found") + def _create_llm(model: str) -> LLM: + return get_llm( + provider=llm_provider.provider, + model=model, + api_key=llm_provider.api_key, + api_base=llm_provider.api_base, + api_version=llm_provider.api_version, + custom_config=llm_provider.custom_config, + additional_headers=additional_headers, + ) -def get_default_llm( + return _create_llm(model), _create_llm(fast_model) + + +def get_default_llms( timeout: int = QA_TIMEOUT, temperature: float = GEN_AI_TEMPERATURE, - use_fast_llm: bool = False, - model_provider_name: str | None = None, - model_version: str | None = None, additional_headers: dict[str, str] | None = None, -) -> LLM: +) -> tuple[LLM, LLM]: if DISABLE_GENERATIVE_AI: raise GenAIDisabledException() - # TODO: pass this in - with get_session_context_manager() as session: - if model_provider_name is None: - llm_provider = fetch_default_provider(session) - else: - llm_provider = fetch_provider(session, model_provider_name) + with get_session_context_manager() as db_session: + llm_provider = fetch_default_provider(db_session) if not llm_provider: raise ValueError("No default LLM provider found") - model_name = model_version or ( - (llm_provider.fast_default_model_name or llm_provider.default_model_name) - if use_fast_llm - else llm_provider.default_model_name + model_name = llm_provider.default_model_name + fast_model_name = ( + llm_provider.fast_default_model_name or 
llm_provider.default_model_name ) if not model_name: raise ValueError("No default model name found") + if not fast_model_name: + raise ValueError("No fast default model name found") - return get_llm( - provider=llm_provider.provider, - model=model_name, - api_key=llm_provider.api_key, - api_base=llm_provider.api_base, - api_version=llm_provider.api_version, - custom_config=llm_provider.custom_config, - timeout=timeout, - temperature=temperature, - additional_headers=additional_headers, - ) + def _create_llm(model: str) -> LLM: + return get_llm( + provider=llm_provider.provider, + model=model, + api_key=llm_provider.api_key, + api_base=llm_provider.api_base, + api_version=llm_provider.api_version, + custom_config=llm_provider.custom_config, + timeout=timeout, + temperature=temperature, + additional_headers=additional_headers, + ) + + return _create_llm(model_name), _create_llm(fast_model_name) def get_llm( @@ -84,12 +111,6 @@ def get_llm( timeout: int = QA_TIMEOUT, additional_headers: dict[str, str] | None = None, ) -> LLM: - extra_headers = {} - if additional_headers: - extra_headers.update(additional_headers) - if LITELLM_EXTRA_HEADERS: - extra_headers.update(LITELLM_EXTRA_HEADERS) - return DefaultMultiLLM( model_provider=provider, model_name=model, @@ -99,5 +120,5 @@ def get_llm( timeout=timeout, temperature=temperature, custom_config=custom_config, - extra_headers=extra_headers, + extra_headers=build_llm_extra_headers(additional_headers), ) diff --git a/backend/danswer/llm/headers.py b/backend/danswer/llm/headers.py index f7ae7436fb4..b43c83e141e 100644 --- a/backend/danswer/llm/headers.py +++ b/backend/danswer/llm/headers.py @@ -1,5 +1,6 @@ from fastapi.datastructures import Headers +from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS @@ -20,3 +21,14 @@ def get_litellm_additional_request_headers( pass_through_headers[lowercase_key] = headers[lowercase_key] return 
pass_through_headers + + +def build_llm_extra_headers( + additional_headers: dict[str, str] | None = None +) -> dict[str, str]: + extra_headers: dict[str, str] = {} + if additional_headers: + extra_headers.update(additional_headers) + if LITELLM_EXTRA_HEADERS: + extra_headers.update(LITELLM_EXTRA_HEADERS) + return extra_headers diff --git a/backend/danswer/llm/interfaces.py b/backend/danswer/llm/interfaces.py index 1f99383fae8..e876403c421 100644 --- a/backend/danswer/llm/interfaces.py +++ b/backend/danswer/llm/interfaces.py @@ -19,6 +19,8 @@ class LLMConfig(BaseModel): model_name: str temperature: float api_key: str | None + api_base: str | None + api_version: str | None class LLM(abc.ABC): diff --git a/backend/danswer/main.py b/backend/danswer/main.py index c363961d642..88064a2e8df 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -46,13 +46,13 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model from danswer.db.index_attempt import expire_index_attempts from danswer.db.persona import delete_old_default_personas +from danswer.db.standard_answer import create_initial_default_standard_answer_category from danswer.db.swap_index import check_index_swap from danswer.document_index.factory import get_default_document_index from danswer.llm.llm_initialization import load_llm_providers from danswer.search.retrieval.search_runner import download_nltk_data from danswer.search.search_nlp_models import warm_up_encoders from danswer.server.auth_check import check_router_auth -from danswer.server.danswer_api.ingestion import get_danswer_api_key from danswer.server.danswer_api.ingestion import router as danswer_api_router from danswer.server.documents.cc_pair import router as cc_pair_router from danswer.server.documents.connector import router as connector_router @@ -72,6 +72,7 @@ from danswer.server.manage.llm.api import basic_router as llm_router from danswer.server.manage.secondary_index import router as secondary_index_router from 
danswer.server.manage.slack_bot import router as slack_bot_management_router +from danswer.server.manage.standard_answer import router as standard_answer_router from danswer.server.manage.users import router as user_router from danswer.server.middleware.latency_logging import add_latency_logging_middleware from danswer.server.query_and_chat.chat_backend import router as chat_router @@ -81,6 +82,9 @@ from danswer.server.query_and_chat.query_backend import basic_router as query_router from danswer.server.settings.api import admin_router as settings_admin_router from danswer.server.settings.api import basic_router as settings_router +from danswer.server.token_rate_limits.api import ( + router as token_rate_limit_settings_router, +) from danswer.tools.built_in_tools import auto_add_search_tool_to_personas from danswer.tools.built_in_tools import load_builtin_tools from danswer.tools.built_in_tools import refresh_built_in_tools_cache @@ -88,6 +92,8 @@ from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType from danswer.utils.variable_functionality import fetch_versioned_implementation +from danswer.utils.variable_functionality import global_version +from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW from shared_configs.configs import MODEL_SERVER_HOST from shared_configs.configs import MODEL_SERVER_PORT @@ -154,10 +160,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: # Will throw exception if an issue is found verify_auth() - # Danswer APIs key - api_key = get_danswer_api_key() - logger.info(f"Danswer API Key: {api_key}") - if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET: logger.info("Both OAuth Client ID and Secret are configured.") @@ -207,6 +209,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: create_initial_default_connector(db_session) associate_default_cc_pair(db_session) + logger.info("Verifying default standard 
answer category exists.") + create_initial_default_standard_answer_category(db_session) + logger.info("Loading LLM providers from env variables") load_llm_providers(db_session) @@ -254,7 +259,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: ) optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__}) - yield @@ -278,6 +282,7 @@ def get_application() -> FastAPI: include_router_with_global_prefix_prepended( application, slack_bot_management_router ) + include_router_with_global_prefix_prepended(application, standard_answer_router) include_router_with_global_prefix_prepended(application, persona_router) include_router_with_global_prefix_prepended(application, admin_persona_router) include_router_with_global_prefix_prepended(application, prompt_router) @@ -290,6 +295,9 @@ def get_application() -> FastAPI: include_router_with_global_prefix_prepended(application, settings_admin_router) include_router_with_global_prefix_prepended(application, llm_admin_router) include_router_with_global_prefix_prepended(application, llm_router) + include_router_with_global_prefix_prepended( + application, token_rate_limit_settings_router + ) if AUTH_TYPE == AuthType.DISABLED: # Server logs this during auth setup verification step @@ -373,11 +381,18 @@ def get_application() -> FastAPI: return application -app = get_application() +# NOTE: needs to be outside of the `if __name__ == "__main__"` block so that the +# app is exportable +set_is_ee_based_on_env_variable() +app = fetch_versioned_implementation(module="danswer.main", attribute="get_application") if __name__ == "__main__": logger.info( f"Starting Danswer Backend version {__version__} on http://{APP_HOST}:{str(APP_PORT)}/" ) + + if global_version.get_is_ee_version(): + logger.info("Running Enterprise Edition") + uvicorn.run(app, host=APP_HOST, port=APP_PORT) diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index c131c0c05ca..0b66aae166b 
100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -30,7 +30,8 @@ from danswer.llm.answering.models import DocumentPruningConfig from danswer.llm.answering.models import PromptConfig from danswer.llm.answering.models import QuotesConfig -from danswer.llm.factory import get_llm_for_persona +from danswer.llm.factory import get_llms_for_persona +from danswer.llm.factory import get_main_llm_from_tuple from danswer.llm.utils import get_default_llm_token_encode from danswer.one_shot_answer.models import DirectQARequest from danswer.one_shot_answer.models import OneShotQAResponse @@ -46,6 +47,7 @@ from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.utils import get_json_line from danswer.tools.force import ForceUseTool +from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID from danswer.tools.search.search_tool import SearchResponseSummary from danswer.tools.search.search_tool import SearchTool @@ -156,7 +158,7 @@ def stream_answer_objects( commit=True, ) - llm = get_llm_for_persona(persona=chat_session.persona) + llm, fast_llm = get_llms_for_persona(persona=chat_session.persona) prompt_config = PromptConfig.from_model(prompt) document_pruning_config = DocumentPruningConfig( max_chunks=int( @@ -174,7 +176,11 @@ def stream_answer_objects( retrieval_options=query_req.retrieval_options, prompt_config=prompt_config, llm=llm, + fast_llm=fast_llm, pruning_config=document_pruning_config, + chunks_above=query_req.chunks_above, + chunks_below=query_req.chunks_below, + full_doc=query_req.full_doc, bypass_acl=bypass_acl, ) @@ -187,16 +193,17 @@ def stream_answer_objects( question=query_msg.message, answer_style_config=answer_config, prompt_config=PromptConfig.from_model(prompt), - llm=get_llm_for_persona(persona=chat_session.persona), + 
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)), single_message_history=history_str, tools=[search_tool], force_use_tool=ForceUseTool( - tool_name=search_tool.name(), + tool_name=search_tool.name, args={"query": rephrased_query}, ), # for now, don't use tool calling for this flow, as we haven't # tested quotes with tool calling too much yet skip_explicit_tool_calling=True, + return_contexts=query_req.return_contexts, ) # won't be any ImageGenerationDisplay responses since that tool is never passed in dropped_inds: list[int] = [] @@ -246,6 +253,8 @@ def stream_answer_objects( ) yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response) + elif packet.id == SEARCH_DOC_CONTENT_ID: + yield packet.response else: yield packet diff --git a/backend/danswer/prompts/chat_prompts.py b/backend/danswer/prompts/chat_prompts.py index e0b20243bc5..a5fa973f37c 100644 --- a/backend/danswer/prompts/chat_prompts.py +++ b/backend/danswer/prompts/chat_prompts.py @@ -144,6 +144,23 @@ Standalone question (Respond with only the short combined query): """.strip() +INTERNET_SEARCH_QUERY_REPHRASE = f""" +Given the following conversation and a follow up input, rephrase the follow up into a SHORT, \ +standalone query suitable for an internet search engine. +IMPORTANT: If a specific query might limit results, keep it broad. \ +If a broad query might yield too many results, make it detailed. +If there is a clear change in topic, ensure the query reflects the new topic accurately. +Strip out any information that is not relevant for the internet search. + +{GENERAL_SEP_PAT} +Chat History: +{{chat_history}} +{GENERAL_SEP_PAT} + +Follow Up Input: {{question}} +Internet Search Query (Respond with a detailed and specific query): +""".strip() + # The below prompts are retired NO_SEARCH = "No Search" @@ -188,7 +205,7 @@ CHAT_NAMING = f""" -Given the following conversation, provide a SHORT name for the conversation. 
+Given the following conversation, provide a SHORT name for the conversation.{{language_hint_or_empty}} IMPORTANT: TRY NOT TO USE MORE THAN 5 WORDS, MAKE IT AS CONCISE AS POSSIBLE. Focus the name on the important keywords to convey the topic of the conversation. diff --git a/backend/danswer/prompts/direct_qa_prompts.py b/backend/danswer/prompts/direct_qa_prompts.py index ee1b492be8a..64a704fa693 100644 --- a/backend/danswer/prompts/direct_qa_prompts.py +++ b/backend/danswer/prompts/direct_qa_prompts.py @@ -41,12 +41,6 @@ Quotes MUST be EXACT substrings from provided documents! """.strip() - -LANGUAGE_HINT = """ -IMPORTANT: Respond in the same language as my query! -""" - - CONTEXT_BLOCK = f""" REFERENCE DOCUMENTS: {GENERAL_SEP_PAT} diff --git a/backend/danswer/prompts/llm_chunk_filter.py b/backend/danswer/prompts/llm_chunk_filter.py index 623ae587703..fe0b8d398db 100644 --- a/backend/danswer/prompts/llm_chunk_filter.py +++ b/backend/danswer/prompts/llm_chunk_filter.py @@ -4,7 +4,7 @@ USEFUL_PAT = "Yes useful" NONUSEFUL_PAT = "Not useful" -CHUNK_FILTER_PROMPT = f""" +SECTION_FILTER_PROMPT = f""" Determine if the reference section is USEFUL for answering the user query. It is NOT enough for the section to be related to the query, \ it must contain information that is USEFUL for answering the query. @@ -27,4 +27,4 @@ # Use the following for easy viewing of prompts if __name__ == "__main__": - print(CHUNK_FILTER_PROMPT) + print(SECTION_FILTER_PROMPT) diff --git a/backend/danswer/prompts/miscellaneous_prompts.py b/backend/danswer/prompts/miscellaneous_prompts.py index 876f0f3591a..81ae5164329 100644 --- a/backend/danswer/prompts/miscellaneous_prompts.py +++ b/backend/danswer/prompts/miscellaneous_prompts.py @@ -2,9 +2,7 @@ LANGUAGE_REPHRASE_PROMPT = """ Translate query to {target_language}. -If the query at the end is already in {target_language}, \ -simply repeat the ORIGINAL query back to me, EXACTLY as is with no edits. 
- +If the query at the end is already in {target_language}, simply repeat the ORIGINAL query back to me, EXACTLY as is with no edits. If the query below is not in {target_language}, translate it into {target_language}. Query: diff --git a/backend/danswer/prompts/prompt_utils.py b/backend/danswer/prompts/prompt_utils.py index 9dc939eb170..6d7bddeec95 100644 --- a/backend/danswer/prompts/prompt_utils.py +++ b/backend/danswer/prompts/prompt_utils.py @@ -5,6 +5,7 @@ from langchain_core.messages import BaseMessage from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import LANGUAGE_HINT from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.configs.constants import DocumentSource from danswer.db.models import Prompt @@ -12,7 +13,6 @@ from danswer.prompts.chat_prompts import ADDITIONAL_INFO from danswer.prompts.chat_prompts import CITATION_REMINDER from danswer.prompts.constants import CODE_BLOCK_PAT -from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT from danswer.search.models import InferenceChunk @@ -35,15 +35,15 @@ def get_current_llm_day_time( return f"{formatted_datetime}" -def add_time_to_system_prompt(system_prompt: str) -> str: - if DANSWER_DATETIME_REPLACEMENT in system_prompt: - return system_prompt.replace( +def add_date_time_to_prompt(prompt_str: str) -> str: + if DANSWER_DATETIME_REPLACEMENT in prompt_str: + return prompt_str.replace( DANSWER_DATETIME_REPLACEMENT, get_current_llm_day_time(full_sentence=False, include_day_of_week=True), ) - if system_prompt: - return system_prompt + ADDITIONAL_INFO.format( + if prompt_str: + return prompt_str + ADDITIONAL_INFO.format( datetime_info=get_current_llm_day_time() ) else: diff --git a/backend/danswer/prompts/token_counts.py b/backend/danswer/prompts/token_counts.py index de77d0b8d11..c5bb7b2881e 100644 --- a/backend/danswer/prompts/token_counts.py +++ b/backend/danswer/prompts/token_counts.py @@ -1,14 +1,13 @@ +from danswer.configs.chat_configs import 
LANGUAGE_HINT from danswer.llm.utils import check_number_of_tokens from danswer.prompts.chat_prompts import ADDITIONAL_INFO from danswer.prompts.chat_prompts import CHAT_USER_PROMPT from danswer.prompts.chat_prompts import CITATION_REMINDER from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT -from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT from danswer.prompts.prompt_utils import get_current_llm_day_time -# tokens outside of the actual persona's "user_prompt" that make up the end -# user message +# tokens outside of the actual persona's "user_prompt" that make up the end user message CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT = check_number_of_tokens( CHAT_USER_PROMPT.format( context_docs_str="", diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index f0069d5c750..6e16de2c75e 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -4,6 +4,8 @@ from pydantic import BaseModel from pydantic import validator +from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE +from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER from danswer.configs.chat_configs import HYBRID_ALPHA from danswer.configs.chat_configs import NUM_RERANKED_RESULTS @@ -47,8 +49,8 @@ class ChunkMetric(BaseModel): class ChunkContext(BaseModel): # Additional surrounding context options, if full doc, then chunks are deduped # If surrounding context overlap, it is combined into one - chunks_above: int = 0 - chunks_below: int = 0 + chunks_above: int = CONTEXT_CHUNKS_ABOVE + chunks_below: int = CONTEXT_CHUNKS_BELOW full_doc: bool = False @validator("chunks_above", "chunks_below", pre=True, each_item=False) @@ -94,7 +96,7 @@ class SearchQuery(ChunkContext): # Only used if not skip_rerank num_rerank: int | None = NUM_RERANKED_RESULTS # Only used if not skip_llm_chunk_filter - 
max_llm_filter_chunks: int = NUM_RERANKED_RESULTS + max_llm_filter_sections: int = NUM_RERANKED_RESULTS class Config: frozen = True @@ -162,19 +164,38 @@ def __eq__(self, other: Any) -> bool: def __hash__(self) -> int: return hash((self.document_id, self.chunk_id)) + def __lt__(self, other: Any) -> bool: + if not isinstance(other, InferenceChunk): + return NotImplemented + if self.score is None: + if other.score is None: + return self.chunk_id > other.chunk_id + return True + if other.score is None: + return False + if self.score == other.score: + return self.chunk_id > other.chunk_id + return self.score < other.score -class InferenceSection(InferenceChunk): - """Section is a combination of chunks. A section could be a single chunk, several consecutive - chunks or the entire document""" + def __gt__(self, other: Any) -> bool: + if not isinstance(other, InferenceChunk): + return NotImplemented + if self.score is None: + return False + if other.score is None: + return True + if self.score == other.score: + return self.chunk_id < other.chunk_id + return self.score > other.score - combined_content: str - @classmethod - def from_chunk( - cls, inf_chunk: InferenceChunk, content: str | None = None - ) -> "InferenceSection": - inf_chunk_data = inf_chunk.dict() - return cls(**inf_chunk_data, combined_content=content or inf_chunk.content) +class InferenceSection(BaseModel): + """Section list of chunks with a combined content. 
A section could be a single chunk, several + chunks from the same document or the entire document.""" + + center_chunk: InferenceChunk + chunks: list[InferenceChunk] + combined_content: str class SearchDoc(BaseModel): @@ -199,6 +220,7 @@ class SearchDoc(BaseModel): updated_at: datetime | None primary_owners: list[str] | None secondary_owners: list[str] | None + is_internet: bool = False def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore initial_dict = super().dict(*args, **kwargs) # type: ignore @@ -229,6 +251,13 @@ def __lt__(self, other: Any) -> bool: return self.score < other.score +class SavedSearchDocWithContent(SavedSearchDoc): + """Used for endpoints that need to return the actual contents of the retrieved + section in addition to the match_highlights.""" + + content: str + + class RetrievalDocs(BaseModel): top_documents: list[SavedSearchDoc] diff --git a/backend/danswer/search/pipeline.py b/backend/danswer/search/pipeline.py index 7f68178bf08..2d990c15a3f 100644 --- a/backend/danswer/search/pipeline.py +++ b/backend/danswer/search/pipeline.py @@ -1,15 +1,16 @@ from collections import defaultdict from collections.abc import Callable -from collections.abc import Generator +from collections.abc import Iterator from typing import cast -from pydantic import BaseModel from sqlalchemy.orm import Session from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.models import User from danswer.document_index.factory import get_default_document_index +from danswer.llm.answering.prune_and_merge import ChunkRange +from danswer.llm.answering.prune_and_merge import merge_chunk_intervals from danswer.llm.interfaces import LLM from danswer.search.enums import QueryFlow from danswer.search.enums import SearchType @@ -23,31 +24,11 @@ from danswer.search.postprocessing.postprocessing import search_postprocessing from 
danswer.search.preprocessing.preprocessing import retrieval_preprocessing from danswer.search.retrieval.search_runner import retrieve_chunks +from danswer.search.utils import inference_section_from_chunks +from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel - -class ChunkRange(BaseModel): - chunk: InferenceChunk - start: int - end: int - combined_content: str | None = None - - -def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]: - """This acts on a single document to merge the overlapping ranges of sections - Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals - """ - sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start) - - ans: list[ChunkRange] = [] - - for chunk_range in sorted_ranges: - if not ans or ans[-1].end < chunk_range.start: - ans.append(chunk_range) - else: - ans[-1].end = max(ans[-1].end, chunk_range.end) - - return ans +logger = setup_logger() class SearchPipeline: @@ -56,6 +37,7 @@ def __init__( search_request: SearchRequest, user: User | None, llm: LLM, + fast_llm: LLM, db_session: Session, bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] @@ -65,6 +47,7 @@ def __init__( self.search_request = search_request self.user = user self.llm = llm + self.fast_llm = fast_llm self.db_session = db_session self.bypass_acl = bypass_acl self.retrieval_metrics_callback = retrieval_metrics_callback @@ -76,60 +59,113 @@ def __init__( secondary_index_name=None, ) + # Preprocessing steps generate this self._search_query: SearchQuery | None = None self._predicted_search_type: SearchType | None = None self._predicted_flow: QueryFlow | None = None + # Initial document index retrieval chunks self._retrieved_chunks: list[InferenceChunk] | None = None + # Another call made to the document index to get surrounding sections 
self._retrieved_sections: list[InferenceSection] | None = None - self._reranked_chunks: list[InferenceChunk] | None = None + # Reranking and LLM section selection can be run together + # If only LLM selection is on, the reranked chunks are yielded immediatly self._reranked_sections: list[InferenceSection] | None = None - self._relevant_chunk_indices: list[int] | None = None - - # If chunks have been merged, the LLM filter flow no longer applies - # as the indices no longer match. Can be implemented later as needed - self.ran_merge_chunk = False + self._relevant_section_indices: list[int] | None = None - # generator state - self._postprocessing_generator: Generator[ - list[InferenceChunk] | list[str], None, None + # Generates reranked chunks and LLM selections + self._postprocessing_generator: Iterator[ + list[InferenceSection] | list[int] ] | None = None - def _combine_chunks(self, post_rerank: bool) -> list[InferenceSection]: - if not post_rerank and self._retrieved_sections: - return self._retrieved_sections - if post_rerank and self._reranked_sections: - return self._reranked_sections + """Pre-processing""" - if not post_rerank: - chunks = self.retrieved_chunks - else: - chunks = self.reranked_chunks + def _run_preprocessing(self) -> None: + ( + final_search_query, + predicted_search_type, + predicted_flow, + ) = retrieval_preprocessing( + search_request=self.search_request, + user=self.user, + llm=self.llm, + db_session=self.db_session, + bypass_acl=self.bypass_acl, + ) + self._search_query = final_search_query + self._predicted_search_type = predicted_search_type + self._predicted_flow = predicted_flow - if self._search_query is None: - # Should never happen - raise RuntimeError("Failed in Query Preprocessing") + @property + def search_query(self) -> SearchQuery: + if self._search_query is not None: + return self._search_query - functions_with_args: list[tuple[Callable, tuple]] = [] - final_inference_sections = [] + self._run_preprocessing() - # Nothing to 
combine, just return the chunks - if ( - not self._search_query.chunks_above - and not self._search_query.chunks_below - and not self._search_query.full_doc - ): - return [InferenceSection.from_chunk(chunk) for chunk in chunks] + return cast(SearchQuery, self._search_query) - # If chunk merges have been run, LLM reranking loses meaning - # Needs reimplementation, out of scope for now - self.ran_merge_chunk = True + @property + def predicted_search_type(self) -> SearchType: + if self._predicted_search_type is not None: + return self._predicted_search_type + + self._run_preprocessing() + return cast(SearchType, self._predicted_search_type) + + @property + def predicted_flow(self) -> QueryFlow: + if self._predicted_flow is not None: + return self._predicted_flow + + self._run_preprocessing() + return cast(QueryFlow, self._predicted_flow) + + """Retrieval and Postprocessing""" + + def _get_chunks(self) -> list[InferenceChunk]: + """TODO as a future extension: + If large chunks (above 512 tokens) are used which cannot be directly fed to the LLM, + This step should run the two retrievals to get all of the base size chunks + """ + if self._retrieved_chunks is not None: + return self._retrieved_chunks + + self._retrieved_chunks = retrieve_chunks( + query=self.search_query, + document_index=self.document_index, + db_session=self.db_session, + hybrid_alpha=self.search_request.hybrid_alpha, + multilingual_expansion_str=MULTILINGUAL_QUERY_EXPANSION, + retrieval_metrics_callback=self.retrieval_metrics_callback, + ) + + return cast(list[InferenceChunk], self._retrieved_chunks) + + def _get_sections(self) -> list[InferenceSection]: + """Returns an expanded section from each of the chunks. + If whole docs (instead of above/below context) is specified then it will give back all of the whole docs + that have a corresponding chunk. + + This step should be fast for any document index implementation. 
+ """ + if self._retrieved_sections is not None: + return self._retrieved_sections + + retrieved_chunks = self._get_chunks() + + above = self.search_query.chunks_above + below = self.search_query.chunks_below + + functions_with_args: list[tuple[Callable, tuple]] = [] + expanded_inference_sections = [] # Full doc setting takes priority - if self._search_query.full_doc: + if self.search_query.full_doc: seen_document_ids = set() unique_chunks = [] - for chunk in chunks: + # This preserves the ordering since the chunks are retrieved in score order + for chunk in retrieved_chunks: if chunk.document_id not in seen_document_ids: seen_document_ids.add(chunk.document_id) unique_chunks.append(chunk) @@ -154,43 +190,54 @@ def _combine_chunks(self, post_rerank: bool) -> list[InferenceSection]: for ind, chunk in enumerate(unique_chunks): inf_chunks = list_inference_chunks[ind] - combined_content = "\n".join([chunk.content for chunk in inf_chunks]) - final_inference_sections.append( - InferenceSection.from_chunk(chunk, content=combined_content) + + inference_section = inference_section_from_chunks( + center_chunk=chunk, + chunks=inf_chunks, ) - return final_inference_sections + if inference_section is not None: + expanded_inference_sections.append(inference_section) + else: + logger.warning("Skipped creation of section, no chunks found") + + self._retrieved_sections = expanded_inference_sections + return expanded_inference_sections # General flow: # - Combine chunks into lists by document_id # - For each document, run merge-intervals to get combined ranges + # - This allows for less queries to the document index # - Fetch all of the new chunks with contents for the combined ranges - # - Map it back to the combined ranges (which each know their "center" chunk) # - Reiterate the chunks again and map to the results above based on the chunk. # This maintains the original chunks ordering. 
Note, we cannot simply sort by score here # as reranking flow may wipe the scores for a lot of the chunks. doc_chunk_ranges_map = defaultdict(list) - for chunk in chunks: + for chunk in retrieved_chunks: + # The list of ranges for each document is ordered by score doc_chunk_ranges_map[chunk.document_id].append( ChunkRange( - chunk=chunk, - start=max(0, chunk.chunk_id - self._search_query.chunks_above), + chunks=[chunk], + start=max(0, chunk.chunk_id - above), # No max known ahead of time, filter will handle this anyway - end=chunk.chunk_id + self._search_query.chunks_below, + end=chunk.chunk_id + below, ) ) + # List of ranges, outside list represents documents, inner list represents ranges merged_ranges = [ merge_chunk_intervals(ranges) for ranges in doc_chunk_ranges_map.values() ] - reverse_map = {r.chunk: r for doc_ranges in merged_ranges for r in doc_ranges} + flat_ranges = [r for ranges in merged_ranges for r in ranges] - for chunk_range in reverse_map.values(): + for chunk_range in flat_ranges: functions_with_args.append( ( + # If Large Chunks are introduced, additional filters need to be added here self.document_index.id_based_retrieval, ( - chunk_range.chunk.document_id, + # Only need the document_id here, just use any chunk in the range is fine + chunk_range.chunks[0].document_id, chunk_range.start, chunk_range.end, # There is no chunk level permissioning, this expansion around chunks @@ -204,151 +251,81 @@ def _combine_chunks(self, post_rerank: bool) -> list[InferenceSection]: list_inference_chunks = run_functions_tuples_in_parallel( functions_with_args, allow_failures=False ) + flattened_inference_chunks = [ + chunk for sublist in list_inference_chunks for chunk in sublist + ] - for ind, chunk_range in enumerate(reverse_map.values()): - inf_chunks = list_inference_chunks[ind] - combined_content = "\n".join([chunk.content for chunk in inf_chunks]) - chunk_range.combined_content = combined_content - - for chunk in chunks: - if chunk not in reverse_map: - 
continue - chunk_range = reverse_map[chunk] - final_inference_sections.append( - InferenceSection.from_chunk( - chunk_range.chunk, content=chunk_range.combined_content - ) + doc_chunk_ind_to_chunk = { + (chunk.document_id, chunk.chunk_id): chunk + for chunk in flattened_inference_chunks + } + + # Build the surroundings for all of the initial retrieved chunks + for chunk in retrieved_chunks: + start_ind = max(0, chunk.chunk_id - above) + end_ind = chunk.chunk_id + below + + # Since the index of the max_chunk is unknown, just allow it to be None and filter after + surrounding_chunks_or_none = [ + doc_chunk_ind_to_chunk.get((chunk.document_id, chunk_ind)) + for chunk_ind in range(start_ind, end_ind + 1) # end_ind is inclusive + ] + # The None will apply to the would be "chunks" that are larger than the index of the last chunk + # of the document + surrounding_chunks = [ + chunk for chunk in surrounding_chunks_or_none if chunk is not None + ] + + inference_section = inference_section_from_chunks( + center_chunk=chunk, + chunks=surrounding_chunks, ) + if inference_section is not None: + expanded_inference_sections.append(inference_section) + else: + logger.warning("Skipped creation of section, no chunks found") - return final_inference_sections - - """Pre-processing""" - - def _run_preprocessing(self) -> None: - ( - final_search_query, - predicted_search_type, - predicted_flow, - ) = retrieval_preprocessing( - search_request=self.search_request, - user=self.user, - llm=self.llm, - db_session=self.db_session, - bypass_acl=self.bypass_acl, - ) - self._predicted_search_type = predicted_search_type - self._predicted_flow = predicted_flow - self._search_query = final_search_query + self._retrieved_sections = expanded_inference_sections + return expanded_inference_sections @property - def search_query(self) -> SearchQuery: - if self._search_query is not None: - return self._search_query - - self._run_preprocessing() - return cast(SearchQuery, self._search_query) - - @property 
- def predicted_search_type(self) -> SearchType: - if self._predicted_search_type is not None: - return self._predicted_search_type - - self._run_preprocessing() - return cast(SearchType, self._predicted_search_type) - - @property - def predicted_flow(self) -> QueryFlow: - if self._predicted_flow is not None: - return self._predicted_flow - - self._run_preprocessing() - return cast(QueryFlow, self._predicted_flow) - - """Retrieval""" - - @property - def retrieved_chunks(self) -> list[InferenceChunk]: - if self._retrieved_chunks is not None: - return self._retrieved_chunks - - self._retrieved_chunks = retrieve_chunks( - query=self.search_query, - document_index=self.document_index, - db_session=self.db_session, - hybrid_alpha=self.search_request.hybrid_alpha, - multilingual_expansion_str=MULTILINGUAL_QUERY_EXPANSION, - retrieval_metrics_callback=self.retrieval_metrics_callback, - ) - - return cast(list[InferenceChunk], self._retrieved_chunks) - - @property - def retrieved_sections(self) -> list[InferenceSection]: - # Calls retrieved_chunks inside - self._retrieved_sections = self._combine_chunks(post_rerank=False) - return self._retrieved_sections - - """Post-Processing""" - - @property - def reranked_chunks(self) -> list[InferenceChunk]: - if self._reranked_chunks is not None: - return self._reranked_chunks + def reranked_sections(self) -> list[InferenceSection]: + """Reranking is always done at the chunk level since section merging could create arbitrarily + long sections which could be: + 1. Longer than the maximum context limit of even large rerankers + 2. 
Slow to calculate due to the quadratic scaling laws of Transformers + + See implementation in search_postprocessing for details + """ + if self._reranked_sections is not None: + return self._reranked_sections self._postprocessing_generator = search_postprocessing( search_query=self.search_query, - retrieved_chunks=self.retrieved_chunks, + retrieved_sections=self._get_sections(), + llm=self.fast_llm, rerank_metrics_callback=self.rerank_metrics_callback, ) - self._reranked_chunks = cast( - list[InferenceChunk], next(self._postprocessing_generator) + + self._reranked_sections = cast( + list[InferenceSection], next(self._postprocessing_generator) ) - return self._reranked_chunks - @property - def reranked_sections(self) -> list[InferenceSection]: - # Calls reranked_chunks inside - self._reranked_sections = self._combine_chunks(post_rerank=True) return self._reranked_sections @property - def relevant_chunk_indices(self) -> list[int]: - # If chunks have been merged, then we cannot simply rely on the leading chunk - # relevance, there is no way to get the full relevance of the Section now - # without running a more token heavy pass. This can be an option but not - # implementing now. 
- if self.ran_merge_chunk: - return [] - - if self._relevant_chunk_indices is not None: - return self._relevant_chunk_indices - - # run first step of postprocessing generator if not already done - reranked_docs = self.reranked_chunks - - relevant_chunk_ids = next( - cast(Generator[list[str], None, None], self._postprocessing_generator) - ) - self._relevant_chunk_indices = [ - ind - for ind, chunk in enumerate(reranked_docs) - if chunk.unique_id in relevant_chunk_ids - ] - return self._relevant_chunk_indices + def relevant_section_indices(self) -> list[int]: + if self._relevant_section_indices is not None: + return self._relevant_section_indices - @property - def chunk_relevance_list(self) -> list[bool]: - return [ - True if ind in self.relevant_chunk_indices else False - for ind in range(len(self.reranked_chunks)) - ] + self._relevant_section_indices = next( + cast(Iterator[list[int]], self._postprocessing_generator) + ) + return self._relevant_section_indices @property def section_relevance_list(self) -> list[bool]: - if self.ran_merge_chunk: - return [False] * len(self.reranked_sections) - return [ - True if ind in self.relevant_chunk_indices else False - for ind in range(len(self.reranked_chunks)) + True if ind in self.relevant_section_indices else False + for ind in range(len(self.reranked_sections)) ] diff --git a/backend/danswer/search/postprocessing/postprocessing.py b/backend/danswer/search/postprocessing/postprocessing.py index f7c750eaf3b..b457549917f 100644 --- a/backend/danswer/search/postprocessing/postprocessing.py +++ b/backend/danswer/search/postprocessing/postprocessing.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from collections.abc import Generator +from collections.abc import Iterator from typing import cast import numpy @@ -9,14 +9,16 @@ from danswer.document_index.document_index_utils import ( translate_boost_count_to_multiplier, ) +from danswer.llm.interfaces import LLM from danswer.search.models import ChunkMetric from 
danswer.search.models import InferenceChunk +from danswer.search.models import InferenceSection from danswer.search.models import MAX_METRICS_CONTENT from danswer.search.models import RerankMetricsContainer from danswer.search.models import SearchQuery from danswer.search.models import SearchType from danswer.search.search_nlp_models import CrossEncoderEnsembleModel -from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks +from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import FunctionCall from danswer.utils.threadpool_concurrency import run_functions_in_parallel @@ -26,9 +28,12 @@ logger = setup_logger() -def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None: +def _log_top_section_links(search_flow: str, sections: list[InferenceSection]) -> None: top_links = [ - c.source_links[0] if c.source_links is not None else "No Link" for c in chunks + section.center_chunk.source_links[0] + if section.center_chunk.source_links is not None + else "No Link" + for section in sections ] logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}") @@ -112,79 +117,115 @@ def semantic_reranking( return list(ranked_chunks), list(ranked_indices) -def rerank_chunks( +def rerank_sections( query: SearchQuery, - chunks_to_rerank: list[InferenceChunk], + sections_to_rerank: list[InferenceSection], rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> list[InferenceChunk]: +) -> list[InferenceSection]: + """Chunks are reranked rather than the containing sections, this is because of speed + implications, if reranking models have lower latency for long inputs in the future + we may rerank on the combined context of the section instead + + Making the assumption here that often times we want larger Sections to provide context + for the LLM to determine if a section is useful but 
for reranking, we don't need to be + as stringent. If the Section is relevant, we assume that the chunk rerank score will + also be high. + """ + chunks_to_rerank = [section.center_chunk for section in sections_to_rerank] + ranked_chunks, _ = semantic_reranking( query=query.query, chunks=chunks_to_rerank[: query.num_rerank], rerank_metrics_callback=rerank_metrics_callback, ) lower_chunks = chunks_to_rerank[query.num_rerank :] + # Scores from rerank cannot be meaningfully combined with scores without rerank + # However the ordering is still important for lower_chunk in lower_chunks: lower_chunk.score = None ranked_chunks.extend(lower_chunks) - return ranked_chunks + + chunk_id_to_section = { + section.center_chunk.unique_id: section for section in sections_to_rerank + } + ordered_sections = [chunk_id_to_section[chunk.unique_id] for chunk in ranked_chunks] + return ordered_sections @log_function_time(print_only=True) -def filter_chunks( +def filter_sections( query: SearchQuery, - chunks_to_filter: list[InferenceChunk], -) -> list[str]: - """Filters chunks based on whether the LLM thought they were relevant to the query. + sections_to_filter: list[InferenceSection], + llm: LLM, + # For cost saving, we may turn this on + use_chunk: bool = False, +) -> list[InferenceSection]: + """Filters sections based on whether the LLM thought they were relevant to the query. + This applies on the section which has more context than the chunk. Hopefully this yields more accurate LLM evaluations. 
+ + Returns a list of the unique chunk IDs that were marked as relevant + """ + sections_to_filter = sections_to_filter[: query.max_llm_filter_sections] - Returns a list of the unique chunk IDs that were marked as relevant""" - chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks] - llm_chunk_selection = llm_batch_eval_chunks( + contents = [ + section.center_chunk.content if use_chunk else section.combined_content + for section in sections_to_filter + ] + + llm_chunk_selection = llm_batch_eval_sections( query=query.query, - chunk_contents=[chunk.content for chunk in chunks_to_filter], + section_contents=contents, + llm=llm, ) + return [ - chunk.unique_id - for ind, chunk in enumerate(chunks_to_filter) + section + for ind, section in enumerate(sections_to_filter) if llm_chunk_selection[ind] ] def search_postprocessing( search_query: SearchQuery, - retrieved_chunks: list[InferenceChunk], + retrieved_sections: list[InferenceSection], + llm: LLM, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> Generator[list[InferenceChunk] | list[str], None, None]: +) -> Iterator[list[InferenceSection] | list[int]]: post_processing_tasks: list[FunctionCall] = [] rerank_task_id = None - chunks_yielded = False + sections_yielded = False if should_rerank(search_query): post_processing_tasks.append( FunctionCall( - rerank_chunks, + rerank_sections, ( search_query, - retrieved_chunks, + retrieved_sections, rerank_metrics_callback, ), ) ) rerank_task_id = post_processing_tasks[-1].result_id else: - final_chunks = retrieved_chunks # NOTE: if we don't rerank, we can return the chunks immediately - # since we know this is the final order - _log_top_chunk_links(search_query.search_type.value, final_chunks) - yield final_chunks - chunks_yielded = True + # since we know this is the final order. 
+ # This way the user experience isn't delayed by the LLM step + _log_top_section_links(search_query.search_type.value, retrieved_sections) + yield retrieved_sections + sections_yielded = True llm_filter_task_id = None if should_apply_llm_based_relevance_filter(search_query): post_processing_tasks.append( FunctionCall( - filter_chunks, - (search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]), + filter_sections, + ( + search_query, + retrieved_sections[: search_query.max_llm_filter_sections], + llm, + ), ) ) llm_filter_task_id = post_processing_tasks[-1].result_id @@ -194,30 +235,30 @@ def search_postprocessing( if post_processing_tasks else {} ) - reranked_chunks = cast( - list[InferenceChunk] | None, + reranked_sections = cast( + list[InferenceSection] | None, post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None, ) - if reranked_chunks: - if chunks_yielded: + if reranked_sections: + if sections_yielded: logger.error( - "Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen." + "Trying to yield re-ranked sections, but sections were already yielded. This should never happen." 
) else: - _log_top_chunk_links(search_query.search_type.value, reranked_chunks) - yield reranked_chunks + _log_top_section_links(search_query.search_type.value, reranked_sections) + yield reranked_sections - llm_chunk_selection = cast( + llm_section_selection = cast( list[str] | None, post_processing_results.get(str(llm_filter_task_id)) if llm_filter_task_id else None, ) - if llm_chunk_selection is not None: + if llm_section_selection is not None: yield [ - chunk.unique_id - for chunk in reranked_chunks or retrieved_chunks - if chunk.unique_id in llm_chunk_selection + index + for index, section in enumerate(reranked_sections or retrieved_sections) + if section.center_chunk.unique_id in llm_section_selection ] else: - yield cast(list[str], []) + yield cast(list[int], []) diff --git a/backend/danswer/search/preprocessing/access_filters.py b/backend/danswer/search/preprocessing/access_filters.py index 5ea2132028a..e8141864d11 100644 --- a/backend/danswer/search/preprocessing/access_filters.py +++ b/backend/danswer/search/preprocessing/access_filters.py @@ -16,5 +16,6 @@ def build_user_only_filters(user: User | None, db_session: Session) -> IndexFilt source_type=None, document_set=None, time_cutoff=None, + tags=None, access_control_list=user_acl_filters, ) diff --git a/backend/danswer/search/preprocessing/preprocessing.py b/backend/danswer/search/preprocessing/preprocessing.py index b4be8dca685..bb2449efe52 100644 --- a/backend/danswer/search/preprocessing/preprocessing.py +++ b/backend/danswer/search/preprocessing/preprocessing.py @@ -2,7 +2,6 @@ from danswer.configs.chat_configs import BASE_RECENCY_DECAY from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER -from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER from danswer.configs.chat_configs import NUM_RETURNED_HITS from danswer.db.models import User @@ -36,8 +35,6 @@ def retrieval_preprocessing( db_session: 
Session, bypass_acl: bool = False, include_query_intent: bool = True, - enable_auto_detect_filters: bool = False, - disable_llm_filter_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION, disable_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER, base_recency_decay: float = BASE_RECENCY_DECAY, favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER, @@ -63,10 +60,7 @@ def retrieval_preprocessing( auto_detect_time_filter = True auto_detect_source_filter = True - if disable_llm_filter_extraction: - auto_detect_time_filter = False - auto_detect_source_filter = False - elif enable_auto_detect_filters is False: + if not search_request.enable_auto_detect_filters: logger.debug("Retrieval details disables auto detect filters") auto_detect_time_filter = False auto_detect_source_filter = False diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py index 411db5b0f56..3313d243942 100644 --- a/backend/danswer/search/retrieval/search_runner.py +++ b/backend/danswer/search/retrieval/search_runner.py @@ -7,7 +7,6 @@ from nltk.tokenize import word_tokenize # type:ignore from sqlalchemy.orm import Session -from danswer.chat.models import LlmDoc from danswer.configs.chat_configs import HYBRID_ALPHA from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.db.embedding_model import get_current_db_embedding_model @@ -16,11 +15,13 @@ from danswer.search.models import ChunkMetric from danswer.search.models import IndexFilters from danswer.search.models import InferenceChunk +from danswer.search.models import InferenceSection from danswer.search.models import MAX_METRICS_CONTENT from danswer.search.models import RetrievalMetricsContainer from danswer.search.models import SearchQuery from danswer.search.models import SearchType from danswer.search.search_nlp_models import EmbeddingModel +from danswer.search.utils import inference_section_from_chunks from 
danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel @@ -240,30 +241,10 @@ def retrieve_chunks( return top_chunks -def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc: - if not inf_chunks: - raise ValueError("Cannot combine empty list of chunks") - - # Use the first link of the document - first_chunk = inf_chunks[0] - chunk_texts = [chunk.content for chunk in inf_chunks] - return LlmDoc( - document_id=first_chunk.document_id, - content="\n".join(chunk_texts), - blurb=first_chunk.blurb, - semantic_identifier=first_chunk.semantic_identifier, - source_type=first_chunk.source_type, - metadata=first_chunk.metadata, - updated_at=first_chunk.updated_at, - link=first_chunk.source_links[0] if first_chunk.source_links else None, - source_links=first_chunk.source_links, - ) - - -def inference_documents_from_ids( +def inference_sections_from_ids( doc_identifiers: list[tuple[str, int]], document_index: DocumentIndex, -) -> list[LlmDoc]: +) -> list[InferenceSection]: # Currently only fetches whole docs doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers) @@ -282,4 +263,17 @@ def inference_documents_from_ids( # Any failures to retrieve would give a None, drop the Nones and empty lists inference_chunks_sets = [res for res in parallel_results if res] - return [combine_inference_chunks(chunk_set) for chunk_set in inference_chunks_sets] + return [ + inference_section + for inference_section in [ + inference_section_from_chunks( + # The scores will always be 0 because the fetching by id gives back + # no search scores. This is not needed though if the user is explicitly + # selecting a document. 
+ center_chunk=chunk_set[0], + chunks=chunk_set, + ) + for chunk_set in inference_chunks_sets + ] + if inference_section is not None + ] diff --git a/backend/danswer/search/utils.py b/backend/danswer/search/utils.py index 5b5a6464aad..8b138d2e9b8 100644 --- a/backend/danswer/search/utils.py +++ b/backend/danswer/search/utils.py @@ -4,10 +4,19 @@ from danswer.db.models import SearchDoc as DBSearchDoc from danswer.search.models import InferenceChunk from danswer.search.models import InferenceSection +from danswer.search.models import SavedSearchDoc +from danswer.search.models import SavedSearchDocWithContent from danswer.search.models import SearchDoc -T = TypeVar("T", InferenceSection, InferenceChunk, SearchDoc) +T = TypeVar( + "T", + InferenceSection, + InferenceChunk, + SearchDoc, + SavedSearchDoc, + SavedSearchDocWithContent, +) def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]: @@ -15,8 +24,13 @@ def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]: deduped_items = [] dropped_indices = [] for index, item in enumerate(items): - if item.document_id not in seen_ids: - seen_ids.add(item.document_id) + if isinstance(item, InferenceSection): + document_id = item.center_chunk.document_id + else: + document_id = item.document_id + + if document_id not in seen_ids: + seen_ids.add(document_id) deduped_items.append(item) else: dropped_indices.append(index) @@ -24,7 +38,9 @@ def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]: def drop_llm_indices( - llm_indices: list[int], search_docs: list[DBSearchDoc], dropped_indices: list[int] + llm_indices: list[int], + search_docs: Sequence[DBSearchDoc | SavedSearchDoc], + dropped_indices: list[int], ) -> list[int]: llm_bools = [True if i in llm_indices else False for i in range(len(search_docs))] if dropped_indices: @@ -34,30 +50,51 @@ def drop_llm_indices( return [i for i, val in enumerate(llm_bools) if val] +def inference_section_from_chunks( + center_chunk: InferenceChunk, + chunks: 
list[InferenceChunk], +) -> InferenceSection | None: + if not chunks: + return None + + combined_content = "\n".join([chunk.content for chunk in chunks]) + + return InferenceSection( + center_chunk=center_chunk, + chunks=chunks, + combined_content=combined_content, + ) + + def chunks_or_sections_to_search_docs( - chunks: Sequence[InferenceChunk | InferenceSection] | None, + items: Sequence[InferenceChunk | InferenceSection] | None, ) -> list[SearchDoc]: - search_docs = ( - [ - SearchDoc( - document_id=chunk.document_id, - chunk_ind=chunk.chunk_id, - semantic_identifier=chunk.semantic_identifier or "Unknown", - link=chunk.source_links.get(0) if chunk.source_links else None, - blurb=chunk.blurb, - source_type=chunk.source_type, - boost=chunk.boost, - hidden=chunk.hidden, - metadata=chunk.metadata, - score=chunk.score, - match_highlights=chunk.match_highlights, - updated_at=chunk.updated_at, - primary_owners=chunk.primary_owners, - secondary_owners=chunk.secondary_owners, - ) - for chunk in chunks - ] - if chunks - else [] - ) + if not items: + return [] + + search_docs = [ + SearchDoc( + document_id=( + chunk := item.center_chunk + if isinstance(item, InferenceSection) + else item + ).document_id, + chunk_ind=chunk.chunk_id, + semantic_identifier=chunk.semantic_identifier or "Unknown", + link=chunk.source_links[0] if chunk.source_links else None, + blurb=chunk.blurb, + source_type=chunk.source_type, + boost=chunk.boost, + hidden=chunk.hidden, + metadata=chunk.metadata, + score=chunk.score, + match_highlights=chunk.match_highlights, + updated_at=chunk.updated_at, + primary_owners=chunk.primary_owners, + secondary_owners=chunk.secondary_owners, + is_internet=False, + ) + for item in items + ] + return search_docs diff --git a/backend/danswer/secondary_llm_flows/answer_validation.py b/backend/danswer/secondary_llm_flows/answer_validation.py index 2ef3787c11b..68587109538 100644 --- a/backend/danswer/secondary_llm_flows/answer_validation.py +++ 
b/backend/danswer/secondary_llm_flows/answer_validation.py @@ -1,5 +1,5 @@ from danswer.llm.exceptions import GenAIDisabledException -from danswer.llm.factory import get_default_llm +from danswer.llm.factory import get_default_llms from danswer.llm.utils import dict_based_prompt_to_langchain_prompt from danswer.llm.utils import message_to_string from danswer.prompts.answer_validation import ANSWER_VALIDITY_PROMPT @@ -44,7 +44,7 @@ def _extract_validity(model_output: str) -> bool: return True # If something is wrong, let's not toss away the answer try: - llm = get_default_llm() + llm, _ = get_default_llms() except GenAIDisabledException: return True diff --git a/backend/danswer/secondary_llm_flows/chat_session_naming.py b/backend/danswer/secondary_llm_flows/chat_session_naming.py index 7a48cac80b8..9449eaded7a 100644 --- a/backend/danswer/secondary_llm_flows/chat_session_naming.py +++ b/backend/danswer/secondary_llm_flows/chat_session_naming.py @@ -1,4 +1,6 @@ from danswer.chat.chat_utils import combine_message_chain +from danswer.configs.chat_configs import LANGUAGE_CHAT_NAMING_HINT +from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.db.models import ChatMessage from danswer.llm.interfaces import LLM @@ -18,10 +20,18 @@ def get_renamed_conversation_name( messages=full_history, token_limit=GEN_AI_HISTORY_CUTOFF ) + language_hint = ( + f"\n{LANGUAGE_CHAT_NAMING_HINT.strip()}" + if bool(MULTILINGUAL_QUERY_EXPANSION) + else "" + ) + prompt_msgs = [ { "role": "user", - "content": CHAT_NAMING.format(chat_history=history_str), + "content": CHAT_NAMING.format( + language_hint_or_empty=language_hint, chat_history=history_str + ), }, ] diff --git a/backend/danswer/secondary_llm_flows/chunk_usefulness.py b/backend/danswer/secondary_llm_flows/chunk_usefulness.py index d37feb0c0b1..b672a563d0e 100644 --- a/backend/danswer/secondary_llm_flows/chunk_usefulness.py +++ 
b/backend/danswer/secondary_llm_flows/chunk_usefulness.py @@ -1,24 +1,23 @@ from collections.abc import Callable -from danswer.llm.exceptions import GenAIDisabledException -from danswer.llm.factory import get_default_llm +from danswer.llm.interfaces import LLM from danswer.llm.utils import dict_based_prompt_to_langchain_prompt from danswer.llm.utils import message_to_string -from danswer.prompts.llm_chunk_filter import CHUNK_FILTER_PROMPT from danswer.prompts.llm_chunk_filter import NONUSEFUL_PAT +from danswer.prompts.llm_chunk_filter import SECTION_FILTER_PROMPT from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel logger = setup_logger() -def llm_eval_chunk(query: str, chunk_content: str) -> bool: +def llm_eval_section(query: str, section_content: str, llm: LLM) -> bool: def _get_usefulness_messages() -> list[dict[str, str]]: messages = [ { "role": "user", - "content": CHUNK_FILTER_PROMPT.format( - chunk_text=chunk_content, user_query=query + "content": SECTION_FILTER_PROMPT.format( + chunk_text=section_content, user_query=query ), }, ] @@ -32,14 +31,6 @@ def _extract_usefulness(model_output: str) -> bool: return False return True - # If Gen AI is disabled, none of the messages are more "useful" than any other - # All are marked not useful (False) so that the icon for Gen AI likes this answer - # is not shown for any result - try: - llm = get_default_llm(use_fast_llm=True, timeout=5) - except GenAIDisabledException: - return False - messages = _get_usefulness_messages() filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) # When running in a batch, it takes as long as the longest thread @@ -51,12 +42,13 @@ def _extract_usefulness(model_output: str) -> bool: return _extract_usefulness(model_output) -def llm_batch_eval_chunks( - query: str, chunk_contents: list[str], use_threads: bool = True +def llm_batch_eval_sections( + query: str, section_contents: list[str], llm: LLM, 
use_threads: bool = True ) -> list[bool]: if use_threads: functions_with_args: list[tuple[Callable, tuple]] = [ - (llm_eval_chunk, (query, chunk_content)) for chunk_content in chunk_contents + (llm_eval_section, (query, section_content, llm)) + for section_content in section_contents ] logger.debug( @@ -66,10 +58,11 @@ def llm_batch_eval_chunks( functions_with_args, allow_failures=True ) - # In case of failure/timeout, don't throw out the chunk + # In case of failure/timeout, don't throw out the section return [True if item is None else item for item in parallel_results] else: return [ - llm_eval_chunk(query, chunk_content) for chunk_content in chunk_contents + llm_eval_section(query, section_content, llm) + for section_content in section_contents ] diff --git a/backend/danswer/secondary_llm_flows/query_expansion.py b/backend/danswer/secondary_llm_flows/query_expansion.py index 2f221bfa903..e5a2b67e1fa 100644 --- a/backend/danswer/secondary_llm_flows/query_expansion.py +++ b/backend/danswer/secondary_llm_flows/query_expansion.py @@ -6,7 +6,7 @@ from danswer.db.models import ChatMessage from danswer.llm.answering.models import PreviousMessage from danswer.llm.exceptions import GenAIDisabledException -from danswer.llm.factory import get_default_llm +from danswer.llm.factory import get_default_llms from danswer.llm.interfaces import LLM from danswer.llm.utils import dict_based_prompt_to_langchain_prompt from danswer.llm.utils import message_to_string @@ -33,7 +33,7 @@ def _get_rephrase_messages() -> list[dict[str, str]]: return messages try: - llm = get_default_llm(use_fast_llm=True, timeout=5) + _, fast_llm = get_default_llms(timeout=5) except GenAIDisabledException: logger.warning( "Unable to perform multilingual query expansion, Gen AI disabled" @@ -42,7 +42,7 @@ def _get_rephrase_messages() -> list[dict[str, str]]: messages = _get_rephrase_messages() filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) - model_output = 
message_to_string(llm.invoke(filled_llm_prompt)) + model_output = message_to_string(fast_llm.invoke(filled_llm_prompt)) logger.debug(model_output) return model_output @@ -74,11 +74,12 @@ def multilingual_query_expansion( def get_contextual_rephrase_messages( question: str, history_str: str, + prompt_template: str = HISTORY_QUERY_REPHRASE, ) -> list[dict[str, str]]: messages = [ { "role": "user", - "content": HISTORY_QUERY_REPHRASE.format( + "content": prompt_template.format( question=question, chat_history=history_str ), }, @@ -94,6 +95,7 @@ def history_based_query_rephrase( size_heuristic: int = 200, punctuation_heuristic: int = 10, skip_first_rephrase: bool = False, + prompt_template: str = HISTORY_QUERY_REPHRASE, ) -> str: # Globally disabled, just use the exact user query if DISABLE_LLM_QUERY_REPHRASE: @@ -119,7 +121,7 @@ def history_based_query_rephrase( ) prompt_msgs = get_contextual_rephrase_messages( - question=query, history_str=history_str + question=query, history_str=history_str, prompt_template=prompt_template ) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs) @@ -148,7 +150,7 @@ def thread_based_query_rephrase( if llm is None: try: - llm = get_default_llm() + llm, _ = get_default_llms() except GenAIDisabledException: # If Generative AI is turned off, just return the original query return user_query diff --git a/backend/danswer/secondary_llm_flows/query_validation.py b/backend/danswer/secondary_llm_flows/query_validation.py index 4130b7ee356..bbc1ef412b9 100644 --- a/backend/danswer/secondary_llm_flows/query_validation.py +++ b/backend/danswer/secondary_llm_flows/query_validation.py @@ -5,7 +5,7 @@ from danswer.chat.models import StreamingError from danswer.configs.chat_configs import DISABLE_LLM_QUERY_ANSWERABILITY from danswer.llm.exceptions import GenAIDisabledException -from danswer.llm.factory import get_default_llm +from danswer.llm.factory import get_default_llms from danswer.llm.utils import 
dict_based_prompt_to_langchain_prompt from danswer.llm.utils import message_generator_to_string_generator from danswer.llm.utils import message_to_string @@ -52,7 +52,7 @@ def get_query_answerability( return "Query Answerability Evaluation feature is turned off", True try: - llm = get_default_llm() + llm, _ = get_default_llms() except GenAIDisabledException: return "Generative AI is turned off - skipping check", True @@ -79,7 +79,7 @@ def stream_query_answerability( return try: - llm = get_default_llm() + llm, _ = get_default_llms() except GenAIDisabledException: yield get_json_line( QueryValidationResponse( diff --git a/backend/danswer/secondary_llm_flows/source_filter.py b/backend/danswer/secondary_llm_flows/source_filter.py index 78dc504abdf..802a14f42fa 100644 --- a/backend/danswer/secondary_llm_flows/source_filter.py +++ b/backend/danswer/secondary_llm_flows/source_filter.py @@ -159,11 +159,13 @@ def _extract_source_filters_from_llm_out( if __name__ == "__main__": - from danswer.llm.factory import get_default_llm + from danswer.llm.factory import get_default_llms, get_main_llm_from_tuple # Just for testing purposes with Session(get_sqlalchemy_engine()) as db_session: while True: user_input = input("Query to Extract Sources: ") - sources = extract_source_filter(user_input, get_default_llm(), db_session) + sources = extract_source_filter( + user_input, get_main_llm_from_tuple(get_default_llms()), db_session + ) print(sources) diff --git a/backend/danswer/secondary_llm_flows/time_filter.py b/backend/danswer/secondary_llm_flows/time_filter.py index 7de6efb3dd2..aef32d7bdb2 100644 --- a/backend/danswer/secondary_llm_flows/time_filter.py +++ b/backend/danswer/secondary_llm_flows/time_filter.py @@ -156,10 +156,12 @@ def _extract_time_filter_from_llm_out( if __name__ == "__main__": # Just for testing purposes, too tedious to unit test as it relies on an LLM - from danswer.llm.factory import get_default_llm + from danswer.llm.factory import get_default_llms, 
get_main_llm_from_tuple while True: user_input = input("Query to Extract Time: ") - cutoff, recency_bias = extract_time_filter(user_input, get_default_llm()) + cutoff, recency_bias = extract_time_filter( + user_input, get_main_llm_from_tuple(get_default_llms()) + ) print(f"Time Cutoff: {cutoff}") print(f"Favor Recent: {recency_bias}") diff --git a/backend/danswer/server/danswer_api/ingestion.py b/backend/danswer/server/danswer_api/ingestion.py index 33a5b4f52f8..1b6e6d9852f 100644 --- a/backend/danswer/server/danswer_api/ingestion.py +++ b/backend/danswer/server/danswer_api/ingestion.py @@ -1,8 +1,5 @@ -import secrets - from fastapi import APIRouter from fastapi import Depends -from fastapi import Header from fastapi import HTTPException from sqlalchemy.orm import Session @@ -17,55 +14,19 @@ from danswer.db.engine import get_session from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index -from danswer.dynamic_configs.factory import get_dynamic_config_store -from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.indexing.embedder import DefaultIndexingEmbedder from danswer.indexing.indexing_pipeline import build_indexing_pipeline from danswer.server.danswer_api.models import DocMinimalInfo from danswer.server.danswer_api.models import IngestionDocument from danswer.server.danswer_api.models import IngestionResult from danswer.utils.logger import setup_logger +from ee.danswer.auth.users import api_key_dep logger = setup_logger() # not using /api to avoid confusion with nginx api path routing router = APIRouter(prefix="/danswer-api") -# Assumes this gives admin privileges, basic users should not be allowed to call any Danswer apis -_DANSWER_API_KEY = "danswer_api_key" - - -def get_danswer_api_key(key_len: int = 30, dont_regenerate: bool = False) -> str | None: - kv_store = get_dynamic_config_store() - try: - return str(kv_store.load(_DANSWER_API_KEY)) - 
except ConfigNotFoundError: - if dont_regenerate: - return None - - logger.info("Generating Danswer API Key") - - api_key = "dn_" + secrets.token_urlsafe(key_len) - kv_store.store(_DANSWER_API_KEY, api_key, encrypt=True) - - return api_key - - -def delete_danswer_api_key() -> None: - kv_store = get_dynamic_config_store() - try: - kv_store.delete(_DANSWER_API_KEY) - except ConfigNotFoundError: - pass - - -def api_key_dep(authorization: str = Header(...)) -> str: - saved_key = get_danswer_api_key(dont_regenerate=True) - token = authorization.removeprefix("Bearer ").strip() - if token != saved_key or not saved_key: - raise HTTPException(status_code=401, detail="Invalid API key") - return token - @router.get("/connector-docs/{cc_pair_id}") def get_docs_by_connector_credential_pair( diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index 28981c716b7..861657a43ea 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -77,7 +77,7 @@ def associate_credential_to_connector( connector_id: int, credential_id: int, metadata: ConnectorCredentialPairMetadata, - user: User = Depends(current_user), + user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: try: @@ -85,6 +85,7 @@ def associate_credential_to_connector( connector_id=connector_id, credential_id=credential_id, cc_pair_name=metadata.name, + is_public=metadata.is_public, user=user, db_session=db_session, ) @@ -96,7 +97,7 @@ def associate_credential_to_connector( def dissociate_credential_from_connector( connector_id: int, credential_id: int, - user: User = Depends(current_user), + user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: return remove_credential_from_connector( diff --git a/backend/danswer/server/documents/models.py b/backend/danswer/server/documents/models.py index f1eb71dcea2..e02132e741a 
100644 --- a/backend/danswer/server/documents/models.py +++ b/backend/danswer/server/documents/models.py @@ -183,6 +183,7 @@ class ConnectorCredentialPairIdentifier(BaseModel): class ConnectorCredentialPairMetadata(BaseModel): name: str | None + is_public: bool class ConnectorCredentialPairDescriptor(BaseModel): diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index 006f7506894..6739da46606 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -180,6 +180,7 @@ def get_persona( persona_id=persona_id, user=user, db_session=db_session, + is_for_edit=False, ) ) diff --git a/backend/danswer/server/features/tool/models.py b/backend/danswer/server/features/tool/models.py index feb3ba68269..0c1da965d4f 100644 --- a/backend/danswer/server/features/tool/models.py +++ b/backend/danswer/server/features/tool/models.py @@ -10,6 +10,7 @@ class ToolSnapshot(BaseModel): name: str description: str definition: dict[str, Any] | None + display_name: str in_code_tool_id: str | None @classmethod @@ -19,5 +20,6 @@ def from_model(cls, tool: Tool) -> "ToolSnapshot": name=tool.name, description=tool.description, definition=tool.openapi_schema, + display_name=tool.display_name or tool.name, in_code_tool_id=tool.in_code_tool_id, ) diff --git a/backend/danswer/server/gpts/api.py b/backend/danswer/server/gpts/api.py index cefb923410b..a3ce59edc37 100644 --- a/backend/danswer/server/gpts/api.py +++ b/backend/danswer/server/gpts/api.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import Session from danswer.db.engine import get_session -from danswer.llm.factory import get_default_llm +from danswer.llm.factory import get_default_llms from danswer.search.models import SearchRequest from danswer.search.pipeline import SearchPipeline from danswer.server.danswer_api.ingestion import api_key_dep @@ -67,27 +67,31 @@ def gpt_search( _: str | None = Depends(api_key_dep), db_session: Session = 
Depends(get_session), ) -> GptSearchResponse: - top_chunks = SearchPipeline( + llm, fast_llm = get_default_llms() + top_sections = SearchPipeline( search_request=SearchRequest( query=search_request.query, ), user=None, - llm=get_default_llm(), + llm=llm, + fast_llm=fast_llm, db_session=db_session, - ).reranked_chunks + ).reranked_sections return GptSearchResponse( matching_document_chunks=[ GptDocChunk( - title=chunk.semantic_identifier, - content=chunk.content, - source_type=chunk.source_type, - link=chunk.source_links.get(0, "") if chunk.source_links else "", - metadata=chunk.metadata, - document_age=time_ago(chunk.updated_at) - if chunk.updated_at + title=section.center_chunk.semantic_identifier, + content=section.center_chunk.content, + source_type=section.center_chunk.source_type, + link=section.center_chunk.source_links.get(0, "") + if section.center_chunk.source_links + else "", + metadata=section.center_chunk.metadata, + document_age=time_ago(section.center_chunk.updated_at) + if section.center_chunk.updated_at else "Unknown", ) - for chunk in top_chunks + for section in top_sections ], ) diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index c60206ca3f3..d6a52917f3b 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -1,23 +1,16 @@ -import json from datetime import datetime from datetime import timedelta from datetime import timezone from typing import cast from fastapi import APIRouter -from fastapi import Body from fastapi import Depends from fastapi import HTTPException from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ -from danswer.configs.app_configs import TOKEN_BUDGET_GLOBALLY_ENABLED from danswer.configs.constants import DocumentSource -from danswer.configs.constants import ENABLE_TOKEN_BUDGET -from 
danswer.configs.constants import TOKEN_BUDGET -from danswer.configs.constants import TOKEN_BUDGET_SETTINGS -from danswer.configs.constants import TOKEN_BUDGET_TIME_PERIOD from danswer.db.connector_credential_pair import get_connector_credential_pair from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed from danswer.db.engine import get_session @@ -31,7 +24,7 @@ from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.file_store.file_store import get_default_file_store -from danswer.llm.factory import get_default_llm +from danswer.llm.factory import get_default_llms from danswer.llm.utils import test_llm from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.manage.models import BoostDoc @@ -133,7 +126,7 @@ def validate_existing_genai_api_key( pass try: - llm = get_default_llm(timeout=10) + llm, __ = get_default_llms(timeout=10) except ValueError: raise HTTPException(status_code=404, detail="LLM not setup") @@ -152,7 +145,9 @@ def create_deletion_attempt_for_connector_id( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> None: - from danswer.background.celery.celery import cleanup_connector_credential_pair_task + from danswer.background.celery.celery_app import ( + cleanup_connector_credential_pair_task, + ) connector_id = connector_credential_pair_identifier.connector_id credential_id = connector_credential_pair_identifier.credential_id @@ -193,41 +188,3 @@ def create_deletion_attempt_for_connector_id( file_store = get_default_file_store(db_session) for file_name in connector.connector_specific_config["file_locations"]: file_store.delete_file(file_name) - - -@router.get("/admin/token-budget-settings") -def get_token_budget_settings(_: User = Depends(current_admin_user)) -> dict: - if not TOKEN_BUDGET_GLOBALLY_ENABLED: - raise HTTPException( - status_code=400, detail="Token 
budget is not enabled in the application." - ) - - try: - settings_json = cast( - str, get_dynamic_config_store().load(TOKEN_BUDGET_SETTINGS) - ) - settings = json.loads(settings_json) - return settings - except ConfigNotFoundError: - raise HTTPException(status_code=404, detail="Token budget settings not found.") - - -@router.put("/admin/token-budget-settings") -def update_token_budget_settings( - _: User = Depends(current_admin_user), - enable_token_budget: bool = Body(..., embed=True), - token_budget: int = Body(..., ge=0, embed=True), # Ensure non-negative - token_budget_time_period: int = Body(..., ge=1, embed=True), # Ensure positive -) -> dict[str, str]: - # Prepare the settings as a JSON string - settings_json = json.dumps( - { - ENABLE_TOKEN_BUDGET: enable_token_budget, - TOKEN_BUDGET: token_budget, - TOKEN_BUDGET_TIME_PERIOD: token_budget_time_period, - } - ) - - # Store the settings in the dynamic config store - get_dynamic_config_store().store(TOKEN_BUDGET_SETTINGS, settings_json) - return {"message": "Token budget settings updated successfully."} diff --git a/backend/danswer/server/manage/llm/api.py b/backend/danswer/server/manage/llm/api.py index 6e21a29b12b..4df00b529af 100644 --- a/backend/danswer/server/manage/llm/api.py +++ b/backend/danswer/server/manage/llm/api.py @@ -13,7 +13,7 @@ from danswer.db.llm import update_default_provider from danswer.db.llm import upsert_llm_provider from danswer.db.models import User -from danswer.llm.factory import get_default_llm +from danswer.llm.factory import get_default_llms from danswer.llm.factory import get_llm from danswer.llm.llm_provider_options import fetch_available_well_known_llms from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor @@ -85,8 +85,7 @@ def test_default_provider( _: User | None = Depends(current_admin_user), ) -> None: try: - llm = get_default_llm() - fast_llm = get_default_llm(use_fast_llm=True) + llm, fast_llm = get_default_llms() except ValueError: 
logger.exception("Failed to fetch default LLM Provider") raise HTTPException(status_code=400, detail="No LLM Provider setup") diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py index cbc3251bfd3..f544df0f24e 100644 --- a/backend/danswer/server/manage/models.py +++ b/backend/danswer/server/manage/models.py @@ -12,6 +12,8 @@ from danswer.db.models import ChannelConfig from danswer.db.models import SlackBotConfig as SlackBotConfigModel from danswer.db.models import SlackBotResponseType +from danswer.db.models import StandardAnswer as StandardAnswerModel +from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel from danswer.indexing.models import EmbeddingModelDetail from danswer.server.features.persona.models import PersonaSnapshot from danswer.server.models import FullUserSnapshot @@ -84,6 +86,57 @@ class HiddenUpdateRequest(BaseModel): hidden: bool +class StandardAnswerCategoryCreationRequest(BaseModel): + name: str + + +class StandardAnswerCategory(BaseModel): + id: int + name: str + + @classmethod + def from_model( + cls, standard_answer_category: StandardAnswerCategoryModel + ) -> "StandardAnswerCategory": + return cls( + id=standard_answer_category.id, + name=standard_answer_category.name, + ) + + +class StandardAnswer(BaseModel): + id: int + keyword: str + answer: str + categories: list[StandardAnswerCategory] + + @classmethod + def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer": + return cls( + id=standard_answer_model.id, + keyword=standard_answer_model.keyword, + answer=standard_answer_model.answer, + categories=[ + StandardAnswerCategory.from_model(standard_answer_category_model) + for standard_answer_category_model in standard_answer_model.categories + ], + ) + + +class StandardAnswerCreationRequest(BaseModel): + keyword: str + answer: str + categories: list[int] + + @validator("categories", pre=True) + def validate_categories(cls, value: list[int]) -> 
list[int]: + if len(value) < 1: + raise ValueError( + "At least one category must be attached to a standard answer" + ) + return value + + class SlackBotTokens(BaseModel): bot_token: str app_token: str @@ -102,12 +155,15 @@ class SlackBotConfigCreationRequest(BaseModel): channel_names: list[str] respond_tag_only: bool = False respond_to_bots: bool = False + enable_auto_filters: bool = False # If no team members, assume respond in the channel to everyone respond_team_member_list: list[str] = [] + respond_slack_group_list: list[str] = [] answer_filters: list[AllowedAnswerFilters] = [] # list of user emails follow_up_tags: list[str] | None = None response_type: SlackBotResponseType + standard_answer_categories: list[int] = [] @validator("answer_filters", pre=True) def validate_filters(cls, value: list[str]) -> list[str]: @@ -132,6 +188,8 @@ class SlackBotConfig(BaseModel): persona: PersonaSnapshot | None channel_config: ChannelConfig response_type: SlackBotResponseType + standard_answer_categories: list[StandardAnswerCategory] + enable_auto_filters: bool @classmethod def from_model( @@ -148,6 +206,11 @@ def from_model( ), channel_config=slack_bot_config_model.channel_config, response_type=slack_bot_config_model.response_type, + standard_answer_categories=[ + StandardAnswerCategory.from_model(standard_answer_category_model) + for standard_answer_category_model in slack_bot_config_model.standard_answer_categories + ], + enable_auto_filters=slack_bot_config_model.enable_auto_filters, ) diff --git a/backend/danswer/server/manage/slack_bot.py b/backend/danswer/server/manage/slack_bot.py index aade420adeb..d5f08e2694a 100644 --- a/backend/danswer/server/manage/slack_bot.py +++ b/backend/danswer/server/manage/slack_bot.py @@ -37,6 +37,9 @@ def _form_channel_config( respond_team_member_list = ( slack_bot_config_creation_request.respond_team_member_list ) + respond_slack_group_list = ( + slack_bot_config_creation_request.respond_slack_group_list + ) answer_filters = 
slack_bot_config_creation_request.answer_filters follow_up_tags = slack_bot_config_creation_request.follow_up_tags @@ -58,7 +61,7 @@ def _form_channel_config( detail=str(e), ) - if respond_tag_only and respond_team_member_list: + if respond_tag_only and (respond_team_member_list or respond_slack_group_list): raise ValueError( "Cannot set DanswerBot to only respond to tags only and " "also respond to a predetermined set of users." @@ -71,6 +74,8 @@ def _form_channel_config( channel_config["respond_tag_only"] = respond_tag_only if respond_team_member_list: channel_config["respond_team_member_list"] = respond_team_member_list + if respond_slack_group_list: + channel_config["respond_slack_group_list"] = respond_slack_group_list if answer_filters: channel_config["answer_filters"] = answer_filters if follow_up_tags is not None: @@ -108,7 +113,9 @@ def create_slack_bot_config( persona_id=persona_id, channel_config=channel_config, response_type=slack_bot_config_creation_request.response_type, + standard_answer_category_ids=slack_bot_config_creation_request.standard_answer_categories, db_session=db_session, + enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters, ) return SlackBotConfig.from_model(slack_bot_config_model) @@ -140,7 +147,10 @@ def patch_slack_bot_config( existing_persona_id = existing_slack_bot_config.persona_id if existing_persona_id is not None: persona = get_persona_by_id( - persona_id=existing_persona_id, user=None, db_session=db_session + persona_id=existing_persona_id, + user=None, + db_session=db_session, + is_for_edit=False, ) if not persona.name.startswith(SLACK_BOT_PERSONA_PREFIX): @@ -163,7 +173,9 @@ def patch_slack_bot_config( persona_id=persona_id, channel_config=channel_config, response_type=slack_bot_config_creation_request.response_type, + standard_answer_category_ids=slack_bot_config_creation_request.standard_answer_categories, db_session=db_session, + 
enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters, ) return SlackBotConfig.from_model(slack_bot_config_model) diff --git a/backend/danswer/server/manage/standard_answer.py b/backend/danswer/server/manage/standard_answer.py new file mode 100644 index 00000000000..69f9e8146df --- /dev/null +++ b/backend/danswer/server/manage/standard_answer.py @@ -0,0 +1,139 @@ +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from sqlalchemy.orm import Session + +from danswer.auth.users import current_admin_user +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.db.standard_answer import fetch_standard_answer +from danswer.db.standard_answer import fetch_standard_answer_categories +from danswer.db.standard_answer import fetch_standard_answer_category +from danswer.db.standard_answer import fetch_standard_answers +from danswer.db.standard_answer import insert_standard_answer +from danswer.db.standard_answer import insert_standard_answer_category +from danswer.db.standard_answer import remove_standard_answer +from danswer.db.standard_answer import update_standard_answer +from danswer.db.standard_answer import update_standard_answer_category +from danswer.server.manage.models import StandardAnswer +from danswer.server.manage.models import StandardAnswerCategory +from danswer.server.manage.models import StandardAnswerCategoryCreationRequest +from danswer.server.manage.models import StandardAnswerCreationRequest + +router = APIRouter(prefix="/manage") + + +@router.post("/admin/standard-answer") +def create_standard_answer( + standard_answer_creation_request: StandardAnswerCreationRequest, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> StandardAnswer: + standard_answer_model = insert_standard_answer( + keyword=standard_answer_creation_request.keyword, + answer=standard_answer_creation_request.answer, + 
category_ids=standard_answer_creation_request.categories, + db_session=db_session, + ) + return StandardAnswer.from_model(standard_answer_model) + + +@router.get("/admin/standard-answer") +def list_standard_answers( + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> list[StandardAnswer]: + standard_answer_models = fetch_standard_answers(db_session=db_session) + return [ + StandardAnswer.from_model(standard_answer_model) + for standard_answer_model in standard_answer_models + ] + + +@router.patch("/admin/standard-answer/{standard_answer_id}") +def patch_standard_answer( + standard_answer_id: int, + standard_answer_creation_request: StandardAnswerCreationRequest, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> StandardAnswer: + existing_standard_answer = fetch_standard_answer( + standard_answer_id=standard_answer_id, + db_session=db_session, + ) + + if existing_standard_answer is None: + raise HTTPException(status_code=404, detail="Standard answer not found") + + standard_answer_model = update_standard_answer( + standard_answer_id=standard_answer_id, + keyword=standard_answer_creation_request.keyword, + answer=standard_answer_creation_request.answer, + category_ids=standard_answer_creation_request.categories, + db_session=db_session, + ) + return StandardAnswer.from_model(standard_answer_model) + + +@router.delete("/admin/standard-answer/{standard_answer_id}") +def delete_standard_answer( + standard_answer_id: int, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> None: + return remove_standard_answer( + standard_answer_id=standard_answer_id, + db_session=db_session, + ) + + +@router.post("/admin/standard-answer/category") +def create_standard_answer_category( + standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest, + db_session: Session = Depends(get_session), + _: User | None = 
Depends(current_admin_user), +) -> StandardAnswerCategory: + standard_answer_category_model = insert_standard_answer_category( + category_name=standard_answer_category_creation_request.name, + db_session=db_session, + ) + return StandardAnswerCategory.from_model(standard_answer_category_model) + + +@router.get("/admin/standard-answer/category") +def list_standard_answer_categories( + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> list[StandardAnswerCategory]: + standard_answer_category_models = fetch_standard_answer_categories( + db_session=db_session + ) + return [ + StandardAnswerCategory.from_model(standard_answer_category_model) + for standard_answer_category_model in standard_answer_category_models + ] + + +@router.patch("/admin/standard-answer/category/{standard_answer_category_id}") +def patch_standard_answer_category( + standard_answer_category_id: int, + standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> StandardAnswerCategory: + existing_standard_answer_category = fetch_standard_answer_category( + standard_answer_category_id=standard_answer_category_id, + db_session=db_session, + ) + + if existing_standard_answer_category is None: + raise HTTPException( + status_code=404, detail="Standard answer category not found" + ) + + standard_answer_category_model = update_standard_answer_category( + standard_answer_category_id=standard_answer_category_id, + category_name=standard_answer_category_creation_request.name, + db_session=db_session, + ) + return StandardAnswerCategory.from_model(standard_answer_category_model) diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 3fbd15c7e98..c635469919e 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -34,10 +34,10 @@ from danswer.server.models 
import InvitedUserSnapshot from danswer.server.models import MinimalUserSnapshot from danswer.utils.logger import setup_logger +from ee.danswer.db.api_key import is_api_key_email_address logger = setup_logger() - router = APIRouter() @@ -85,13 +85,20 @@ async def demote_admin( @router.get("/manage/users") def list_all_users( - q: str, - accepted_page: int, - invited_page: int, + q: str | None = None, + accepted_page: int | None = None, + invited_page: int | None = None, _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> AllUsersResponse: - users = list_users(db_session, q=q) + if not q: + q = "" + + users = [ + user + for user in list_users(db_session, q=q) + if not is_api_key_email_address(user.email) + ] accepted_emails = {user.email for user in users} invited_emails = get_invited_users() if q: @@ -102,6 +109,26 @@ def list_all_users( accepted_count = len(accepted_emails) invited_count = len(invited_emails) + # If any of q, accepted_page, or invited_page is None, return all users + if accepted_page is None or invited_page is None: + return AllUsersResponse( + accepted=[ + FullUserSnapshot( + id=user.id, + email=user.email, + role=user.role, + status=UserStatus.LIVE + if user.is_active + else UserStatus.DEACTIVATED, + ) + for user in users + ], + invited=[InvitedUserSnapshot(email=email) for email in invited_emails], + accepted_pages=1, + invited_pages=1, + ) + + # Otherwise, return paginated results return AllUsersResponse( accepted=[ FullUserSnapshot( diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 834453e6e2d..646660c9fac 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -43,7 +43,7 @@ compute_max_document_tokens_for_persona, ) from danswer.llm.exceptions import GenAIDisabledException -from danswer.llm.factory import get_default_llm +from danswer.llm.factory 
import get_default_llms from danswer.llm.headers import get_litellm_additional_request_headers from danswer.llm.utils import get_default_llm_tokenizer from danswer.secondary_llm_flows.chat_session_naming import ( @@ -63,6 +63,8 @@ from danswer.server.query_and_chat.models import PromptOverride from danswer.server.query_and_chat.models import RenameChatSessionResponse from danswer.server.query_and_chat.models import SearchFeedbackRequest +from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest +from danswer.server.query_and_chat.token_limit import check_token_rate_limits from danswer.utils.logger import setup_logger logger = setup_logger() @@ -77,9 +79,13 @@ def get_user_chat_sessions( ) -> ChatSessionsResponse: user_id = user.id if user is not None else None - chat_sessions = get_chat_sessions_by_user( - user_id=user_id, deleted=False, db_session=db_session - ) + try: + chat_sessions = get_chat_sessions_by_user( + user_id=user_id, deleted=False, db_session=db_session + ) + + except ValueError: + raise ValueError("Chat session does not exist or has been deleted") return ChatSessionsResponse( sessions=[ @@ -90,12 +96,30 @@ def get_user_chat_sessions( time_created=chat.time_created.isoformat(), shared_status=chat.shared_status, folder_id=chat.folder_id, + current_alternate_model=chat.current_alternate_model, ) for chat in chat_sessions ] ) +@router.put("/update-chat-session-model") +def update_chat_session_model( + update_thread_req: UpdateChatSessionThreadRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + chat_session = get_chat_session_by_id( + chat_session_id=update_thread_req.chat_session_id, + user_id=user.id if user is not None else None, + db_session=db_session, + ) + chat_session.current_alternate_model = update_thread_req.new_alternate_model + + db_session.add(chat_session) + db_session.commit() + + @router.get("/get-chat-session/{session_id}") def get_chat_session( 
session_id: int, @@ -138,6 +162,7 @@ def get_chat_session( description=chat_session.description, persona_id=chat_session.persona_id, persona_name=chat_session.persona.name, + current_alternate_model=chat_session.current_alternate_model, messages=[ translate_db_message_to_chat_message_detail( msg, remove_doc_content=is_shared # if shared, don't leak doc content @@ -199,7 +224,7 @@ def rename_chat_session( full_history = history_msgs + [final_msg] try: - llm = get_default_llm( + llm, _ = get_default_llms( additional_headers=get_litellm_additional_request_headers(request.headers) ) except GenAIDisabledException: @@ -251,6 +276,7 @@ def handle_new_chat_message( chat_message_req: CreateChatMessageRequest, request: Request, user: User | None = Depends(current_user), + _: None = Depends(check_token_rate_limits), ) -> StreamingResponse: """This endpoint is both used for all the following purposes: - Sending a new message in the session @@ -361,6 +387,7 @@ def get_max_document_tokens( persona_id=persona_id, user=user, db_session=db_session, + is_for_edit=False, ) except ValueError: raise HTTPException(status_code=404, detail="Persona not found") diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 09561bf24f8..ea1ce1ff680 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -32,6 +32,12 @@ class SimpleQueryRequest(BaseModel): query: str +class UpdateChatSessionThreadRequest(BaseModel): + # If not specified, use Danswer default persona + chat_session_id: int + new_alternate_model: str + + class ChatSessionCreationRequest(BaseModel): # If not specified, use Danswer default persona persona_id: int = 0 @@ -101,6 +107,9 @@ class CreateChatMessageRequest(ChunkContext): llm_override: LLMOverride | None = None prompt_override: PromptOverride | None = None + # allow user to specify an alternate assistnat + alternate_assistant_id: int | None = None + # used 
for seeded chats to kick off the generation of an AI answer use_existing_user_message: bool = False @@ -142,6 +151,7 @@ class ChatSessionDetails(BaseModel): time_created: str shared_status: ChatSessionSharedStatus folder_id: int | None + current_alternate_model: str | None = None class ChatSessionsResponse(BaseModel): @@ -174,6 +184,7 @@ class ChatMessageDetail(BaseModel): context_docs: RetrievalDocs | None message_type: MessageType time_sent: datetime + alternate_assistant_id: str | None # Dict mapping citation number to db_doc_id citations: dict[int, int] | None files: list[FileDescriptor] @@ -193,6 +204,7 @@ class ChatSessionDetailResponse(BaseModel): messages: list[ChatMessageDetail] time_created: datetime shared_status: ChatSessionSharedStatus + current_alternate_model: str | None class QueryValidationResponse(BaseModel): diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index b8c6945dcab..43192211b79 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -29,7 +29,7 @@ from danswer.server.query_and_chat.models import SimpleQueryRequest from danswer.server.query_and_chat.models import SourceTag from danswer.server.query_and_chat.models import TagResponse -from danswer.server.query_and_chat.token_budget import check_token_budget +from danswer.server.query_and_chat.token_limit import check_token_rate_limits from danswer.utils.logger import setup_logger logger = setup_logger() @@ -52,6 +52,7 @@ def admin_search( source_type=question.filters.source_type, document_set=question.filters.document_set, time_cutoff=question.filters.time_cutoff, + tags=question.filters.tags, access_control_list=user_acl_filters, ) @@ -149,7 +150,7 @@ def stream_query_validation( def get_answer_with_quote( query_request: DirectQARequest, user: User = Depends(current_user), - _: bool = Depends(check_token_budget), + _: None = 
Depends(check_token_rate_limits), ) -> StreamingResponse: query = query_request.messages[0].message logger.info(f"Received query for one shot answer with quotes: {query}") diff --git a/backend/danswer/server/query_and_chat/token_budget.py b/backend/danswer/server/query_and_chat/token_budget.py deleted file mode 100644 index 49a84f5b0ec..00000000000 --- a/backend/danswer/server/query_and_chat/token_budget.py +++ /dev/null @@ -1,79 +0,0 @@ -import json -from datetime import datetime -from datetime import timedelta -from typing import cast - -from fastapi import HTTPException -from sqlalchemy import func -from sqlalchemy.orm import Session - -from danswer.configs.app_configs import TOKEN_BUDGET_GLOBALLY_ENABLED -from danswer.configs.constants import ENABLE_TOKEN_BUDGET -from danswer.configs.constants import TOKEN_BUDGET -from danswer.configs.constants import TOKEN_BUDGET_SETTINGS -from danswer.configs.constants import TOKEN_BUDGET_TIME_PERIOD -from danswer.db.engine import get_session_context_manager -from danswer.db.models import ChatMessage -from danswer.dynamic_configs.factory import get_dynamic_config_store - -BUDGET_LIMIT_DEFAULT = -1 # Default to no limit -TIME_PERIOD_HOURS_DEFAULT = 12 - - -def is_under_token_budget(db_session: Session) -> bool: - try: - settings_json = cast( - str, get_dynamic_config_store().load(TOKEN_BUDGET_SETTINGS) - ) - except Exception: - return True - - settings = json.loads(settings_json) - - is_enabled = settings.get(ENABLE_TOKEN_BUDGET, False) - - if not is_enabled: - return True - - budget_limit = settings.get(TOKEN_BUDGET, -1) - - if budget_limit < 0: - return True - - period_hours = settings.get(TOKEN_BUDGET_TIME_PERIOD, TIME_PERIOD_HOURS_DEFAULT) - period_start_time = datetime.now() - timedelta(hours=period_hours) - - # Fetch the sum of all tokens used within the period - token_sum = ( - db_session.query(func.sum(ChatMessage.token_count)) - .filter(ChatMessage.time_sent >= period_start_time) - .scalar() - or 0 - ) - - print( - 
"token_sum:", - token_sum, - "budget_limit:", - budget_limit, - "period_hours:", - period_hours, - "period_start_time:", - period_start_time, - ) - - return token_sum < ( - budget_limit * 1000 - ) # Budget limit is expressed in thousands of tokens - - -def check_token_budget() -> None: - if not TOKEN_BUDGET_GLOBALLY_ENABLED: - return None - - with get_session_context_manager() as db_session: - # Perform the token budget check here, possibly using `user` and `db_session` for database access if needed - if not is_under_token_budget(db_session): - raise HTTPException( - status_code=429, detail="Sorry, token budget exceeded. Try again later." - ) diff --git a/backend/danswer/server/query_and_chat/token_limit.py b/backend/danswer/server/query_and_chat/token_limit.py new file mode 100644 index 00000000000..b44eec3a64c --- /dev/null +++ b/backend/danswer/server/query_and_chat/token_limit.py @@ -0,0 +1,135 @@ +from collections.abc import Sequence +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from functools import lru_cache + +from dateutil import tz +from fastapi import Depends +from fastapi import HTTPException +from sqlalchemy import func +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.auth.users import current_user +from danswer.db.engine import get_session_context_manager +from danswer.db.models import ChatMessage +from danswer.db.models import ChatSession +from danswer.db.models import TokenRateLimit +from danswer.db.models import User +from danswer.utils.logger import setup_logger +from danswer.utils.variable_functionality import fetch_versioned_implementation +from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits + + +logger = setup_logger() + + +TOKEN_BUDGET_UNIT = 1_000 + + +def check_token_rate_limits( + user: User | None = Depends(current_user), +) -> None: + # short circuit if no rate limits are set up + # NOTE: result of `any_rate_limit_exists` is cached, so 
this call is fast 99% of the time + if not any_rate_limit_exists(): + return + + versioned_rate_limit_strategy = fetch_versioned_implementation( + "danswer.server.query_and_chat.token_limit", "_check_token_rate_limits" + ) + return versioned_rate_limit_strategy(user) + + +def _check_token_rate_limits(_: User | None) -> None: + _user_is_rate_limited_by_global() + + +""" +Global rate limits +""" + + +def _user_is_rate_limited_by_global() -> None: + with get_session_context_manager() as db_session: + global_rate_limits = fetch_all_global_token_rate_limits( + db_session=db_session, enabled_only=True, ordered=False + ) + + if global_rate_limits: + global_cutoff_time = _get_cutoff_time(global_rate_limits) + global_usage = _fetch_global_usage(global_cutoff_time, db_session) + + if _is_rate_limited(global_rate_limits, global_usage): + raise HTTPException( + status_code=429, + detail="Token budget exceeded for organization. Try again later.", + ) + + +def _fetch_global_usage( + cutoff_time: datetime, db_session: Session +) -> Sequence[tuple[datetime, int]]: + """ + Fetch global token usage within the cutoff time, grouped by minute + """ + result = db_session.execute( + select( + func.date_trunc("minute", ChatMessage.time_sent), + func.sum(ChatMessage.token_count), + ) + .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id) + .filter( + ChatMessage.time_sent >= cutoff_time, + ) + .group_by(func.date_trunc("minute", ChatMessage.time_sent)) + ).all() + + return [(row[0], row[1]) for row in result] + + +""" +Common functions +""" + + +def _get_cutoff_time(rate_limits: Sequence[TokenRateLimit]) -> datetime: + max_period_hours = max(rate_limit.period_hours for rate_limit in rate_limits) + return datetime.now(tz=timezone.utc) - timedelta(hours=max_period_hours) + + +def _is_rate_limited( + rate_limits: Sequence[TokenRateLimit], usage: Sequence[tuple[datetime, int]] +) -> bool: + """ + If at least one rate limit is exceeded, return True + """ + for rate_limit in 
rate_limits: + tokens_used = sum( + u_token_count + for u_date, u_token_count in usage + if u_date + >= datetime.now(tz=tz.UTC) - timedelta(hours=rate_limit.period_hours) + ) + + if tokens_used >= rate_limit.token_budget * TOKEN_BUDGET_UNIT: + return True + + return False + + +@lru_cache() +def any_rate_limit_exists() -> bool: + """Checks if any rate limit exists in the database. Is cached, so that if no rate limits + are setup, we don't have any effect on average query latency.""" + logger.info("Checking for any rate limits...") + with get_session_context_manager() as db_session: + return ( + db_session.scalar( + select(TokenRateLimit.id).where( + TokenRateLimit.enabled == True # noqa: E712 + ) + ) + is not None + ) diff --git a/backend/danswer/server/settings/models.py b/backend/danswer/server/settings/models.py index 041e360d72d..9afacf5add2 100644 --- a/backend/danswer/server/settings/models.py +++ b/backend/danswer/server/settings/models.py @@ -14,6 +14,7 @@ class Settings(BaseModel): chat_page_enabled: bool = True search_page_enabled: bool = True default_page: PageType = PageType.SEARCH + maximum_chat_retention_days: int | None = None def check_validity(self) -> None: chat_page_enabled = self.chat_page_enabled diff --git a/backend/danswer/server/token_rate_limits/api.py b/backend/danswer/server/token_rate_limits/api.py new file mode 100644 index 00000000000..245e3391410 --- /dev/null +++ b/backend/danswer/server/token_rate_limits/api.py @@ -0,0 +1,79 @@ +from fastapi import APIRouter +from fastapi import Depends +from sqlalchemy.orm import Session + +from danswer.auth.users import current_admin_user +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.server.query_and_chat.token_limit import any_rate_limit_exists +from danswer.server.token_rate_limits.models import TokenRateLimitArgs +from danswer.server.token_rate_limits.models import TokenRateLimitDisplay +from ee.danswer.db.token_limit import delete_token_rate_limit 
+from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits +from ee.danswer.db.token_limit import insert_global_token_rate_limit +from ee.danswer.db.token_limit import update_token_rate_limit + +router = APIRouter(prefix="/admin/token-rate-limits") + + +""" +Global Token Limit Settings +""" + + +@router.get("/global") +def get_global_token_limit_settings( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[TokenRateLimitDisplay]: + return [ + TokenRateLimitDisplay.from_db(token_rate_limit) + for token_rate_limit in fetch_all_global_token_rate_limits(db_session) + ] + + +@router.post("/global") +def create_global_token_limit_settings( + token_limit_settings: TokenRateLimitArgs, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> TokenRateLimitDisplay: + rate_limit_display = TokenRateLimitDisplay.from_db( + insert_global_token_rate_limit(db_session, token_limit_settings) + ) + # clear cache in case this was the first rate limit created + any_rate_limit_exists.cache_clear() + return rate_limit_display + + +""" +General Token Limit Settings +""" + + +@router.put("/rate-limit/{token_rate_limit_id}") +def update_token_limit_settings( + token_rate_limit_id: int, + token_limit_settings: TokenRateLimitArgs, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> TokenRateLimitDisplay: + return TokenRateLimitDisplay.from_db( + update_token_rate_limit( + db_session=db_session, + token_rate_limit_id=token_rate_limit_id, + token_rate_limit_settings=token_limit_settings, + ) + ) + + +@router.delete("/rate-limit/{token_rate_limit_id}") +def delete_token_limit_settings( + token_rate_limit_id: int, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + return delete_token_rate_limit( + db_session=db_session, + token_rate_limit_id=token_rate_limit_id, + ) diff --git 
a/backend/danswer/server/token_rate_limits/models.py b/backend/danswer/server/token_rate_limits/models.py new file mode 100644 index 00000000000..351abe92e7e --- /dev/null +++ b/backend/danswer/server/token_rate_limits/models.py @@ -0,0 +1,25 @@ +from pydantic import BaseModel + +from danswer.db.models import TokenRateLimit + + +class TokenRateLimitArgs(BaseModel): + enabled: bool + token_budget: int + period_hours: int + + +class TokenRateLimitDisplay(BaseModel): + token_id: int + enabled: bool + token_budget: int + period_hours: int + + @classmethod + def from_db(cls, token_rate_limit: TokenRateLimit) -> "TokenRateLimitDisplay": + return cls( + token_id=token_rate_limit.id, + enabled=token_rate_limit.enabled, + token_budget=token_rate_limit.token_budget, + period_hours=token_rate_limit.period_hours, + ) diff --git a/backend/danswer/tools/built_in_tools.py b/backend/danswer/tools/built_in_tools.py index 94ffb085748..68fae10c061 100644 --- a/backend/danswer/tools/built_in_tools.py +++ b/backend/danswer/tools/built_in_tools.py @@ -1,3 +1,4 @@ +import os from typing import Type from typing import TypedDict @@ -9,6 +10,7 @@ from danswer.db.models import Persona from danswer.db.models import Tool as ToolDBModel from danswer.tools.images.image_generation_tool import ImageGenerationTool +from danswer.tools.internet_search.internet_search_tool import InternetSearchTool from danswer.tools.search.search_tool import SearchTool from danswer.tools.tool import Tool from danswer.utils.logger import setup_logger @@ -20,22 +22,41 @@ class InCodeToolInfo(TypedDict): cls: Type[Tool] description: str in_code_tool_id: str + display_name: str BUILT_IN_TOOLS: list[InCodeToolInfo] = [ - { - "cls": SearchTool, - "description": "The Search Tool allows the Assistant to search through connected knowledge to help build an answer.", - "in_code_tool_id": SearchTool.__name__, - }, - { - "cls": ImageGenerationTool, - "description": ( + InCodeToolInfo( + cls=SearchTool, + description="The Search 
Tool allows the Assistant to search through connected knowledge to help build an answer.", + in_code_tool_id=SearchTool.__name__, + display_name=SearchTool._DISPLAY_NAME, + ), + InCodeToolInfo( + cls=ImageGenerationTool, + description=( "The Image Generation Tool allows the assistant to use DALL-E 3 to generate images. " "The tool will be used when the user asks the assistant to generate an image." ), - "in_code_tool_id": ImageGenerationTool.__name__, - }, + in_code_tool_id=ImageGenerationTool.__name__, + display_name=ImageGenerationTool._DISPLAY_NAME, + ), + # don't show the InternetSearchTool as an option if BING_API_KEY is not available + *( + [ + InCodeToolInfo( + cls=InternetSearchTool, + description=( + "The Internet Search Tool allows the assistant " + "to perform internet searches for up-to-date information." + ), + in_code_tool_id=InternetSearchTool.__name__, + display_name=InternetSearchTool._DISPLAY_NAME, + ) + ] + if os.environ.get("BING_API_KEY") + else [] + ), ] @@ -55,12 +76,14 @@ def load_builtin_tools(db_session: Session) -> None: # Update existing tool tool.name = tool_name tool.description = tool_info["description"] + tool.display_name = tool_info["display_name"] logger.info(f"Updated tool: {tool_name}") else: # Add new tool new_tool = ToolDBModel( name=tool_name, description=tool_info["description"], + display_name=tool_info["display_name"], in_code_tool_id=tool_info["in_code_tool_id"], ) db_session.add(new_tool) diff --git a/backend/danswer/tools/custom/custom_tool.py b/backend/danswer/tools/custom/custom_tool.py index ea232fc5a74..84f10c3ec05 100644 --- a/backend/danswer/tools/custom/custom_tool.py +++ b/backend/danswer/tools/custom/custom_tool.py @@ -44,11 +44,20 @@ def __init__(self, method_spec: MethodSpec, base_url: str) -> None: self._tool_definition = self._method_spec.to_tool_definition() self._name = self._method_spec.name - self.description = self._method_spec.summary + self._description = self._method_spec.summary + @property def 
name(self) -> str: return self._name + @property + def description(self) -> str: + return self._description + + @property + def display_name(self) -> str: + return self._name + """For LLMs which support explicit tool calling""" def tool_definition(self) -> dict: @@ -77,7 +86,7 @@ def get_args_for_non_tool_calling_llm( content=SHOULD_USE_CUSTOM_TOOL_USER_PROMPT.format( history=history, query=query, - tool_name=self.name(), + tool_name=self.name, tool_description=self.description, ) ), @@ -93,7 +102,7 @@ def get_args_for_non_tool_calling_llm( content=TOOL_ARG_USER_PROMPT.format( history=history, query=query, - tool_name=self.name(), + tool_name=self.name, tool_description=self.description, tool_args=self.tool_definition()["function"]["parameters"], ) @@ -121,7 +130,7 @@ def get_args_for_non_tool_calling_llm( # pretend like nothing happened if not parse-able logger.error( - f"Failed to parse args for '{self.name()}' tool. Recieved: {args_result_str}" + f"Failed to parse args for '{self.name}' tool. 
Received: {args_result_str}" ) return None diff --git a/backend/danswer/tools/force.py b/backend/danswer/tools/force.py index 1c3f0a220ed..53f7997d6ed 100644 --- a/backend/danswer/tools/force.py +++ b/backend/danswer/tools/force.py @@ -37,4 +37,4 @@ def filter_tools_for_force_tool_use( if not force_use_tool: return tools - return [tool for tool in tools if tool.name() == force_use_tool.tool_name] + return [tool for tool in tools if tool.name == force_use_tool.tool_name] diff --git a/backend/danswer/tools/images/image_generation_tool.py b/backend/danswer/tools/images/image_generation_tool.py index 22aa40993b6..ed145a55cf2 100644 --- a/backend/danswer/tools/images/image_generation_tool.py +++ b/backend/danswer/tools/images/image_generation_tool.py @@ -10,6 +10,7 @@ from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.dynamic_configs.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage +from danswer.llm.headers import build_llm_extra_headers from danswer.llm.interfaces import LLM from danswer.llm.utils import build_content_with_imgs from danswer.llm.utils import message_to_string @@ -54,24 +55,46 @@ class ImageGenerationResponse(BaseModel): class ImageGenerationTool(Tool): - NAME = "run_image_generation" + _NAME = "run_image_generation" + _DESCRIPTION = "Generate an image from a prompt."
+ _DISPLAY_NAME = "Image Generation Tool" def __init__( - self, api_key: str, model: str = "dall-e-3", num_imgs: int = 2 + self, + api_key: str, + api_base: str | None, + api_version: str | None, + model: str = "dall-e-3", + num_imgs: int = 2, + additional_headers: dict[str, str] | None = None, ) -> None: self.api_key = api_key + self.api_base = api_base + self.api_version = api_version + self.model = model self.num_imgs = num_imgs + self.additional_headers = additional_headers + + @property def name(self) -> str: - return self.NAME + return self._NAME + + @property + def description(self) -> str: + return self._DESCRIPTION + + @property + def display_name(self) -> str: + return self._DISPLAY_NAME def tool_definition(self) -> dict: return { "type": "function", "function": { - "name": self.name(), - "description": "Generate an image from a prompt", + "name": self.name, + "description": self.description, "parameters": { "type": "object", "properties": { @@ -141,7 +164,11 @@ def _generate_image(self, prompt: str) -> ImageGenerationResponse: prompt=prompt, model=self.model, api_key=self.api_key, + # need to pass in None rather than empty str + api_base=self.api_base or None, + api_version=self.api_version or None, n=1, + extra_headers=build_llm_extra_headers(self.additional_headers), ) return ImageGenerationResponse( revised_prompt=response.data[0]["revised_prompt"], diff --git a/backend/danswer/tools/internet_search/internet_search_tool.py b/backend/danswer/tools/internet_search/internet_search_tool.py new file mode 100644 index 00000000000..0f92deadd04 --- /dev/null +++ b/backend/danswer/tools/internet_search/internet_search_tool.py @@ -0,0 +1,233 @@ +import json +from collections.abc import Generator +from datetime import datetime +from typing import Any +from typing import cast + +import httpx + +from danswer.chat.chat_utils import combine_message_chain +from danswer.chat.models import LlmDoc +from danswer.configs.constants import DocumentSource +from 
danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF +from danswer.dynamic_configs.interface import JSON_ro +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.interfaces import LLM +from danswer.llm.utils import message_to_string +from danswer.prompts.chat_prompts import INTERNET_SEARCH_QUERY_REPHRASE +from danswer.prompts.constants import GENERAL_SEP_PAT +from danswer.search.models import SearchDoc +from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase +from danswer.tools.internet_search.models import InternetSearchResponse +from danswer.tools.internet_search.models import InternetSearchResult +from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS +from danswer.tools.tool import Tool +from danswer.tools.tool import ToolResponse +from danswer.utils.logger import setup_logger + +logger = setup_logger() + +INTERNET_SEARCH_RESPONSE_ID = "internet_search_response" + +YES_INTERNET_SEARCH = "Yes Internet Search" +SKIP_INTERNET_SEARCH = "Skip Internet Search" + +INTERNET_SEARCH_TEMPLATE = f""" +Given the conversation history and a follow up query, determine if the system should call \ +an external internet search tool to better answer the latest user input. +Your default response is {SKIP_INTERNET_SEARCH}. + +Respond "{YES_INTERNET_SEARCH}" if: +- The user is asking for information that requires an internet search. + +Conversation History: +{GENERAL_SEP_PAT} +{{chat_history}} +{GENERAL_SEP_PAT} + +If you are at all unsure, respond with {SKIP_INTERNET_SEARCH}. 
+Respond with EXACTLY and ONLY "{YES_INTERNET_SEARCH}" or "{SKIP_INTERNET_SEARCH}" + +Follow Up Input: +{{final_query}} +""".strip() + + +def llm_doc_from_internet_search_result(result: InternetSearchResult) -> LlmDoc: + return LlmDoc( + document_id=result.link, + content=result.snippet, + blurb=result.snippet, + semantic_identifier=result.link, + source_type=DocumentSource.WEB, + metadata={}, + updated_at=datetime.now(), + link=result.link, + source_links={0: result.link}, + ) + + +def internet_search_response_to_search_docs( + internet_search_response: InternetSearchResponse, +) -> list[SearchDoc]: + return [ + SearchDoc( + document_id=doc.link, + chunk_ind=-1, + semantic_identifier=doc.title, + link=doc.link, + blurb=doc.snippet, + source_type=DocumentSource.NOT_APPLICABLE, + boost=0, + hidden=False, + metadata={}, + score=None, + match_highlights=[], + updated_at=None, + primary_owners=[], + secondary_owners=[], + is_internet=True, + ) + for doc in internet_search_response.internet_results + ] + + +class InternetSearchTool(Tool): + _NAME = "run_internet_search" + _DISPLAY_NAME = "[Beta] Internet Search Tool" + _DESCRIPTION = "Perform an internet search for up-to-date information." 
+ + def __init__(self, api_key: str, num_results: int = 10) -> None: + self.api_key = api_key + self.host = "https://api.bing.microsoft.com/v7.0" + self.headers = { + "Ocp-Apim-Subscription-Key": api_key, + "Content-Type": "application/json", + } + self.num_results = num_results + self.client = httpx.Client() + + @property + def name(self) -> str: + return self._NAME + + @property + def description(self) -> str: + return self._DESCRIPTION + + @property + def display_name(self) -> str: + return self._DISPLAY_NAME + + def tool_definition(self) -> dict: + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": { + "internet_search_query": { + "type": "string", + "description": "Query to search on the internet", + }, + }, + "required": ["internet_search_query"], + }, + }, + } + + def check_if_needs_internet_search( + self, + query: str, + history: list[PreviousMessage], + llm: LLM, + ) -> bool: + history_str = combine_message_chain( + messages=history, token_limit=GEN_AI_HISTORY_CUTOFF + ) + prompt = INTERNET_SEARCH_TEMPLATE.format( + chat_history=history_str, + final_query=query, + ) + use_internet_search_output = message_to_string(llm.invoke(prompt)) + + logger.debug( + f"Evaluated if should use internet search: {use_internet_search_output}" + ) + + return ( + YES_INTERNET_SEARCH.split()[0] + ).lower() in use_internet_search_output.lower() + + def get_args_for_non_tool_calling_llm( + self, + query: str, + history: list[PreviousMessage], + llm: LLM, + force_run: bool = False, + ) -> dict[str, Any] | None: + if not force_run and not self.check_if_needs_internet_search( + query, history, llm + ): + return None + + rephrased_query = history_based_query_rephrase( + query=query, + history=history, + llm=llm, + prompt_template=INTERNET_SEARCH_QUERY_REPHRASE, + ) + return { + "internet_search_query": rephrased_query, + } + + def build_tool_message_content( + self, *args: 
ToolResponse + ) -> str | list[str | dict[str, Any]]: + search_response = cast(InternetSearchResponse, args[0].response) + return json.dumps(search_response.dict()) + + def _perform_search(self, query: str) -> InternetSearchResponse: + response = self.client.get( + f"{self.host}/search", + headers=self.headers, + params={"q": query, "count": self.num_results}, + ) + results = response.json() + + return InternetSearchResponse( + revised_query=query, + internet_results=[ + InternetSearchResult( + title=result["name"], + link=result["url"], + snippet=result["snippet"], + ) + for result in results["webPages"]["value"][: self.num_results] + ], + ) + + def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: + query = cast(str, kwargs["internet_search_query"]) + + results = self._perform_search(query) + yield ToolResponse( + id=INTERNET_SEARCH_RESPONSE_ID, + response=results, + ) + + llm_docs = [ + llm_doc_from_internet_search_result(result) + for result in results.internet_results + ] + + yield ToolResponse( + id=FINAL_CONTEXT_DOCUMENTS, + response=llm_docs, + ) + + def final_result(self, *args: ToolResponse) -> JSON_ro: + search_response = cast(InternetSearchResponse, args[0].response) + return search_response.dict() diff --git a/backend/danswer/tools/internet_search/models.py b/backend/danswer/tools/internet_search/models.py new file mode 100644 index 00000000000..e6db4179b7f --- /dev/null +++ b/backend/danswer/tools/internet_search/models.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel + + +class InternetSearchResult(BaseModel): + title: str + link: str + snippet: str + + +class InternetSearchResponse(BaseModel): + revised_query: str + internet_results: list[InternetSearchResult] diff --git a/backend/danswer/tools/search/search_tool.py b/backend/danswer/tools/search/search_tool.py index b0b45bd8f40..44c5001d6cf 100644 --- a/backend/danswer/tools/search/search_tool.py +++ b/backend/danswer/tools/search/search_tool.py @@ -7,14 +7,16 @@ from 
sqlalchemy.orm import Session from danswer.chat.chat_utils import llm_doc_from_inference_section +from danswer.chat.models import DanswerContext +from danswer.chat.models import DanswerContexts from danswer.chat.models import LlmDoc from danswer.db.models import Persona from danswer.db.models import User from danswer.dynamic_configs.interface import JSON_ro -from danswer.llm.answering.doc_pruning import prune_documents from danswer.llm.answering.models import DocumentPruningConfig from danswer.llm.answering.models import PreviousMessage from danswer.llm.answering.models import PromptConfig +from danswer.llm.answering.prune_and_merge import prune_and_merge_sections from danswer.llm.interfaces import LLM from danswer.search.enums import QueryFlow from danswer.search.enums import SearchType @@ -30,6 +32,7 @@ from danswer.tools.tool import ToolResponse SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary" +SEARCH_DOC_CONTENT_ID = "search_doc_content" SECTION_RELEVANCE_LIST_ID = "section_relevance_list" FINAL_CONTEXT_DOCUMENTS = "final_context_documents" @@ -43,7 +46,7 @@ class SearchResponseSummary(BaseModel): recency_bias_multiplier: float -search_tool_description = """ +SEARCH_TOOL_DESCRIPTION = """ Runs a semantic search over the user's knowledge base. The default behavior is to use this tool. \ The only scenario where you should not use this tool is if: @@ -56,7 +59,9 @@ class SearchResponseSummary(BaseModel): class SearchTool(Tool): - NAME = "run_search" + _NAME = "run_search" + _DISPLAY_NAME = "Search Tool" + _DESCRIPTION = SEARCH_TOOL_DESCRIPTION def __init__( self, @@ -66,10 +71,11 @@ def __init__( retrieval_options: RetrievalDetails | None, prompt_config: PromptConfig, llm: LLM, + fast_llm: LLM, pruning_config: DocumentPruningConfig, # if specified, will not actually run a search and will instead return these # sections. 
Used when the user selects specific docs to talk to - selected_docs: list[LlmDoc] | None = None, + selected_sections: list[InferenceSection] | None = None, chunks_above: int = 0, chunks_below: int = 0, full_doc: bool = False, @@ -80,9 +86,10 @@ def __init__( self.retrieval_options = retrieval_options self.prompt_config = prompt_config self.llm = llm + self.fast_llm = fast_llm self.pruning_config = pruning_config - self.selected_docs = selected_docs + self.selected_sections = selected_sections self.chunks_above = chunks_above self.chunks_below = chunks_below @@ -90,8 +97,17 @@ def __init__( self.bypass_acl = bypass_acl self.db_session = db_session + @property def name(self) -> str: - return self.NAME + return self._NAME + + @property + def description(self) -> str: + return self._DESCRIPTION + + @property + def display_name(self) -> str: + return self._DISPLAY_NAME """For explicit tool calling""" @@ -99,8 +115,8 @@ def tool_definition(self) -> dict: return { "type": "function", "function": { - "name": self.name(), - "description": search_tool_description, + "name": self.name, + "description": self.description, "parameters": { "type": "object", "properties": { @@ -117,7 +133,9 @@ def tool_definition(self) -> dict: def build_tool_message_content( self, *args: ToolResponse ) -> str | list[str | dict[str, Any]]: - final_context_docs_response = args[2] + final_context_docs_response = next( + response for response in args if response.id == FINAL_CONTEXT_DOCUMENTS + ) final_context_docs = cast(list[LlmDoc], final_context_docs_response.response) return json.dumps( @@ -153,7 +171,7 @@ def get_args_for_non_tool_calling_llm( def _build_response_for_specified_sections( self, query: str ) -> Generator[ToolResponse, None, None]: - if self.selected_docs is None: + if self.selected_sections is None: raise ValueError("sections must be specified") yield ToolResponse( @@ -169,24 +187,27 @@ def _build_response_for_specified_sections( ) yield ToolResponse( id=SECTION_RELEVANCE_LIST_ID, 
- response=[i for i in range(len(self.selected_docs))], + response=[i for i in range(len(self.selected_sections))], ) - yield ToolResponse( - id=FINAL_CONTEXT_DOCUMENTS, - response=prune_documents( - docs=self.selected_docs, - doc_relevance_list=None, - prompt_config=self.prompt_config, - llm_config=self.llm.config, - question=query, - document_pruning_config=self.pruning_config, - ), + + final_context_sections = prune_and_merge_sections( + sections=self.selected_sections, + section_relevance_list=None, + prompt_config=self.prompt_config, + llm_config=self.llm.config, + question=query, + document_pruning_config=self.pruning_config, ) + llm_docs = [ + llm_doc_from_inference_section(section) + for section in final_context_sections + ] + yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=llm_docs) def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: query = cast(str, kwargs["query"]) - if self.selected_docs: + if self.selected_sections: yield from self._build_response_for_specified_sections(query) return @@ -204,9 +225,13 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: chunks_above=self.chunks_above, chunks_below=self.chunks_below, full_doc=self.full_doc, + enable_auto_detect_filters=self.retrieval_options.enable_auto_detect_filters + if self.retrieval_options + else None, ), user=self.user, llm=self.llm, + fast_llm=self.fast_llm, bypass_acl=self.bypass_acl, db_session=self.db_session, ) @@ -221,27 +246,40 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier, ), ) + yield ToolResponse( + id=SEARCH_DOC_CONTENT_ID, + response=DanswerContexts( + contexts=[ + DanswerContext( + content=section.combined_content, + document_id=section.center_chunk.document_id, + semantic_identifier=section.center_chunk.semantic_identifier, + blurb=section.center_chunk.blurb, + ) + for section in search_pipeline.reranked_sections + ] + ), + ) yield 
ToolResponse( id=SECTION_RELEVANCE_LIST_ID, - response=search_pipeline.relevant_chunk_indices, + response=search_pipeline.relevant_section_indices, ) - llm_docs = [ - llm_doc_from_inference_section(section) - for section in search_pipeline.reranked_sections - ] - final_context_documents = prune_documents( - docs=llm_docs, - doc_relevance_list=[ - True if ind in search_pipeline.relevant_chunk_indices else False - for ind in range(len(llm_docs)) - ], + final_context_sections = prune_and_merge_sections( + sections=search_pipeline.reranked_sections, + section_relevance_list=search_pipeline.section_relevance_list, prompt_config=self.prompt_config, llm_config=self.llm.config, question=query, document_pruning_config=self.pruning_config, ) - yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=final_context_documents) + + llm_docs = [ + llm_doc_from_inference_section(section) + for section in final_context_sections + ] + + yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=llm_docs) def final_result(self, *args: ToolResponse) -> JSON_ro: final_docs = cast( diff --git a/backend/danswer/tools/search/search_utils.py b/backend/danswer/tools/search/search_utils.py index 7e5151bb582..5076632a694 100644 --- a/backend/danswer/tools/search/search_utils.py +++ b/backend/danswer/tools/search/search_utils.py @@ -1,5 +1,6 @@ from danswer.chat.models import LlmDoc from danswer.prompts.prompt_utils import clean_up_source +from danswer.search.models import InferenceSection def llm_doc_to_dict(llm_doc: LlmDoc, doc_num: int) -> dict: @@ -13,3 +14,18 @@ def llm_doc_to_dict(llm_doc: LlmDoc, doc_num: int) -> dict: if llm_doc.updated_at: doc_dict["updated_at"] = llm_doc.updated_at.strftime("%B %d, %Y %H:%M") return doc_dict + + +def section_to_dict(section: InferenceSection, section_num: int) -> dict: + doc_dict = { + "document_number": section_num + 1, + "title": section.center_chunk.semantic_identifier, + "content": section.combined_content, + "source": 
clean_up_source(section.center_chunk.source_type), + "metadata": section.center_chunk.metadata, + } + if section.center_chunk.updated_at: + doc_dict["updated_at"] = section.center_chunk.updated_at.strftime( + "%B %d, %Y %H:%M" + ) + return doc_dict diff --git a/backend/danswer/tools/tool.py b/backend/danswer/tools/tool.py index e335a049838..81b9b457178 100644 --- a/backend/danswer/tools/tool.py +++ b/backend/danswer/tools/tool.py @@ -9,10 +9,21 @@ class Tool(abc.ABC): + @property @abc.abstractmethod def name(self) -> str: raise NotImplementedError + @property + @abc.abstractmethod + def description(self) -> str: + raise NotImplementedError + + @property + @abc.abstractmethod + def display_name(self) -> str: + raise NotImplementedError + """For LLMs which support explicit tool calling""" @abc.abstractmethod diff --git a/backend/danswer/tools/tool_runner.py b/backend/danswer/tools/tool_runner.py index a4367d865d5..f962c214a03 100644 --- a/backend/danswer/tools/tool_runner.py +++ b/backend/danswer/tools/tool_runner.py @@ -18,11 +18,12 @@ def __init__(self, tool: Tool, args: dict[str, Any]): self._tool_responses: list[ToolResponse] | None = None def kickoff(self) -> ToolCallKickoff: - return ToolCallKickoff(tool_name=self.tool.name(), tool_args=self.args) + return ToolCallKickoff(tool_name=self.tool.name, tool_args=self.args) def tool_responses(self) -> Generator[ToolResponse, None, None]: if self._tool_responses is not None: yield from self._tool_responses + return tool_responses: list[ToolResponse] = [] for tool_response in self.tool.run(**self.args): @@ -37,7 +38,7 @@ def tool_message_content(self) -> str | list[str | dict[str, Any]]: def tool_final_result(self) -> ToolCallFinalResult: return ToolCallFinalResult( - tool_name=self.tool.name(), + tool_name=self.tool.name, tool_args=self.args, tool_result=self.tool.final_result(*self.tool_responses()), ) diff --git a/backend/danswer/tools/tool_selection.py b/backend/danswer/tools/tool_selection.py new file mode 100644 
index 00000000000..dc8d697c2ad --- /dev/null +++ b/backend/danswer/tools/tool_selection.py @@ -0,0 +1,78 @@ +import re +from typing import Any + +from danswer.chat.chat_utils import combine_message_chain +from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.interfaces import LLM +from danswer.llm.utils import message_to_string +from danswer.prompts.constants import GENERAL_SEP_PAT +from danswer.tools.tool import Tool +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +SINGLE_TOOL_SELECTION_PROMPT = f""" +You are an expert at selecting the most useful tool to run for answering the query. +You will be given a numbered list of tools and their arguments, a message history, and a query. +You will select a single tool that will be most useful for answering the query. +Respond with only the number corresponding to the tool you want to use. + +Conversation History: +{GENERAL_SEP_PAT} +{{chat_history}} +{GENERAL_SEP_PAT} + +Query: +{{query}} + +Tools: +{{tool_list}} + +Respond with EXACTLY and ONLY the number corresponding to the tool you want to use. 
+ +Your selection: +""" + + +def select_single_tool_for_non_tool_calling_llm( + tools_and_args: list[tuple[Tool, dict[str, Any]]], + history: list[PreviousMessage], + query: str, + llm: LLM, +) -> tuple[Tool, dict[str, Any]] | None: + if len(tools_and_args) == 1: + return tools_and_args[0] + + tool_list_str = "\n".join( + f"""```{ind}: {tool.name} ({args}) - {tool.description}```""" + for ind, (tool, args) in enumerate(tools_and_args) + ).lstrip() + + history_str = combine_message_chain( + messages=history, + token_limit=GEN_AI_HISTORY_CUTOFF, + ) + prompt = SINGLE_TOOL_SELECTION_PROMPT.format( + tool_list=tool_list_str, chat_history=history_str, query=query + ) + output = message_to_string(llm.invoke(prompt)) + try: + # First try to match the number + number_match = re.search(r"\d+", output) + if number_match: + tool_ind = int(number_match.group()) + return tools_and_args[tool_ind] + + # If that fails, try to match the tool name + for tool, args in tools_and_args: + if tool.name.lower() in output.lower(): + return tool, args + + # If that fails, return the first tool + return tools_and_args[0] + + except Exception: + logger.error(f"Failed to select single tool for non-tool-calling LLM: {output}") + return None diff --git a/backend/danswer/utils/acl.py b/backend/danswer/utils/acl.py index 8fbadb30000..5608530fac6 100644 --- a/backend/danswer/utils/acl.py +++ b/backend/danswer/utils/acl.py @@ -3,8 +3,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session -from danswer.access.models import DocumentAccess -from danswer.db.document import get_acccess_info_for_documents +from danswer.access.access import get_access_for_documents from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import Document from danswer.document_index.document_index_utils import get_both_index_names @@ -37,7 +36,7 @@ def set_acl_for_vespa(should_check_if_already_done: bool = False) -> None: # for all documents, set the `access_control_list` field appropriately 
# based on the state of Postgres documents = db_session.scalars(select(Document)).all() - document_access_info = get_acccess_info_for_documents( + document_access_dict = get_access_for_documents( db_session=db_session, document_ids=[document.id for document in documents], ) @@ -46,15 +45,16 @@ def set_acl_for_vespa(should_check_if_already_done: bool = False) -> None: vespa_index = get_default_document_index( primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name ) + if not isinstance(vespa_index, VespaIndex): raise ValueError("This script is only for Vespa indexes") update_requests = [ UpdateRequest( document_ids=[document_id], - access=DocumentAccess.build(user_ids, is_public), + access=access, ) - for document_id, user_ids, is_public in document_access_info + for document_id, access in document_access_dict.items() ] vespa_index.update(update_requests=update_requests) diff --git a/backend/danswer/utils/variable_functionality.py b/backend/danswer/utils/variable_functionality.py index 61414effdb1..2348c3271a7 100644 --- a/backend/danswer/utils/variable_functionality.py +++ b/backend/danswer/utils/variable_functionality.py @@ -1,10 +1,11 @@ import functools import importlib from typing import Any +from typing import TypeVar +from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED from danswer.utils.logger import setup_logger - logger = setup_logger() @@ -22,6 +23,12 @@ def get_is_ee_version(self) -> bool: global_version = DanswerVersion() +def set_is_ee_based_on_env_variable() -> None: + if ENTERPRISE_EDITION_ENABLED and not global_version.get_is_ee_version(): + logger.info("Enterprise Edition enabled") + global_version.set_ee() + + @functools.lru_cache(maxsize=128) def fetch_versioned_implementation(module: str, attribute: str) -> Any: logger.debug("Fetching versioned implementation for %s.%s", module, attribute) @@ -36,3 +43,21 @@ def fetch_versioned_implementation(module: str, attribute: str) -> Any: return 
getattr(importlib.import_module(module), attribute) raise + + +T = TypeVar("T") + + +def fetch_versioned_implementation_with_fallback( + module: str, attribute: str, fallback: T +) -> T: + try: + return fetch_versioned_implementation(module, attribute) + except Exception as e: + logger.warning( + "Failed to fetch versioned implementation for %s.%s: %s", + module, + attribute, + e, + ) + return fallback diff --git a/backend/ee/LICENSE b/backend/ee/LICENSE new file mode 100644 index 00000000000..1f2d7c15faa --- /dev/null +++ b/backend/ee/LICENSE @@ -0,0 +1,36 @@ +The DanswerAI Enterprise license (the “Enterprise License”) +Copyright (c) 2023-present DanswerAI, Inc. + +With regard to the Danswer Software: + +This software and associated documentation files (the "Software") may only be +used in production, if you (and any entity that you represent) have agreed to, +and are in compliance with, the DanswerAI Subscription Terms of Service, available +at https://danswer.ai/terms (the “Enterprise Terms”), or other +agreement governing the use of the Software, as agreed by you and DanswerAI, +and otherwise have a valid Danswer Enterprise license for the +correct number of user seats. Subject to the foregoing sentence, you are free to +modify this Software and publish patches to the Software. You agree that DanswerAI +and/or its licensors (as applicable) retain all right, title and interest in and +to all such modifications and/or patches, and all such modifications and/or +patches may only be used, copied, modified, displayed, distributed, or otherwise +exploited with a valid Danswer Enterprise license for the correct +number of user seats. Notwithstanding the foregoing, you may copy and modify +the Software for development and testing purposes, without requiring a +subscription. You agree that DanswerAI and/or its licensors (as applicable) retain +all right, title and interest in and to all such modifications. 
You are not +granted any other rights beyond what is expressly stated herein. Subject to the +foregoing, it is forbidden to copy, merge, publish, distribute, sublicense, +and/or sell the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +For all third party components incorporated into the Danswer Software, those +components are licensed under the original license provided by the owner of the +applicable component. diff --git a/backend/ee/__init__.py b/backend/ee/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/ee/danswer/__init__.py b/backend/ee/danswer/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/ee/danswer/access/access.py b/backend/ee/danswer/access/access.py new file mode 100644 index 00000000000..254e76e6681 --- /dev/null +++ b/backend/ee/danswer/access/access.py @@ -0,0 +1,55 @@ +from sqlalchemy.orm import Session + +from danswer.access.access import ( + _get_access_for_documents as get_access_for_documents_without_groups, +) +from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups +from danswer.access.models import DocumentAccess +from danswer.access.utils import prefix_user_group +from danswer.db.models import User +from danswer.server.documents.models import ConnectorCredentialPairIdentifier +from ee.danswer.db.user_group import fetch_user_groups_for_documents +from ee.danswer.db.user_group import fetch_user_groups_for_user + + +def _get_access_for_documents( + document_ids: list[str], + db_session: Session, + 
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None, +) -> dict[str, DocumentAccess]: + non_ee_access_dict = get_access_for_documents_without_groups( + document_ids=document_ids, + db_session=db_session, + cc_pair_to_delete=cc_pair_to_delete, + ) + user_group_info = { + document_id: group_names + for document_id, group_names in fetch_user_groups_for_documents( + db_session=db_session, + document_ids=document_ids, + cc_pair_to_delete=cc_pair_to_delete, + ) + } + + return { + document_id: DocumentAccess( + user_ids=non_ee_access.user_ids, + user_groups=user_group_info.get(document_id, []), # type: ignore + is_public=non_ee_access.is_public, + ) + for document_id, non_ee_access in non_ee_access_dict.items() + } + + +def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]: + """Returns a list of ACL entries that the user has access to. This is meant to be + used downstream to filter out documents that the user does not have access to. The + user should have access to a document if at least one entry in the document's ACL + matches one entry in the returned set. 
+ + NOTE: is imported in danswer.access.access by `fetch_versioned_implementation` + DO NOT REMOVE.""" + user_groups = fetch_user_groups_for_user(db_session, user.id) if user else [] + return set( + [prefix_user_group(user_group.name) for user_group in user_groups] + ).union(get_acl_for_user_without_groups(user, db_session)) diff --git a/backend/ee/danswer/auth/__init__.py b/backend/ee/danswer/auth/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/ee/danswer/auth/api_key.py b/backend/ee/danswer/auth/api_key.py new file mode 100644 index 00000000000..9f9d2bac4bf --- /dev/null +++ b/backend/ee/danswer/auth/api_key.py @@ -0,0 +1,51 @@ +import secrets +import uuid + +from fastapi import Request +from passlib.hash import sha256_crypt +from pydantic import BaseModel + +from ee.danswer.configs.app_configs import API_KEY_HASH_ROUNDS + + +_API_KEY_HEADER_NAME = "Authorization" +_BEARER_PREFIX = "Bearer " +_API_KEY_PREFIX = "dn_" +_API_KEY_LEN = 192 + + +class ApiKeyDescriptor(BaseModel): + api_key_id: int + api_key_display: str + api_key: str | None = None # only present on initial creation + api_key_name: str | None = None + + user_id: uuid.UUID + + +def generate_api_key() -> str: + return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN) + + +def hash_api_key(api_key: str) -> str: + # NOTE: no salt is needed, as the API key is randomly generated + # and overlaps are impossible + return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS) + + +def build_displayable_api_key(api_key: str) -> str: + if api_key.startswith(_API_KEY_PREFIX): + api_key = api_key[len(_API_KEY_PREFIX) :] + + return _API_KEY_PREFIX + api_key[:4] + "********" + api_key[-4:] + + +def get_hashed_api_key_from_request(request: Request) -> str | None: + raw_api_key_header = request.headers.get(_API_KEY_HEADER_NAME) + if raw_api_key_header is None: + return None + + if raw_api_key_header.startswith(_BEARER_PREFIX): + raw_api_key_header = 
raw_api_key_header[len(_BEARER_PREFIX) :].strip() + + return hash_api_key(raw_api_key_header) diff --git a/backend/ee/danswer/auth/users.py b/backend/ee/danswer/auth/users.py new file mode 100644 index 00000000000..f5f5dbd58f3 --- /dev/null +++ b/backend/ee/danswer/auth/users.py @@ -0,0 +1,65 @@ +from fastapi import Depends +from fastapi import HTTPException +from fastapi import Request +from sqlalchemy.orm import Session + +from danswer.configs.app_configs import AUTH_TYPE +from danswer.configs.constants import AuthType +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.utils.logger import setup_logger +from ee.danswer.auth.api_key import get_hashed_api_key_from_request +from ee.danswer.db.api_key import fetch_user_for_api_key +from ee.danswer.db.saml import get_saml_account +from ee.danswer.server.seeding import get_seed_config +from ee.danswer.utils.secrets import extract_hashed_cookie + +logger = setup_logger() + + +def verify_auth_setting() -> None: + # All the Auth flows are valid for EE version + logger.info(f"Using Auth Type: {AUTH_TYPE.value}") + + +async def optional_user_( + request: Request, + user: User | None, + db_session: Session, +) -> User | None: + # Check if the user has a session cookie from SAML + if AUTH_TYPE == AuthType.SAML: + saved_cookie = extract_hashed_cookie(request) + + if saved_cookie: + saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session) + user = saml_account.user if saml_account else None + + # check if an API key is present + if user is None: + hashed_api_key = get_hashed_api_key_from_request(request) + if hashed_api_key: + user = fetch_user_for_api_key(hashed_api_key, db_session) + + return user + + +def api_key_dep(request: Request, db_session: Session = Depends(get_session)) -> User: + hashed_api_key = get_hashed_api_key_from_request(request) + if not hashed_api_key: + raise HTTPException(status_code=401, detail="Missing API key") + + if hashed_api_key: + user = 
fetch_user_for_api_key(hashed_api_key, db_session) + + if user is None: + raise HTTPException(status_code=401, detail="Invalid API key") + + return user + + +def get_default_admin_user_emails_() -> list[str]: + seed_config = get_seed_config() + if seed_config and seed_config.admin_user_emails: + return seed_config.admin_user_emails + return [] diff --git a/backend/ee/danswer/background/celery/celery_app.py b/backend/ee/danswer/background/celery/celery_app.py new file mode 100644 index 00000000000..2dd3ecb472e --- /dev/null +++ b/backend/ee/danswer/background/celery/celery_app.py @@ -0,0 +1,115 @@ +from datetime import timedelta + +from sqlalchemy.orm import Session + +from danswer.background.celery.celery_app import celery_app +from danswer.background.task_utils import build_celery_task_wrapper +from danswer.configs.app_configs import JOB_TIMEOUT +from danswer.db.chat import delete_chat_sessions_older_than +from danswer.db.engine import get_sqlalchemy_engine +from danswer.server.settings.store import load_settings +from danswer.utils.logger import setup_logger +from danswer.utils.variable_functionality import global_version +from ee.danswer.background.celery_utils import should_perform_chat_ttl_check +from ee.danswer.background.celery_utils import should_sync_user_groups +from ee.danswer.background.task_name_builders import name_chat_ttl_task +from ee.danswer.background.task_name_builders import name_user_group_sync_task +from ee.danswer.db.user_group import fetch_user_groups +from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report +from ee.danswer.user_groups.sync import sync_user_groups + +logger = setup_logger() + +# mark as EE for all tasks in this file +global_version.set_ee() + + +@build_celery_task_wrapper(name_user_group_sync_task) +@celery_app.task(soft_time_limit=JOB_TIMEOUT) +def sync_user_group_task(user_group_id: int) -> None: + with Session(get_sqlalchemy_engine()) as db_session: + # actual sync logic + try: + 
sync_user_groups(user_group_id=user_group_id, db_session=db_session) + except Exception as e: + logger.exception(f"Failed to sync user group - {e}") + + +@build_celery_task_wrapper(name_chat_ttl_task) +@celery_app.task(soft_time_limit=JOB_TIMEOUT) +def perform_ttl_management_task(retention_limit_days: int) -> None: + with Session(get_sqlalchemy_engine()) as db_session: + delete_chat_sessions_older_than(retention_limit_days, db_session) + + +##### +# Periodic Tasks +##### + + +@celery_app.task( + name="check_ttl_management_task", + soft_time_limit=JOB_TIMEOUT, +) +def check_ttl_management_task() -> None: + """Runs periodically to check if any ttl tasks should be run and adds them + to the queue""" + settings = load_settings() + retention_limit_days = settings.maximum_chat_retention_days + with Session(get_sqlalchemy_engine()) as db_session: + if should_perform_chat_ttl_check(retention_limit_days, db_session): + perform_ttl_management_task.apply_async( + kwargs=dict(retention_limit_days=retention_limit_days), + ) + + +@celery_app.task( + name="check_for_user_groups_sync_task", + soft_time_limit=JOB_TIMEOUT, +) +def check_for_user_groups_sync_task() -> None: + """Runs periodically to check if any user groups are out of sync + Creates a task to sync the user group if needed""" + with Session(get_sqlalchemy_engine()) as db_session: + # check if any document sets are not synced + user_groups = fetch_user_groups(db_session=db_session, only_current=False) + for user_group in user_groups: + if should_sync_user_groups(user_group, db_session): + logger.info(f"User Group {user_group.id} is not synced. 
Syncing now!") + sync_user_group_task.apply_async( + kwargs=dict(user_group_id=user_group.id), + ) + + +@celery_app.task( + name="autogenerate_usage_report_task", + soft_time_limit=JOB_TIMEOUT, +) +def autogenerate_usage_report_task() -> None: + """This generates usage report under the /admin/generate-usage/report endpoint""" + with Session(get_sqlalchemy_engine()) as db_session: + create_new_usage_report( + db_session=db_session, + user_id=None, + period=None, + ) + + +##### +# Celery Beat (Periodic Tasks) Settings +##### +celery_app.conf.beat_schedule = { + "check-for-user-group-sync": { + "task": "check_for_user_groups_sync_task", + "schedule": timedelta(seconds=5), + }, + "autogenerate_usage_report": { + "task": "autogenerate_usage_report_task", + "schedule": timedelta(days=30), # TODO: change this to config flag + }, + "check-ttl-management": { + "task": "check_ttl_management_task", + "schedule": timedelta(hours=1), + }, + **(celery_app.conf.beat_schedule or {}), +} diff --git a/backend/ee/danswer/background/celery_utils.py b/backend/ee/danswer/background/celery_utils.py new file mode 100644 index 00000000000..0134f6642f7 --- /dev/null +++ b/backend/ee/danswer/background/celery_utils.py @@ -0,0 +1,40 @@ +from sqlalchemy.orm import Session + +from danswer.db.models import UserGroup +from danswer.db.tasks import check_task_is_live_and_not_timed_out +from danswer.db.tasks import get_latest_task +from danswer.utils.logger import setup_logger +from ee.danswer.background.task_name_builders import name_chat_ttl_task +from ee.danswer.background.task_name_builders import name_user_group_sync_task + +logger = setup_logger() + + +def should_sync_user_groups(user_group: UserGroup, db_session: Session) -> bool: + if user_group.is_up_to_date: + return False + task_name = name_user_group_sync_task(user_group.id) + latest_sync = get_latest_task(task_name, db_session) + + if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session): + logger.info("TTL check 
is already being performed. Skipping.") + return False + return True + + +def should_perform_chat_ttl_check( + retention_limit_days: int | None, db_session: Session +) -> bool: + # TODO: make this a check for None and add behavior for 0 day TTL + if not retention_limit_days: + return False + + task_name = name_chat_ttl_task(retention_limit_days) + latest_task = get_latest_task(task_name, db_session) + if not latest_task: + return True + + if latest_task and check_task_is_live_and_not_timed_out(latest_task, db_session): + logger.info("TTL check is already being performed. Skipping.") + return False + return True diff --git a/backend/ee/danswer/background/permission_sync.py b/backend/ee/danswer/background/permission_sync.py new file mode 100644 index 00000000000..b3e8845ab3b --- /dev/null +++ b/backend/ee/danswer/background/permission_sync.py @@ -0,0 +1,221 @@ +import logging +import time +from datetime import datetime + +import dask +from dask.distributed import Client +from dask.distributed import Future +from distributed import LocalCluster +from sqlalchemy.orm import Session + +from danswer.background.indexing.dask_utils import ResourceLogger +from danswer.background.indexing.job_client import SimpleJob +from danswer.background.indexing.job_client import SimpleJobClient +from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT +from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED +from danswer.configs.constants import DocumentSource +from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.models import PermissionSyncStatus +from danswer.utils.logger import setup_logger +from ee.danswer.configs.app_configs import NUM_PERMISSION_WORKERS +from ee.danswer.connectors.factory import CONNECTOR_PERMISSION_FUNC_MAP +from ee.danswer.db.connector import fetch_sources_with_connectors +from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source +from ee.danswer.db.permission_sync import create_perm_sync +from 
ee.danswer.db.permission_sync import expire_perm_sync_timed_out +from ee.danswer.db.permission_sync import get_perm_sync_attempt +from ee.danswer.db.permission_sync import mark_all_inprogress_permission_sync_failed +from shared_configs.configs import LOG_LEVEL + +logger = setup_logger() + +# If the indexing dies, it's most likely due to resource constraints, +# restarting just delays the eventual failure, not useful to the user +dask.config.set({"distributed.scheduler.allowed-failures": 0}) + + +def cleanup_perm_sync_jobs( + existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob], + # Just reusing the same timeout, fine for now + timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT, +) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]: + existing_jobs_copy = existing_jobs.copy() + + with Session(get_sqlalchemy_engine()) as db_session: + # clean up completed jobs + for (attempt_id, details), job in existing_jobs.items(): + perm_sync_attempt = get_perm_sync_attempt( + attempt_id=attempt_id, db_session=db_session + ) + + # do nothing for ongoing jobs that haven't been stopped + if ( + not job.done() + and perm_sync_attempt.status == PermissionSyncStatus.IN_PROGRESS + ): + continue + + if job.status == "error": + logger.error(job.exception()) + + job.release() + del existing_jobs_copy[(attempt_id, details)] + + # clean up in-progress jobs that were never completed + expire_perm_sync_timed_out( + timeout_hours=timeout_hours, + db_session=db_session, + ) + + return existing_jobs_copy + + +def create_group_sync_jobs( + existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob], + client: Client | SimpleJobClient, +) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]: + """Creates new relational DB group permission sync job for each source that: + - has permission sync enabled + - has at least 1 connector (enabled or paused) + - has no sync already running + """ + existing_jobs_copy = existing_jobs.copy() + 
sources_w_runs = [ + key[1] + for key in existing_jobs_copy.keys() + if isinstance(key[1], DocumentSource) + ] + with Session(get_sqlalchemy_engine()) as db_session: + sources_w_connector = fetch_sources_with_connectors(db_session) + for source_type in sources_w_connector: + if source_type not in CONNECTOR_PERMISSION_FUNC_MAP: + continue + if source_type in sources_w_runs: + continue + + db_group_fnc, _ = CONNECTOR_PERMISSION_FUNC_MAP[source_type] + perm_sync = create_perm_sync( + source_type=source_type, + group_update=True, + cc_pair_id=None, + db_session=db_session, + ) + + run = client.submit(db_group_fnc, pure=False) + + logger.info( + f"Kicked off group permission sync for source type {source_type}" + ) + + if run: + existing_jobs_copy[(perm_sync.id, source_type)] = run + + return existing_jobs_copy + + +def create_connector_perm_sync_jobs( + existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob], + client: Client | SimpleJobClient, +) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]: + """Update Document Index ACL sync job for each cc-pair where: + - source type has permission sync enabled + - has no sync already running + """ + existing_jobs_copy = existing_jobs.copy() + cc_pairs_w_runs = [ + key[1] + for key in existing_jobs_copy.keys() + if isinstance(key[1], DocumentSource) + ] + with Session(get_sqlalchemy_engine()) as db_session: + sources_w_connector = fetch_sources_with_connectors(db_session) + for source_type in sources_w_connector: + if source_type not in CONNECTOR_PERMISSION_FUNC_MAP: + continue + + _, index_sync_fnc = CONNECTOR_PERMISSION_FUNC_MAP[source_type] + + cc_pairs = get_cc_pairs_by_source(source_type, db_session) + + for cc_pair in cc_pairs: + if cc_pair.id in cc_pairs_w_runs: + continue + + perm_sync = create_perm_sync( + source_type=source_type, + group_update=False, + cc_pair_id=cc_pair.id, + db_session=db_session, + ) + + run = client.submit(index_sync_fnc, cc_pair.id, pure=False) + + 
logger.info(f"Kicked off ACL sync for cc-pair {cc_pair.id}") + + if run: + existing_jobs_copy[(perm_sync.id, cc_pair.id)] = run + + return existing_jobs_copy + + +def permission_loop(delay: int = 60, num_workers: int = NUM_PERMISSION_WORKERS) -> None: + client: Client | SimpleJobClient + if DASK_JOB_CLIENT_ENABLED: + cluster_primary = LocalCluster( + n_workers=num_workers, + threads_per_worker=1, + # there are warning about high memory usage + "Event loop unresponsive" + # which are not relevant to us since our workers are expected to use a + # lot of memory + involve CPU intensive tasks that will not relinquish + # the event loop + silence_logs=logging.ERROR, + ) + client = Client(cluster_primary) + if LOG_LEVEL.lower() == "debug": + client.register_worker_plugin(ResourceLogger()) + else: + client = SimpleJobClient(n_workers=num_workers) + + existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob] = {} + engine = get_sqlalchemy_engine() + + with Session(engine) as db_session: + # Any jobs still in progress on restart must have died + mark_all_inprogress_permission_sync_failed(db_session) + + while True: + start = time.time() + start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S") + logger.info(f"Running Permission Sync, current UTC time: {start_time_utc}") + + if existing_jobs: + logger.debug( + "Found existing permission sync jobs: " + f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}" + ) + + try: + # TODO turn this on when it works + """ + existing_jobs = cleanup_perm_sync_jobs(existing_jobs=existing_jobs) + existing_jobs = create_group_sync_jobs( + existing_jobs=existing_jobs, client=client + ) + existing_jobs = create_connector_perm_sync_jobs( + existing_jobs=existing_jobs, client=client + ) + """ + except Exception as e: + logger.exception(f"Failed to run update due to {e}") + sleep_time = delay - (time.time() - start) + if sleep_time > 0: + time.sleep(sleep_time) + + +def update__main() -> 
None: + logger.info("Starting Permission Syncing Loop") + permission_loop() + + +if __name__ == "__main__": + update__main() diff --git a/backend/ee/danswer/background/task_name_builders.py b/backend/ee/danswer/background/task_name_builders.py new file mode 100644 index 00000000000..4f1046adbbb --- /dev/null +++ b/backend/ee/danswer/background/task_name_builders.py @@ -0,0 +1,6 @@ +def name_user_group_sync_task(user_group_id: int) -> str: + return f"user_group_sync_task__{user_group_id}" + + +def name_chat_ttl_task(retention_limit_days: int) -> str: + return f"chat_ttl_{retention_limit_days}_days" diff --git a/backend/ee/danswer/configs/__init__.py b/backend/ee/danswer/configs/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/ee/danswer/configs/app_configs.py b/backend/ee/danswer/configs/app_configs.py new file mode 100644 index 00000000000..1430a499136 --- /dev/null +++ b/backend/ee/danswer/configs/app_configs.py @@ -0,0 +1,23 @@ +import os + +# Applicable for OIDC Auth +OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "") + +# Applicable for SAML Auth +SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/danswer/configs/saml_config" + + +##### +# API Key Configs +##### +# refers to the rounds described here: https://passlib.readthedocs.io/en/stable/lib/passlib.hash.sha256_crypt.html +_API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS") +API_KEY_HASH_ROUNDS = ( + int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None +) + + +##### +# Auto Permission Sync +##### +NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2) diff --git a/backend/ee/danswer/configs/saml_config/template.settings.json b/backend/ee/danswer/configs/saml_config/template.settings.json new file mode 100644 index 00000000000..e3c828944af --- /dev/null +++ b/backend/ee/danswer/configs/saml_config/template.settings.json @@ -0,0 +1,20 @@ +{ + "strict": true, + "debug": false, + "idp": { + "entityId": "", + 
"singleSignOnService": { + "url": " https://trial-1234567.okta.com/home/trial-1234567_danswer/somevalues/somevalues", + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" + }, + "x509cert": "" + }, + "sp": { + "entityId": "", + "assertionConsumerService": { + "url": "http://127.0.0.1:3000/auth/saml/callback", + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" + }, + "x509cert": "" + } +} diff --git a/backend/ee/danswer/connectors/__init__.py b/backend/ee/danswer/connectors/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/ee/danswer/connectors/confluence/__init__.py b/backend/ee/danswer/connectors/confluence/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/ee/danswer/connectors/confluence/perm_sync.py b/backend/ee/danswer/connectors/confluence/perm_sync.py new file mode 100644 index 00000000000..2985b47b0d1 --- /dev/null +++ b/backend/ee/danswer/connectors/confluence/perm_sync.py @@ -0,0 +1,12 @@ +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + + +def confluence_update_db_group() -> None: + logger.debug("Not yet implemented group sync for confluence, no-op") + + +def confluence_update_index_acl(cc_pair_id: int) -> None: + logger.debug("Not yet implemented ACL sync for confluence, no-op") diff --git a/backend/ee/danswer/connectors/factory.py b/backend/ee/danswer/connectors/factory.py new file mode 100644 index 00000000000..52f9324948b --- /dev/null +++ b/backend/ee/danswer/connectors/factory.py @@ -0,0 +1,8 @@ +from danswer.configs.constants import DocumentSource +from ee.danswer.connectors.confluence.perm_sync import confluence_update_db_group +from ee.danswer.connectors.confluence.perm_sync import confluence_update_index_acl + + +CONNECTOR_PERMISSION_FUNC_MAP = { + DocumentSource.CONFLUENCE: (confluence_update_db_group, confluence_update_index_acl) +} diff --git a/backend/ee/danswer/db/__init__.py b/backend/ee/danswer/db/__init__.py new file 
mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/ee/danswer/db/analytics.py b/backend/ee/danswer/db/analytics.py new file mode 100644 index 00000000000..e0eff7850e4 --- /dev/null +++ b/backend/ee/danswer/db/analytics.py @@ -0,0 +1,172 @@ +import datetime +from collections.abc import Sequence +from uuid import UUID + +from sqlalchemy import case +from sqlalchemy import cast +from sqlalchemy import Date +from sqlalchemy import func +from sqlalchemy import or_ +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.configs.constants import MessageType +from danswer.db.models import ChatMessage +from danswer.db.models import ChatMessageFeedback +from danswer.db.models import ChatSession + + +def fetch_query_analytics( + start: datetime.datetime, + end: datetime.datetime, + db_session: Session, +) -> Sequence[tuple[int, int, int, datetime.date]]: + stmt = ( + select( + func.count(ChatMessage.id), + func.sum(case((ChatMessageFeedback.is_positive, 1), else_=0)), + func.sum( + case( + (ChatMessageFeedback.is_positive == False, 1), else_=0 # noqa: E712 + ) + ), + cast(ChatMessage.time_sent, Date), + ) + .join( + ChatMessageFeedback, + ChatMessageFeedback.chat_message_id == ChatMessage.id, + isouter=True, + ) + .where( + ChatMessage.time_sent >= start, + ) + .where( + ChatMessage.time_sent <= end, + ) + .where(ChatMessage.message_type == MessageType.ASSISTANT) + .group_by(cast(ChatMessage.time_sent, Date)) + .order_by(cast(ChatMessage.time_sent, Date)) + ) + + return db_session.execute(stmt).all() # type: ignore + + +def fetch_per_user_query_analytics( + start: datetime.datetime, + end: datetime.datetime, + db_session: Session, +) -> Sequence[tuple[int, int, int, datetime.date, UUID]]: + stmt = ( + select( + func.count(ChatMessage.id), + func.sum(case((ChatMessageFeedback.is_positive, 1), else_=0)), + func.sum( + case( + (ChatMessageFeedback.is_positive == False, 1), else_=0 # noqa: E712 + ) + ), + cast(ChatMessage.time_sent, Date), 
+ ChatSession.user_id, + ) + .join(ChatSession, ChatSession.id == ChatMessage.chat_session_id) + .where( + ChatMessage.time_sent >= start, + ) + .where( + ChatMessage.time_sent <= end, + ) + .where(ChatMessage.message_type == MessageType.ASSISTANT) + .group_by(cast(ChatMessage.time_sent, Date), ChatSession.user_id) + .order_by(cast(ChatMessage.time_sent, Date), ChatSession.user_id) + ) + + return db_session.execute(stmt).all() # type: ignore + + +def fetch_danswerbot_analytics( + start: datetime.datetime, + end: datetime.datetime, + db_session: Session, +) -> Sequence[tuple[int, int, datetime.date]]: + """Gets the: + Date of each set of aggregated statistics + Number of DanswerBot Queries (Chat Sessions) + Number of instances of Negative feedback OR Needing additional help + (only counting the last feedback) + """ + # Get every chat session in the time range which is a Danswerbot flow + # along with the first Assistant message which is the response to the user question. + # Generally there should not be more than one AI message per chat session of this type + subquery_first_ai_response = ( + db_session.query( + ChatMessage.chat_session_id.label("chat_session_id"), + func.min(ChatMessage.id).label("chat_message_id"), + ) + .join(ChatSession, ChatSession.id == ChatMessage.chat_session_id) + .where( + ChatSession.time_created >= start, + ChatSession.time_created <= end, + ChatSession.danswerbot_flow.is_(True), + ) + .where( + ChatMessage.message_type == MessageType.ASSISTANT, + ) + .group_by(ChatMessage.chat_session_id) + .subquery() + ) + + # Get the chat message ids and most recent feedback for each of those chat messages, + # not including the messages that have no feedback + subquery_last_feedback = ( + db_session.query( + ChatMessageFeedback.chat_message_id.label("chat_message_id"), + func.max(ChatMessageFeedback.id).label("max_feedback_id"), + ) + .group_by(ChatMessageFeedback.chat_message_id) + .subquery() + ) + + results = ( + db_session.query( + 
func.count(ChatSession.id).label("total_sessions"), + # Need to explicitly specify this as False to handle the NULL case so the cases without + # feedback aren't counted against Danswerbot + func.sum( + case( + ( + or_( + ChatMessageFeedback.is_positive.is_(False), + ChatMessageFeedback.required_followup, + ), + 1, + ), + else_=0, + ) + ).label("negative_answer"), + cast(ChatSession.time_created, Date).label("session_date"), + ) + .join( + subquery_first_ai_response, + ChatSession.id == subquery_first_ai_response.c.chat_session_id, + ) + # Combine the chat sessions with latest feedback to get the latest feedback for the first AI + # message of the chat session where the chat session is Danswerbot type and within the time + # range specified. Left/outer join used here to ensure that if no feedback, a null is used + # for the feedback id + .outerjoin( + subquery_last_feedback, + subquery_first_ai_response.c.chat_message_id + == subquery_last_feedback.c.chat_message_id, + ) + # Join the actual feedback table to get the feedback info for the sums + # Outer join because the "last feedback" may be null + .outerjoin( + ChatMessageFeedback, + ChatMessageFeedback.id == subquery_last_feedback.c.max_feedback_id, + ) + .group_by(cast(ChatSession.time_created, Date)) + .order_by(cast(ChatSession.time_created, Date)) + .all() + ) + + return results diff --git a/backend/ee/danswer/db/api_key.py b/backend/ee/danswer/db/api_key.py new file mode 100644 index 00000000000..c2555bb1127 --- /dev/null +++ b/backend/ee/danswer/db/api_key.py @@ -0,0 +1,154 @@ +import uuid + +from fastapi_users.password import PasswordHelper +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.auth.schemas import UserRole +from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN +from danswer.configs.constants import DANSWER_API_KEY_PREFIX +from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER +from danswer.db.models import ApiKey +from danswer.db.models 
import User +from ee.danswer.auth.api_key import ApiKeyDescriptor +from ee.danswer.auth.api_key import build_displayable_api_key +from ee.danswer.auth.api_key import generate_api_key +from ee.danswer.auth.api_key import hash_api_key +from ee.danswer.server.api_key.models import APIKeyArgs + + +def is_api_key_email_address(email: str) -> bool: + return email.endswith(f"{DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN}") + + +def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]: + api_keys = db_session.scalars(select(ApiKey)).all() + return [ + ApiKeyDescriptor( + api_key_id=api_key.id, + api_key_display=api_key.api_key_display, + api_key_name=api_key.name, + user_id=api_key.user_id, + ) + for api_key in api_keys + ] + + +def fetch_user_for_api_key(hashed_api_key: str, db_session: Session) -> User | None: + api_key = db_session.scalar( + select(ApiKey).where(ApiKey.hashed_api_key == hashed_api_key) + ) + if api_key is None: + return None + + return db_session.scalar(select(User).where(User.id == api_key.user_id)) # type: ignore + + +def get_api_key_fake_email( + name: str, + unique_id: str, +) -> str: + return f"{DANSWER_API_KEY_PREFIX}{name}@{unique_id}{DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN}" + + +def insert_api_key( + db_session: Session, api_key_args: APIKeyArgs, user_id: uuid.UUID | None +) -> ApiKeyDescriptor: + std_password_helper = PasswordHelper() + api_key = generate_api_key() + api_key_user_id = uuid.uuid4() + + display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER + api_key_user_row = User( + id=api_key_user_id, + email=get_api_key_fake_email(display_name, str(api_key_user_id)), + # a random password for the "user" + hashed_password=std_password_helper.hash(std_password_helper.generate()), + is_active=True, + is_superuser=False, + is_verified=True, + role=UserRole.BASIC, + ) + db_session.add(api_key_user_row) + + api_key_row = ApiKey( + name=api_key_args.name, + hashed_api_key=hash_api_key(api_key), + api_key_display=build_displayable_api_key(api_key), 
+ user_id=api_key_user_id, + owner_id=user_id, + ) + db_session.add(api_key_row) + + db_session.commit() + return ApiKeyDescriptor( + api_key_id=api_key_row.id, + api_key_display=api_key_row.api_key_display, + api_key=api_key, + api_key_name=api_key_args.name, + user_id=api_key_user_id, + ) + + +def update_api_key( + db_session: Session, api_key_id: int, api_key_args: APIKeyArgs +) -> ApiKeyDescriptor: + existing_api_key = db_session.scalar(select(ApiKey).where(ApiKey.id == api_key_id)) + if existing_api_key is None: + raise ValueError(f"API key with id {api_key_id} does not exist") + + existing_api_key.name = api_key_args.name + api_key_user = db_session.scalar( + select(User).where(User.id == existing_api_key.user_id) # type: ignore + ) + if api_key_user is None: + raise RuntimeError("API Key does not have associated user.") + + email_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER + api_key_user.email = get_api_key_fake_email(email_name, str(api_key_user.id)) + db_session.commit() + + return ApiKeyDescriptor( + api_key_id=existing_api_key.id, + api_key_display=existing_api_key.api_key_display, + api_key_name=api_key_args.name, + user_id=existing_api_key.user_id, + ) + + +def regenerate_api_key(db_session: Session, api_key_id: int) -> ApiKeyDescriptor: + """NOTE: currently, any admin can regenerate any API key.""" + existing_api_key = db_session.scalar(select(ApiKey).where(ApiKey.id == api_key_id)) + if existing_api_key is None: + raise ValueError(f"API key with id {api_key_id} does not exist") + + new_api_key = generate_api_key() + existing_api_key.hashed_api_key = hash_api_key(new_api_key) + existing_api_key.api_key_display = build_displayable_api_key(new_api_key) + db_session.commit() + + return ApiKeyDescriptor( + api_key_id=existing_api_key.id, + api_key_display=existing_api_key.api_key_display, + api_key=new_api_key, + api_key_name=existing_api_key.name, + user_id=existing_api_key.user_id, + ) + + +def remove_api_key(db_session: Session, api_key_id: 
int) -> None: + existing_api_key = db_session.scalar(select(ApiKey).where(ApiKey.id == api_key_id)) + if existing_api_key is None: + raise ValueError(f"API key with id {api_key_id} does not exist") + + user_associated_with_key = db_session.scalar( + select(User).where(User.id == existing_api_key.user_id) # type: ignore + ) + if user_associated_with_key is None: + raise ValueError( + f"User associated with API key with id {api_key_id} does not exist. This should not happen." + ) + + db_session.delete(existing_api_key) + db_session.delete(user_associated_with_key) + db_session.commit() diff --git a/backend/ee/danswer/db/connector.py b/backend/ee/danswer/db/connector.py new file mode 100644 index 00000000000..44505f51510 --- /dev/null +++ b/backend/ee/danswer/db/connector.py @@ -0,0 +1,16 @@ +from sqlalchemy import distinct +from sqlalchemy.orm import Session + +from danswer.configs.constants import DocumentSource +from danswer.db.models import Connector +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def fetch_sources_with_connectors(db_session: Session) -> list[DocumentSource]: + sources = db_session.query(distinct(Connector.source)).all() # type: ignore + + document_sources = [source[0] for source in sources] + + return document_sources diff --git a/backend/ee/danswer/db/connector_credential_pair.py b/backend/ee/danswer/db/connector_credential_pair.py new file mode 100644 index 00000000000..a4938138592 --- /dev/null +++ b/backend/ee/danswer/db/connector_credential_pair.py @@ -0,0 +1,22 @@ +from sqlalchemy.orm import Session + +from danswer.configs.constants import DocumentSource +from danswer.db.models import Connector +from danswer.db.models import ConnectorCredentialPair +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def get_cc_pairs_by_source( + source_type: DocumentSource, + db_session: Session, +) -> list[ConnectorCredentialPair]: + cc_pairs = ( + db_session.query(ConnectorCredentialPair) + 
.join(ConnectorCredentialPair.connector) + .filter(Connector.source == source_type) + .all() + ) + + return cc_pairs diff --git a/backend/ee/danswer/db/document.py b/backend/ee/danswer/db/document.py new file mode 100644 index 00000000000..5a368ea170e --- /dev/null +++ b/backend/ee/danswer/db/document.py @@ -0,0 +1,14 @@ +from collections.abc import Sequence + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.db.models import Document + + +def fetch_documents_from_ids( + db_session: Session, document_ids: list[str] +) -> Sequence[Document]: + return db_session.scalars( + select(Document).where(Document.id.in_(document_ids)) + ).all() diff --git a/backend/ee/danswer/db/document_set.py b/backend/ee/danswer/db/document_set.py new file mode 100644 index 00000000000..bcbd06874da --- /dev/null +++ b/backend/ee/danswer/db/document_set.py @@ -0,0 +1,123 @@ +from uuid import UUID + +from sqlalchemy.orm import Session + +from danswer.db.models import ConnectorCredentialPair +from danswer.db.models import DocumentSet +from danswer.db.models import DocumentSet__ConnectorCredentialPair +from danswer.db.models import DocumentSet__User +from danswer.db.models import DocumentSet__UserGroup +from danswer.db.models import User__UserGroup +from danswer.db.models import UserGroup + + +def make_doc_set_private( + document_set_id: int, + user_ids: list[UUID] | None, + group_ids: list[int] | None, + db_session: Session, +) -> None: + db_session.query(DocumentSet__User).filter( + DocumentSet__User.document_set_id == document_set_id + ).delete(synchronize_session="fetch") + db_session.query(DocumentSet__UserGroup).filter( + DocumentSet__UserGroup.document_set_id == document_set_id + ).delete(synchronize_session="fetch") + + if user_ids: + for user_uuid in user_ids: + db_session.add( + DocumentSet__User(document_set_id=document_set_id, user_id=user_uuid) + ) + + if group_ids: + for group_id in group_ids: + db_session.add( + DocumentSet__UserGroup( + 
document_set_id=document_set_id, user_group_id=group_id + ) + ) + + +def delete_document_set_privacy__no_commit( + document_set_id: int, db_session: Session +) -> None: + db_session.query(DocumentSet__User).filter( + DocumentSet__User.document_set_id == document_set_id + ).delete(synchronize_session="fetch") + + db_session.query(DocumentSet__UserGroup).filter( + DocumentSet__UserGroup.document_set_id == document_set_id + ).delete(synchronize_session="fetch") + + +def fetch_document_sets( + user_id: UUID | None, + db_session: Session, + include_outdated: bool = True, # Parameter only for versioned implementation, unused +) -> list[tuple[DocumentSet, list[ConnectorCredentialPair]]]: + assert user_id is not None + + # Public document sets + public_document_sets = ( + db_session.query(DocumentSet) + .filter(DocumentSet.is_public == True) # noqa + .all() + ) + + # Document sets via shared user relationships + shared_document_sets = ( + db_session.query(DocumentSet) + .join(DocumentSet__User, DocumentSet.id == DocumentSet__User.document_set_id) + .filter(DocumentSet__User.user_id == user_id) + .all() + ) + + # Document sets via groups + # First, find the user groups the user belongs to + user_groups = ( + db_session.query(UserGroup) + .join(User__UserGroup, UserGroup.id == User__UserGroup.user_group_id) + .filter(User__UserGroup.user_id == user_id) + .all() + ) + + group_document_sets = [] + for group in user_groups: + group_document_sets.extend( + db_session.query(DocumentSet) + .join( + DocumentSet__UserGroup, + DocumentSet.id == DocumentSet__UserGroup.document_set_id, + ) + .filter(DocumentSet__UserGroup.user_group_id == group.id) + .all() + ) + + # Combine and deduplicate document sets from all sources + all_document_sets = list( + set(public_document_sets + shared_document_sets + group_document_sets) + ) + + document_set_with_cc_pairs: list[ + tuple[DocumentSet, list[ConnectorCredentialPair]] + ] = [] + + for document_set in all_document_sets: + # Fetch the 
associated ConnectorCredentialPairs + cc_pairs = ( + db_session.query(ConnectorCredentialPair) + .join( + DocumentSet__ConnectorCredentialPair, + ConnectorCredentialPair.id + == DocumentSet__ConnectorCredentialPair.connector_credential_pair_id, + ) + .filter( + DocumentSet__ConnectorCredentialPair.document_set_id == document_set.id, + ) + .all() + ) + + document_set_with_cc_pairs.append((document_set, cc_pairs)) # type: ignore + + return document_set_with_cc_pairs diff --git a/backend/ee/danswer/db/permission_sync.py b/backend/ee/danswer/db/permission_sync.py new file mode 100644 index 00000000000..7642bb65321 --- /dev/null +++ b/backend/ee/danswer/db/permission_sync.py @@ -0,0 +1,72 @@ +from datetime import timedelta + +from sqlalchemy import func +from sqlalchemy import select +from sqlalchemy import update +from sqlalchemy.exc import NoResultFound +from sqlalchemy.orm import Session + +from danswer.configs.constants import DocumentSource +from danswer.db.models import PermissionSyncRun +from danswer.db.models import PermissionSyncStatus +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def mark_all_inprogress_permission_sync_failed( + db_session: Session, +) -> None: + stmt = ( + update(PermissionSyncRun) + .where(PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS) + .values(status=PermissionSyncStatus.FAILED) + ) + db_session.execute(stmt) + db_session.commit() + + +def get_perm_sync_attempt(attempt_id: int, db_session: Session) -> PermissionSyncRun: + stmt = select(PermissionSyncRun).where(PermissionSyncRun.id == attempt_id) + try: + return db_session.scalars(stmt).one() + except NoResultFound: + raise ValueError(f"No PermissionSyncRun found with id {attempt_id}") + + +def expire_perm_sync_timed_out( + timeout_hours: int, + db_session: Session, +) -> None: + cutoff_time = func.now() - timedelta(hours=timeout_hours) + + update_stmt = ( + update(PermissionSyncRun) + .where( + PermissionSyncRun.status == 
PermissionSyncStatus.IN_PROGRESS, + PermissionSyncRun.updated_at < cutoff_time, + ) + .values(status=PermissionSyncStatus.FAILED, error_msg="timed out") + ) + + db_session.execute(update_stmt) + db_session.commit() + + +def create_perm_sync( + source_type: DocumentSource, + group_update: bool, + cc_pair_id: int | None, + db_session: Session, +) -> PermissionSyncRun: + new_run = PermissionSyncRun( + source_type=source_type, + status=PermissionSyncStatus.IN_PROGRESS, + group_update=group_update, + cc_pair_id=cc_pair_id, + ) + + db_session.add(new_run) + db_session.commit() + + return new_run diff --git a/backend/ee/danswer/db/persona.py b/backend/ee/danswer/db/persona.py new file mode 100644 index 00000000000..a7e257278c9 --- /dev/null +++ b/backend/ee/danswer/db/persona.py @@ -0,0 +1,32 @@ +from uuid import UUID + +from sqlalchemy.orm import Session + +from danswer.db.models import Persona__User +from danswer.db.models import Persona__UserGroup + + +def make_persona_private( + persona_id: int, + user_ids: list[UUID] | None, + group_ids: list[int] | None, + db_session: Session, +) -> None: + db_session.query(Persona__User).filter( + Persona__User.persona_id == persona_id + ).delete(synchronize_session="fetch") + db_session.query(Persona__UserGroup).filter( + Persona__UserGroup.persona_id == persona_id + ).delete(synchronize_session="fetch") + + if user_ids: + for user_uuid in user_ids: + db_session.add(Persona__User(persona_id=persona_id, user_id=user_uuid)) + + if group_ids: + for group_id in group_ids: + db_session.add( + Persona__UserGroup(persona_id=persona_id, user_group_id=group_id) + ) + + db_session.commit() diff --git a/backend/ee/danswer/db/query_history.py b/backend/ee/danswer/db/query_history.py new file mode 100644 index 00000000000..868afef23ce --- /dev/null +++ b/backend/ee/danswer/db/query_history.py @@ -0,0 +1,59 @@ +import datetime +from typing import Literal + +from sqlalchemy import asc +from sqlalchemy import BinaryExpression +from sqlalchemy 
import ColumnElement +from sqlalchemy import desc +from sqlalchemy.orm import contains_eager +from sqlalchemy.orm import joinedload +from sqlalchemy.orm import Session + +from danswer.db.models import ChatMessage +from danswer.db.models import ChatSession + +SortByOptions = Literal["time_sent"] + + +def fetch_chat_sessions_eagerly_by_time( + start: datetime.datetime, + end: datetime.datetime, + db_session: Session, + limit: int | None = 500, + initial_id: int | None = None, +) -> list[ChatSession]: + id_order = desc(ChatSession.id) # type: ignore + time_order = desc(ChatSession.time_created) # type: ignore + message_order = asc(ChatMessage.id) # type: ignore + + filters: list[ColumnElement | BinaryExpression] = [ + ChatSession.time_created.between(start, end) + ] + if initial_id: + filters.append(ChatSession.id < initial_id) + subquery = ( + db_session.query(ChatSession.id, ChatSession.time_created) + .filter(*filters) + .order_by(id_order, time_order) + .distinct(ChatSession.id) + .limit(limit) + .subquery() + ) + + query = ( + db_session.query(ChatSession) + .join(subquery, ChatSession.id == subquery.c.id) # type: ignore + .outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id) + .options( + joinedload(ChatSession.user), + joinedload(ChatSession.persona), + contains_eager(ChatSession.messages).joinedload( + ChatMessage.chat_message_feedbacks + ), + ) + .order_by(time_order, message_order) + ) + + chat_sessions = query.all() + + return chat_sessions diff --git a/backend/ee/danswer/db/saml.py b/backend/ee/danswer/db/saml.py new file mode 100644 index 00000000000..6689a7a7e14 --- /dev/null +++ b/backend/ee/danswer/db/saml.py @@ -0,0 +1,65 @@ +import datetime +from typing import cast +from uuid import UUID + +from sqlalchemy import and_ +from sqlalchemy import func +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS +from danswer.db.models import SamlAccount +from 
danswer.db.models import User + + +def upsert_saml_account( + user_id: UUID, + cookie: str, + db_session: Session, + expiration_offset: int = SESSION_EXPIRE_TIME_SECONDS, +) -> datetime.datetime: + expires_at = func.now() + datetime.timedelta(seconds=expiration_offset) + + existing_saml_acc = ( + db_session.query(SamlAccount) + .filter(SamlAccount.user_id == user_id) + .one_or_none() + ) + + if existing_saml_acc: + existing_saml_acc.encrypted_cookie = cookie + existing_saml_acc.expires_at = cast(datetime.datetime, expires_at) + existing_saml_acc.updated_at = func.now() + saml_acc = existing_saml_acc + else: + saml_acc = SamlAccount( + user_id=user_id, + encrypted_cookie=cookie, + expires_at=expires_at, + ) + db_session.add(saml_acc) + + db_session.commit() + + return saml_acc.expires_at + + +def get_saml_account(cookie: str, db_session: Session) -> SamlAccount | None: + stmt = ( + select(SamlAccount) + .join(User, User.id == SamlAccount.user_id) # type: ignore + .where( + and_( + SamlAccount.encrypted_cookie == cookie, + SamlAccount.expires_at > func.now(), + ) + ) + ) + + result = db_session.execute(stmt) + return result.scalar_one_or_none() + + +def expire_saml_account(saml_account: SamlAccount, db_session: Session) -> None: + saml_account.expires_at = func.now() + db_session.commit() diff --git a/backend/ee/danswer/db/token_limit.py b/backend/ee/danswer/db/token_limit.py new file mode 100644 index 00000000000..9b153811635 --- /dev/null +++ b/backend/ee/danswer/db/token_limit.py @@ -0,0 +1,176 @@ +from collections.abc import Sequence + +from sqlalchemy import Row +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.configs.constants import TokenRateLimitScope +from danswer.db.models import TokenRateLimit +from danswer.db.models import TokenRateLimit__UserGroup +from danswer.db.models import UserGroup +from danswer.server.token_rate_limits.models import TokenRateLimitArgs + + +def fetch_all_user_token_rate_limits( + db_session: 
Session, + enabled_only: bool = False, + ordered: bool = True, +) -> Sequence[TokenRateLimit]: + query = select(TokenRateLimit).where( + TokenRateLimit.scope == TokenRateLimitScope.USER + ) + + if enabled_only: + query = query.where(TokenRateLimit.enabled.is_(True)) + + if ordered: + query = query.order_by(TokenRateLimit.created_at.desc()) + + return db_session.scalars(query).all() + + +def fetch_all_global_token_rate_limits( + db_session: Session, + enabled_only: bool = False, + ordered: bool = True, +) -> Sequence[TokenRateLimit]: + query = select(TokenRateLimit).where( + TokenRateLimit.scope == TokenRateLimitScope.GLOBAL + ) + + if enabled_only: + query = query.where(TokenRateLimit.enabled.is_(True)) + + if ordered: + query = query.order_by(TokenRateLimit.created_at.desc()) + + token_rate_limits = db_session.scalars(query).all() + return token_rate_limits + + +def fetch_all_user_group_token_rate_limits( + db_session: Session, group_id: int, enabled_only: bool = False, ordered: bool = True +) -> Sequence[TokenRateLimit]: + query = ( + select(TokenRateLimit) + .join( + TokenRateLimit__UserGroup, + TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id, + ) + .where( + TokenRateLimit__UserGroup.user_group_id == group_id, + TokenRateLimit.scope == TokenRateLimitScope.USER_GROUP, + ) + ) + + if enabled_only: + query = query.where(TokenRateLimit.enabled.is_(True)) + + if ordered: + query = query.order_by(TokenRateLimit.created_at.desc()) + + token_rate_limits = db_session.scalars(query).all() + return token_rate_limits + + +def fetch_all_user_group_token_rate_limits_by_group( + db_session: Session, +) -> Sequence[Row[tuple[TokenRateLimit, str]]]: + query = ( + select(TokenRateLimit, UserGroup.name) + .join( + TokenRateLimit__UserGroup, + TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id, + ) + .join(UserGroup, UserGroup.id == TokenRateLimit__UserGroup.user_group_id) + ) + + return db_session.execute(query).all() + + +def insert_user_token_rate_limit( 
+ db_session: Session, + token_rate_limit_settings: TokenRateLimitArgs, +) -> TokenRateLimit: + token_limit = TokenRateLimit( + enabled=token_rate_limit_settings.enabled, + token_budget=token_rate_limit_settings.token_budget, + period_hours=token_rate_limit_settings.period_hours, + scope=TokenRateLimitScope.USER, + ) + db_session.add(token_limit) + db_session.commit() + + return token_limit + + +def insert_global_token_rate_limit( + db_session: Session, + token_rate_limit_settings: TokenRateLimitArgs, +) -> TokenRateLimit: + token_limit = TokenRateLimit( + enabled=token_rate_limit_settings.enabled, + token_budget=token_rate_limit_settings.token_budget, + period_hours=token_rate_limit_settings.period_hours, + scope=TokenRateLimitScope.GLOBAL, + ) + db_session.add(token_limit) + db_session.commit() + + return token_limit + + +def insert_user_group_token_rate_limit( + db_session: Session, + token_rate_limit_settings: TokenRateLimitArgs, + group_id: int, +) -> TokenRateLimit: + token_limit = TokenRateLimit( + enabled=token_rate_limit_settings.enabled, + token_budget=token_rate_limit_settings.token_budget, + period_hours=token_rate_limit_settings.period_hours, + scope=TokenRateLimitScope.USER_GROUP, + ) + db_session.add(token_limit) + db_session.flush() + + rate_limit = TokenRateLimit__UserGroup( + rate_limit_id=token_limit.id, user_group_id=group_id + ) + db_session.add(rate_limit) + db_session.commit() + + return token_limit + + +def update_token_rate_limit( + db_session: Session, + token_rate_limit_id: int, + token_rate_limit_settings: TokenRateLimitArgs, +) -> TokenRateLimit: + token_limit = db_session.get(TokenRateLimit, token_rate_limit_id) + if token_limit is None: + raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found") + + token_limit.enabled = token_rate_limit_settings.enabled + token_limit.token_budget = token_rate_limit_settings.token_budget + token_limit.period_hours = token_rate_limit_settings.period_hours + db_session.commit() + + 
return token_limit + + +def delete_token_rate_limit( + db_session: Session, + token_rate_limit_id: int, +) -> None: + token_limit = db_session.get(TokenRateLimit, token_rate_limit_id) + if token_limit is None: + raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found") + + db_session.query(TokenRateLimit__UserGroup).filter( + TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id + ).delete() + + db_session.delete(token_limit) + db_session.commit() diff --git a/backend/ee/danswer/db/usage_export.py b/backend/ee/danswer/db/usage_export.py new file mode 100644 index 00000000000..bf53362e97e --- /dev/null +++ b/backend/ee/danswer/db/usage_export.py @@ -0,0 +1,108 @@ +import uuid +from collections.abc import Generator +from datetime import datetime +from typing import IO + +from fastapi_users_db_sqlalchemy import UUID_ID +from sqlalchemy.orm import Session + +from danswer.configs.constants import MessageType +from danswer.db.models import UsageReport +from danswer.file_store.file_store import get_default_file_store +from ee.danswer.db.query_history import fetch_chat_sessions_eagerly_by_time +from ee.danswer.server.reporting.usage_export_models import ChatMessageSkeleton +from ee.danswer.server.reporting.usage_export_models import FlowType +from ee.danswer.server.reporting.usage_export_models import UsageReportMetadata + + +# Gets skeletons of all message +def get_empty_chat_messages_entries__paginated( + db_session: Session, + period: tuple[datetime, datetime], + limit: int | None = 1, + initial_id: int | None = None, +) -> list[ChatMessageSkeleton]: + chat_sessions = fetch_chat_sessions_eagerly_by_time( + period[0], period[1], db_session, limit=limit, initial_id=initial_id + ) + + message_skeletons: list[ChatMessageSkeleton] = [] + for chat_session in chat_sessions: + if chat_session.one_shot: + flow_type = FlowType.SEARCH + elif chat_session.danswerbot_flow: + flow_type = FlowType.SLACK + else: + flow_type = FlowType.CHAT + + for message 
in chat_session.messages: + # only count user messages + if message.message_type != MessageType.USER: + continue + + message_skeletons.append( + ChatMessageSkeleton( + message_id=chat_session.id, + chat_session_id=chat_session.id, + user_id=str(chat_session.user_id) if chat_session.user_id else None, + flow_type=flow_type, + time_sent=message.time_sent, + ) + ) + + return message_skeletons + + +def get_all_empty_chat_message_entries( + db_session: Session, + period: tuple[datetime, datetime], +) -> Generator[list[ChatMessageSkeleton], None, None]: + initial_id = None + while True: + message_skeletons = get_empty_chat_messages_entries__paginated( + db_session, period, initial_id=initial_id + ) + if not message_skeletons: + return + + yield message_skeletons + initial_id = message_skeletons[-1].message_id + + +def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]: + return [ + UsageReportMetadata( + report_name=r.report_name, + requestor=str(r.requestor_user_id) if r.requestor_user_id else None, + time_created=r.time_created, + period_from=r.period_from, + period_to=r.period_to, + ) + for r in db_session.query(UsageReport).all() + ] + + +def get_usage_report_data( + db_session: Session, + report_name: str, +) -> IO: + file_store = get_default_file_store(db_session) + # usage report may be very large, so don't load it all into memory + return file_store.read_file(file_name=report_name, mode="b", use_tempfile=True) + + +def write_usage_report( + db_session: Session, + report_name: str, + user_id: uuid.UUID | UUID_ID | None, + period: tuple[datetime, datetime] | None, +) -> UsageReport: + new_report = UsageReport( + report_name=report_name, + requestor_user_id=user_id, + period_from=period[0] if period else None, + period_to=period[1] if period else None, + ) + db_session.add(new_report) + db_session.commit() + return new_report diff --git a/backend/ee/danswer/db/user_group.py b/backend/ee/danswer/db/user_group.py new file mode 100644 index 
00000000000..0451db9b633 --- /dev/null +++ b/backend/ee/danswer/db/user_group.py @@ -0,0 +1,332 @@ +from collections.abc import Sequence +from operator import and_ +from uuid import UUID + +from sqlalchemy import func +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.db.models import ConnectorCredentialPair +from danswer.db.models import Document +from danswer.db.models import DocumentByConnectorCredentialPair +from danswer.db.models import TokenRateLimit__UserGroup +from danswer.db.models import User +from danswer.db.models import User__UserGroup +from danswer.db.models import UserGroup +from danswer.db.models import UserGroup__ConnectorCredentialPair +from danswer.server.documents.models import ConnectorCredentialPairIdentifier +from ee.danswer.server.user_group.models import UserGroupCreate +from ee.danswer.server.user_group.models import UserGroupUpdate + + +def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | None: + stmt = select(UserGroup).where(UserGroup.id == user_group_id) + return db_session.scalar(stmt) + + +def fetch_user_groups( + db_session: Session, only_current: bool = True +) -> Sequence[UserGroup]: + stmt = select(UserGroup) + if only_current: + stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712 + return db_session.scalars(stmt).all() + + +def fetch_user_groups_for_user( + db_session: Session, user_id: UUID +) -> Sequence[UserGroup]: + stmt = ( + select(UserGroup) + .join(User__UserGroup, User__UserGroup.user_group_id == UserGroup.id) + .join(User, User.id == User__UserGroup.user_id) # type: ignore + .where(User.id == user_id) # type: ignore + ) + return db_session.scalars(stmt).all() + + +def fetch_documents_for_user_group_paginated( + db_session: Session, + user_group_id: int, + last_document_id: str | None = None, + limit: int = 100, +) -> tuple[Sequence[Document], str | None]: + stmt = ( + select(Document) + .join( + DocumentByConnectorCredentialPair, + Document.id == 
DocumentByConnectorCredentialPair.id, + ) + .join( + ConnectorCredentialPair, + and_( + DocumentByConnectorCredentialPair.connector_id + == ConnectorCredentialPair.connector_id, + DocumentByConnectorCredentialPair.credential_id + == ConnectorCredentialPair.credential_id, + ), + ) + .join( + UserGroup__ConnectorCredentialPair, + UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id, + ) + .join( + UserGroup, + UserGroup__ConnectorCredentialPair.user_group_id == UserGroup.id, + ) + .where(UserGroup.id == user_group_id) + .order_by(Document.id) + .limit(limit) + ) + if last_document_id is not None: + stmt = stmt.where(Document.id > last_document_id) + stmt = stmt.distinct() + + documents = db_session.scalars(stmt).all() + return documents, documents[-1].id if documents else None + + +def fetch_user_groups_for_documents( + db_session: Session, + document_ids: list[str], + cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None, +) -> Sequence[tuple[int, list[str]]]: + stmt = ( + select(Document.id, func.array_agg(UserGroup.name)) + .join( + UserGroup__ConnectorCredentialPair, + UserGroup.id == UserGroup__ConnectorCredentialPair.user_group_id, + ) + .join( + ConnectorCredentialPair, + ConnectorCredentialPair.id == UserGroup__ConnectorCredentialPair.cc_pair_id, + ) + .join( + DocumentByConnectorCredentialPair, + and_( + DocumentByConnectorCredentialPair.connector_id + == ConnectorCredentialPair.connector_id, + DocumentByConnectorCredentialPair.credential_id + == ConnectorCredentialPair.credential_id, + ), + ) + .join(Document, Document.id == DocumentByConnectorCredentialPair.id) + .where(Document.id.in_(document_ids)) + .where(UserGroup__ConnectorCredentialPair.is_current == True) # noqa: E712 + .group_by(Document.id) + ) + + # pretend that the specified cc pair doesn't exist + if cc_pair_to_delete is not None: + stmt = stmt.where( + and_( + ConnectorCredentialPair.connector_id != cc_pair_to_delete.connector_id, + 
ConnectorCredentialPair.credential_id + != cc_pair_to_delete.credential_id, + ) + ) + + return db_session.execute(stmt).all() # type: ignore + + +def _check_user_group_is_modifiable(user_group: UserGroup) -> None: + if not user_group.is_up_to_date: + raise ValueError( + "Specified user group is currently syncing. Wait until the current " + "sync has finished before editing." + ) + + +def _add_user__user_group_relationships__no_commit( + db_session: Session, user_group_id: int, user_ids: list[UUID] +) -> list[User__UserGroup]: + """NOTE: does not commit the transaction.""" + relationships = [ + User__UserGroup(user_id=user_id, user_group_id=user_group_id) + for user_id in user_ids + ] + db_session.add_all(relationships) + return relationships + + +def _add_user_group__cc_pair_relationships__no_commit( + db_session: Session, user_group_id: int, cc_pair_ids: list[int] +) -> list[UserGroup__ConnectorCredentialPair]: + """NOTE: does not commit the transaction.""" + relationships = [ + UserGroup__ConnectorCredentialPair( + user_group_id=user_group_id, cc_pair_id=cc_pair_id + ) + for cc_pair_id in cc_pair_ids + ] + db_session.add_all(relationships) + return relationships + + +def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup: + db_user_group = UserGroup(name=user_group.name) + db_session.add(db_user_group) + db_session.flush() # give the group an ID + + _add_user__user_group_relationships__no_commit( + db_session=db_session, + user_group_id=db_user_group.id, + user_ids=user_group.user_ids, + ) + _add_user_group__cc_pair_relationships__no_commit( + db_session=db_session, + user_group_id=db_user_group.id, + cc_pair_ids=user_group.cc_pair_ids, + ) + + db_session.commit() + return db_user_group + + +def _cleanup_user__user_group_relationships__no_commit( + db_session: Session, user_group_id: int +) -> None: + """NOTE: does not commit the transaction.""" + user__user_group_relationships = db_session.scalars( + 
select(User__UserGroup).where(User__UserGroup.user_group_id == user_group_id) + ).all() + for user__user_group_relationship in user__user_group_relationships: + db_session.delete(user__user_group_relationship) + + +def _mark_user_group__cc_pair_relationships_outdated__no_commit( + db_session: Session, user_group_id: int +) -> None: + """NOTE: does not commit the transaction.""" + user_group__cc_pair_relationships = db_session.scalars( + select(UserGroup__ConnectorCredentialPair).where( + UserGroup__ConnectorCredentialPair.user_group_id == user_group_id + ) + ) + for user_group__cc_pair_relationship in user_group__cc_pair_relationships: + user_group__cc_pair_relationship.is_current = False + + +def update_user_group( + db_session: Session, user_group_id: int, user_group: UserGroupUpdate +) -> UserGroup: + stmt = select(UserGroup).where(UserGroup.id == user_group_id) + db_user_group = db_session.scalar(stmt) + if db_user_group is None: + raise ValueError(f"UserGroup with id '{user_group_id}' not found") + + _check_user_group_is_modifiable(db_user_group) + + existing_cc_pairs = db_user_group.cc_pairs + cc_pairs_updated = set([cc_pair.id for cc_pair in existing_cc_pairs]) != set( + user_group.cc_pair_ids + ) + users_updated = set([user.id for user in db_user_group.users]) != set( + user_group.user_ids + ) + + if users_updated: + _cleanup_user__user_group_relationships__no_commit( + db_session=db_session, user_group_id=user_group_id + ) + _add_user__user_group_relationships__no_commit( + db_session=db_session, + user_group_id=user_group_id, + user_ids=user_group.user_ids, + ) + if cc_pairs_updated: + _mark_user_group__cc_pair_relationships_outdated__no_commit( + db_session=db_session, user_group_id=user_group_id + ) + _add_user_group__cc_pair_relationships__no_commit( + db_session=db_session, + user_group_id=db_user_group.id, + cc_pair_ids=user_group.cc_pair_ids, + ) + + # only needs to sync with Vespa if the cc_pairs have been updated + if cc_pairs_updated: + 
db_user_group.is_up_to_date = False + + db_session.commit() + return db_user_group + + +def _cleanup_token_rate_limit__user_group_relationships__no_commit( + db_session: Session, user_group_id: int +) -> None: + """NOTE: does not commit the transaction.""" + token_rate_limit__user_group_relationships = db_session.scalars( + select(TokenRateLimit__UserGroup).where( + TokenRateLimit__UserGroup.user_group_id == user_group_id + ) + ).all() + for ( + token_rate_limit__user_group_relationship + ) in token_rate_limit__user_group_relationships: + db_session.delete(token_rate_limit__user_group_relationship) + + +def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None: + stmt = select(UserGroup).where(UserGroup.id == user_group_id) + db_user_group = db_session.scalar(stmt) + if db_user_group is None: + raise ValueError(f"UserGroup with id '{user_group_id}' not found") + + _check_user_group_is_modifiable(db_user_group) + + _cleanup_user__user_group_relationships__no_commit( + db_session=db_session, user_group_id=user_group_id + ) + _mark_user_group__cc_pair_relationships_outdated__no_commit( + db_session=db_session, user_group_id=user_group_id + ) + _cleanup_token_rate_limit__user_group_relationships__no_commit( + db_session=db_session, user_group_id=user_group_id + ) + + db_user_group.is_up_to_date = False + db_user_group.is_up_for_deletion = True + db_session.commit() + + +def _cleanup_user_group__cc_pair_relationships__no_commit( + db_session: Session, user_group_id: int, outdated_only: bool +) -> None: + """NOTE: does not commit the transaction.""" + stmt = select(UserGroup__ConnectorCredentialPair).where( + UserGroup__ConnectorCredentialPair.user_group_id == user_group_id + ) + if outdated_only: + stmt = stmt.where( + UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712 + ) + user_group__cc_pair_relationships = db_session.scalars(stmt) + for user_group__cc_pair_relationship in user_group__cc_pair_relationships: + 
db_session.delete(user_group__cc_pair_relationship) + + +def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> None: + # cleanup outdated relationships + _cleanup_user_group__cc_pair_relationships__no_commit( + db_session=db_session, user_group_id=user_group.id, outdated_only=True + ) + user_group.is_up_to_date = True + db_session.commit() + + +def delete_user_group(db_session: Session, user_group: UserGroup) -> None: + _cleanup_user__user_group_relationships__no_commit( + db_session=db_session, user_group_id=user_group.id + ) + _cleanup_user_group__cc_pair_relationships__no_commit( + db_session=db_session, + user_group_id=user_group.id, + outdated_only=False, + ) + + # need to flush so that we don't get a foreign key error when deleting the user group row + db_session.flush() + + db_session.delete(user_group) + db_session.commit() diff --git a/backend/ee/danswer/main.py b/backend/ee/danswer/main.py new file mode 100644 index 00000000000..d7d1d6406a3 --- /dev/null +++ b/backend/ee/danswer/main.py @@ -0,0 +1,107 @@ +from fastapi import FastAPI +from httpx_oauth.clients.openid import OpenID + +from danswer.auth.users import auth_backend +from danswer.auth.users import fastapi_users +from danswer.configs.app_configs import AUTH_TYPE +from danswer.configs.app_configs import OAUTH_CLIENT_ID +from danswer.configs.app_configs import OAUTH_CLIENT_SECRET +from danswer.configs.app_configs import USER_AUTH_SECRET +from danswer.configs.app_configs import WEB_DOMAIN +from danswer.configs.constants import AuthType +from danswer.main import get_application as get_application_base +from danswer.main import include_router_with_global_prefix_prepended +from danswer.utils.logger import setup_logger +from danswer.utils.variable_functionality import global_version +from ee.danswer.configs.app_configs import OPENID_CONFIG_URL +from ee.danswer.server.analytics.api import router as analytics_router +from ee.danswer.server.api_key.api import router as api_key_router 
from ee.danswer.server.auth_check import check_ee_router_auth
from ee.danswer.server.enterprise_settings.api import (
    admin_router as enterprise_settings_admin_router,
)
from ee.danswer.server.enterprise_settings.api import (
    basic_router as enterprise_settings_router,
)
from ee.danswer.server.query_and_chat.chat_backend import (
    router as chat_router,
)
from ee.danswer.server.query_and_chat.query_backend import (
    basic_router as query_router,
)
from ee.danswer.server.query_history.api import router as query_history_router
from ee.danswer.server.reporting.usage_export_api import router as usage_export_router
from ee.danswer.server.saml import router as saml_router
from ee.danswer.server.seeding import seed_db
from ee.danswer.server.token_rate_limits.api import (
    router as token_rate_limit_settings_router,
)
from ee.danswer.server.user_group.api import router as user_group_router
from ee.danswer.utils.encryption import test_encryption

logger = setup_logger()


def get_application() -> FastAPI:
    """Build the EE FastAPI app: the base (OSS) app plus EE auth and routers."""
    # Anything that happens at import time is not guaranteed to be running ee-version
    # Anything after the server startup will be running ee version
    global_version.set_ee()

    # fail fast on startup if the encryption key is misconfigured
    test_encryption()

    application = get_application_base()

    if AUTH_TYPE == AuthType.OIDC:
        include_router_with_global_prefix_prepended(
            application,
            fastapi_users.get_oauth_router(
                OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
                auth_backend,
                USER_AUTH_SECRET,
                associate_by_email=True,
                is_verified_by_default=True,
                redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
            ),
            prefix="/auth/oidc",
            tags=["auth"],
        )
        # need basic auth router for `logout` endpoint
        include_router_with_global_prefix_prepended(
            application,
            fastapi_users.get_auth_router(auth_backend),
            prefix="/auth",
            tags=["auth"],
        )

    elif AUTH_TYPE == AuthType.SAML:
        include_router_with_global_prefix_prepended(application, saml_router)

    # RBAC / group access control
    include_router_with_global_prefix_prepended(application, user_group_router)
    # Analytics endpoints
    include_router_with_global_prefix_prepended(application, analytics_router)
    include_router_with_global_prefix_prepended(application, query_history_router)
    # Api key management
    include_router_with_global_prefix_prepended(application, api_key_router)
    # EE only backend APIs
    include_router_with_global_prefix_prepended(application, query_router)
    include_router_with_global_prefix_prepended(application, chat_router)
    # Enterprise-only global settings
    include_router_with_global_prefix_prepended(
        application, enterprise_settings_admin_router
    )
    # Token rate limit settings
    include_router_with_global_prefix_prepended(
        application, token_rate_limit_settings_router
    )
    include_router_with_global_prefix_prepended(application, enterprise_settings_router)
    include_router_with_global_prefix_prepended(application, usage_export_router)

    # Ensure all routes have auth enabled or are explicitly marked as public
    check_ee_router_auth(application)

    # seed the Danswer environment with LLMs, Assistants, etc. based on an optional
    # environment variable. Used to automate deployment for multiple environments.
    seed_db()

    return application


# === backend/ee/danswer/server/analytics/api.py ===
import datetime
from collections import defaultdict

from fastapi import APIRouter
from fastapi import Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session

import danswer.db.models as db_models
from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from ee.danswer.db.analytics import fetch_danswerbot_analytics
from ee.danswer.db.analytics import fetch_per_user_query_analytics
from ee.danswer.db.analytics import fetch_query_analytics

router = APIRouter(prefix="/analytics")

# Shared default reporting window for all analytics endpoints.
# NOTE: uses naive UTC timestamps to match the stored ChatMessage times —
# TODO confirm against the DB column's timezone handling.
_DEFAULT_LOOKBACK = datetime.timedelta(days=30)


def _resolve_period(
    start: datetime.datetime | None, end: datetime.datetime | None
) -> tuple[datetime.datetime, datetime.datetime]:
    """Fill in the shared defaults: trailing 30 days, ending now (UTC).

    Extracted because all three analytics endpoints previously duplicated this
    exact fallback logic inline.
    """
    now = datetime.datetime.utcnow()
    return start or (now - _DEFAULT_LOOKBACK), end or now


class QueryAnalyticsResponse(BaseModel):
    # per-day aggregate of queries and feedback counts
    total_queries: int
    total_likes: int
    total_dislikes: int
    date: datetime.date


@router.get("/admin/query")
def get_query_analytics(
    start: datetime.datetime | None = None,
    end: datetime.datetime | None = None,
    _: db_models.User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[QueryAnalyticsResponse]:
    # Daily query counts with like/dislike totals over the requested window.
    start, end = _resolve_period(start, end)
    daily_query_usage_info = fetch_query_analytics(
        start=start,
        end=end,
        db_session=db_session,
    )
    return [
        QueryAnalyticsResponse(
            total_queries=total_queries,
            total_likes=total_likes,
            total_dislikes=total_dislikes,
            date=date,
        )
        for total_queries, total_likes, total_dislikes, date in daily_query_usage_info
    ]


class UserAnalyticsResponse(BaseModel):
    # number of distinct active users on a given day
    total_active_users: int
    date: datetime.date


@router.get("/admin/user")
def get_user_analytics(
    start: datetime.datetime | None = None,
    end: datetime.datetime | None = None,
    _: db_models.User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[UserAnalyticsResponse]:
    # Count active users per day: each returned row is one (user, day) pair,
    # so counting rows per day bucket yields active-user counts.
    start, end = _resolve_period(start, end)
    daily_query_usage_info_per_user = fetch_per_user_query_analytics(
        start=start,
        end=end,
        db_session=db_session,
    )

    user_analytics: dict[datetime.date, int] = defaultdict(int)
    for row in daily_query_usage_info_per_user:
        # index 3 of the row tuple is the day bucket
        user_analytics[row[3]] += 1
    return [
        UserAnalyticsResponse(
            total_active_users=cnt,
            date=date,
        )
        for date, cnt in user_analytics.items()
    ]


class DanswerbotAnalyticsResponse(BaseModel):
    # per-day DanswerBot query volume and how many resolved without negative feedback
    total_queries: int
    auto_resolved: int
    date: datetime.date


@router.get("/admin/danswerbot")
def get_danswerbot_analytics(
    start: datetime.datetime | None = None,
    end: datetime.datetime | None = None,
    _: db_models.User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[DanswerbotAnalyticsResponse]:
    # Daily DanswerBot resolution stats over the requested window.
    start, end = _resolve_period(start, end)
    daily_danswerbot_info = fetch_danswerbot_analytics(
        start=start,
        end=end,
        db_session=db_session,
    )

    resolution_results = [
        DanswerbotAnalyticsResponse(
            total_queries=total_queries,
            # If it hits negatives, something has gone wrong...
            # clamp so the API never reports a negative resolution count
            auto_resolved=max(0, total_queries - total_negatives),
            date=date,
        )
        for total_queries, total_negatives, date in daily_danswerbot_info
    ]

    return resolution_results


# === backend/ee/danswer/server/api_key/api.py ===
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session

import danswer.db.models as db_models
from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from ee.danswer.db.api_key import ApiKeyDescriptor
from ee.danswer.db.api_key import fetch_api_keys
from ee.danswer.db.api_key import insert_api_key
from ee.danswer.db.api_key import regenerate_api_key
from ee.danswer.db.api_key import remove_api_key
from ee.danswer.db.api_key import update_api_key
from ee.danswer.server.api_key.models import APIKeyArgs

router = APIRouter(prefix="/admin/api-key")


@router.get("")
def list_api_keys(
    _: db_models.User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[ApiKeyDescriptor]:
    # admin-only: list all existing API keys
    return fetch_api_keys(db_session)


@router.post("")
def create_api_key(
    api_key_args: APIKeyArgs,
    user: db_models.User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> ApiKeyDescriptor:
    # record the creating admin (if any) as the key's owner
    return insert_api_key(db_session, api_key_args, user.id if user else None)


@router.post("/{api_key_id}/regenerate")
def regenerate_existing_api_key(
    api_key_id: int,
    _: db_models.User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> ApiKeyDescriptor:
    # rotate the secret for an existing key, keeping its metadata
    return regenerate_api_key(db_session, api_key_id)


@router.patch("/{api_key_id}")
def update_existing_api_key(
    api_key_id: int,
    api_key_args: APIKeyArgs,
    _: db_models.User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> ApiKeyDescriptor:
    # update mutable metadata (e.g. the display name) of an existing key
    return update_api_key(db_session, api_key_id, api_key_args)


@router.delete("/{api_key_id}")
def delete_api_key(
    api_key_id: int,
    _: db_models.User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    # permanently remove the key
    remove_api_key(db_session, api_key_id)


# === backend/ee/danswer/server/api_key/models.py ===
from pydantic import BaseModel


class APIKeyArgs(BaseModel):
    # optional human-readable label for the key
    name: str | None = None


# === backend/ee/danswer/server/auth_check.py ===
from fastapi import FastAPI

from danswer.server.auth_check import check_router_auth
from danswer.server.auth_check import PUBLIC_ENDPOINT_SPECS


EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
    # needs to be accessible prior to user login
    ("/enterprise-settings", {"GET"}),
    ("/enterprise-settings/logo", {"GET"}),
    ("/enterprise-settings/custom-analytics-script", {"GET"}),
    # oidc
    ("/auth/oidc/authorize", {"GET"}),
    ("/auth/oidc/callback", {"GET"}),
    # saml
    ("/auth/saml/authorize", {"GET"}),
    ("/auth/saml/callback", {"POST"}),
    ("/auth/saml/logout", {"POST"}),
]


def check_ee_router_auth(
    application: FastAPI,
    public_endpoint_specs: list[tuple[str, set[str]]] = EE_PUBLIC_ENDPOINT_SPECS,
) -> None:
    """Verify every route either requires auth or is explicitly public.

    Same check as the open-source version, but with the EE-only public
    endpoints added to the allow-list.
    """
    check_router_auth(application, public_endpoint_specs)
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from fastapi import UploadFile
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.file_store.file_store import get_default_file_store
from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.danswer.server.enterprise_settings.models import EnterpriseSettings
from ee.danswer.server.enterprise_settings.store import _LOGO_FILENAME
from ee.danswer.server.enterprise_settings.store import load_analytics_script
from ee.danswer.server.enterprise_settings.store import load_settings
from ee.danswer.server.enterprise_settings.store import store_analytics_script
from ee.danswer.server.enterprise_settings.store import store_settings
from ee.danswer.server.enterprise_settings.store import upload_logo

admin_router = APIRouter(prefix="/admin/enterprise-settings")
basic_router = APIRouter(prefix="/enterprise-settings")


@admin_router.put("")
def put_settings(
    settings: EnterpriseSettings, _: User | None = Depends(current_admin_user)
) -> None:
    # validate before persisting; surface validation problems as a 400
    try:
        settings.check_validity()
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    store_settings(settings)


@basic_router.get("")
def fetch_settings() -> EnterpriseSettings:
    # public: read by the web client before login
    return load_settings()


@admin_router.put("/logo")
def put_logo(
    file: UploadFile,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> None:
    # store the uploaded logo in the shared file store
    upload_logo(file=file, db_session=db_session)


@basic_router.get("/logo")
def fetch_logo(db_session: Session = Depends(get_session)) -> Response:
    # public: serve the stored logo; any failure (including "not stored yet")
    # is reported as a 404
    try:
        logo_io = get_default_file_store(db_session).read_file(
            _LOGO_FILENAME, mode="b"
        )
    except Exception:
        raise HTTPException(status_code=404, detail="No logo file found")
    # NOTE: specifying "image/jpeg" here, but it still works for pngs
    # TODO: do this properly
    return Response(content=logo_io.read(), media_type="image/jpeg")


@admin_router.put("/custom-analytics-script")
def upload_custom_analytics_script(
    script_upload: AnalyticsScriptUpload, _: User | None = Depends(current_admin_user)
) -> None:
    # store rejects uploads with a bad secret key; surface that as a 400
    try:
        store_analytics_script(script_upload)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


@basic_router.get("/custom-analytics-script")
def fetch_custom_analytics_script() -> str | None:
    # public: None when no script has been configured
    return load_analytics_script()


# === backend/ee/danswer/server/enterprise_settings/models.py ===
from pydantic import BaseModel


class EnterpriseSettings(BaseModel):
    """General settings that only apply to the Enterprise Edition of Danswer

    NOTE: don't put anything sensitive in here, as this is accessible without auth."""

    application_name: str | None = None
    use_custom_logo: bool = False

    # custom Chat components
    custom_header_content: str | None = None
    custom_popup_header: str | None = None
    custom_popup_content: str | None = None

    def check_validity(self) -> None:
        # no cross-field constraints yet; hook kept for future validation
        return


class AnalyticsScriptUpload(BaseModel):
    # raw script content plus the shared secret that authorizes the upload
    script: str
    secret_key: str
_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
logger = setup_logger()


def load_settings() -> EnterpriseSettings:
    """Load EE settings from the dynamic config store.

    On first access (key missing) the defaults are created and persisted so
    subsequent loads are consistent.
    """
    dynamic_config_store = get_dynamic_config_store()
    try:
        settings = EnterpriseSettings(
            **cast(dict, dynamic_config_store.load(_ENTERPRISE_SETTINGS_KEY))
        )
    except ConfigNotFoundError:
        settings = EnterpriseSettings()
        dynamic_config_store.store(_ENTERPRISE_SETTINGS_KEY, settings.dict())

    return settings


def store_settings(settings: EnterpriseSettings) -> None:
    """Persist EE settings to the dynamic config store."""
    get_dynamic_config_store().store(_ENTERPRISE_SETTINGS_KEY, settings.dict())


_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
# shared secret that authorizes analytics-script uploads; unset -> uploads disabled
_CUSTOM_ANALYTICS_SECRET_KEY = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY")


def load_analytics_script() -> str | None:
    """Return the stored custom analytics script, or None if not configured."""
    dynamic_config_store = get_dynamic_config_store()
    try:
        return cast(str, dynamic_config_store.load(_CUSTOM_ANALYTICS_SCRIPT_KEY))
    except ConfigNotFoundError:
        return None


def store_analytics_script(analytics_script_upload: AnalyticsScriptUpload) -> None:
    """Persist a custom analytics script.

    Raises:
        ValueError: if the secret key is unset or does not match.
    """
    if (
        not _CUSTOM_ANALYTICS_SECRET_KEY
        or analytics_script_upload.secret_key != _CUSTOM_ANALYTICS_SECRET_KEY
    ):
        raise ValueError("Invalid secret key")

    get_dynamic_config_store().store(
        _CUSTOM_ANALYTICS_SCRIPT_KEY, analytics_script_upload.script
    )


_LOGO_FILENAME = "__logo__"


def is_valid_file_type(filename: str) -> bool:
    """True if the filename has an allowed image extension (case-insensitive).

    FIX: previously this check was case-sensitive while guess_file_type()
    lowercases — so e.g. "LOGO.PNG" was rejected here yet classified as
    image/png there. Lowercase before checking for consistency.
    """
    valid_extensions = (".png", ".jpg", ".jpeg")
    return filename.lower().endswith(valid_extensions)


def guess_file_type(filename: str) -> str:
    """Best-effort MIME type from the file extension (case-insensitive)."""
    lowered = filename.lower()
    if lowered.endswith(".png"):
        return "image/png"
    if lowered.endswith((".jpg", ".jpeg")):
        return "image/jpeg"
    return "application/octet-stream"


def upload_logo(
    db_session: Session,
    file: UploadFile | str,
) -> bool:
    """Store a logo image in the default file store.

    Accepts either a local filesystem path (str) or an uploaded file.

    Returns:
        True on success; False when given an invalid local path.

    Raises:
        HTTPException: 400 when an uploaded file has a disallowed extension.
    """
    content: IO[Any]

    if isinstance(file, str):
        logger.info(f"Uploading logo from local path {file}")
        if not os.path.isfile(file) or not is_valid_file_type(file):
            logger.error(
                "Invalid file type- only .png, .jpg, and .jpeg files are allowed"
            )
            return False

        with open(file, "rb") as file_handle:
            file_content = file_handle.read()
        content = BytesIO(file_content)
        display_name = file
        file_type = guess_file_type(file)

    else:
        logger.info("Uploading logo from uploaded file")
        if not file.filename or not is_valid_file_type(file.filename):
            raise HTTPException(
                status_code=400,
                detail="Invalid file type- only .png, .jpg, and .jpeg files are allowed",
            )
        content = file.file
        display_name = file.filename
        file_type = file.content_type or "image/jpeg"

    file_store = get_default_file_store(db_session)
    file_store.save_file(
        file_name=_LOGO_FILENAME,
        content=content,
        display_name=display_name,
        file_origin=FileOrigin.OTHER,
        file_type=file_type,
    )
    return True
import re

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.process_message import stream_chat_message_objects
from danswer.db.chat import get_or_create_root_message
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.utils.logger import setup_logger
from ee.danswer.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.danswer.server.query_and_chat.models import ChatBasicResponse
from ee.danswer.server.query_and_chat.models import SimpleDoc

logger = setup_logger()

router = APIRouter(prefix="/chat")

# matches inline citation markers of the form ` [[3]](http://...)`
_CITATION_PATTERN = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"


def translate_doc_response_to_simple_doc(
    doc_response: QADocsResponse,
) -> list[SimpleDoc]:
    """Strip a full QADocsResponse down to the minimal SimpleDoc fields."""
    simple_docs = []
    for doc in doc_response.top_documents:
        simple_docs.append(
            SimpleDoc(
                semantic_identifier=doc.semantic_identifier,
                link=doc.link,
                blurb=doc.blurb,
                # drop empty highlight strings
                match_highlights=[
                    highlight for highlight in doc.match_highlights if highlight
                ],
                source_type=doc.source_type,
            )
        )
    return simple_docs


def remove_answer_citations(answer: str) -> str:
    """Remove inline citation markers from the answer text."""
    return re.sub(_CITATION_PATTERN, "", answer)


@router.post("/send-message-simple-api")
def handle_simplified_chat_message(
    chat_message_req: BasicCreateChatMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatBasicResponse:
    """This is a Non-Streaming version that only gives back a minimal set of information"""
    logger.info(f"Received new simple api chat message: {chat_message_req.message}")

    if not chat_message_req.message:
        raise HTTPException(status_code=400, detail="Empty chat message is invalid")

    # find the latest message in this session to chain onto; if the session has
    # no usable chain yet, anchor to the root message instead
    try:
        parent_message, _ = create_chat_chain(
            chat_session_id=chat_message_req.chat_session_id, db_session=db_session
        )
    except Exception:
        parent_message = get_or_create_root_message(
            chat_session_id=chat_message_req.chat_session_id, db_session=db_session
        )

    retrieval_options = chat_message_req.retrieval_options
    if retrieval_options is None and chat_message_req.search_doc_ids is None:
        # neither explicit options nor pinned docs: always run a non-realtime search
        retrieval_options = RetrievalDetails(
            run_search=OptionalSearchSetting.ALWAYS,
            real_time=False,
        )

    full_chat_msg_info = CreateChatMessageRequest(
        chat_session_id=chat_message_req.chat_session_id,
        parent_message_id=parent_message.id,
        message=chat_message_req.message,
        file_descriptors=[],
        prompt_id=None,
        search_doc_ids=chat_message_req.search_doc_ids,
        retrieval_options=retrieval_options,
        query_override=chat_message_req.query_override,
        chunks_above=chat_message_req.chunks_above,
        chunks_below=chat_message_req.chunks_below,
        full_doc=chat_message_req.full_doc,
    )

    response = ChatBasicResponse()
    answer_pieces: list[str] = []

    # consume the streaming pipeline and fold the packets into the flat response
    for packet in stream_chat_message_objects(
        new_msg_req=full_chat_msg_info,
        user=user,
        db_session=db_session,
    ):
        if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
            answer_pieces.append(packet.answer_piece)
        elif isinstance(packet, QADocsResponse):
            response.simple_search_docs = translate_doc_response_to_simple_doc(packet)
        elif isinstance(packet, StreamingError):
            response.error_msg = packet.error
        elif isinstance(packet, ChatMessageDetail):
            response.message_id = packet.message_id

    answer = "".join(answer_pieces)
    response.answer = answer
    if answer:
        response.answer_citationless = remove_answer_citations(answer)

    return response
from pydantic import BaseModel

from danswer.configs.constants import DocumentSource
from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext
from danswer.search.models import RetrievalDetails


class DocumentSearchRequest(ChunkContext):
    """Request body for the stateless /query/document-search endpoint."""

    message: str
    search_type: SearchType
    retrieval_options: RetrievalDetails
    recency_bias_multiplier: float = 1.0
    # None -> use the system defaults; True/False force the step on or off
    skip_rerank: bool | None = None
    skip_llm_chunk_filter: bool | None = None


class BasicCreateChatMessageRequest(ChunkContext):
    """Before creating messages, be sure to create a chat_session and get an id
    Note, for simplicity this option only allows for a single linear chain of messages
    """

    chat_session_id: int
    # New message contents
    message: str
    # Defaults to using retrieval with no additional filters
    retrieval_options: RetrievalDetails | None = None
    # Allows the caller to specify the exact search query they want to use
    # will disable Query Rewording if specified
    query_override: str | None = None
    # If search_doc_ids provided, then retrieval options are unused
    search_doc_ids: list[int] | None = None


class SimpleDoc(BaseModel):
    """Minimal document info returned by the simple-api chat endpoint."""

    semantic_identifier: str
    link: str | None
    blurb: str
    match_highlights: list[str]
    source_type: DocumentSource


class ChatBasicResponse(BaseModel):
    # This is built piece by piece, any of these can be None as the flow could break
    answer: str | None = None
    answer_citationless: str | None = None
    simple_search_docs: list[SimpleDoc] | None = None
    error_msg: str | None = None
    message_id: int | None = None
from fastapi import APIRouter
from fastapi import Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.llm.answering.prompts.citations_prompt import (
    compute_max_document_tokens_for_persona,
)
from danswer.llm.factory import get_default_llms
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.models import SavedSearchDocWithContent
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.utils.logger import setup_logger
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest


logger = setup_logger()
basic_router = APIRouter(prefix="/query")


class DocumentSearchResponse(BaseModel):
    # ranked documents plus the indices the LLM filter judged relevant
    top_documents: list[SavedSearchDocWithContent]
    llm_indices: list[int]


def _section_to_saved_doc(section) -> SavedSearchDocWithContent:
    """Flatten a retrieved section's center chunk into the API response shape."""
    chunk = section.center_chunk
    return SavedSearchDocWithContent(
        document_id=chunk.document_id,
        chunk_ind=chunk.chunk_id,
        content=chunk.content,
        semantic_identifier=chunk.semantic_identifier or "Unknown",
        link=chunk.source_links.get(0) if chunk.source_links else None,
        blurb=chunk.blurb,
        source_type=chunk.source_type,
        boost=chunk.boost,
        hidden=chunk.hidden,
        metadata=chunk.metadata,
        score=chunk.score or 0.0,
        match_highlights=chunk.match_highlights,
        updated_at=chunk.updated_at,
        primary_owners=chunk.primary_owners,
        secondary_owners=chunk.secondary_owners,
        is_internet=False,
        db_doc_id=0,
    )


@basic_router.post("/document-search")
def handle_search_request(
    search_request: DocumentSearchRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> DocumentSearchResponse:
    """Simple search endpoint, does not create a new message or records in the DB"""
    query = search_request.message
    logger.info(f"Received document search query: {query}")

    llm, fast_llm = get_default_llms()
    search_pipeline = SearchPipeline(
        search_request=SearchRequest(
            query=query,
            search_type=search_request.search_type,
            human_selected_filters=search_request.retrieval_options.filters,
            enable_auto_detect_filters=search_request.retrieval_options.enable_auto_detect_filters,
            persona=None,  # For simplicity, default settings should be good for this search
            offset=search_request.retrieval_options.offset,
            limit=search_request.retrieval_options.limit,
            skip_rerank=search_request.skip_rerank,
            skip_llm_chunk_filter=search_request.skip_llm_chunk_filter,
            chunks_above=search_request.chunks_above,
            chunks_below=search_request.chunks_below,
            full_doc=search_request.full_doc,
        ),
        user=user,
        llm=llm,
        fast_llm=fast_llm,
        db_session=db_session,
        bypass_acl=False,
    )
    # If using surrounding context or full doc, the relevant indices will be empty
    relevant_section_indices = search_pipeline.relevant_section_indices
    top_docs = [
        _section_to_saved_doc(section)
        for section in search_pipeline.reranked_sections
    ]

    # Deduping happens at the last step to avoid harming quality by dropping content early on
    deduped_docs = top_docs
    dropped_inds = None
    if search_request.retrieval_options.dedupe_docs:
        deduped_docs, dropped_inds = dedupe_documents(top_docs)

    if dropped_inds:
        # remap the relevance indices onto the deduped document list
        relevant_section_indices = drop_llm_indices(
            llm_indices=relevant_section_indices,
            search_docs=deduped_docs,
            dropped_indices=dropped_inds,
        )

    return DocumentSearchResponse(
        top_documents=deduped_docs, llm_indices=relevant_section_indices
    )


@basic_router.post("/answer-with-quote")
def get_answer_with_quote(
    query_request: DirectQARequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> OneShotQAResponse:
    # one-shot QA with quotes over the first message in the request
    query = query_request.messages[0].message
    logger.info(f"Received query for one shot answer API with quotes: {query}")

    persona = get_persona_by_id(
        persona_id=query_request.persona_id,
        user=user,
        db_session=db_session,
        is_for_edit=False,
    )

    llm = get_main_llm_from_tuple(
        get_default_llms() if not persona else get_llms_for_persona(persona)
    )

    # split the LLM context window between chat history and retrieved documents
    input_tokens = get_max_input_tokens(
        model_name=llm.config.model_name, model_provider=llm.config.model_provider
    )
    max_history_tokens = int(input_tokens * DANSWER_BOT_TARGET_CHUNK_PERCENTAGE)
    remaining_tokens = input_tokens - max_history_tokens

    max_document_tokens = compute_max_document_tokens_for_persona(
        persona=persona,
        actual_user_input=query,
        max_llm_token_override=remaining_tokens,
    )

    return get_search_answer(
        query_req=query_request,
        user=user,
        max_document_tokens=max_document_tokens,
        max_history_tokens=max_history_tokens,
        db_session=db_session,
    )
import Sequence +from datetime import datetime +from itertools import groupby +from typing import Dict +from typing import List +from typing import Tuple +from uuid import UUID + +from fastapi import HTTPException +from sqlalchemy import func +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.db.engine import get_session_context_manager +from danswer.db.models import ChatMessage +from danswer.db.models import ChatSession +from danswer.db.models import TokenRateLimit +from danswer.db.models import TokenRateLimit__UserGroup +from danswer.db.models import User +from danswer.db.models import User__UserGroup +from danswer.db.models import UserGroup +from danswer.server.query_and_chat.token_limit import _get_cutoff_time +from danswer.server.query_and_chat.token_limit import _is_rate_limited +from danswer.server.query_and_chat.token_limit import _user_is_rate_limited_by_global +from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel +from ee.danswer.db.api_key import is_api_key_email_address +from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits + + +def _check_token_rate_limits(user: User | None) -> None: + if user is None: + # Unauthenticated users are only rate limited by global settings + _user_is_rate_limited_by_global() + + elif is_api_key_email_address(user.email): + # API keys are only rate limited by global settings + _user_is_rate_limited_by_global() + + else: + run_functions_tuples_in_parallel( + [ + (_user_is_rate_limited, (user.id,)), + (_user_is_rate_limited_by_group, (user.id,)), + (_user_is_rate_limited_by_global, ()), + ] + ) + + +""" +User rate limits +""" + + +def _user_is_rate_limited(user_id: UUID) -> None: + with get_session_context_manager() as db_session: + user_rate_limits = fetch_all_user_token_rate_limits( + db_session=db_session, enabled_only=True, ordered=False + ) + + if user_rate_limits: + user_cutoff_time = _get_cutoff_time(user_rate_limits) + user_usage = 
_fetch_user_usage(user_id, user_cutoff_time, db_session) + + if _is_rate_limited(user_rate_limits, user_usage): + raise HTTPException( + status_code=429, + detail="Token budget exceeded for user. Try again later.", + ) + + +def _fetch_user_usage( + user_id: UUID, cutoff_time: datetime, db_session: Session +) -> Sequence[tuple[datetime, int]]: + """ + Fetch user usage within the cutoff time, grouped by minute + """ + result = db_session.execute( + select( + func.date_trunc("minute", ChatMessage.time_sent), + func.sum(ChatMessage.token_count), + ) + .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id) + .where(ChatSession.user_id == user_id, ChatMessage.time_sent >= cutoff_time) + .group_by(func.date_trunc("minute", ChatMessage.time_sent)) + ).all() + + return [(row[0], row[1]) for row in result] + + +""" +User Group rate limits +""" + + +def _user_is_rate_limited_by_group(user_id: UUID) -> None: + with get_session_context_manager() as db_session: + group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session) + + if group_rate_limits: + # Group cutoff time is the same for all groups. + # This could be optimized to only fetch the maximum cutoff time for + # a specific group, but seems unnecessary for now. + group_cutoff_time = _get_cutoff_time( + [e for sublist in group_rate_limits.values() for e in sublist] + ) + + user_group_ids = list(group_rate_limits.keys()) + group_usage = _fetch_user_group_usage( + user_group_ids, group_cutoff_time, db_session + ) + + has_at_least_one_untriggered_limit = False + for user_group_id, rate_limits in group_rate_limits.items(): + usage = group_usage.get(user_group_id, []) + + if not _is_rate_limited(rate_limits, usage): + has_at_least_one_untriggered_limit = True + break + + if not has_at_least_one_untriggered_limit: + raise HTTPException( + status_code=429, + detail="Token budget exceeded for user's groups. 
Try again later.", + ) + + +def _fetch_all_user_group_rate_limits( + user_id: UUID, db_session: Session +) -> Dict[int, List[TokenRateLimit]]: + group_limits = ( + select(TokenRateLimit, User__UserGroup.user_group_id) + .join( + TokenRateLimit__UserGroup, + TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id, + ) + .join( + UserGroup, + UserGroup.id == TokenRateLimit__UserGroup.user_group_id, + ) + .join( + User__UserGroup, + User__UserGroup.user_group_id == UserGroup.id, + ) + .where( + User__UserGroup.user_id == user_id, + TokenRateLimit.enabled.is_(True), + ) + ) + + raw_rate_limits = db_session.execute(group_limits).all() + + group_rate_limits = defaultdict(list) + for rate_limit, user_group_id in raw_rate_limits: + group_rate_limits[user_group_id].append(rate_limit) + + return group_rate_limits + + +def _fetch_user_group_usage( + user_group_ids: list[int], cutoff_time: datetime, db_session: Session +) -> dict[int, list[Tuple[datetime, int]]]: + """ + Fetch user group usage within the cutoff time, grouped by minute + """ + user_group_usage = db_session.execute( + select( + func.sum(ChatMessage.token_count), + func.date_trunc("minute", ChatMessage.time_sent), + UserGroup.id, + ) + .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id) + .join(User__UserGroup, User__UserGroup.user_id == ChatSession.user_id) + .join(UserGroup, UserGroup.id == User__UserGroup.user_group_id) + .filter(UserGroup.id.in_(user_group_ids), ChatMessage.time_sent >= cutoff_time) + .group_by(func.date_trunc("minute", ChatMessage.time_sent), UserGroup.id) + ).all() + + return { + user_group_id: [(usage, time_sent) for time_sent, usage, _ in group_usage] + for user_group_id, group_usage in groupby( + user_group_usage, key=lambda row: row[2] + ) + } diff --git a/backend/ee/danswer/server/query_history/api.py b/backend/ee/danswer/server/query_history/api.py new file mode 100644 index 00000000000..c01c82794e0 --- /dev/null +++ b/backend/ee/danswer/server/query_history/api.py 
@@ -0,0 +1,387 @@ +import csv +import io +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from typing import Literal + +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from fastapi.responses import StreamingResponse +from pydantic import BaseModel +from sqlalchemy.orm import Session + +import danswer.db.models as db_models +from danswer.auth.users import current_admin_user +from danswer.auth.users import get_display_email +from danswer.chat.chat_utils import create_chat_chain +from danswer.configs.constants import MessageType +from danswer.configs.constants import QAFeedbackType +from danswer.db.chat import get_chat_session_by_id +from danswer.db.engine import get_session +from danswer.db.models import ChatMessage +from danswer.db.models import ChatSession +from ee.danswer.db.query_history import fetch_chat_sessions_eagerly_by_time + + +router = APIRouter() + + +class AbridgedSearchDoc(BaseModel): + """A subset of the info present in `SearchDoc`""" + + document_id: str + semantic_identifier: str + link: str | None + + +class MessageSnapshot(BaseModel): + message: str + message_type: MessageType + documents: list[AbridgedSearchDoc] + feedback_type: QAFeedbackType | None + feedback_text: str | None + time_created: datetime + + @classmethod + def build(cls, message: ChatMessage) -> "MessageSnapshot": + latest_messages_feedback_obj = ( + message.chat_message_feedbacks[-1] + if len(message.chat_message_feedbacks) > 0 + else None + ) + feedback_type = ( + ( + QAFeedbackType.LIKE + if latest_messages_feedback_obj.is_positive + else QAFeedbackType.DISLIKE + ) + if latest_messages_feedback_obj + else None + ) + feedback_text = ( + latest_messages_feedback_obj.feedback_text + if latest_messages_feedback_obj + else None + ) + return cls( + message=message.message, + message_type=message.message_type, + documents=[ + AbridgedSearchDoc( + document_id=document.document_id, + 
semantic_identifier=document.semantic_id, + link=document.link, + ) + for document in message.search_docs + ], + feedback_type=feedback_type, + feedback_text=feedback_text, + time_created=message.time_sent, + ) + + +class ChatSessionMinimal(BaseModel): + id: int + user_email: str + name: str | None + first_user_message: str + first_ai_message: str + persona_name: str + time_created: datetime + feedback_type: QAFeedbackType | Literal["mixed"] | None + + +class ChatSessionSnapshot(BaseModel): + id: int + user_email: str + name: str | None + messages: list[MessageSnapshot] + persona_name: str + time_created: datetime + + +class QuestionAnswerPairSnapshot(BaseModel): + user_message: str + ai_response: str + retrieved_documents: list[AbridgedSearchDoc] + feedback_type: QAFeedbackType | None + feedback_text: str | None + persona_name: str + user_email: str + time_created: datetime + + @classmethod + def from_chat_session_snapshot( + cls, + chat_session_snapshot: ChatSessionSnapshot, + ) -> list["QuestionAnswerPairSnapshot"]: + message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = [] + for ind in range(1, len(chat_session_snapshot.messages), 2): + message_pairs.append( + ( + chat_session_snapshot.messages[ind - 1], + chat_session_snapshot.messages[ind], + ) + ) + + return [ + cls( + user_message=user_message.message, + ai_response=ai_message.message, + retrieved_documents=ai_message.documents, + feedback_type=ai_message.feedback_type, + feedback_text=ai_message.feedback_text, + persona_name=chat_session_snapshot.persona_name, + user_email=get_display_email(chat_session_snapshot.user_email), + time_created=user_message.time_created, + ) + for user_message, ai_message in message_pairs + ] + + def to_json(self) -> dict[str, str]: + return { + "user_message": self.user_message, + "ai_response": self.ai_response, + "retrieved_documents": "|".join( + [ + doc.link or doc.semantic_identifier + for doc in self.retrieved_documents + ] + ), + "feedback_type": 
self.feedback_type.value if self.feedback_type else "", + "feedback_text": self.feedback_text or "", + "persona_name": self.persona_name, + "user_email": self.user_email, + "time_created": str(self.time_created), + } + + +def fetch_and_process_chat_session_history_minimal( + db_session: Session, + start: datetime, + end: datetime, + feedback_filter: QAFeedbackType | None = None, + limit: int | None = 500, +) -> list[ChatSessionMinimal]: + chat_sessions = fetch_chat_sessions_eagerly_by_time( + start=start, end=end, db_session=db_session, limit=limit + ) + + minimal_sessions = [] + for chat_session in chat_sessions: + if not chat_session.messages: + continue + + first_user_message = next( + ( + message.message + for message in chat_session.messages + if message.message_type == MessageType.USER + ), + "", + ) + first_ai_message = next( + ( + message.message + for message in chat_session.messages + if message.message_type == MessageType.ASSISTANT + ), + "", + ) + + has_positive_feedback = any( + feedback.is_positive + for message in chat_session.messages + for feedback in message.chat_message_feedbacks + ) + + has_negative_feedback = any( + not feedback.is_positive + for message in chat_session.messages + for feedback in message.chat_message_feedbacks + ) + + feedback_type: QAFeedbackType | Literal["mixed"] | None = ( + "mixed" + if has_positive_feedback and has_negative_feedback + else QAFeedbackType.LIKE + if has_positive_feedback + else QAFeedbackType.DISLIKE + if has_negative_feedback + else None + ) + + if feedback_filter: + if feedback_filter == QAFeedbackType.LIKE and not has_positive_feedback: + continue + if feedback_filter == QAFeedbackType.DISLIKE and not has_negative_feedback: + continue + + minimal_sessions.append( + ChatSessionMinimal( + id=chat_session.id, + user_email=get_display_email( + chat_session.user.email if chat_session.user else None + ), + name=chat_session.description, + first_user_message=first_user_message, + 
first_ai_message=first_ai_message, + persona_name=chat_session.persona.name, + time_created=chat_session.time_created, + feedback_type=feedback_type, + ) + ) + + return minimal_sessions + + +def fetch_and_process_chat_session_history( + db_session: Session, + start: datetime, + end: datetime, + feedback_type: QAFeedbackType | None, + limit: int | None = 500, +) -> list[ChatSessionSnapshot]: + chat_sessions = fetch_chat_sessions_eagerly_by_time( + start=start, end=end, db_session=db_session, limit=limit + ) + + chat_session_snapshots = [ + snapshot_from_chat_session(chat_session=chat_session, db_session=db_session) + for chat_session in chat_sessions + ] + + valid_snapshots = [ + snapshot for snapshot in chat_session_snapshots if snapshot is not None + ] + + if feedback_type: + valid_snapshots = [ + snapshot + for snapshot in valid_snapshots + if any( + message.feedback_type == feedback_type for message in snapshot.messages + ) + ] + + return valid_snapshots + + +def snapshot_from_chat_session( + chat_session: ChatSession, + db_session: Session, +) -> ChatSessionSnapshot | None: + try: + # Older chats may not have the right structure + last_message, messages = create_chat_chain( + chat_session_id=chat_session.id, db_session=db_session + ) + messages.append(last_message) + except RuntimeError: + return None + + return ChatSessionSnapshot( + id=chat_session.id, + user_email=get_display_email( + chat_session.user.email if chat_session.user else None + ), + name=chat_session.description, + messages=[ + MessageSnapshot.build(message) + for message in messages + if message.message_type != MessageType.SYSTEM + ], + persona_name=chat_session.persona.name, + time_created=chat_session.time_created, + ) + + +@router.get("/admin/chat-session-history") +def get_chat_session_history( + feedback_type: QAFeedbackType | None = None, + start: datetime | None = None, + end: datetime | None = None, + _: db_models.User | None = Depends(current_admin_user), + db_session: Session = 
Depends(get_session), +) -> list[ChatSessionMinimal]: + return fetch_and_process_chat_session_history_minimal( + db_session=db_session, + start=start + or ( + datetime.now(tz=timezone.utc) - timedelta(days=30) + ), # default is 30d lookback + end=end or datetime.now(tz=timezone.utc), + feedback_filter=feedback_type, + ) + + +@router.get("/admin/chat-session-history/{chat_session_id}") +def get_chat_session_admin( + chat_session_id: int, + _: db_models.User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> ChatSessionSnapshot: + try: + chat_session = get_chat_session_by_id( + chat_session_id=chat_session_id, + user_id=None, # view chat regardless of user + db_session=db_session, + include_deleted=True, + ) + except ValueError: + raise HTTPException( + 400, f"Chat session with id '{chat_session_id}' does not exist." + ) + snapshot = snapshot_from_chat_session( + chat_session=chat_session, db_session=db_session + ) + + if snapshot is None: + raise HTTPException( + 400, + f"Could not create snapshot for chat session with id '{chat_session_id}'", + ) + + return snapshot + + +@router.get("/admin/query-history-csv") +def get_query_history_as_csv( + _: db_models.User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> StreamingResponse: + complete_chat_session_history = fetch_and_process_chat_session_history( + db_session=db_session, + start=datetime.fromtimestamp(0, tz=timezone.utc), + end=datetime.now(tz=timezone.utc), + feedback_type=None, + limit=None, + ) + + question_answer_pairs: list[QuestionAnswerPairSnapshot] = [] + for chat_session_snapshot in complete_chat_session_history: + question_answer_pairs.extend( + QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot) + ) + + # Create an in-memory text stream + stream = io.StringIO() + writer = csv.DictWriter( + stream, fieldnames=list(QuestionAnswerPairSnapshot.__fields__.keys()) + ) + writer.writeheader() + for row in 
question_answer_pairs: + writer.writerow(row.to_json()) + + # Reset the stream's position to the start + stream.seek(0) + + return StreamingResponse( + iter([stream.getvalue()]), + media_type="text/csv", + headers={ + "Content-Disposition": "attachment;filename=danswer_query_history.csv" + }, + ) diff --git a/backend/ee/danswer/server/reporting/usage_export_api.py b/backend/ee/danswer/server/reporting/usage_export_api.py new file mode 100644 index 00000000000..409702c3fbf --- /dev/null +++ b/backend/ee/danswer/server/reporting/usage_export_api.py @@ -0,0 +1,82 @@ +from collections.abc import Generator +from datetime import datetime + +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from fastapi import Response +from fastapi.responses import StreamingResponse +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from danswer.auth.users import current_admin_user +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.file_store.constants import STANDARD_CHUNK_SIZE +from ee.danswer.db.usage_export import get_all_usage_reports +from ee.danswer.db.usage_export import get_usage_report_data +from ee.danswer.db.usage_export import UsageReportMetadata +from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report + +router = APIRouter() + + +class GenerateUsageReportParams(BaseModel): + period_from: str | None = None + period_to: str | None = None + + +@router.post("/admin/generate-usage-report") +def generate_report( + params: GenerateUsageReportParams, + user: User = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> UsageReportMetadata: + period = None + if params.period_from and params.period_to: + try: + period = ( + datetime.fromisoformat(params.period_from), + datetime.fromisoformat(params.period_to), + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + new_report = 
create_new_usage_report(db_session, user.id if user else None, period) + return new_report + + +@router.get("/admin/usage-report/{report_name}") +def read_usage_report( + report_name: str, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> Response: + try: + file = get_usage_report_data(db_session, report_name) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + def iterfile() -> Generator[bytes, None, None]: + while True: + chunk = file.read(STANDARD_CHUNK_SIZE) + if not chunk: + break + yield chunk + + return StreamingResponse( + content=iterfile(), + media_type="application/zip", + headers={"Content-Disposition": f"attachment; filename={report_name}"}, + ) + + +@router.get("/admin/usage-report") +def fetch_usage_reports( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[UsageReportMetadata]: + try: + return get_all_usage_reports(db_session) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) diff --git a/backend/ee/danswer/server/reporting/usage_export_generation.py b/backend/ee/danswer/server/reporting/usage_export_generation.py new file mode 100644 index 00000000000..2274b00dd70 --- /dev/null +++ b/backend/ee/danswer/server/reporting/usage_export_generation.py @@ -0,0 +1,165 @@ +import csv +import tempfile +import uuid +import zipfile +from datetime import datetime +from datetime import timedelta +from datetime import timezone + +from fastapi_users_db_sqlalchemy import UUID_ID +from sqlalchemy.orm import Session + +from danswer.auth.schemas import UserStatus +from danswer.configs.constants import FileOrigin +from danswer.db.users import list_users +from danswer.file_store.constants import MAX_IN_MEMORY_SIZE +from danswer.file_store.file_store import FileStore +from danswer.file_store.file_store import get_default_file_store +from ee.danswer.db.usage_export import get_all_empty_chat_message_entries 
+from ee.danswer.db.usage_export import write_usage_report +from ee.danswer.server.reporting.usage_export_models import UsageReportMetadata +from ee.danswer.server.reporting.usage_export_models import UserSkeleton + + +def generate_chat_messages_report( + db_session: Session, + file_store: FileStore, + report_id: str, + period: tuple[datetime, datetime] | None, +) -> str: + file_name = f"{report_id}_chat_sessions" + + if period is None: + period = ( + datetime.fromtimestamp(0, tz=timezone.utc), + datetime.now(tz=timezone.utc), + ) + else: + # time-picker sends a time which is at the beginning of the day + # so we need to add one day to the end time to make it inclusive + period = ( + period[0], + period[1] + timedelta(days=1), + ) + + with tempfile.SpooledTemporaryFile( + max_size=MAX_IN_MEMORY_SIZE, mode="w+" + ) as temp_file: + csvwriter = csv.writer(temp_file, delimiter=",") + csvwriter.writerow(["session_id", "user_id", "flow_type", "time_sent"]) + for chat_message_skeleton_batch in get_all_empty_chat_message_entries( + db_session, period + ): + for chat_message_skeleton in chat_message_skeleton_batch: + csvwriter.writerow( + [ + chat_message_skeleton.chat_session_id, + chat_message_skeleton.user_id, + chat_message_skeleton.flow_type, + chat_message_skeleton.time_sent.isoformat(), + ] + ) + + # after writing seek to begining of buffer + temp_file.seek(0) + file_store.save_file( + file_name=file_name, + content=temp_file, + display_name=file_name, + file_origin=FileOrigin.OTHER, + file_type="text/csv", + ) + + return file_name + + +def generate_user_report( + db_session: Session, + file_store: FileStore, + report_id: str, +) -> str: + file_name = f"{report_id}_users" + + with tempfile.SpooledTemporaryFile( + max_size=MAX_IN_MEMORY_SIZE, mode="w+" + ) as temp_file: + csvwriter = csv.writer(temp_file, delimiter=",") + csvwriter.writerow(["user_id", "status"]) + + users = list_users(db_session) + for user in users: + user_skeleton = UserSkeleton( + 
user_id=str(user.id), + status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED, + ) + csvwriter.writerow([user_skeleton.user_id, user_skeleton.status]) + + temp_file.seek(0) + file_store.save_file( + file_name=file_name, + content=temp_file, + display_name=file_name, + file_origin=FileOrigin.OTHER, + file_type="text/csv", + ) + + return file_name + + +def create_new_usage_report( + db_session: Session, + user_id: UUID_ID | None, # None = auto-generated + period: tuple[datetime, datetime] | None, +) -> UsageReportMetadata: + report_id = str(uuid.uuid4()) + file_store = get_default_file_store(db_session) + + messages_filename = generate_chat_messages_report( + db_session, file_store, report_id, period + ) + users_filename = generate_user_report(db_session, file_store, report_id) + + with tempfile.SpooledTemporaryFile(max_size=MAX_IN_MEMORY_SIZE) as zip_buffer: + with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED) as zip_file: + # write messages + chat_messages_tmpfile = file_store.read_file( + messages_filename, mode="b", use_tempfile=True + ) + zip_file.writestr( + "chat_messages.csv", + chat_messages_tmpfile.read(), + ) + + # write users + users_tmpfile = file_store.read_file( + users_filename, mode="b", use_tempfile=True + ) + zip_file.writestr("users.csv", users_tmpfile.read()) + + zip_buffer.seek(0) + + # store zip blob to file_store + report_name = ( + f"{datetime.now(tz=timezone.utc).strftime('%Y-%m-%d')}" + f"_{report_id}_usage_report.zip" + ) + file_store.save_file( + file_name=report_name, + content=zip_buffer, + display_name=report_name, + file_origin=FileOrigin.GENERATED_REPORT, + file_type="application/zip", + ) + + # add report after zip file is written + new_report = write_usage_report(db_session, report_name, user_id, period) + + return UsageReportMetadata( + report_name=new_report.report_name, + requestor=( + str(new_report.requestor_user_id) if new_report.requestor_user_id else None + ), + time_created=new_report.time_created, 
+ period_from=new_report.period_from, + period_to=new_report.period_to, + ) diff --git a/backend/ee/danswer/server/reporting/usage_export_models.py b/backend/ee/danswer/server/reporting/usage_export_models.py new file mode 100644 index 00000000000..98d9021f816 --- /dev/null +++ b/backend/ee/danswer/server/reporting/usage_export_models.py @@ -0,0 +1,33 @@ +from datetime import datetime +from enum import Enum + +from pydantic import BaseModel + +from danswer.auth.schemas import UserStatus + + +class FlowType(str, Enum): + CHAT = "chat" + SEARCH = "search" + SLACK = "slack" + + +class ChatMessageSkeleton(BaseModel): + message_id: int + chat_session_id: int + user_id: str | None + flow_type: FlowType + time_sent: datetime + + +class UserSkeleton(BaseModel): + user_id: str + status: UserStatus + + +class UsageReportMetadata(BaseModel): + report_name: str + requestor: str | None + time_created: datetime + period_from: datetime | None # None = All time + period_to: datetime | None diff --git a/backend/ee/danswer/server/saml.py b/backend/ee/danswer/server/saml.py new file mode 100644 index 00000000000..5fe57efe0ea --- /dev/null +++ b/backend/ee/danswer/server/saml.py @@ -0,0 +1,177 @@ +import contextlib +import secrets +from typing import Any + +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from fastapi import Request +from fastapi import Response +from fastapi import status +from fastapi_users import exceptions +from fastapi_users.password import PasswordHelper +from onelogin.saml2.auth import OneLogin_Saml2_Auth # type: ignore +from pydantic import BaseModel +from pydantic import EmailStr +from sqlalchemy.orm import Session + +from danswer.auth.schemas import UserCreate +from danswer.auth.schemas import UserRole +from danswer.auth.users import get_user_manager +from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS +from danswer.db.auth import get_user_count +from danswer.db.auth import get_user_db +from 
danswer.db.engine import get_async_session +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.utils.logger import setup_logger +from ee.danswer.configs.app_configs import SAML_CONF_DIR +from ee.danswer.db.saml import expire_saml_account +from ee.danswer.db.saml import get_saml_account +from ee.danswer.db.saml import upsert_saml_account +from ee.danswer.utils.secrets import encrypt_string +from ee.danswer.utils.secrets import extract_hashed_cookie + + +logger = setup_logger() +router = APIRouter(prefix="/auth/saml") + + +async def upsert_saml_user(email: str) -> User: + get_async_session_context = contextlib.asynccontextmanager( + get_async_session + ) # type:ignore + get_user_db_context = contextlib.asynccontextmanager(get_user_db) + get_user_manager_context = contextlib.asynccontextmanager(get_user_manager) + + async with get_async_session_context() as session: + async with get_user_db_context(session) as user_db: + async with get_user_manager_context(user_db) as user_manager: + try: + return await user_manager.get_by_email(email) + except exceptions.UserNotExists: + logger.info("Creating user from SAML login") + + user_count = await get_user_count() + role = UserRole.ADMIN if user_count == 0 else UserRole.BASIC + + fastapi_users_pw_helper = PasswordHelper() + password = fastapi_users_pw_helper.generate() + hashed_pass = fastapi_users_pw_helper.hash(password) + + user: User = await user_manager.create( + UserCreate( + email=EmailStr(email), + password=hashed_pass, + is_verified=True, + role=role, + ) + ) + + return user + + +async def prepare_from_fastapi_request(request: Request) -> dict[str, Any]: + form_data = await request.form() + if request.client is None: + raise ValueError("Invalid request for SAML") + + rv: dict[str, Any] = { + "http_host": request.client.host, + "server_port": request.url.port, + "script_name": request.url.path, + "post_data": {}, + "get_data": {}, + } + if request.query_params: + rv["get_data"] 
= (request.query_params,) + if "SAMLResponse" in form_data: + SAMLResponse = form_data["SAMLResponse"] + rv["post_data"]["SAMLResponse"] = SAMLResponse + if "RelayState" in form_data: + RelayState = form_data["RelayState"] + rv["post_data"]["RelayState"] = RelayState + return rv + + +class SAMLAuthorizeResponse(BaseModel): + authorization_url: str + + +@router.get("/authorize") +async def saml_login(request: Request) -> SAMLAuthorizeResponse: + req = await prepare_from_fastapi_request(request) + auth = OneLogin_Saml2_Auth(req, custom_base_path=SAML_CONF_DIR) + callback_url = auth.login() + return SAMLAuthorizeResponse(authorization_url=callback_url) + + +@router.post("/callback") +async def saml_login_callback( + request: Request, + db_session: Session = Depends(get_session), +) -> Response: + req = await prepare_from_fastapi_request(request) + auth = OneLogin_Saml2_Auth(req, custom_base_path=SAML_CONF_DIR) + auth.process_response() + errors = auth.get_errors() + if len(errors) != 0: + logger.error( + "Error when processing SAML Response: %s %s" + % (", ".join(errors), auth.get_last_error_reason()) + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access denied. Failed to parse SAML Response.", + ) + + if not auth.is_authenticated(): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access denied. 
User was not Authenticated.", + ) + + user_email = auth.get_attribute("email") + if not user_email: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="SAML is not set up correctly, email attribute must be provided.", + ) + + user_email = user_email[0] + + user = await upsert_saml_user(email=user_email) + + # Generate a random session cookie and Sha256 encrypt before saving + session_cookie = secrets.token_hex(16) + saved_cookie = encrypt_string(session_cookie) + + upsert_saml_account(user_id=user.id, cookie=saved_cookie, db_session=db_session) + + # Redirect to main Danswer search page + response = Response(status_code=status.HTTP_204_NO_CONTENT) + + response.set_cookie( + key="session", + value=session_cookie, + httponly=True, + secure=True, + max_age=SESSION_EXPIRE_TIME_SECONDS, + ) + + return response + + +@router.post("/logout") +def saml_logout( + request: Request, + db_session: Session = Depends(get_session), +) -> None: + saved_cookie = extract_hashed_cookie(request) + + if saved_cookie: + saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session) + if saml_account: + expire_saml_account(saml_account, db_session) + + return diff --git a/backend/ee/danswer/server/seeding.py b/backend/ee/danswer/server/seeding.py new file mode 100644 index 00000000000..20c57facbe5 --- /dev/null +++ b/backend/ee/danswer/server/seeding.py @@ -0,0 +1,137 @@ +import os + +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from danswer.db.engine import get_session_context_manager +from danswer.db.llm import fetch_existing_llm_providers +from danswer.db.llm import update_default_provider +from danswer.db.llm import upsert_llm_provider +from danswer.db.persona import get_personas +from danswer.db.persona import upsert_persona +from danswer.search.enums import RecencyBiasSetting +from danswer.server.features.persona.models import CreatePersonaRequest +from danswer.server.manage.llm.models import LLMProviderUpsertRequest +from 
danswer.server.settings.models import Settings +from danswer.server.settings.store import store_settings as store_base_settings +from danswer.utils.logger import setup_logger +from ee.danswer.server.enterprise_settings.models import EnterpriseSettings +from ee.danswer.server.enterprise_settings.store import ( + store_settings as store_ee_settings, +) +from ee.danswer.server.enterprise_settings.store import upload_logo + + +logger = setup_logger() + +_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION" + + +class SeedConfiguration(BaseModel): + llms: list[LLMProviderUpsertRequest] | None = None + admin_user_emails: list[str] | None = None + seeded_name: str | None = None + seeded_logo_path: str | None = None + personas: list[CreatePersonaRequest] | None = None + settings: Settings | None = None + + +def _parse_env() -> SeedConfiguration | None: + seed_config_str = os.getenv(_SEED_CONFIG_ENV_VAR_NAME) + if not seed_config_str: + return None + seed_config = SeedConfiguration.parse_raw(seed_config_str) + return seed_config + + +def _seed_llms( + db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest] +) -> None: + # don't seed LLMs if we've already done this + existing_llms = fetch_existing_llm_providers(db_session) + if existing_llms: + return + + logger.info("Seeding LLMs") + seeded_providers = [ + upsert_llm_provider(db_session, llm_upsert_request) + for llm_upsert_request in llm_upsert_requests + ] + update_default_provider(db_session, seeded_providers[0].id) + + +def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None: + # don't seed personas if we've already done this + existing_personas = get_personas( + user_id=None, # Admin view + db_session=db_session, + include_default=True, + include_slack_bot_personas=True, + include_deleted=False, + ) + if existing_personas: + return + + logger.info("Seeding Personas") + for persona in personas: + upsert_persona( + user=None, # Seeding is done as admin + name=persona.name, + 
description=persona.description, + num_chunks=persona.num_chunks if persona.num_chunks is not None else 0.0, + llm_relevance_filter=persona.llm_relevance_filter, + llm_filter_extraction=persona.llm_filter_extraction, + recency_bias=RecencyBiasSetting.AUTO, + prompt_ids=persona.prompt_ids, + document_set_ids=persona.document_set_ids, + llm_model_provider_override=persona.llm_model_provider_override, + llm_model_version_override=persona.llm_model_version_override, + starter_messages=persona.starter_messages, + is_public=persona.is_public, + db_session=db_session, + tool_ids=persona.tool_ids, + ) + + +def _seed_settings(settings: Settings) -> None: + logger.info("Seeding Settings") + try: + settings.check_validity() + store_base_settings(settings) + logger.info("Successfully seeded Settings") + except ValueError as e: + logger.error(f"Failed to seed Settings: {str(e)}") + + +def get_seed_config() -> SeedConfiguration | None: + return _parse_env() + + +def seed_db() -> None: + seed_config = _parse_env() + + if seed_config is None: + logger.info("No seeding configuration file passed") + return + + with get_session_context_manager() as db_session: + if seed_config.llms is not None: + _seed_llms(db_session, seed_config.llms) + if seed_config.personas is not None: + _seed_personas(db_session, seed_config.personas) + if seed_config.settings is not None: + _seed_settings(seed_config.settings) + + is_seeded_logo = ( + upload_logo(db_session=db_session, file=seed_config.seeded_logo_path) + if seed_config.seeded_logo_path + else False + ) + seeded_name = seed_config.seeded_name + + if is_seeded_logo or seeded_name: + logger.info("Seeding enterprise settings") + seeded_settings = EnterpriseSettings( + application_name=seeded_name, use_custom_logo=is_seeded_logo + ) + store_ee_settings(seeded_settings) diff --git a/backend/ee/danswer/server/token_rate_limits/api.py b/backend/ee/danswer/server/token_rate_limits/api.py new file mode 100644 index 00000000000..aac3ebb16c0 --- 
/dev/null +++ b/backend/ee/danswer/server/token_rate_limits/api.py @@ -0,0 +1,105 @@ +from collections import defaultdict + +from fastapi import APIRouter +from fastapi import Depends +from sqlalchemy.orm import Session + +from danswer.auth.users import current_admin_user +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.server.query_and_chat.token_limit import any_rate_limit_exists +from danswer.server.token_rate_limits.models import TokenRateLimitArgs +from danswer.server.token_rate_limits.models import TokenRateLimitDisplay +from ee.danswer.db.token_limit import fetch_all_user_group_token_rate_limits +from ee.danswer.db.token_limit import fetch_all_user_group_token_rate_limits_by_group +from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits +from ee.danswer.db.token_limit import insert_user_group_token_rate_limit +from ee.danswer.db.token_limit import insert_user_token_rate_limit + +router = APIRouter(prefix="/admin/token-rate-limits") + + +""" +Group Token Limit Settings +""" + + +@router.get("/user-groups") +def get_all_group_token_limit_settings( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> dict[str, list[TokenRateLimitDisplay]]: + user_groups_to_token_rate_limits = fetch_all_user_group_token_rate_limits_by_group( + db_session + ) + + token_rate_limits_by_group = defaultdict(list) + for token_rate_limit, group_name in user_groups_to_token_rate_limits: + token_rate_limits_by_group[group_name].append( + TokenRateLimitDisplay.from_db(token_rate_limit) + ) + + return dict(token_rate_limits_by_group) + + +@router.get("/user-group/{group_id}") +def get_group_token_limit_settings( + group_id: int, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[TokenRateLimitDisplay]: + return [ + TokenRateLimitDisplay.from_db(token_rate_limit) + for token_rate_limit in fetch_all_user_group_token_rate_limits( + 
db_session, group_id + ) + ] + + +@router.post("/user-group/{group_id}") +def create_group_token_limit_settings( + group_id: int, + token_limit_settings: TokenRateLimitArgs, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> TokenRateLimitDisplay: + rate_limit_display = TokenRateLimitDisplay.from_db( + insert_user_group_token_rate_limit( + db_session=db_session, + token_rate_limit_settings=token_limit_settings, + group_id=group_id, + ) + ) + # clear cache in case this was the first rate limit created + any_rate_limit_exists.cache_clear() + return rate_limit_display + + +""" +User Token Limit Settings +""" + + +@router.get("/users") +def get_user_token_limit_settings( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[TokenRateLimitDisplay]: + return [ + TokenRateLimitDisplay.from_db(token_rate_limit) + for token_rate_limit in fetch_all_user_token_rate_limits(db_session) + ] + + +@router.post("/users") +def create_user_token_limit_settings( + token_limit_settings: TokenRateLimitArgs, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> TokenRateLimitDisplay: + rate_limit_display = TokenRateLimitDisplay.from_db( + insert_user_token_rate_limit(db_session, token_limit_settings) + ) + # clear cache in case this was the first rate limit created + any_rate_limit_exists.cache_clear() + return rate_limit_display diff --git a/backend/ee/danswer/server/user_group/api.py b/backend/ee/danswer/server/user_group/api.py new file mode 100644 index 00000000000..6b5163d5417 --- /dev/null +++ b/backend/ee/danswer/server/user_group/api.py @@ -0,0 +1,71 @@ +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +import danswer.db.models as db_models +from danswer.auth.users import current_admin_user +from 
danswer.db.engine import get_session +from ee.danswer.db.user_group import fetch_user_groups +from ee.danswer.db.user_group import insert_user_group +from ee.danswer.db.user_group import prepare_user_group_for_deletion +from ee.danswer.db.user_group import update_user_group +from ee.danswer.server.user_group.models import UserGroup +from ee.danswer.server.user_group.models import UserGroupCreate +from ee.danswer.server.user_group.models import UserGroupUpdate + +router = APIRouter(prefix="/manage") + + +@router.get("/admin/user-group") +def list_user_groups( + _: db_models.User = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[UserGroup]: + user_groups = fetch_user_groups(db_session, only_current=False) + return [UserGroup.from_model(user_group) for user_group in user_groups] + + +@router.post("/admin/user-group") +def create_user_group( + user_group: UserGroupCreate, + _: db_models.User = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> UserGroup: + try: + db_user_group = insert_user_group(db_session, user_group) + except IntegrityError: + raise HTTPException( + 400, + f"User group with name '{user_group.name}' already exists. 
Please " + + "choose a different name.", + ) + return UserGroup.from_model(db_user_group) + + +@router.patch("/admin/user-group/{user_group_id}") +def patch_user_group( + user_group_id: int, + user_group: UserGroupUpdate, + _: db_models.User = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> UserGroup: + try: + return UserGroup.from_model( + update_user_group(db_session, user_group_id, user_group) + ) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@router.delete("/admin/user-group/{user_group_id}") +def delete_user_group( + user_group_id: int, + _: db_models.User = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + try: + prepare_user_group_for_deletion(db_session, user_group_id) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) diff --git a/backend/ee/danswer/server/user_group/models.py b/backend/ee/danswer/server/user_group/models.py new file mode 100644 index 00000000000..fafb73a3e09 --- /dev/null +++ b/backend/ee/danswer/server/user_group/models.py @@ -0,0 +1,78 @@ +from uuid import UUID + +from pydantic import BaseModel + +from danswer.db.models import UserGroup as UserGroupModel +from danswer.server.documents.models import ConnectorCredentialPairDescriptor +from danswer.server.documents.models import ConnectorSnapshot +from danswer.server.documents.models import CredentialSnapshot +from danswer.server.features.document_set.models import DocumentSet +from danswer.server.features.persona.models import PersonaSnapshot +from danswer.server.manage.models import UserInfo +from danswer.server.manage.models import UserPreferences + + +class UserGroup(BaseModel): + id: int + name: str + users: list[UserInfo] + cc_pairs: list[ConnectorCredentialPairDescriptor] + document_sets: list[DocumentSet] + personas: list[PersonaSnapshot] + is_up_to_date: bool + is_up_for_deletion: bool + + @classmethod + def from_model(cls, user_group_model: 
UserGroupModel) -> "UserGroup": + return cls( + id=user_group_model.id, + name=user_group_model.name, + users=[ + UserInfo( + id=str(user.id), + email=user.email, + is_active=user.is_active, + is_superuser=user.is_superuser, + is_verified=user.is_verified, + role=user.role, + preferences=UserPreferences( + chosen_assistants=user.chosen_assistants + ), + ) + for user in user_group_model.users + ], + cc_pairs=[ + ConnectorCredentialPairDescriptor( + id=cc_pair_relationship.cc_pair.id, + name=cc_pair_relationship.cc_pair.name, + connector=ConnectorSnapshot.from_connector_db_model( + cc_pair_relationship.cc_pair.connector + ), + credential=CredentialSnapshot.from_credential_db_model( + cc_pair_relationship.cc_pair.credential + ), + ) + for cc_pair_relationship in user_group_model.cc_pair_relationships + if cc_pair_relationship.is_current + ], + document_sets=[ + DocumentSet.from_model(ds) for ds in user_group_model.document_sets + ], + personas=[ + PersonaSnapshot.from_model(persona) + for persona in user_group_model.personas + ], + is_up_to_date=user_group_model.is_up_to_date, + is_up_for_deletion=user_group_model.is_up_for_deletion, + ) + + +class UserGroupCreate(BaseModel): + name: str + user_ids: list[UUID] + cc_pair_ids: list[int] + + +class UserGroupUpdate(BaseModel): + user_ids: list[UUID] + cc_pair_ids: list[int] diff --git a/backend/ee/danswer/user_groups/sync.py b/backend/ee/danswer/user_groups/sync.py new file mode 100644 index 00000000000..e33655ba2f9 --- /dev/null +++ b/backend/ee/danswer/user_groups/sync.py @@ -0,0 +1,87 @@ +from sqlalchemy.orm import Session + +from danswer.access.access import get_access_for_documents +from danswer.db.document import prepare_to_modify_documents +from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.db.embedding_model import get_secondary_db_embedding_model +from danswer.document_index.factory import get_default_document_index +from danswer.document_index.interfaces import DocumentIndex 
+from danswer.document_index.interfaces import UpdateRequest +from danswer.utils.logger import setup_logger +from ee.danswer.db.user_group import delete_user_group +from ee.danswer.db.user_group import fetch_documents_for_user_group_paginated +from ee.danswer.db.user_group import fetch_user_group +from ee.danswer.db.user_group import mark_user_group_as_synced + +logger = setup_logger() + +_SYNC_BATCH_SIZE = 100 + + +def _sync_user_group_batch( + document_ids: list[str], document_index: DocumentIndex, db_session: Session +) -> None: + logger.debug(f"Syncing document sets for: {document_ids}") + + # Acquires a lock on the documents so that no other process can modify them + with prepare_to_modify_documents(db_session=db_session, document_ids=document_ids): + # get current state of document sets for these documents + document_id_to_access = get_access_for_documents( + document_ids=document_ids, db_session=db_session + ) + + # update Vespa + document_index.update( + update_requests=[ + UpdateRequest( + document_ids=[document_id], + access=document_id_to_access[document_id], + ) + for document_id in document_ids + ] + ) + + # Finish the transaction and release the locks + db_session.commit() + + +def sync_user_groups(user_group_id: int, db_session: Session) -> None: + """Sync the status of Postgres for the specified user group""" + db_embedding_model = get_current_db_embedding_model(db_session) + secondary_db_embedding_model = get_secondary_db_embedding_model(db_session) + + document_index = get_default_document_index( + primary_index_name=db_embedding_model.index_name, + secondary_index_name=secondary_db_embedding_model.index_name + if secondary_db_embedding_model + else None, + ) + + user_group = fetch_user_group(db_session=db_session, user_group_id=user_group_id) + if user_group is None: + raise ValueError(f"User group '{user_group_id}' does not exist") + + cursor = None + while True: + # NOTE: this may miss some documents, but that is okay. 
Any new documents added + # will be added with the correct group membership + document_batch, cursor = fetch_documents_for_user_group_paginated( + db_session=db_session, + user_group_id=user_group_id, + last_document_id=cursor, + limit=_SYNC_BATCH_SIZE, + ) + + _sync_user_group_batch( + document_ids=[document.id for document in document_batch], + document_index=document_index, + db_session=db_session, + ) + + if cursor is None: + break + + if user_group.is_up_for_deletion: + delete_user_group(db_session=db_session, user_group=user_group) + else: + mark_user_group_as_synced(db_session=db_session, user_group=user_group) diff --git a/backend/ee/danswer/utils/__init__.py b/backend/ee/danswer/utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/ee/danswer/utils/encryption.py b/backend/ee/danswer/utils/encryption.py new file mode 100644 index 00000000000..4e2329985ce --- /dev/null +++ b/backend/ee/danswer/utils/encryption.py @@ -0,0 +1,85 @@ +from functools import lru_cache +from os import urandom + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import padding +from cryptography.hazmat.primitives.ciphers import algorithms +from cryptography.hazmat.primitives.ciphers import Cipher +from cryptography.hazmat.primitives.ciphers import modes + +from danswer.configs.app_configs import ENCRYPTION_KEY_SECRET +from danswer.utils.logger import setup_logger +from danswer.utils.variable_functionality import fetch_versioned_implementation + +logger = setup_logger() + + +@lru_cache(maxsize=1) +def _get_trimmed_key(key: str) -> bytes: + encoded_key = key.encode() + key_length = len(encoded_key) + if key_length < 16: + raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short") + elif key_length > 32: + key = key[:32] + elif key_length not in (16, 24, 32): + valid_lengths = [16, 24, 32] + key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))] + + return encoded_key + + +def 
_encrypt_string(input_str: str) -> bytes: + if not ENCRYPTION_KEY_SECRET: + return input_str.encode() + + key = _get_trimmed_key(ENCRYPTION_KEY_SECRET) + iv = urandom(16) + padder = padding.PKCS7(algorithms.AES.block_size).padder() + padded_data = padder.update(input_str.encode()) + padder.finalize() + + cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) + encryptor = cipher.encryptor() + encrypted_data = encryptor.update(padded_data) + encryptor.finalize() + + return iv + encrypted_data + + +def _decrypt_bytes(input_bytes: bytes) -> str: + if not ENCRYPTION_KEY_SECRET: + return input_bytes.decode() + + key = _get_trimmed_key(ENCRYPTION_KEY_SECRET) + iv = input_bytes[:16] + encrypted_data = input_bytes[16:] + + cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) + decryptor = cipher.decryptor() + decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize() + + unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() + decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize() + + return decrypted_data.decode() + + +def encrypt_string_to_bytes(input_str: str) -> bytes: + versioned_encryption_fn = fetch_versioned_implementation( + "danswer.utils.encryption", "_encrypt_string" + ) + return versioned_encryption_fn(input_str) + + +def decrypt_bytes_to_string(input_bytes: bytes) -> str: + versioned_decryption_fn = fetch_versioned_implementation( + "danswer.utils.encryption", "_decrypt_bytes" + ) + return versioned_decryption_fn(input_bytes) + + +def test_encryption() -> None: + test_string = "Danswer is the BEST!" 
+ encrypted_bytes = encrypt_string_to_bytes(test_string) + decrypted_string = decrypt_bytes_to_string(encrypted_bytes) + if test_string != decrypted_string: + raise RuntimeError("Encryption decryption test failed") diff --git a/backend/ee/danswer/utils/secrets.py b/backend/ee/danswer/utils/secrets.py new file mode 100644 index 00000000000..d59d3b77ac8 --- /dev/null +++ b/backend/ee/danswer/utils/secrets.py @@ -0,0 +1,14 @@ +import hashlib + +from fastapi import Request + +from danswer.configs.constants import SESSION_KEY + + +def encrypt_string(s: str) -> str: + return hashlib.sha256(s.encode()).hexdigest() + + +def extract_hashed_cookie(request: Request) -> str | None: + session_cookie = request.cookies.get(SESSION_KEY) + return encrypt_string(session_cookie) if session_cookie else None diff --git a/backend/pytest.ini b/backend/pytest.ini new file mode 100644 index 00000000000..db3dbf8b00d --- /dev/null +++ b/backend/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +pythonpath = . +markers = + slow: marks tests as slow \ No newline at end of file diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 598b0e6de17..a5fe49da52f 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -74,3 +74,4 @@ zulip==0.8.2 hubspot-api-client==8.1.0 zenpy==2.0.41 dropbox==11.36.2 +boto3-stubs[s3]==1.34.133 \ No newline at end of file diff --git a/backend/requirements/dev.txt b/backend/requirements/dev.txt index 4a9bd21d3e7..881920af7f2 100644 --- a/backend/requirements/dev.txt +++ b/backend/requirements/dev.txt @@ -3,6 +3,7 @@ celery-types==0.19.0 mypy-extensions==1.0.0 mypy==1.8.0 pre-commit==3.2.2 +pytest==7.4.4 reorder-python-imports==3.9.0 ruff==0.0.286 types-PyYAML==6.0.12.11 @@ -19,3 +20,4 @@ types-regex==2023.3.23.1 types-requests==2.28.11.17 types-retry==0.9.9.3 types-urllib3==1.26.25.11 +boto3-stubs[s3]==1.34.133 \ No newline at end of file diff --git a/backend/requirements/ee.txt b/backend/requirements/ee.txt new file 
mode 100644 index 00000000000..0717e3a67e7 --- /dev/null +++ b/backend/requirements/ee.txt @@ -0,0 +1 @@ +python3-saml==1.15.0 diff --git a/backend/scripts/dev_run_background_jobs.py b/backend/scripts/dev_run_background_jobs.py index c9b91b00caa..adbb5d22090 100644 --- a/backend/scripts/dev_run_background_jobs.py +++ b/backend/scripts/dev_run_background_jobs.py @@ -21,7 +21,7 @@ def run_jobs(exclude_indexing: bool) -> None: cmd_worker = [ "celery", "-A", - "danswer.background.celery", + "ee.danswer.background.celery", "worker", "--pool=threads", "--autoscale=3,10", @@ -29,7 +29,13 @@ def run_jobs(exclude_indexing: bool) -> None: "--concurrency=1", ] - cmd_beat = ["celery", "-A", "danswer.background.celery", "beat", "--loglevel=INFO"] + cmd_beat = [ + "celery", + "-A", + "ee.danswer.background.celery", + "beat", + "--loglevel=INFO", + ] worker_process = subprocess.Popen( cmd_worker, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True @@ -65,6 +71,26 @@ def run_jobs(exclude_indexing: bool) -> None: indexing_thread.start() indexing_thread.join() + try: + update_env = os.environ.copy() + update_env["PYTHONPATH"] = "." 
+ cmd_perm_sync = ["python", "ee.danswer/background/permission_sync.py"] + + indexing_process = subprocess.Popen( + cmd_perm_sync, + env=update_env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + + perm_sync_thread = threading.Thread( + target=monitor_process, args=("INDEXING", indexing_process) + ) + perm_sync_thread.start() + perm_sync_thread.join() + except Exception: + pass worker_thread.join() beat_thread.join() diff --git a/backend/scripts/force_delete_connector_by_id.py b/backend/scripts/force_delete_connector_by_id.py new file mode 100644 index 00000000000..a740bb06398 --- /dev/null +++ b/backend/scripts/force_delete_connector_by_id.py @@ -0,0 +1,198 @@ +import argparse +import os +import sys + +from sqlalchemy import delete +from sqlalchemy.orm import Session + +# Modify sys.path +current_dir = os.path.dirname(os.path.abspath(__file__)) +parent_dir = os.path.dirname(current_dir) +sys.path.append(parent_dir) + +# pylint: disable=E402 +# flake8: noqa: E402 + +# Now import Danswer modules +from danswer.db.models import DocumentSet__ConnectorCredentialPair +from danswer.db.connector import fetch_connector_by_id +from danswer.db.document import get_documents_for_connector_credential_pair +from danswer.db.index_attempt import ( + delete_index_attempts, + cancel_indexing_attempts_for_connector, +) +from danswer.db.models import ConnectorCredentialPair +from danswer.document_index.interfaces import DocumentIndex +from danswer.utils.logger import setup_logger +from danswer.configs.constants import DocumentSource +from danswer.db.connector_credential_pair import ( + get_connector_credential_pair_from_id, + get_connector_credential_pair, +) +from danswer.db.engine import get_session_context_manager +from danswer.document_index.factory import get_default_document_index +from danswer.file_store.file_store import get_default_file_store +from danswer.document_index.document_index_utils import get_both_index_names +from danswer.db.document import 
delete_documents_complete__no_commit + +# pylint: enable=E402 +# flake8: noqa: E402 + + +logger = setup_logger() + +_DELETION_BATCH_SIZE = 1000 + + +def unsafe_deletion( + db_session: Session, + document_index: DocumentIndex, + cc_pair: ConnectorCredentialPair, + pair_id: int, +) -> int: + connector_id = cc_pair.connector_id + credential_id = cc_pair.credential_id + + num_docs_deleted = 0 + + # Gather and delete documents + while True: + documents = get_documents_for_connector_credential_pair( + db_session=db_session, + connector_id=connector_id, + credential_id=credential_id, + limit=_DELETION_BATCH_SIZE, + ) + if not documents: + break + + document_ids = [document.id for document in documents] + document_index.delete(doc_ids=document_ids) + delete_documents_complete__no_commit( + db_session=db_session, + document_ids=document_ids, + ) + + num_docs_deleted += len(documents) + + # Delete index attempts + delete_index_attempts( + db_session=db_session, + connector_id=connector_id, + credential_id=credential_id, + ) + + # Delete document sets + connector / credential Pairs + stmt = delete(DocumentSet__ConnectorCredentialPair).where( + DocumentSet__ConnectorCredentialPair.connector_credential_pair_id == pair_id + ) + db_session.execute(stmt) + stmt = delete(ConnectorCredentialPair).where( + ConnectorCredentialPair.connector_id == connector_id, + ConnectorCredentialPair.credential_id == credential_id, + ) + db_session.execute(stmt) + + # Delete Connector + connector = fetch_connector_by_id( + db_session=db_session, + connector_id=connector_id, + ) + if not connector or not len(connector.credentials): + logger.debug("Found no credentials left for connector, deleting connector") + db_session.delete(connector) + db_session.commit() + + logger.info( + "Successfully deleted connector_credential_pair with connector_id:" + f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs." 
+ ) + return num_docs_deleted + + +def _delete_connector(cc_pair_id: int, db_session: Session) -> None: + user_input = input( + "DO NOT USE THIS UNLESS YOU KNOW WHAT YOU ARE DOING. \ + IT MAY CAUSE ISSUES with your Danswer instance! \ + Are you SURE you want to continue? (enter 'Y' to continue): " + ) + if user_input != "Y": + logger.info(f"You entered {user_input}. Exiting!") + return + + logger.info("Getting connector credential pair") + cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + + if not cc_pair: + logger.error(f"Connector credential pair with ID {cc_pair_id} not found") + return + + if not cc_pair.connector.disabled: + logger.error( + f"Connector {cc_pair.connector.name} is not disabled, cannot continue. \ + Please navigate to the connector and disbale before attempting again" + ) + return + + connector_id = cc_pair.connector_id + credential_id = cc_pair.credential_id + + if cc_pair is None: + logger.error( + f"Connector with ID '{connector_id}' and credential ID " + f"'{credential_id}' does not exist. Has it already been deleted?", + ) + return + + logger.info("Cancelling indexing attempt for the connector") + cancel_indexing_attempts_for_connector( + connector_id=connector_id, db_session=db_session, include_secondary_index=True + ) + + validated_cc_pair = get_connector_credential_pair( + db_session=db_session, + connector_id=connector_id, + credential_id=credential_id, + ) + + if not validated_cc_pair: + logger.error( + f"Cannot run deletion attempt - connector_credential_pair with Connector ID: " + f"{connector_id} and Credential ID: {credential_id} does not exist." 
+ ) + + try: + logger.info("Deleting information from Vespa and Postgres") + curr_ind_name, sec_ind_name = get_both_index_names(db_session) + document_index = get_default_document_index( + primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name + ) + + files_deleted_count = unsafe_deletion( + db_session=db_session, + document_index=document_index, + cc_pair=cc_pair, + pair_id=cc_pair_id, + ) + logger.info(f"Deleted {files_deleted_count} files!") + + except Exception as e: + logger.error(f"Failed to delete connector due to {e}") + + if cc_pair.connector.source == DocumentSource.FILE: + connector = cc_pair.connector + logger.info("Deleting stored files!") + file_store = get_default_file_store(db_session) + for file_name in connector.connector_specific_config["file_locations"]: + logger.info(f"Deleting file {file_name}") + file_store.delete_file(file_name) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Delete a connector by its ID") + parser.add_argument( + "connector_id", type=int, help="The ID of the connector to delete" + ) + args = parser.parse_args() + with get_session_context_manager() as db_session: + _delete_connector(args.connector_id, db_session) diff --git a/backend/scripts/save_load_state.py b/backend/scripts/save_load_state.py index 73586c5fb1b..19d6fa66e25 100644 --- a/backend/scripts/save_load_state.py +++ b/backend/scripts/save_load_state.py @@ -86,7 +86,9 @@ def load_vespa(filename: str) -> None: new_doc = json.loads(line.strip()) doc_id = new_doc["update"].split("::")[-1] response = requests.post( - DOCUMENT_ID_ENDPOINT + "/" + doc_id, headers=headers, json=new_doc + DOCUMENT_ID_ENDPOINT + "/" + doc_id, + headers=headers, + json=new_doc, ) response.raise_for_status() diff --git a/backend/supervisord.conf b/backend/supervisord.conf index f9e3a5e331f..7f6376939a5 100644 --- a/backend/supervisord.conf +++ b/backend/supervisord.conf @@ -25,7 +25,7 @@ autorestart=true # relatively compute-light (e.g. 
they tend to just make a bunch of requests to # Vespa / Postgres) [program:celery_worker] -command=celery -A danswer.background.celery worker --pool=threads --autoscale=3,10 --loglevel=DEBUG --logfile=/var/log/celery_worker.log +command=celery -A danswer.background.celery.celery_run:celery_app worker --pool=threads --autoscale=3,10 --loglevel=INFO --logfile=/var/log/celery_worker.log stdout_logfile=/var/log/celery_worker_supervisor.log stdout_logfile_maxbytes=52428800 redirect_stderr=true @@ -33,7 +33,7 @@ autorestart=true # Job scheduler for periodic tasks [program:celery_beat] -command=celery -A danswer.background.celery beat --loglevel=DEBUG --logfile=/var/log/celery_beat.log +command=celery -A danswer.background.celery.celery_run:celery_app beat --loglevel=INFO --logfile=/var/log/celery_beat.log stdout_logfile=/var/log/celery_beat_supervisor.log stdout_logfile_maxbytes=52428800 redirect_stderr=true diff --git a/backend/tests/regression/answer_quality/README.md b/backend/tests/regression/answer_quality/README.md new file mode 100644 index 00000000000..4610a9abc2e --- /dev/null +++ b/backend/tests/regression/answer_quality/README.md @@ -0,0 +1,74 @@ +# Search Quality Test Script + +This Python script automates the process of running search quality tests for a backend system. + +## Features + +- Loads configuration from a YAML file +- Sets up Docker environment +- Manages environment variables +- Switches to specified Git branch +- Uploads test documents +- Runs search quality tests using Relari +- Cleans up Docker containers (optional) + +## Usage + +1. Ensure you have the required dependencies installed. +2. Configure the `search_test_config.yaml` file based on the `search_test_config.yaml.template` file. +3. Configure the `.env_eval` file with the correct environment variables. +4. Navigate to the answer_quality folder: +``` +cd danswer/backend/tests/regression/answer_quality +``` +4. 
Run the script: +``` +python search_quality_test.py +``` + +## Configuration + +Edit `search_test_config.yaml` to set: + +- output_folder + This is the folder where the folders for each test will go + These folders will contain the postgres/vespa data as well as the results for each test +- zipped_documents_file + The path to the zip file containing the files you'd like to test against +- questions_file + The path to the yaml containing the questions you'd like to test with +- branch + Set the branch to null if you want it to just use the code as is +- clean_up_docker_containers + Set this to true to automatically delete all docker containers, networks and volumes after the test +- launch_web_ui + Set this to true if you want to use the UI during/after the testing process +- use_cloud_gpu + Set to true or false depending on if you want to use the remote gpu + Only need to set this if use_cloud_gpu is true +- model_server_ip + This is the ip of the remote model server + Only need to set this if use_cloud_gpu is true +- model_server_port + This is the port of the remote model server + Only need to set this if use_cloud_gpu is true +- existing_test_suffix + Use this if you would like to relaunch a previous test instance + Input the suffix of the test you'd like to re-launch + (E.g. to use the data from folder "test_1234_5678" put "_1234_5678") + No new files will automatically be uploaded + Leave empty to run a new test +- limit + Max number of questions you'd like to ask against the dataset + Set to null for no limit +- llm + Fill this out according to the normal LLM seeding + + +To restart the evaluation using a particular index, set the suffix and turn off clean_up_docker_containers. +This also will skip running the evaluation questions, in this case, the relari.py script can be run manually. + + +Docker daemon must be running for this to work. 
+ +Each script is able to be individually run to upload additional docs or run additional tests \ No newline at end of file diff --git a/backend/tests/regression/answer_quality/api_utils.py b/backend/tests/regression/answer_quality/api_utils.py new file mode 100644 index 00000000000..20292f35b63 --- /dev/null +++ b/backend/tests/regression/answer_quality/api_utils.py @@ -0,0 +1,220 @@ +import requests +from retry import retry + +from danswer.configs.constants import DocumentSource +from danswer.configs.constants import MessageType +from danswer.connectors.models import InputType +from danswer.db.enums import IndexingStatus +from danswer.one_shot_answer.models import DirectQARequest +from danswer.one_shot_answer.models import ThreadMessage +from danswer.search.models import IndexFilters +from danswer.search.models import OptionalSearchSetting +from danswer.search.models import RetrievalDetails +from danswer.server.documents.models import ConnectorBase +from tests.regression.answer_quality.cli_utils import ( + get_api_server_host_port, +) + + +def _api_url_builder(run_suffix: str, api_path: str) -> str: + return f"http://localhost:{get_api_server_host_port(run_suffix)}" + api_path + + +@retry(tries=5, delay=2, backoff=2) +def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]: + filters = IndexFilters( + source_type=None, + document_set=None, + time_cutoff=None, + tags=None, + access_control_list=None, + ) + + messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)] + + new_message_request = DirectQARequest( + messages=messages, + prompt_id=0, + persona_id=0, + retrieval_options=RetrievalDetails( + run_search=OptionalSearchSetting.ALWAYS, + real_time=True, + filters=filters, + enable_auto_detect_filters=False, + ), + chain_of_thought=False, + return_contexts=True, + ) + + url = _api_url_builder(run_suffix, "/query/answer-with-quote/") + headers = { + "Content-Type": "application/json", + } + + body = 
new_message_request.dict() + body["user"] = None + try: + response_json = requests.post(url, headers=headers, json=body).json() + content_list = [ + context.get("content", "") + for context in response_json.get("contexts", {}).get("contexts", []) + ] + answer = response_json.get("answer") + except Exception as e: + print("Failed to answer the questions, trying again") + print(f"error: {str(e)}") + raise e + + print("\nquery: ", query) + print("answer: ", answer) + print("content_list: ", content_list) + + return content_list, answer + + +def check_if_query_ready(run_suffix: str) -> bool: + url = _api_url_builder(run_suffix, "/manage/admin/connector/indexing-status/") + headers = { + "Content-Type": "application/json", + } + + indexing_status_dict = requests.get(url, headers=headers).json() + + ongoing_index_attempts = False + doc_count = 0 + for index_attempt in indexing_status_dict: + status = index_attempt["last_status"] + if status == IndexingStatus.IN_PROGRESS or status == IndexingStatus.NOT_STARTED: + ongoing_index_attempts = True + doc_count += index_attempt["docs_indexed"] + + if not doc_count: + print("No docs indexed, waiting for indexing to start") + elif ongoing_index_attempts: + print( + f"{doc_count} docs indexed but waiting for ongoing indexing jobs to finish..." 
+ ) + + return doc_count > 0 and not ongoing_index_attempts + + +def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None: + url = _api_url_builder(run_suffix, "/manage/admin/connector/run-once/") + headers = { + "Content-Type": "application/json", + } + + body = { + "connector_id": connector_id, + "credential_ids": [credential_id], + "from_beginning": True, + } + print("body:", body) + response = requests.post(url, headers=headers, json=body) + if response.status_code == 200: + print("Connector created successfully:", response.json()) + else: + print("Failed status_code:", response.status_code) + print("Failed text:", response.text) + + +def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> None: + url = _api_url_builder( + run_suffix, f"/manage/connector/{connector_id}/credential/{credential_id}" + ) + headers = { + "Content-Type": "application/json", + } + + body = {"name": "zip_folder_contents", "is_public": True} + print("body:", body) + response = requests.put(url, headers=headers, json=body) + if response.status_code == 200: + print("Connector created successfully:", response.json()) + else: + print("Failed status_code:", response.status_code) + print("Failed text:", response.text) + + +def _get_existing_connector_names(run_suffix: str) -> list[str]: + url = _api_url_builder(run_suffix, "/manage/connector") + headers = { + "Content-Type": "application/json", + } + body = { + "credential_json": {}, + "admin_public": True, + } + response = requests.get(url, headers=headers, json=body) + if response.status_code == 200: + connectors = response.json() + return [connector["name"] for connector in connectors] + else: + raise RuntimeError(response.__dict__) + + +def create_connector(run_suffix: str, file_paths: list[str]) -> int: + url = _api_url_builder(run_suffix, "/manage/admin/connector") + headers = { + "Content-Type": "application/json", + } + connector_name = base_connector_name = "search_eval_connector" + 
existing_connector_names = _get_existing_connector_names(run_suffix) + + count = 1 + while connector_name in existing_connector_names: + connector_name = base_connector_name + "_" + str(count) + count += 1 + + connector = ConnectorBase( + name=connector_name, + source=DocumentSource.FILE, + input_type=InputType.LOAD_STATE, + connector_specific_config={"file_locations": file_paths}, + refresh_freq=None, + prune_freq=None, + disabled=False, + ) + + body = connector.dict() + print("body:", body) + response = requests.post(url, headers=headers, json=body) + if response.status_code == 200: + print("Connector created successfully:", response.json()) + return response.json()["id"] + else: + raise RuntimeError(response.__dict__) + + +def create_credential(run_suffix: str) -> int: + url = _api_url_builder(run_suffix, "/manage/credential") + headers = { + "Content-Type": "application/json", + } + body = { + "credential_json": {}, + "admin_public": True, + } + response = requests.post(url, headers=headers, json=body) + if response.status_code == 200: + print("credential created successfully:", response.json()) + return response.json()["id"] + else: + raise RuntimeError(response.__dict__) + + +@retry(tries=10, delay=2, backoff=2) +def upload_file(run_suffix: str, zip_file_path: str) -> list[str]: + files = [ + ("files", open(zip_file_path, "rb")), + ] + + api_path = _api_url_builder(run_suffix, "/manage/admin/connector/file/upload") + try: + response = requests.post(api_path, files=files) + response.raise_for_status() # Raises an HTTPError for bad responses + print("file uploaded successfully:", response.json()) + return response.json()["file_paths"] + except Exception as e: + print("File upload failed, waiting for API server to come up and trying again") + raise e diff --git a/backend/tests/regression/answer_quality/cli_utils.py b/backend/tests/regression/answer_quality/cli_utils.py new file mode 100644 index 00000000000..a39309efd58 --- /dev/null +++ 
b/backend/tests/regression/answer_quality/cli_utils.py @@ -0,0 +1,236 @@ +import json +import os +import subprocess +import sys +from threading import Thread +from typing import IO + +from retry import retry + + +def _run_command(command: str, stream_output: bool = False) -> tuple[str, str]: + process = subprocess.Popen( + command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + + stdout_lines: list[str] = [] + stderr_lines: list[str] = [] + + def process_stream(stream: IO[str], lines: list[str]) -> None: + for line in stream: + lines.append(line) + if stream_output: + print( + line, + end="", + file=sys.stdout if stream == process.stdout else sys.stderr, + ) + + stdout_thread = Thread(target=process_stream, args=(process.stdout, stdout_lines)) + stderr_thread = Thread(target=process_stream, args=(process.stderr, stderr_lines)) + + stdout_thread.start() + stderr_thread.start() + + stdout_thread.join() + stderr_thread.join() + + process.wait() + + if process.returncode != 0: + raise RuntimeError(f"Command failed with error: {''.join(stderr_lines)}") + + return "".join(stdout_lines), "".join(stderr_lines) + + +def get_current_commit_sha() -> str: + print("Getting current commit SHA...") + stdout, _ = _run_command("git rev-parse HEAD") + sha = stdout.strip() + print(f"Current commit SHA: {sha}") + return sha + + +def switch_to_branch(branch: str) -> None: + print(f"Switching to branch: {branch}...") + _run_command(f"git checkout {branch}") + _run_command("git pull") + print(f"Successfully switched to branch: {branch}") + print("Repository updated successfully.") + + +def manage_data_directories(suffix: str, base_path: str, use_cloud_gpu: bool) -> str: + # Use the user's home directory as the base path + target_path = os.path.join(os.path.expanduser(base_path), f"test{suffix}") + directories = { + "DANSWER_POSTGRES_DATA_DIR": os.path.join(target_path, "postgres/"), + "DANSWER_VESPA_DATA_DIR": os.path.join(target_path, 
"vespa/"), + } + if not use_cloud_gpu: + directories["DANSWER_INDEX_MODEL_CACHE_DIR"] = os.path.join( + target_path, "index_model_cache/" + ) + directories["DANSWER_INFERENCE_MODEL_CACHE_DIR"] = os.path.join( + target_path, "inference_model_cache/" + ) + + # Create directories if they don't exist + for env_var, directory in directories.items(): + os.makedirs(directory, exist_ok=True) + os.environ[env_var] = directory + print(f"Set {env_var} to: {directory}") + relari_output_path = os.path.join(target_path, "relari_output/") + os.makedirs(relari_output_path, exist_ok=True) + return relari_output_path + + +def set_env_variables( + remote_server_ip: str, + remote_server_port: str, + use_cloud_gpu: bool, + llm_config: dict, +) -> None: + env_vars: dict = {} + env_vars["ENV_SEED_CONFIGURATION"] = json.dumps({"llms": [llm_config]}) + env_vars["ENABLE_PAID_ENTERPRISE_EDITION_FEATURES"] = "true" + if use_cloud_gpu: + env_vars["MODEL_SERVER_HOST"] = remote_server_ip + env_vars["MODEL_SERVER_PORT"] = remote_server_port + env_vars["INDEXING_MODEL_SERVER_HOST"] = remote_server_ip + + for env_var_name, env_var in env_vars.items(): + os.environ[env_var_name] = env_var + print(f"Set {env_var_name} to: {env_var}") + + +def start_docker_compose( + run_suffix: str, launch_web_ui: bool, use_cloud_gpu: bool +) -> None: + print("Starting Docker Compose...") + os.chdir("../deployment/docker_compose") + command = f"docker compose -f docker-compose.search-testing.yml -p danswer-stack{run_suffix} up -d" + command += " --build" + command += " --force-recreate" + if not launch_web_ui: + command += " --scale web_server=0" + command += " --scale nginx=0" + if use_cloud_gpu: + command += " --scale indexing_model_server=0" + command += " --scale inference_model_server=0" + + print("Docker Command:\n", command) + + _run_command(command, stream_output=True) + print("Containers have been launched") + + +def cleanup_docker(run_suffix: str) -> None: + print( + f"Deleting Docker containers, volumes, 
and networks for project suffix: {run_suffix}"
+ )
+
+ stdout, _ = _run_command("docker ps -a --format '{{json .}}'")
+
+ containers = [json.loads(line) for line in stdout.splitlines()]
+
+ project_name = f"danswer-stack{run_suffix}"
+ containers_to_delete = [
+ c for c in containers if c["Names"].startswith(project_name)
+ ]
+
+ if not containers_to_delete:
+ print(f"No containers found for project: {project_name}")
+ else:
+ container_ids = " ".join([c["ID"] for c in containers_to_delete])
+ _run_command(f"docker rm -f {container_ids}")
+
+ print(
+ f"Successfully deleted {len(containers_to_delete)} containers for project: {project_name}"
+ )
+
+ stdout, _ = _run_command("docker volume ls --format '{{.Name}}'")
+
+ volumes = stdout.splitlines()
+
+ volumes_to_delete = [v for v in volumes if v.startswith(project_name)]
+
+ if not volumes_to_delete:
+ print(f"No volumes found for project: {project_name}")
+ return
+
+ # Delete filtered volumes
+ volume_names = " ".join(volumes_to_delete)
+ _run_command(f"docker volume rm {volume_names}")
+
+ print(
+ f"Successfully deleted {len(volumes_to_delete)} volumes for project: {project_name}"
+ )
+ stdout, _ = _run_command("docker network ls --format '{{.Name}}'")
+
+ networks = stdout.splitlines()
+
+ networks_to_delete = [n for n in networks if run_suffix in n]
+
+ if not networks_to_delete:
+ print(f"No networks found containing suffix: {run_suffix}")
+ else:
+ network_names = " ".join(networks_to_delete)
+ _run_command(f"docker network rm {network_names}")
+
+ print(
+ f"Successfully deleted {len(networks_to_delete)} networks containing suffix: {run_suffix}"
+ )
+
+
+@retry(tries=5, delay=5, backoff=2)
+def get_api_server_host_port(suffix: str) -> str:
+ """
+ This pulls all containers with the provided suffix
+ It then grabs the JSON specific container with a name containing "api_server"
+ It then grabs the port info from the JSON and strips out the relevant data
+ """
+ container_name = "api_server"
+
+ stdout, _ = 
_run_command("docker ps -a --format '{{json .}}'") + containers = [json.loads(line) for line in stdout.splitlines()] + server_jsons = [] + + for container in containers: + if container_name in container["Names"] and suffix in container["Names"]: + server_jsons.append(container) + + if not server_jsons: + raise RuntimeError( + f"No container found containing: {container_name} and {suffix}" + ) + elif len(server_jsons) > 1: + raise RuntimeError( + f"Too many containers matching {container_name} found, please indicate a suffix" + ) + server_json = server_jsons[0] + + # This is in case the api_server has multiple ports + client_port = "8080" + ports = server_json.get("Ports", "") + port_infos = ports.split(",") if ports else [] + port_dict = {} + for port_info in port_infos: + port_arr = port_info.split(":")[-1].split("->") if port_info else [] + if len(port_arr) == 2: + port_dict[port_arr[1]] = port_arr[0] + + # Find the host port where client_port is in the key + matching_ports = [value for key, value in port_dict.items() if client_port in key] + + if len(matching_ports) > 1: + raise RuntimeError(f"Too many ports matching {client_port} found") + if not matching_ports: + raise RuntimeError( + f"No port found containing: {client_port} for container: {container_name} and suffix: {suffix}" + ) + return matching_ports[0] diff --git a/backend/tests/regression/answer_quality/eval_direct_qa.py b/backend/tests/regression/answer_quality/eval_direct_qa.py index d32f2754725..1a8ca7c4526 100644 --- a/backend/tests/regression/answer_quality/eval_direct_qa.py +++ b/backend/tests/regression/answer_quality/eval_direct_qa.py @@ -82,6 +82,7 @@ def get_answer_for_question( source_type=None, document_set=None, time_cutoff=None, + tags=None, access_control_list=None, ) diff --git a/backend/tests/regression/answer_quality/file_uploader.py b/backend/tests/regression/answer_quality/file_uploader.py new file mode 100644 index 00000000000..a80cadb2cac --- /dev/null +++ 
b/backend/tests/regression/answer_quality/file_uploader.py @@ -0,0 +1,31 @@ +import os +from types import SimpleNamespace + +import yaml + +from tests.regression.answer_quality.api_utils import create_cc_pair +from tests.regression.answer_quality.api_utils import create_connector +from tests.regression.answer_quality.api_utils import create_credential +from tests.regression.answer_quality.api_utils import run_cc_once +from tests.regression.answer_quality.api_utils import upload_file + + +def upload_test_files(zip_file_path: str, run_suffix: str) -> None: + print("zip:", zip_file_path) + file_paths = upload_file(run_suffix, zip_file_path) + + conn_id = create_connector(run_suffix, file_paths) + cred_id = create_credential(run_suffix) + + create_cc_pair(run_suffix, conn_id, cred_id) + run_cc_once(run_suffix, conn_id, cred_id) + + +if __name__ == "__main__": + current_dir = os.path.dirname(os.path.abspath(__file__)) + config_path = os.path.join(current_dir, "search_test_config.yaml") + with open(config_path, "r") as file: + config = SimpleNamespace(**yaml.safe_load(file)) + file_location = config.zipped_documents_file + run_suffix = config.existing_test_suffix + upload_test_files(file_location, run_suffix) diff --git a/backend/tests/regression/answer_quality/relari.py b/backend/tests/regression/answer_quality/relari.py index a63226a3d97..c669d55c253 100644 --- a/backend/tests/regression/answer_quality/relari.py +++ b/backend/tests/regression/answer_quality/relari.py @@ -1,118 +1,119 @@ -import argparse import json +import os +import time +from types import SimpleNamespace -from sqlalchemy.orm import Session - -from danswer.configs.constants import MessageType -from danswer.db.engine import get_sqlalchemy_engine -from danswer.one_shot_answer.answer_question import get_search_answer -from danswer.one_shot_answer.models import DirectQARequest -from danswer.one_shot_answer.models import OneShotQAResponse -from danswer.one_shot_answer.models import ThreadMessage -from 
danswer.search.models import IndexFilters -from danswer.search.models import OptionalSearchSetting -from danswer.search.models import RetrievalDetails - - -def get_answer_for_question(query: str, db_session: Session) -> OneShotQAResponse: - filters = IndexFilters( - source_type=None, - document_set=None, - time_cutoff=None, - access_control_list=None, - ) +import yaml - messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)] - - new_message_request = DirectQARequest( - messages=messages, - prompt_id=0, - persona_id=0, - retrieval_options=RetrievalDetails( - run_search=OptionalSearchSetting.ALWAYS, - real_time=True, - filters=filters, - enable_auto_detect_filters=False, - ), - chain_of_thought=False, - return_contexts=True, - ) +from tests.regression.answer_quality.api_utils import check_if_query_ready +from tests.regression.answer_quality.api_utils import get_answer_from_query +from tests.regression.answer_quality.cli_utils import get_current_commit_sha - answer = get_search_answer( - query_req=new_message_request, - user=None, - max_document_tokens=None, - max_history_tokens=None, - db_session=db_session, - answer_generation_timeout=100, - enable_reflexion=False, - bypass_acl=True, - ) - return answer +def _get_and_write_relari_outputs( + samples: list[dict], run_suffix: str, output_file_path: str +) -> None: + while not check_if_query_ready(run_suffix): + time.sleep(5) + with open(output_file_path, "w", encoding="utf-8") as file: + for sample in samples: + retrieved_context, answer = get_answer_from_query( + query=sample["question"], + run_suffix=run_suffix, + ) -def read_questions(questions_file_path: str) -> list[dict]: - samples = [] - with open(questions_file_path, "r", encoding="utf-8") as file: - for line in file: - sample = json.loads(line.strip()) - samples.append(sample) - return samples + if not answer: + print("NO ANSWER GIVEN FOR QUESTION:", sample["question"]) + continue + output = { + "label": sample["uid"], + "question": 
sample["question"], + "answer": answer, + "retrieved_context": retrieved_context, + } -def main(questions_file: str, output_file: str, limit: int | None = None) -> None: - samples = read_questions(questions_file) + file.write(json.dumps(output) + "\n") + file.flush() - if limit is not None: - samples = samples[:limit] - response_dicts = [] - with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session: - for sample in samples: - answer = get_answer_for_question( - query=sample["question"], db_session=db_session - ) - assert answer.contexts +def _write_metadata_file(run_suffix: str, metadata_file_path: str) -> None: + metadata = {"commit_sha": get_current_commit_sha(), "run_suffix": run_suffix} - response_dict = { - "question": sample["question"], - "retrieved_contexts": [ - context.content for context in answer.contexts.contexts - ], - "ground_truth_contexts": sample["ground_truth_contexts"], - "answer": answer.answer, - "ground_truths": sample["ground_truths"], - } + print("saving metadata to:", metadata_file_path) + with open(metadata_file_path, "w", encoding="utf-8") as yaml_file: + yaml.dump(metadata, yaml_file) - response_dicts.append(response_dict) - with open(output_file, "w", encoding="utf-8") as out_file: - for response_dict in response_dicts: - json_line = json.dumps(response_dict) - out_file.write(json_line + "\n") +def _read_questions_jsonl(questions_file_path: str) -> list[dict]: + questions = [] + with open(questions_file_path, "r") as file: + for line in file: + json_obj = json.loads(line) + questions.append(json_obj) + return questions -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--questions_file", - type=str, - help="Path to the Relari questions file.", - default="./tests/regression/answer_quality/combined_golden_dataset.jsonl", - ) - parser.add_argument( - "--output_file", - type=str, - help="Path to the output results file.", - 
default="./tests/regression/answer_quality/relari_results.txt", +def answer_relari_questions( + questions_file_path: str, + results_folder_path: str, + run_suffix: str, + limit: int | None = None, +) -> None: + results_file = "run_results.jsonl" + metadata_file = "run_metadata.yaml" + samples = _read_questions_jsonl(questions_file_path) + + if limit is not None: + samples = samples[:limit] + + counter = 1 + output_file_path = os.path.join(results_folder_path, results_file) + metadata_file_path = os.path.join(results_folder_path, metadata_file) + while os.path.exists(output_file_path): + output_file_path = os.path.join( + results_folder_path, + results_file.replace("run_results", f"run_results_{counter}"), + ) + metadata_file_path = os.path.join( + results_folder_path, + metadata_file.replace("run_metadata", f"run_metadata_{counter}"), + ) + counter += 1 + + print("saving question results to:", output_file_path) + _write_metadata_file(run_suffix, metadata_file_path) + _get_and_write_relari_outputs( + samples=samples, run_suffix=run_suffix, output_file_path=output_file_path ) - parser.add_argument( - "--limit", - type=int, - default=None, - help="Limit the number of examples to process.", + + +def main() -> None: + current_dir = os.path.dirname(os.path.abspath(__file__)) + config_path = os.path.join(current_dir, "search_test_config.yaml") + with open(config_path, "r") as file: + config = SimpleNamespace(**yaml.safe_load(file)) + + current_output_folder = os.path.expanduser(config.output_folder) + if config.existing_test_suffix: + current_output_folder = os.path.join( + current_output_folder, "test" + config.existing_test_suffix, "relari_output" + ) + else: + current_output_folder = os.path.join(current_output_folder, "no_defined_suffix") + + answer_relari_questions( + config.questions_file, + current_output_folder, + config.existing_test_suffix, + config.limit, ) - args = parser.parse_args() - main(args.questions_file, args.output_file, args.limit) + +if __name__ == 
"__main__": + """ + To run a different set of questions, update the questions_file in search_test_config.yaml + If there is more than one instance of Danswer running, specify the suffix in search_test_config.yaml + """ + main() diff --git a/backend/tests/regression/answer_quality/search_quality_test.py b/backend/tests/regression/answer_quality/search_quality_test.py new file mode 100644 index 00000000000..06da3995354 --- /dev/null +++ b/backend/tests/regression/answer_quality/search_quality_test.py @@ -0,0 +1,58 @@ +import os +from datetime import datetime +from types import SimpleNamespace + +import yaml + +from tests.regression.answer_quality.cli_utils import cleanup_docker +from tests.regression.answer_quality.cli_utils import manage_data_directories +from tests.regression.answer_quality.cli_utils import set_env_variables +from tests.regression.answer_quality.cli_utils import start_docker_compose +from tests.regression.answer_quality.cli_utils import switch_to_branch +from tests.regression.answer_quality.file_uploader import upload_test_files +from tests.regression.answer_quality.relari import answer_relari_questions + + +def load_config(config_filename: str) -> SimpleNamespace: + current_dir = os.path.dirname(os.path.abspath(__file__)) + config_path = os.path.join(current_dir, config_filename) + with open(config_path, "r") as file: + return SimpleNamespace(**yaml.safe_load(file)) + + +def main() -> None: + config = load_config("search_test_config.yaml") + if config.existing_test_suffix: + run_suffix = config.existing_test_suffix + print("launching danswer with existing data suffix:", run_suffix) + else: + run_suffix = datetime.now().strftime("_%Y%m%d_%H%M%S") + print("run_suffix:", run_suffix) + + set_env_variables( + config.model_server_ip, + config.model_server_port, + config.use_cloud_gpu, + config.llm, + ) + relari_output_folder_path = manage_data_directories( + run_suffix, config.output_folder, config.use_cloud_gpu + ) + if config.branch: + 
switch_to_branch(config.branch) + + start_docker_compose(run_suffix, config.launch_web_ui, config.use_cloud_gpu) + + if not config.existing_test_suffix: + upload_test_files(config.zipped_documents_file, run_suffix) + + answer_relari_questions( + config.questions_file, relari_output_folder_path, run_suffix, config.limit + ) + + if config.clean_up_docker_containers: + cleanup_docker(run_suffix) + + +if __name__ == "__main__": + main() diff --git a/backend/tests/regression/answer_quality/search_test_config.yaml.template b/backend/tests/regression/answer_quality/search_test_config.yaml.template new file mode 100644 index 00000000000..47310a3c373 --- /dev/null +++ b/backend/tests/regression/answer_quality/search_test_config.yaml.template @@ -0,0 +1,52 @@ +# Copy this to search_test_config.yaml and fill in the values to run the eval pipeline +# Don't forget to also update the .env_eval file with the correct values + +# Directory where test results will be saved +output_folder: "~/danswer_test_results" + +# Path to the zip file containing sample documents +zipped_documents_file: "~/sampledocs.zip" + +# Path to the YAML file containing sample questions +questions_file: "~/sample_questions.yaml" + +# Git branch to use (null means use current branch as is) +branch: null + +# Whether to remove Docker containers after the test +clean_up_docker_containers: true + +# Whether to launch a web UI for the test +launch_web_ui: false + +# Whether to use a cloud GPU for processing +use_cloud_gpu: false + +# IP address of the model server (placeholder) +model_server_ip: "PUT_PUBLIC_CLOUD_IP_HERE" + +# Port of the model server (placeholder) +model_server_port: "PUT_PUBLIC_CLOUD_PORT_HERE" + +# Suffix for existing test results (empty string means no suffix) +existing_test_suffix: "" + +# Limit on number of tests to run (null means no limit) +limit: null + +# LLM configuration +llm: + # Name of the LLM + name: "default_test_llm" + + # Provider of the LLM (e.g., OpenAI) + provider: "openai" 
+ + # API key + api_key: "PUT_API_KEY_HERE" + + # Default model name to use + default_model_name: "gpt-4o" + + # List of model names to use for testing + model_names: ["gpt-4o"] diff --git a/backend/tests/regression/search_quality/eval_search.py b/backend/tests/regression/search_quality/eval_search.py index e67fe7c5b44..d20c685dc27 100644 --- a/backend/tests/regression/search_quality/eval_search.py +++ b/backend/tests/regression/search_quality/eval_search.py @@ -8,9 +8,9 @@ from sqlalchemy.orm import Session from danswer.db.engine import get_sqlalchemy_engine -from danswer.llm.answering.doc_pruning import reorder_docs -from danswer.llm.factory import get_default_llm -from danswer.search.models import InferenceChunk +from danswer.llm.answering.prune_and_merge import reorder_sections +from danswer.llm.factory import get_default_llms +from danswer.search.models import InferenceSection from danswer.search.models import RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer from danswer.search.models import SearchRequest @@ -75,7 +75,7 @@ def word_wrap(s: str, max_line_size: int = 100, prepend_tab: bool = True) -> str def get_search_results( query: str, ) -> tuple[ - list[InferenceChunk], + list[InferenceSection], RetrievalMetricsContainer | None, RerankMetricsContainer | None, ]: @@ -83,22 +83,24 @@ def get_search_results( rerank_metrics = MetricsHander[RerankMetricsContainer]() with Session(get_sqlalchemy_engine()) as db_session: + llm, fast_llm = get_default_llms() search_pipeline = SearchPipeline( search_request=SearchRequest( query=query, ), user=None, - llm=get_default_llm(), + llm=llm, + fast_llm=fast_llm, db_session=db_session, retrieval_metrics_callback=retrieval_metrics.record_metric, rerank_metrics_callback=rerank_metrics.record_metric, ) - top_chunks = search_pipeline.reranked_chunks - llm_chunk_selection = search_pipeline.chunk_relevance_list + top_sections = search_pipeline.reranked_sections + llm_section_selection = 
search_pipeline.section_relevance_list return ( - reorder_docs(top_chunks, llm_chunk_selection), + reorder_sections(top_sections, llm_section_selection), retrieval_metrics.metrics, rerank_metrics.metrics, ) @@ -167,7 +169,7 @@ def main( print(f"\n\nQuestion: {question}") ( - top_chunks, + top_sections, retrieval_metrics, rerank_metrics, ) = get_search_results(query=question) @@ -186,7 +188,7 @@ def main( running_rerank_score += rerank_score print(f"Average: {running_rerank_score / (ind + 1)}") - llm_ids = [chunk.document_id for chunk in top_chunks] + llm_ids = [section.center_chunk.document_id for section in top_sections] llm_score = calculate_score("LLM Filter", llm_ids, targets) running_llm_filter_score += llm_score print(f"Average: {running_llm_filter_score / (ind + 1)}") diff --git a/backend/tests/unit/danswer/chat/test_chat_llm.py b/backend/tests/unit/danswer/chat/test_chat_llm.py deleted file mode 100644 index 7f67fa37b4e..00000000000 --- a/backend/tests/unit/danswer/chat/test_chat_llm.py +++ /dev/null @@ -1,38 +0,0 @@ -import unittest - - -class TestChatLlm(unittest.TestCase): - def test_citation_extraction(self) -> None: - pass # May fix these tests some day - """ - links: list[str | None] = [f"link_{i}" for i in range(1, 21)] - - test_1 = "Something [1]" - res = "".join(list(extract_citations_from_stream(iter(test_1), links))) - self.assertEqual(res, "Something [[1]](link_1)") - - test_2 = "Something [14]" - res = "".join(list(extract_citations_from_stream(iter(test_2), links))) - self.assertEqual(res, "Something [[14]](link_14)") - - test_3 = "Something [14][15]" - res = "".join(list(extract_citations_from_stream(iter(test_3), links))) - self.assertEqual(res, "Something [[14]](link_14)[[15]](link_15)") - - test_4 = ["Something ", "[", "3", "][", "4", "]."] - res = "".join(list(extract_citations_from_stream(iter(test_4), links))) - self.assertEqual(res, "Something [[3]](link_3)[[4]](link_4).") - - test_5 = ["Something ", "[", "31", "][", "4", "]."] - res = 
"".join(list(extract_citations_from_stream(iter(test_5), links))) - self.assertEqual(res, "Something [31][[4]](link_4).") - - links[3] = None - test_1 = "Something [2][4][5]" - res = "".join(list(extract_citations_from_stream(iter(test_1), links))) - self.assertEqual(res, "Something [[2]](link_2)[4][[5]](link_5)") - """ - - -if __name__ == "__main__": - unittest.main() diff --git a/backend/tests/unit/danswer/connectors/cross_connector_utils/test_html_utils.py b/backend/tests/unit/danswer/connectors/cross_connector_utils/test_html_utils.py index 860001e1534..f14a92faa3a 100644 --- a/backend/tests/unit/danswer/connectors/cross_connector_utils/test_html_utils.py +++ b/backend/tests/unit/danswer/connectors/cross_connector_utils/test_html_utils.py @@ -1,19 +1,13 @@ import pathlib -import unittest from danswer.file_processing.html_utils import parse_html_page_basic -class TestQAPostprocessing(unittest.TestCase): - def test_parse_table(self) -> None: - dir_path = pathlib.Path(__file__).parent.resolve() - with open(f"{dir_path}/test_table.html", "r") as file: - content = file.read() +def test_parse_table() -> None: + dir_path = pathlib.Path(__file__).parent.resolve() + with open(f"{dir_path}/test_table.html", "r") as file: + content = file.read() - parsed = parse_html_page_basic(content) - expected = "\n\thello\tthere\tgeneral\n\tkenobi\ta\tb\n\tc\td\te" - self.assertIn(expected, parsed) - - -if __name__ == "__main__": - unittest.main() + parsed = parse_html_page_basic(content) + expected = "\n\thello\tthere\tgeneral\n\tkenobi\ta\tb\n\tc\td\te" + assert expected in parsed diff --git a/backend/tests/unit/danswer/connectors/cross_connector_utils/test_rate_limit.py b/backend/tests/unit/danswer/connectors/cross_connector_utils/test_rate_limit.py index bc38cada51d..471ef424f15 100644 --- a/backend/tests/unit/danswer/connectors/cross_connector_utils/test_rate_limit.py +++ b/backend/tests/unit/danswer/connectors/cross_connector_utils/test_rate_limit.py @@ -1,36 +1,29 @@ import 
time -import unittest from danswer.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, ) -class TestRateLimit(unittest.TestCase): +def test_rate_limit_basic() -> None: call_cnt = 0 - def test_rate_limit_basic(self) -> None: - self.call_cnt = 0 + @rate_limit_builder(max_calls=2, period=5) + def func() -> None: + nonlocal call_cnt + call_cnt += 1 - @rate_limit_builder(max_calls=2, period=5) - def func() -> None: - self.call_cnt += 1 + start = time.time() - start = time.time() + # Make calls that shouldn't be rate-limited + func() + func() + time_to_finish_non_ratelimited = time.time() - start - # make calls that shouldn't be rate-limited - func() - func() - time_to_finish_non_ratelimited = time.time() - start + # Make a call which SHOULD be rate-limited + func() + time_to_finish_ratelimited = time.time() - start - # make a call which SHOULD be rate-limited - func() - time_to_finish_ratelimited = time.time() - start - - self.assertEqual(self.call_cnt, 3) - self.assertLess(time_to_finish_non_ratelimited, 1) - self.assertGreater(time_to_finish_ratelimited, 5) - - -if __name__ == "__main__": - unittest.main() + assert call_cnt == 3 + assert time_to_finish_non_ratelimited < 1 + assert time_to_finish_ratelimited > 5 diff --git a/backend/tests/unit/danswer/connectors/gmail/test_connector.py b/backend/tests/unit/danswer/connectors/gmail/test_connector.py index 83523f1b26d..2689e2a2751 100644 --- a/backend/tests/unit/danswer/connectors/gmail/test_connector.py +++ b/backend/tests/unit/danswer/connectors/gmail/test_connector.py @@ -1,5 +1,4 @@ import datetime -from unittest.mock import MagicMock import pytest from pytest_mock import MockFixture @@ -100,7 +99,7 @@ def test_fetch_mails_from_gmail_empty(mocker: MockFixture) -> None: "messages": [] } connector = GmailConnector() - connector.creds = MagicMock() + connector.creds = mocker.Mock() with pytest.raises(StopIteration): next(connector.load_from_state()) @@ -178,7 +177,7 @@ def 
test_fetch_mails_from_gmail(mocker: MockFixture) -> None: } connector = GmailConnector() - connector.creds = MagicMock() + connector.creds = mocker.Mock() docs = next(connector.load_from_state()) assert len(docs) == 1 doc: Document = docs[0] diff --git a/backend/tests/unit/danswer/connectors/mediawiki/test_mediawiki_family.py b/backend/tests/unit/danswer/connectors/mediawiki/test_mediawiki_family.py index 8f053e9f343..35a189f6dd4 100644 --- a/backend/tests/unit/danswer/connectors/mediawiki/test_mediawiki_family.py +++ b/backend/tests/unit/danswer/connectors/mediawiki/test_mediawiki_family.py @@ -1,7 +1,7 @@ from typing import Final -from unittest import mock import pytest +from pytest_mock import MockFixture from pywikibot.families.wikipedia_family import Family as WikipediaFamily # type: ignore[import-untyped] from pywikibot.family import Family # type: ignore[import-untyped] @@ -50,13 +50,11 @@ def test_family_class_dispatch_builtins( @pytest.mark.parametrize("url, name", NON_BUILTIN_WIKIS) def test_family_class_dispatch_on_non_builtins_generates_new_class_fast( - url: str, name: str + url: str, name: str, mocker: MockFixture ) -> None: """Test that using the family class dispatch function on an unknown url generates a new family class.""" - with mock.patch.object( - family, "generate_family_class" - ) as mock_generate_family_class: - family.family_class_dispatch(url, name) + mock_generate_family_class = mocker.patch.object(family, "generate_family_class") + family.family_class_dispatch(url, name) mock_generate_family_class.assert_called_once_with(url, name) diff --git a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py b/backend/tests/unit/danswer/direct_qa/test_qa_utils.py index a9046691b5a..9a21fc4db41 100644 --- a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py +++ b/backend/tests/unit/danswer/direct_qa/test_qa_utils.py @@ -1,5 +1,6 @@ import textwrap -import unittest + +import pytest from danswer.configs.constants import DocumentSource from 
danswer.llm.answering.stream_processing.quotes_processing import ( @@ -11,194 +12,181 @@ from danswer.search.models import InferenceChunk -class TestQAPostprocessing(unittest.TestCase): - def test_separate_answer_quotes(self) -> None: - test_answer = textwrap.dedent( - """ - It seems many people love dogs - Quote: A dog is a man's best friend - Quote: Air Bud was a movie about dogs and people loved it - """ - ).strip() - answer, quotes = separate_answer_quotes(test_answer) - self.assertEqual(answer, "It seems many people love dogs") - self.assertEqual(quotes[0], "A dog is a man's best friend") # type: ignore - self.assertEqual( - quotes[1], "Air Bud was a movie about dogs and people loved it" # type: ignore - ) - - # Lowercase should be allowed - test_answer = textwrap.dedent( - """ - It seems many people love dogs - quote: A dog is a man's best friend - Quote: Air Bud was a movie about dogs and people loved it - """ - ).strip() - answer, quotes = separate_answer_quotes(test_answer) - self.assertEqual(answer, "It seems many people love dogs") - self.assertEqual(quotes[0], "A dog is a man's best friend") # type: ignore - self.assertEqual( - quotes[1], "Air Bud was a movie about dogs and people loved it" # type: ignore - ) +def test_separate_answer_quotes() -> None: + # Test case 1: Basic quote separation + test_answer = textwrap.dedent( + """ + It seems many people love dogs + Quote: A dog is a man's best friend + Quote: Air Bud was a movie about dogs and people loved it + """ + ).strip() + answer, quotes = separate_answer_quotes(test_answer) + assert answer == "It seems many people love dogs" + assert isinstance(quotes, list) + assert quotes[0] == "A dog is a man's best friend" + assert quotes[1] == "Air Bud was a movie about dogs and people loved it" - # No Answer - test_answer = textwrap.dedent( - """ - Quote: This one has no answer - """ - ).strip() - answer, quotes = separate_answer_quotes(test_answer) - self.assertIsNone(answer) - self.assertIsNone(quotes) + # 
Test case 2: Lowercase 'quote' allowed + test_answer = textwrap.dedent( + """ + It seems many people love dogs + quote: A dog is a man's best friend + Quote: Air Bud was a movie about dogs and people loved it + """ + ).strip() + answer, quotes = separate_answer_quotes(test_answer) + assert answer == "It seems many people love dogs" + assert isinstance(quotes, list) + assert quotes[0] == "A dog is a man's best friend" + assert quotes[1] == "Air Bud was a movie about dogs and people loved it" - # Multiline Quote - test_answer = textwrap.dedent( - """ - It seems many people love dogs - quote: A well known saying is: - A dog is a man's best friend - Quote: Air Bud was a movie about dogs and people loved it - """ - ).strip() - answer, quotes = separate_answer_quotes(test_answer) - self.assertEqual(answer, "It seems many people love dogs") - self.assertEqual( - quotes[0], "A well known saying is:\nA dog is a man's best friend" # type: ignore - ) - self.assertEqual( - quotes[1], "Air Bud was a movie about dogs and people loved it" # type: ignore - ) + # Test case 3: No Answer + test_answer = textwrap.dedent( + """ + Quote: This one has no answer + """ + ).strip() + answer, quotes = separate_answer_quotes(test_answer) + assert answer is None + assert quotes is None - # Random patterns not picked up - test_answer = textwrap.dedent( - """ - It seems many people love quote: dogs - quote: Quote: A well known saying is: - A dog is a man's best friend - Quote: Answer: Air Bud was a movie about dogs and quote: people loved it - """ - ).strip() - answer, quotes = separate_answer_quotes(test_answer) - self.assertEqual(answer, "It seems many people love quote: dogs") - self.assertEqual( - quotes[0], "Quote: A well known saying is:\nA dog is a man's best friend" # type: ignore - ) - self.assertEqual( - quotes[1], # type: ignore - "Answer: Air Bud was a movie about dogs and quote: people loved it", - ) + # Test case 4: Multiline Quote + test_answer = textwrap.dedent( + """ + It seems 
many people love dogs + quote: A well known saying is: + A dog is a man's best friend + Quote: Air Bud was a movie about dogs and people loved it + """ + ).strip() + answer, quotes = separate_answer_quotes(test_answer) + assert answer == "It seems many people love dogs" + assert isinstance(quotes, list) + assert quotes[0] == "A well known saying is:\nA dog is a man's best friend" + assert quotes[1] == "Air Bud was a movie about dogs and people loved it" - @unittest.skip( - "Using fuzzy match is too slow anyway, doesn't matter if it's broken" + # Test case 5: Random patterns not picked up + test_answer = textwrap.dedent( + """ + It seems many people love quote: dogs + quote: Quote: A well known saying is: + A dog is a man's best friend + Quote: Answer: Air Bud was a movie about dogs and quote: people loved it + """ + ).strip() + answer, quotes = separate_answer_quotes(test_answer) + assert answer == "It seems many people love quote: dogs" + assert isinstance(quotes, list) + assert quotes[0] == "Quote: A well known saying is:\nA dog is a man's best friend" + assert ( + quotes[1] == "Answer: Air Bud was a movie about dogs and quote: people loved it" ) - def test_fuzzy_match_quotes_to_docs(self) -> None: - chunk_0_text = textwrap.dedent( - """ - Here's a doc with some LINK embedded in the text - THIS SECTION IS A LINK - Some more text - """ - ).strip() - chunk_1_text = textwrap.dedent( - """ - Some completely different text here - ANOTHER LINK embedded in this text - ending in a DIFFERENT-LINK - """ - ).strip() - test_chunk_0 = InferenceChunk( - document_id="test doc 0", - source_type=DocumentSource.FILE, - chunk_id=0, - content=chunk_0_text, - source_links={ - 0: "doc 0 base", - 23: "first line link", - 49: "second line link", - }, - blurb="anything", - semantic_identifier="anything", - section_continuation=False, - recency_bias=1, - boost=0, - hidden=False, - score=1, - metadata={}, - match_highlights=[], - updated_at=None, - ) - test_chunk_1 = InferenceChunk( - 
document_id="test doc 1", - source_type=DocumentSource.FILE, - chunk_id=0, - content=chunk_1_text, - source_links={0: "doc 1 base", 36: "2nd line link", 82: "last link"}, - blurb="whatever", - semantic_identifier="whatever", - section_continuation=False, - recency_bias=1, - boost=0, - hidden=False, - score=1, - metadata={}, - match_highlights=[], - updated_at=None, - ) - test_quotes = [ - "a doc with some", # Basic case - "a doc with some LINK", # Should take the start of quote, even if a link is in it - "a doc with some \nLINK", # Requires a newline deletion fuzzy match - "a doc with some link", # Capitalization insensitive - "embedded in this text", # Fuzzy match to first doc - "SECTION IS A LINK", # Match exact link - "some more text", # Match the end, after every link offset - "different taxt", # Substitution - "embedded in this texts", # Cannot fuzzy match to first doc, fuzzy match to second doc - "DIFFERENT-LINK", # Exact link match at the end - "Some complitali", # Too many edits, shouldn't match anything - ] - results = match_quotes_to_docs( - test_quotes, [test_chunk_0, test_chunk_1], fuzzy_search=True - ) - self.assertEqual( - results, - { - "a doc with some": {"document": "test doc 0", "link": "doc 0 base"}, - "a doc with some LINK": { - "document": "test doc 0", - "link": "doc 0 base", - }, - "a doc with some \nLINK": { - "document": "test doc 0", - "link": "doc 0 base", - }, - "a doc with some link": { - "document": "test doc 0", - "link": "doc 0 base", - }, - "embedded in this text": { - "document": "test doc 0", - "link": "first line link", - }, - "SECTION IS A LINK": { - "document": "test doc 0", - "link": "second line link", - }, - "some more text": { - "document": "test doc 0", - "link": "second line link", - }, - "different taxt": {"document": "test doc 1", "link": "doc 1 base"}, - "embedded in this texts": { - "document": "test doc 1", - "link": "2nd line link", - }, - "DIFFERENT-LINK": {"document": "test doc 1", "link": "last link"}, - }, - ) 
+@pytest.mark.skip( + reason="Using fuzzy match is too slow anyway, doesn't matter if it's broken" +) +def test_fuzzy_match_quotes_to_docs() -> None: + chunk_0_text = textwrap.dedent( + """ + Here's a doc with some LINK embedded in the text + THIS SECTION IS A LINK + Some more text + """ + ).strip() + chunk_1_text = textwrap.dedent( + """ + Some completely different text here + ANOTHER LINK embedded in this text + ending in a DIFFERENT-LINK + """ + ).strip() + test_chunk_0 = InferenceChunk( + document_id="test doc 0", + source_type=DocumentSource.FILE, + chunk_id=0, + content=chunk_0_text, + source_links={ + 0: "doc 0 base", + 23: "first line link", + 49: "second line link", + }, + blurb="anything", + semantic_identifier="anything", + section_continuation=False, + recency_bias=1, + boost=0, + hidden=False, + score=1, + metadata={}, + match_highlights=[], + updated_at=None, + ) + test_chunk_1 = InferenceChunk( + document_id="test doc 1", + source_type=DocumentSource.FILE, + chunk_id=0, + content=chunk_1_text, + source_links={0: "doc 1 base", 36: "2nd line link", 82: "last link"}, + blurb="whatever", + semantic_identifier="whatever", + section_continuation=False, + recency_bias=1, + boost=0, + hidden=False, + score=1, + metadata={}, + match_highlights=[], + updated_at=None, + ) -if __name__ == "__main__": - unittest.main() + test_quotes = [ + "a doc with some", # Basic case + "a doc with some LINK", # Should take the start of quote, even if a link is in it + "a doc with some \nLINK", # Requires a newline deletion fuzzy match + "a doc with some link", # Capitalization insensitive + "embedded in this text", # Fuzzy match to first doc + "SECTION IS A LINK", # Match exact link + "some more text", # Match the end, after every link offset + "different taxt", # Substitution + "embedded in this texts", # Cannot fuzzy match to first doc, fuzzy match to second doc + "DIFFERENT-LINK", # Exact link match at the end + "Some complitali", # Too many edits, shouldn't match anything 
+ ] + results = match_quotes_to_docs( + test_quotes, [test_chunk_0, test_chunk_1], fuzzy_search=True + ) + assert results == { + "a doc with some": {"document": "test doc 0", "link": "doc 0 base"}, + "a doc with some LINK": { + "document": "test doc 0", + "link": "doc 0 base", + }, + "a doc with some \nLINK": { + "document": "test doc 0", + "link": "doc 0 base", + }, + "a doc with some link": { + "document": "test doc 0", + "link": "doc 0 base", + }, + "embedded in this text": { + "document": "test doc 0", + "link": "first line link", + }, + "SECTION IS A LINK": { + "document": "test doc 0", + "link": "second line link", + }, + "some more text": { + "document": "test doc 0", + "link": "second line link", + }, + "different taxt": {"document": "test doc 1", "link": "doc 1 base"}, + "embedded in this texts": { + "document": "test doc 1", + "link": "2nd line link", + }, + "DIFFERENT-LINK": {"document": "test doc 1", "link": "last link"}, + } diff --git a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py b/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py new file mode 100644 index 00000000000..87d8e0d32f4 --- /dev/null +++ b/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py @@ -0,0 +1,277 @@ +from datetime import datetime + +import pytest + +from danswer.chat.models import CitationInfo +from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import LlmDoc +from danswer.configs.constants import DocumentSource +from danswer.llm.answering.stream_processing.citation_processing import ( + extract_citations_from_stream, +) +from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping + + +""" +This module contains tests for the citation extraction functionality in Danswer. 
+ +The tests focus on the `extract_citations_from_stream` function, which processes +a stream of tokens and extracts citations, replacing them with properly formatted +versions including links where available. + +Key components: +- mock_docs: A list of mock LlmDoc objects used for testing. +- mock_doc_mapping: A dictionary mapping document IDs to their ranks. +- process_text: A helper function that simulates the citation extraction process. +- test_citation_extraction: A parametrized test function covering various citation scenarios. + +To add new test cases: +1. Add a new tuple to the @pytest.mark.parametrize decorator of test_citation_extraction. +2. Each tuple should contain: + - A descriptive test name (string) + - Input tokens (list of strings) + - Expected output text (string) + - Expected citations (list of document IDs) +""" + + +mock_docs = [ + LlmDoc( + document_id=f"doc_{int(id/2)}", + content="Document is a doc", + blurb=f"Document #{id}", + semantic_identifier=f"Doc {id}", + source_type=DocumentSource.WEB, + metadata={}, + updated_at=datetime.now(), + link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None, + source_links={0: "https://mintlify.com/docs/settings/broken-links"}, + ) + for id in range(10) +] + +mock_doc_mapping = { + "doc_0": 1, + "doc_1": 2, + "doc_2": 3, + "doc_3": 4, + "doc_4": 5, + "doc_5": 6, +} + + +@pytest.fixture +def mock_data() -> tuple[list[LlmDoc], dict[str, int]]: + return mock_docs, mock_doc_mapping + + +def process_text( + tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]] +) -> tuple[str, list[CitationInfo]]: + mock_docs, mock_doc_id_to_rank_map = mock_data + mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map) + result = list( + extract_citations_from_stream( + tokens=iter(tokens), + context_docs=mock_docs, + doc_id_to_rank_map=mapping, + stop_stream=None, + ) + ) + final_answer_text = "" + citations = [] + for piece in result: + if isinstance(piece, DanswerAnswerPiece): + 
final_answer_text += piece.answer_piece or "" + elif isinstance(piece, CitationInfo): + citations.append(piece) + return final_answer_text, citations + + +@pytest.mark.parametrize( + "test_name, input_tokens, expected_text, expected_citations", + [ + ( + "Single citation", + ["Gro", "wth! [", "1", "]", "."], + "Growth! [[1]](https://0.com).", + ["doc_0"], + ), + ( + "Repeated citations", + ["Test! ", "[", "1", "]", ". And so", "me more ", "[", "2", "]", "."], + "Test! [[1]](https://0.com). And some more [[1]](https://0.com).", + ["doc_0"], + ), + ( + "Citations at sentence boundaries", + [ + "Citation at the ", + "end of a sen", + "tence.", + "[", + "2", + "]", + " Another sen", + "tence.", + "[", + "4", + "]", + ], + "Citation at the end of a sentence.[[1]](https://0.com) Another sentence.[[2]]()", + ["doc_0", "doc_1"], + ), + ( + "Citations at beginning, middle, and end", + [ + "[", + "1", + "]", + " Citation at ", + "the beginning. ", + "[", + "3", + "]", + " In the mid", + "dle. At the end ", + "[", + "5", + "]", + ".", + ], + "[[1]](https://0.com) Citation at the beginning. [[2]]() In the middle. At the end [[3]](https://2.com).", + ["doc_0", "doc_1", "doc_2"], + ), + ( + "Mixed valid and invalid citations", + [ + "Mixed valid and in", + "valid citations ", + "[", + "1", + "]", + "[", + "99", + "]", + "[", + "3", + "]", + "[", + "100", + "]", + "[", + "5", + "]", + ".", + ], + "Mixed valid and invalid citations [[1]](https://0.com)[99][[2]]()[100][[3]](https://2.com).", + ["doc_0", "doc_1", "doc_2"], + ), + ( + "Hardest!", + [ + "Multiple cit", + "ations in one ", + "sentence [", + "1", + "]", + "[", + "4", + "]", + "[", + "5", + "]", + ". 
", + ], + "Multiple citations in one sentence [[1]](https://0.com)[[2]]()[[3]](https://2.com).", + ["doc_0", "doc_1", "doc_2"], + ), + ( + "Repeated citations with text", + ["[", "1", "]", "Aasf", "asda", "sff ", "[", "1", "]", " ."], + "[[1]](https://0.com)Aasfasdasff [[1]](https://0.com) .", + ["doc_0"], + ), + ( + "Consecutive identical citations!", + [ + "Citations [", + "1", + "]", + "[", + "1]", + "", + "[2", + "", + "]", + ". ", + ], + "Citations [[1]](https://0.com).", + ["doc_0"], + ), + ( + "Consecutive identical citations!", + [ + "test [1]tt[1]t", + "", + ], + "test [[1]](https://0.com)ttt", + ["doc_0"], + ), + ( + "Consecutive identical citations!", + [ + "test [1]t[1]t[1]", + "", + ], + "test [[1]](https://0.com)tt", + ["doc_0"], + ), + ( + "Repeated citations with text", + ["[", "1", "]", "Aasf", "asda", "sff ", "[", "1", "]", " ."], + "[[1]](https://0.com)Aasfasdasff [[1]](https://0.com) .", + ["doc_0"], + ), + ( + "Repeated citations with text", + ["[1][", "1", "]t", "[2]"], + "[[1]](https://0.com)t", + ["doc_0"], + ), + ( + "Repeated citations with text", + ["[1][", "1", "]t]", "[2]"], + "[[1]](https://0.com)t]", + ["doc_0"], + ), + ( + "Repeated citations with text", + ["[1][", "3", "]t]", "[2]"], + "[[1]](https://0.com)[[2]]()t]", + ["doc_0", "doc_1"], + ), + ( + "Repeated citations with text", + ["[1", "][", "3", "]t]", "[2]"], + "[[1]](https://0.com)[[2]]()t]", + ["doc_0", "doc_1"], + ), + ], +) +def test_citation_extraction( + mock_data: tuple[list[LlmDoc], dict[str, int]], + test_name: str, + input_tokens: list[str], + expected_text: str, + expected_citations: list[str], +) -> None: + final_answer_text, citations = process_text(input_tokens, mock_data) + assert ( + final_answer_text.strip() == expected_text.strip() + ), f"Test '{test_name}' failed: Final answer text does not match expected output." 
+ assert [ + citation.document_id for citation in citations + ] == expected_citations, ( + f"Test '{test_name}' failed: Citations do not match expected output." + ) diff --git a/backend/tests/unit/danswer/llm/answering/test_prune_and_merge.py b/backend/tests/unit/danswer/llm/answering/test_prune_and_merge.py new file mode 100644 index 00000000000..1782f3edbbd --- /dev/null +++ b/backend/tests/unit/danswer/llm/answering/test_prune_and_merge.py @@ -0,0 +1,229 @@ +import pytest + +from danswer.configs.constants import DocumentSource +from danswer.llm.answering.prune_and_merge import _merge_sections +from danswer.search.models import InferenceChunk +from danswer.search.models import InferenceSection + + +# This large test accounts for all of the following: +# 1. Merging of adjacent sections +# 2. Merging of non-adjacent sections +# 3. Merging of sections where there are multiple documents +# 4. Verifying the contents of merged sections +# 5. Verifying the order/score of the merged sections + + +def create_inference_chunk( + document_id: str, chunk_id: int, content: str, score: float | None +) -> InferenceChunk: + """ + Create an InferenceChunk with hardcoded values for testing purposes. 
+ """ + return InferenceChunk( + chunk_id=chunk_id, + document_id=document_id, + semantic_identifier=f"{document_id}_{chunk_id}", + blurb=f"{document_id}_{chunk_id}", + content=content, + source_links={0: "fake_link"}, + section_continuation=False, + source_type=DocumentSource.WEB, + boost=0, + recency_bias=1.0, + score=score, + hidden=False, + metadata={}, + match_highlights=[], + updated_at=None, + ) + + +# Document 1, top connected sections +DOC_1_FILLER_1 = create_inference_chunk("doc1", 2, "Content 2", 1.0) +DOC_1_FILLER_2 = create_inference_chunk("doc1", 3, "Content 3", 2.0) +DOC_1_TOP_CHUNK = create_inference_chunk("doc1", 4, "Content 4", None) +DOC_1_MID_CHUNK = create_inference_chunk("doc1", 5, "Content 5", 4.0) +DOC_1_FILLER_3 = create_inference_chunk("doc1", 6, "Content 6", 5.0) +DOC_1_FILLER_4 = create_inference_chunk("doc1", 7, "Content 7", 6.0) +# This chunk below has the top score for testing +DOC_1_BOTTOM_CHUNK = create_inference_chunk("doc1", 8, "Content 8", 70.0) +DOC_1_FILLER_5 = create_inference_chunk("doc1", 9, "Content 9", None) +DOC_1_FILLER_6 = create_inference_chunk("doc1", 10, "Content 10", 9.0) +# Document 1, separate section +DOC_1_FILLER_7 = create_inference_chunk("doc1", 13, "Content 13", 10.0) +DOC_1_FILLER_8 = create_inference_chunk("doc1", 14, "Content 14", 11.0) +DOC_1_DISCONNECTED = create_inference_chunk("doc1", 15, "Content 15", 12.0) +DOC_1_FILLER_9 = create_inference_chunk("doc1", 16, "Content 16", 13.0) +DOC_1_FILLER_10 = create_inference_chunk("doc1", 17, "Content 17", 14.0) +# Document 2 +DOC_2_FILLER_1 = create_inference_chunk("doc2", 1, "Doc 2 Content 1", 15.0) +DOC_2_FILLER_2 = create_inference_chunk("doc2", 2, "Doc 2 Content 2", 16.0) +# This chunk below has top score for testing +DOC_2_TOP_CHUNK = create_inference_chunk("doc2", 3, "Doc 2 Content 3", 170.0) +DOC_2_FILLER_3 = create_inference_chunk("doc2", 4, "Doc 2 Content 4", 18.0) +DOC_2_BOTTOM_CHUNK = create_inference_chunk("doc2", 5, "Doc 2 Content 5", 19.0) 
+DOC_2_FILLER_4 = create_inference_chunk("doc2", 6, "Doc 2 Content 6", 20.0) +DOC_2_FILLER_5 = create_inference_chunk("doc2", 7, "Doc 2 Content 7", 21.0) + + +# Doc 2 has the highest score so it comes first +EXPECTED_CONTENT_1 = """ +Doc 2 Content 1 +Doc 2 Content 2 +Doc 2 Content 3 +Doc 2 Content 4 +Doc 2 Content 5 +Doc 2 Content 6 +Doc 2 Content 7 +""".strip() + + +EXPECTED_CONTENT_2 = """ +Content 2 +Content 3 +Content 4 +Content 5 +Content 6 +Content 7 +Content 8 +Content 9 +Content 10 + +... + +Content 13 +Content 14 +Content 15 +Content 16 +Content 17 +""".strip() + + +@pytest.mark.parametrize( + "sections,expected_contents,expected_center_chunks", + [ + ( + # Sections + [ + # Document 1, top/middle/bot connected + disconnected section + InferenceSection( + center_chunk=DOC_1_TOP_CHUNK, + chunks=[ + DOC_1_FILLER_1, + DOC_1_FILLER_2, + DOC_1_TOP_CHUNK, + DOC_1_MID_CHUNK, + DOC_1_FILLER_3, + ], + combined_content="N/A", # Not used + ), + InferenceSection( + center_chunk=DOC_1_MID_CHUNK, + chunks=[ + DOC_1_FILLER_2, + DOC_1_TOP_CHUNK, + DOC_1_MID_CHUNK, + DOC_1_FILLER_3, + DOC_1_FILLER_4, + ], + combined_content="N/A", + ), + InferenceSection( + center_chunk=DOC_1_BOTTOM_CHUNK, + chunks=[ + DOC_1_FILLER_3, + DOC_1_FILLER_4, + DOC_1_BOTTOM_CHUNK, + DOC_1_FILLER_5, + DOC_1_FILLER_6, + ], + combined_content="N/A", + ), + InferenceSection( + center_chunk=DOC_1_DISCONNECTED, + chunks=[ + DOC_1_FILLER_7, + DOC_1_FILLER_8, + DOC_1_DISCONNECTED, + DOC_1_FILLER_9, + DOC_1_FILLER_10, + ], + combined_content="N/A", + ), + InferenceSection( + center_chunk=DOC_2_TOP_CHUNK, + chunks=[ + DOC_2_FILLER_1, + DOC_2_FILLER_2, + DOC_2_TOP_CHUNK, + DOC_2_FILLER_3, + DOC_2_BOTTOM_CHUNK, + ], + combined_content="N/A", + ), + InferenceSection( + center_chunk=DOC_2_BOTTOM_CHUNK, + chunks=[ + DOC_2_TOP_CHUNK, + DOC_2_FILLER_3, + DOC_2_BOTTOM_CHUNK, + DOC_2_FILLER_4, + DOC_2_FILLER_5, + ], + combined_content="N/A", + ), + ], + # Expected Content + [EXPECTED_CONTENT_1, EXPECTED_CONTENT_2], 
+ # Expected Center Chunks (highest scores) + [DOC_2_TOP_CHUNK, DOC_1_BOTTOM_CHUNK], + ), + ], +) +def test_merge_sections( + sections: list[InferenceSection], + expected_contents: list[str], + expected_center_chunks: list[InferenceChunk], +) -> None: + sections.sort(key=lambda section: section.center_chunk.score or 0, reverse=True) + merged_sections = _merge_sections(sections) + assert merged_sections[0].combined_content == expected_contents[0] + assert merged_sections[1].combined_content == expected_contents[1] + assert merged_sections[0].center_chunk == expected_center_chunks[0] + assert merged_sections[1].center_chunk == expected_center_chunks[1] + + +@pytest.mark.parametrize( + "sections,expected_content,expected_center_chunk", + [ + ( + # Sections + [ + InferenceSection( + center_chunk=DOC_1_TOP_CHUNK, + chunks=[DOC_1_TOP_CHUNK], + combined_content="N/A", # Not used + ), + InferenceSection( + center_chunk=DOC_1_MID_CHUNK, + chunks=[DOC_1_MID_CHUNK], + combined_content="N/A", + ), + ], + # Expected Content + "Content 4\nContent 5", + # Expected Center Chunks (highest scores) + DOC_1_MID_CHUNK, + ), + ], +) +def test_merge_minimal_sections( + sections: list[InferenceSection], + expected_content: str, + expected_center_chunk: InferenceChunk, +) -> None: + sections.sort(key=lambda section: section.center_chunk.score or 0, reverse=True) + merged_sections = _merge_sections(sections) + assert merged_sections[0].combined_content == expected_content + assert merged_sections[0].center_chunk == expected_center_chunk diff --git a/backend/throttle.ctrl b/backend/throttle.ctrl new file mode 100644 index 00000000000..03aa9179535 --- /dev/null +++ b/backend/throttle.ctrl @@ -0,0 +1 @@ +f1f2 1 1718910083.03085 wikipedia:en \ No newline at end of file diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index cec7cac25a9..0f35829230e 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ 
b/deployment/docker_compose/docker-compose.dev.yml @@ -30,6 +30,9 @@ services: - SMTP_USER=${SMTP_USER:-} - SMTP_PASS=${SMTP_PASS:-} - EMAIL_FROM=${EMAIL_FROM:-} + - OAUTH_CLIENT_ID=${OAUTH_CLIENT_ID:-} + - OAUTH_CLIENT_SECRET=${OAUTH_CLIENT_SECRET:-} + - OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-} # Gen AI Settings - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-} - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-} @@ -41,13 +44,13 @@ services: - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - QA_TIMEOUT=${QA_TIMEOUT:-} - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-} - - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-} - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-} - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-} - DISABLE_LLM_QUERY_REPHRASE=${DISABLE_LLM_QUERY_REPHRASE:-} - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-} - DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-} - LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-} + - BING_API_KEY=${BING_API_KEY:-} # if set, allows for the use of the token budget system - TOKEN_BUDGET_GLOBALLY_ENABLED=${TOKEN_BUDGET_GLOBALLY_ENABLED:-} # Enables the use of bedrock models @@ -59,6 +62,8 @@ services: - HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector) - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-} - MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-} + - LANGUAGE_HINT=${LANGUAGE_HINT:-} + - LANGUAGE_CHAT_NAMING_HINT=${LANGUAGE_CHAT_NAMING_HINT:-} - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-} # Other services - POSTGRES_HOST=relational_db @@ -76,12 +81,20 @@ services: # Leave this on pretty please? Nothing sensitive is collected! 
# https://docs.danswer.dev/more/telemetry - DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-} - - LOG_LEVEL=${LOG_LEVEL:-info} # Set to debug to get more fine-grained logs - - LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-} # Log all of the prompts to the LLM + - LOG_LEVEL=${LOG_LEVEL:-info} # Set to debug to get more fine-grained logs + - LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-} # LiteLLM Verbose Logging + # Log all of Danswer prompts and interactions with the LLM + - LOG_DANSWER_MODEL_INTERACTIONS=${LOG_DANSWER_MODEL_INTERACTIONS:-} # If set to `true` will enable additional logs about Vespa query performance # (time spent on finding the right docs + time spent fetching summaries from disk) - LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-} - LOG_ENDPOINT_LATENCY=${LOG_ENDPOINT_LATENCY:-} + + # Enterprise Edition only + - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false} + - API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-} + # Seeding configuration + - ENV_SEED_CONFIGURATION=${ENV_SEED_CONFIGURATION:-} extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -95,7 +108,7 @@ services: build: context: ../../backend dockerfile: Dockerfile - command: /usr/bin/supervisord + command: /usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf depends_on: - relational_db - index @@ -115,7 +128,6 @@ services: - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - QA_TIMEOUT=${QA_TIMEOUT:-} - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-} - - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-} - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-} - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-} - DISABLE_LLM_QUERY_REPHRASE=${DISABLE_LLM_QUERY_REPHRASE:-} @@ -123,11 +135,14 @@ services: - GENERATIVE_MODEL_ACCESS_CHECK_FREQ=${GENERATIVE_MODEL_ACCESS_CHECK_FREQ:-} - DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-} - LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-} + - 
BING_API_KEY=${BING_API_KEY:-} # Query Options - DOC_TIME_DECAY=${DOC_TIME_DECAY:-} # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years) - HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector) - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-} - MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-} + - LANGUAGE_HINT=${LANGUAGE_HINT:-} + - LANGUAGE_CHAT_NAMING_HINT=${LANGUAGE_CHAT_NAMING_HINT:-} - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-} # Other Services - POSTGRES_HOST=relational_db @@ -174,9 +189,14 @@ services: # Leave this on pretty please? Nothing sensitive is collected! # https://docs.danswer.dev/more/telemetry - DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-} - - LOG_LEVEL=${LOG_LEVEL:-info} # Set to debug to get more fine-grained logs - - LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-} # Log all of the prompts to the LLM + - LOG_LEVEL=${LOG_LEVEL:-info} # Set to debug to get more fine-grained logs + - LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-} # LiteLLM Verbose Logging + # Log all of Danswer prompts and interactions with the LLM + - LOG_DANSWER_MODEL_INTERACTIONS=${LOG_DANSWER_MODEL_INTERACTIONS:-} - LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-} + + # Enterprise Edition stuff + - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false} extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -196,12 +216,21 @@ services: - NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-} - NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-} - NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-} + + # Enterprise Edition only + - NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-} + # DO NOT TURN ON unless you have EXPLICIT PERMISSION from Danswer. 
+ - NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED:-false} depends_on: - api_server restart: always environment: - INTERNAL_URL=http://api_server:8080 - WEB_DOMAIN=${WEB_DOMAIN:-} + - THEME_IS_DARK=${THEME_IS_DARK:-} + + # Enterprise Edition only + - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false} inference_model_server: image: danswer/danswer-model-server:latest @@ -302,8 +331,8 @@ services: options: max-size: "50m" max-file: "6" - # the specified script waits for the api_server to start up. - # Without this we've seen issues where nginx shows no error logs but + # The specified script waits for the api_server to start up. + # Without this we've seen issues where nginx shows no error logs but # does not recieve any traffic # NOTE: we have to use dos2unix to remove Carriage Return chars from the file # in order to make this work on both Unix-like systems and windows diff --git a/deployment/docker_compose/docker-compose.gpu-dev.yml b/deployment/docker_compose/docker-compose.gpu-dev.yml index cce0f43028b..a46dd00a38f 100644 --- a/deployment/docker_compose/docker-compose.gpu-dev.yml +++ b/deployment/docker_compose/docker-compose.gpu-dev.yml @@ -41,7 +41,6 @@ services: - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - QA_TIMEOUT=${QA_TIMEOUT:-} - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-} - - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-} - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-} - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-} - DISABLE_LLM_QUERY_REPHRASE=${DISABLE_LLM_QUERY_REPHRASE:-} @@ -59,6 +58,8 @@ services: - HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector) - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-} - MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-} + - LANGUAGE_HINT=${LANGUAGE_HINT:-} + - LANGUAGE_CHAT_NAMING_HINT=${LANGUAGE_CHAT_NAMING_HINT:-} - 
QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-} # Other services - POSTGRES_HOST=relational_db @@ -77,10 +78,16 @@ services: # https://docs.danswer.dev/more/telemetry - DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-} - LOG_LEVEL=${LOG_LEVEL:-info} # Set to debug to get more fine-grained logs - - LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-} # Log all of the prompts to the LLM + - LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-} # LiteLLM Verbose Logging + # Log all of Danswer prompts and interactions with the LLM + - LOG_DANSWER_MODEL_INTERACTIONS=${LOG_DANSWER_MODEL_INTERACTIONS:-} # If set to `true` will enable additional logs about Vespa query performance # (time spent on finding the right docs + time spent fetching summaries from disk) - LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-} + + # Enterprise Edition only + - API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-} + - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false} extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -115,7 +122,6 @@ services: - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - QA_TIMEOUT=${QA_TIMEOUT:-} - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-} - - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-} - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-} - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-} - DISABLE_LLM_QUERY_REPHRASE=${DISABLE_LLM_QUERY_REPHRASE:-} @@ -128,6 +134,8 @@ services: - HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector) - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-} - MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-} + - LANGUAGE_HINT=${LANGUAGE_HINT:-} + - LANGUAGE_CHAT_NAMING_HINT=${LANGUAGE_CHAT_NAMING_HINT:-} - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-} # Other Services - POSTGRES_HOST=relational_db @@ -175,8 +183,14 @@ services: # https://docs.danswer.dev/more/telemetry - 
DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-} - LOG_LEVEL=${LOG_LEVEL:-info} # Set to debug to get more fine-grained logs - - LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-} # Log all of the prompts to the LLM + - LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-} # LiteLLM Verbose Logging + # Log all of Danswer prompts and interactions with the LLM + - LOG_DANSWER_MODEL_INTERACTIONS=${LOG_DANSWER_MODEL_INTERACTIONS:-} - LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-} + + # Enterprise Edition only + - API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-} + - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false} extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -197,12 +211,17 @@ services: - NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-} - NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-} - NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-} + - NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-} depends_on: - api_server restart: always environment: - INTERNAL_URL=http://api_server:8080 - WEB_DOMAIN=${WEB_DOMAIN:-} + - THEME_IS_DARK=${THEME_IS_DARK:-} + + # Enterprise Edition only + - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false} inference_model_server: @@ -326,7 +345,7 @@ services: options: max-size: "50m" max-file: "6" - # the specified script waits for the api_server to start up. + # The specified script waits for the api_server to start up. 
# Without this we've seen issues where nginx shows no error logs but # does not recieve any traffic # NOTE: we have to use dos2unix to remove Carriage Return chars from the file diff --git a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml index c2ef8a13274..49024b6e366 100644 --- a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml +++ b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml @@ -17,7 +17,7 @@ services: env_file: - .env environment: - - AUTH_TYPE=${AUTH_TYPE:-google_oauth} + - AUTH_TYPE=${AUTH_TYPE:-oidc} - POSTGRES_HOST=relational_db - VESPA_HOST=index - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} @@ -34,7 +34,7 @@ services: build: context: ../../backend dockerfile: Dockerfile - command: /usr/bin/supervisord + command: /usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf depends_on: - relational_db - index @@ -44,7 +44,7 @@ services: env_file: - .env environment: - - AUTH_TYPE=${AUTH_TYPE:-google_oauth} + - AUTH_TYPE=${AUTH_TYPE:-oidc} - POSTGRES_HOST=relational_db - VESPA_HOST=index - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} @@ -68,6 +68,7 @@ services: - NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-} - NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-} - NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-} + - NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-} depends_on: - api_server restart: always @@ -182,8 +183,8 @@ services: options: max-size: "50m" max-file: "6" - # the specified script waits for the api_server to start up. - # Without this we've seen issues where nginx shows no error logs but + # The specified script waits for the api_server to start up. 
+ # Without this we've seen issues where nginx shows no error logs but # does not recieve any traffic # NOTE: we have to use dos2unix to remove Carriage Return chars from the file # in order to make this work on both Unix-like systems and windows diff --git a/deployment/docker_compose/docker-compose.prod.yml b/deployment/docker_compose/docker-compose.prod.yml index 26ef6101ce7..9c5d7f8c1f6 100644 --- a/deployment/docker_compose/docker-compose.prod.yml +++ b/deployment/docker_compose/docker-compose.prod.yml @@ -17,7 +17,7 @@ services: env_file: - .env environment: - - AUTH_TYPE=${AUTH_TYPE:-google_oauth} + - AUTH_TYPE=${AUTH_TYPE:-oidc} - POSTGRES_HOST=relational_db - VESPA_HOST=index - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} @@ -34,7 +34,7 @@ services: build: context: ../../backend dockerfile: Dockerfile - command: /usr/bin/supervisord + command: /usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf depends_on: - relational_db - index @@ -44,7 +44,7 @@ services: env_file: - .env environment: - - AUTH_TYPE=${AUTH_TYPE:-google_oauth} + - AUTH_TYPE=${AUTH_TYPE:-oidc} - POSTGRES_HOST=relational_db - VESPA_HOST=index - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} @@ -68,6 +68,7 @@ services: - NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-} - NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-} - NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-} + - NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-} depends_on: - api_server restart: always @@ -186,9 +187,9 @@ services: options: max-size: "50m" max-file: "6" - # the specified script waits for the api_server to start up. - # Without this we've seen issues where nginx shows no error logs but - # does not recieve any traffic + # The specified script waits for the api_server to start up. 
+ # Without this we've seen issues where nginx shows no error logs but + # does not recieve any traffic # NOTE: we have to use dos2unix to remove Carriage Return chars from the file # in order to make this work on both Unix-like systems and windows command: > diff --git a/deployment/docker_compose/docker-compose.search-testing.yml b/deployment/docker_compose/docker-compose.search-testing.yml new file mode 100644 index 00000000000..41eb50eaf8c --- /dev/null +++ b/deployment/docker_compose/docker-compose.search-testing.yml @@ -0,0 +1,214 @@ +version: '3' +services: + api_server: + image: danswer/danswer-backend:latest + build: + context: ../../backend + dockerfile: Dockerfile + command: > + /bin/sh -c "alembic upgrade head && + echo \"Starting Danswer Api Server\" && + uvicorn danswer.main:app --host 0.0.0.0 --port 8080" + depends_on: + - relational_db + - index + restart: always + ports: + - "8080" + env_file: + - .env_eval + environment: + - AUTH_TYPE=disabled + - POSTGRES_HOST=relational_db + - VESPA_HOST=index + - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} + - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-} + - ENV_SEED_CONFIGURATION=${ENV_SEED_CONFIGURATION:-} + - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=True + extra_hosts: + - "host.docker.internal:host-gateway" + logging: + driver: json-file + options: + max-size: "50m" + max-file: "6" + + + background: + image: danswer/danswer-backend:latest + build: + context: ../../backend + dockerfile: Dockerfile + command: /usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf + depends_on: + - relational_db + - index + restart: always + env_file: + - .env_eval + environment: + - AUTH_TYPE=disabled + - POSTGRES_HOST=relational_db + - VESPA_HOST=index + - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} + - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-} + - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server} + - ENV_SEED_CONFIGURATION=${ENV_SEED_CONFIGURATION:-} + - 
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=True + extra_hosts: + - "host.docker.internal:host-gateway" + logging: + driver: json-file + options: + max-size: "50m" + max-file: "6" + + + web_server: + image: danswer/danswer-web-server:latest + build: + context: ../../web + dockerfile: Dockerfile + args: + - NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING:-false} + - NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA=${NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA:-false} + - NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-} + - NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-} + - NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-} + + # Enterprise Edition only + - NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-} + # DO NOT TURN ON unless you have EXPLICIT PERMISSION from Danswer. + - NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED:-false} + depends_on: + - api_server + restart: always + environment: + - INTERNAL_URL=http://api_server:8080 + - WEB_DOMAIN=${WEB_DOMAIN:-} + - THEME_IS_DARK=${THEME_IS_DARK:-} + + # Enterprise Edition only + - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false} + + + inference_model_server: + image: danswer/danswer-model-server:latest + build: + context: ../../backend + dockerfile: Dockerfile.model_server + command: > + /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then + echo 'Skipping service...'; + exit 0; + else + exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000; + fi" + restart: on-failure + environment: + - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} + - LOG_LEVEL=${LOG_LEVEL:-debug} + - inference_model_cache_huggingface:/root/.cache/huggingface/ + logging: + driver: json-file + options: + max-size: "50m" + max-file: "6" + + + indexing_model_server: + image: danswer/danswer-model-server:latest + 
build: + context: ../../backend + dockerfile: Dockerfile.model_server + command: > + /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then + echo 'Skipping service...'; + exit 0; + else + exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000; + fi" + restart: on-failure + environment: + - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} + - INDEXING_ONLY=True + - LOG_LEVEL=${LOG_LEVEL:-debug} + - index_model_cache_huggingface:/root/.cache/huggingface/ + logging: + driver: json-file + options: + max-size: "50m" + max-file: "6" + + + relational_db: + image: postgres:15.2-alpine + restart: always + environment: + - POSTGRES_USER=${POSTGRES_USER:-postgres} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} + ports: + - "5432" + volumes: + - db_volume:/var/lib/postgresql/data + + + # This container name cannot have an underscore in it due to Vespa expectations of the URL + index: + image: vespaengine/vespa:8.277.17 + restart: always + ports: + - "19071" + - "8081" + volumes: + - vespa_volume:/opt/vespa/var + logging: + driver: json-file + options: + max-size: "50m" + max-file: "6" + + + nginx: + image: nginx:1.23.4-alpine + restart: always + # nginx will immediately crash with `nginx: [emerg] host not found in upstream` + # if api_server / web_server are not up + depends_on: + - api_server + - web_server + environment: + - DOMAIN=localhost + ports: + - "80:80" + - "3000:80" # allow for localhost:3000 usage, since that is the norm + volumes: + - ../data/nginx:/etc/nginx/conf.d + logging: + driver: json-file + options: + max-size: "50m" + max-file: "6" + # The specified script waits for the api_server to start up. 
+ # Without this we've seen issues where nginx shows no error logs but + # does not recieve any traffic + # NOTE: we have to use dos2unix to remove Carriage Return chars from the file + # in order to make this work on both Unix-like systems and windows + command: > + /bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh + && /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev" + + +volumes: + db_volume: + driver: local + driver_opts: + type: none + o: bind + device: ${DANSWER_POSTGRES_DATA_DIR:-./postgres_data} + vespa_volume: + driver: local + driver_opts: + type: none + o: bind + device: ${DANSWER_VESPA_DATA_DIR:-./vespa_data} diff --git a/deployment/docker_compose/env.multilingual.template b/deployment/docker_compose/env.multilingual.template index 031015a2326..e6059c6ae7d 100644 --- a/deployment/docker_compose/env.multilingual.template +++ b/deployment/docker_compose/env.multilingual.template @@ -6,6 +6,9 @@ # Rephrase the user query in specified languages using LLM, use comma separated values MULTILINGUAL_QUERY_EXPANSION="English, French" +# Change the below to suit your specific needs, can be more explicit about the language of the response +LANGUAGE_HINT="IMPORTANT: Respond in the same language as my query!" +LANGUAGE_CHAT_NAMING_HINT="The name of the conversation must be in the same language as the user query." 
# A recent MIT license multilingual model: https://huggingface.co/intfloat/multilingual-e5-small DOCUMENT_ENCODER_MODEL="intfloat/multilingual-e5-small" diff --git a/deployment/helm/values.yaml b/deployment/helm/values.yaml index e590c16a496..bb41a7511b9 100644 --- a/deployment/helm/values.yaml +++ b/deployment/helm/values.yaml @@ -402,7 +402,6 @@ configMap: GEN_AI_MAX_TOKENS: "" QA_TIMEOUT: "60" MAX_CHUNKS_FED_TO_CHAT: "" - DISABLE_LLM_FILTER_EXTRACTION: "" DISABLE_LLM_CHUNK_FILTER: "" DISABLE_LLM_CHOOSE_SEARCH: "" DISABLE_LLM_QUERY_REPHRASE: "" @@ -411,7 +410,11 @@ configMap: HYBRID_ALPHA: "" EDIT_KEYWORD_QUERY: "" MULTILINGUAL_QUERY_EXPANSION: "" + LANGUAGE_HINT: "" + LANGUAGE_CHAT_NAMING_HINT: "" QA_PROMPT_OVERRIDE: "" + # Internet Search Tool + BING_API_KEY: "" # Don't change the NLP models unless you know what you're doing DOCUMENT_ENCODER_MODEL: "" NORMALIZE_EMBEDDINGS: "" @@ -445,6 +448,7 @@ configMap: DISABLE_TELEMETRY: "" LOG_LEVEL: "" LOG_ALL_MODEL_INTERACTIONS: "" + LOG_DANSWER_MODEL_INTERACTIONS: "" LOG_VESPA_TIMING_INFORMATION: "" # Shared or Non-backend Related WEB_DOMAIN: "http://localhost:3000" # for web server and api server diff --git a/deployment/kubernetes/background-deployment.yaml b/deployment/kubernetes/background-deployment.yaml index 82369e6d3e8..18521b0f5ad 100644 --- a/deployment/kubernetes/background-deployment.yaml +++ b/deployment/kubernetes/background-deployment.yaml @@ -16,7 +16,7 @@ spec: - name: background image: danswer/danswer-backend:latest imagePullPolicy: IfNotPresent - command: ["/usr/bin/supervisord"] + command: ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"] # There are some extra values since this is shared between services # There are no conflicts though, extra env variables are simply ignored envFrom: diff --git a/deployment/kubernetes/env-configmap.yaml b/deployment/kubernetes/env-configmap.yaml index 45cc806d86a..81918c147a6 100644 --- a/deployment/kubernetes/env-configmap.yaml +++ 
b/deployment/kubernetes/env-configmap.yaml @@ -24,7 +24,6 @@ data: GEN_AI_MAX_TOKENS: "" QA_TIMEOUT: "60" MAX_CHUNKS_FED_TO_CHAT: "" - DISABLE_LLM_FILTER_EXTRACTION: "" DISABLE_LLM_CHUNK_FILTER: "" DISABLE_LLM_CHOOSE_SEARCH: "" DISABLE_LLM_QUERY_REPHRASE: "" @@ -33,10 +32,14 @@ data: HYBRID_ALPHA: "" EDIT_KEYWORD_QUERY: "" MULTILINGUAL_QUERY_EXPANSION: "" + LANGUAGE_HINT: "" + LANGUAGE_CHAT_NAMING_HINT: "" QA_PROMPT_OVERRIDE: "" # Other Services POSTGRES_HOST: "relational-db-service" VESPA_HOST: "document-index-service" + # Internet Search Tool + BING_API_KEY: "" # Don't change the NLP models unless you know what you're doing DOCUMENT_ENCODER_MODEL: "" NORMALIZE_EMBEDDINGS: "" @@ -74,6 +77,7 @@ data: DISABLE_TELEMETRY: "" LOG_LEVEL: "" LOG_ALL_MODEL_INTERACTIONS: "" + LOG_DANSWER_MODEL_INTERACTIONS: "" LOG_VESPA_TIMING_INFORMATION: "" # Shared or Non-backend Related INTERNAL_URL: "http://api-server-service:80" # for web server diff --git a/web/Dockerfile b/web/Dockerfile index 3d27813c7ae..585b1fb756a 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -1,37 +1,29 @@ FROM node:20-alpine AS base LABEL com.danswer.maintainer="founders@danswer.ai" -LABEL com.danswer.description="This image is for the frontend/webserver of Danswer. It is MIT \ -Licensed and free for all to use. You can find it at \ -https://hub.docker.com/r/danswer/danswer-web-server. For more details, visit \ -https://github.com/danswer-ai/danswer." +LABEL com.danswer.description="This image is the web/frontend container of Danswer which \ +contains code for both the Community and Enterprise editions of Danswer. If you do not \ +have a contract or agreement with DanswerAI, you are not permitted to use the Enterprise \ +Edition features outside of personal development or testing purposes. Please reach out to \ +founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer" # Default DANSWER_VERSION, typically overriden during builds by GitHub Actions. 
ARG DANSWER_VERSION=0.3-dev ENV DANSWER_VERSION=${DANSWER_VERSION} RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}" -# Step 1. Install dependencies only when needed -FROM base AS deps +# Step 1. Install dependencies + rebuild the source code only when needed +FROM base AS builder # Check https://github.com/nodejs/docker-node/tree/b4117f9333da4138b03a546ec926ef50a31506c3#nodealpine to understand why libc6-compat might be needed. RUN apk add --no-cache libc6-compat WORKDIR /app -# Install dependencies based on the preferred package manager -COPY package.json yarn.lock* package-lock.json* pnpm-lock.yaml* ./ -RUN \ - if [ -f yarn.lock ]; then yarn --frozen-lockfile; \ - elif [ -f package-lock.json ]; then npm ci; \ - elif [ -f pnpm-lock.yaml ]; then yarn global add pnpm && pnpm i --frozen-lockfile; \ - else echo "Lockfile not found." && exit 1; \ - fi - -# Step 2. Rebuild the source code only when needed -FROM base AS builder -WORKDIR /app -COPY --from=deps /app/node_modules ./node_modules +# pull in source code / package.json / package-lock.json COPY . . +# Install dependencies +RUN npm ci + # needed to get the `standalone` dir we expect later ENV NEXT_PRIVATE_STANDALONE true @@ -54,12 +46,18 @@ ENV NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PRED ARG NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS ENV NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS} +ARG NEXT_PUBLIC_THEME +ENV NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME} + +ARG NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED +ENV NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED} + ARG NEXT_PUBLIC_DISABLE_LOGOUT ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT} -RUN npm run build +RUN npx next build -# Step 3. Production image, copy all the files and run next +# Step 2. 
Production image, copy all the files and run next FROM base AS runner WORKDIR /app @@ -102,6 +100,12 @@ ENV NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PRED ARG NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS ENV NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS} +ARG NEXT_PUBLIC_THEME +ENV NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME} + +ARG NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED +ENV NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED} + ARG NEXT_PUBLIC_DISABLE_LOGOUT ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT} diff --git a/web/package-lock.json b/web/package-lock.json index 0e19baab4e9..48ac21d6477 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -38,6 +38,7 @@ "react-icons": "^4.8.0", "react-loader-spinner": "^5.4.5", "react-markdown": "^9.0.1", + "react-select": "^5.8.0", "rehype-prism-plus": "^2.0.0", "remark-gfm": "^4.0.0", "semver": "^7.5.4", @@ -559,6 +560,46 @@ "react": ">=16.8.0" } }, + "node_modules/@emotion/babel-plugin": { + "version": "11.11.0", + "resolved": "https://registry.npmjs.org/@emotion/babel-plugin/-/babel-plugin-11.11.0.tgz", + "integrity": "sha512-m4HEDZleaaCH+XgDDsPF15Ht6wTLsgDTeR3WYj9Q/k76JtWhrJjcP4+/XlG8LGT/Rol9qUfOIztXeA84ATpqPQ==", + "dependencies": { + "@babel/helper-module-imports": "^7.16.7", + "@babel/runtime": "^7.18.3", + "@emotion/hash": "^0.9.1", + "@emotion/memoize": "^0.8.1", + "@emotion/serialize": "^1.1.2", + "babel-plugin-macros": "^3.1.0", + "convert-source-map": "^1.5.0", + "escape-string-regexp": "^4.0.0", + "find-root": "^1.1.0", + "source-map": "^0.5.7", + "stylis": "4.2.0" + } + }, + "node_modules/@emotion/babel-plugin/node_modules/convert-source-map": { + "version": "1.9.0", + "resolved": "https://registry.npmjs.org/convert-source-map/-/convert-source-map-1.9.0.tgz", + "integrity": 
"sha512-ASFBup0Mz1uyiIjANan1jzLQami9z1PoYSZCiiYW2FczPbenXc45FZdBZLzOT+r6+iciuEModtmCti+hjaAk0A==" + }, + "node_modules/@emotion/cache": { + "version": "11.11.0", + "resolved": "https://registry.npmjs.org/@emotion/cache/-/cache-11.11.0.tgz", + "integrity": "sha512-P34z9ssTCBi3e9EI1ZsWpNHcfY1r09ZO0rZbRO2ob3ZQMnFI35jB536qoXbkdesr5EUhYi22anuEJuyxifaqAQ==", + "dependencies": { + "@emotion/memoize": "^0.8.1", + "@emotion/sheet": "^1.2.2", + "@emotion/utils": "^1.2.1", + "@emotion/weak-memoize": "^0.3.1", + "stylis": "4.2.0" + } + }, + "node_modules/@emotion/hash": { + "version": "0.9.1", + "resolved": "https://registry.npmjs.org/@emotion/hash/-/hash-0.9.1.tgz", + "integrity": "sha512-gJB6HLm5rYwSLI6PQa+X1t5CFGrv1J1TWG+sOyMCeKz2ojaj6Fnl/rZEspogG+cvqbt4AE/2eIyD2QfLKTBNlQ==" + }, "node_modules/@emotion/is-prop-valid": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/@emotion/is-prop-valid/-/is-prop-valid-1.2.2.tgz", @@ -572,6 +613,51 @@ "resolved": "https://registry.npmjs.org/@emotion/memoize/-/memoize-0.8.1.tgz", "integrity": "sha512-W2P2c/VRW1/1tLox0mVUalvnWXxavmv/Oum2aPsRcoDJuob75FC3Y8FbpfLwUegRcxINtGUMPq0tFCvYNTBXNA==" }, + "node_modules/@emotion/react": { + "version": "11.11.4", + "resolved": "https://registry.npmjs.org/@emotion/react/-/react-11.11.4.tgz", + "integrity": "sha512-t8AjMlF0gHpvvxk5mAtCqR4vmxiGHCeJBaQO6gncUSdklELOgtwjerNY2yuJNfwnc6vi16U/+uMF+afIawJ9iw==", + "dependencies": { + "@babel/runtime": "^7.18.3", + "@emotion/babel-plugin": "^11.11.0", + "@emotion/cache": "^11.11.0", + "@emotion/serialize": "^1.1.3", + "@emotion/use-insertion-effect-with-fallbacks": "^1.0.1", + "@emotion/utils": "^1.2.1", + "@emotion/weak-memoize": "^0.3.1", + "hoist-non-react-statics": "^3.3.1" + }, + "peerDependencies": { + "react": ">=16.8.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@emotion/serialize": { + "version": "1.1.4", + "resolved": 
"https://registry.npmjs.org/@emotion/serialize/-/serialize-1.1.4.tgz", + "integrity": "sha512-RIN04MBT8g+FnDwgvIUi8czvr1LU1alUMI05LekWB5DGyTm8cCBMCRpq3GqaiyEDRptEXOyXnvZ58GZYu4kBxQ==", + "dependencies": { + "@emotion/hash": "^0.9.1", + "@emotion/memoize": "^0.8.1", + "@emotion/unitless": "^0.8.1", + "@emotion/utils": "^1.2.1", + "csstype": "^3.0.2" + } + }, + "node_modules/@emotion/serialize/node_modules/@emotion/unitless": { + "version": "0.8.1", + "resolved": "https://registry.npmjs.org/@emotion/unitless/-/unitless-0.8.1.tgz", + "integrity": "sha512-KOEGMu6dmJZtpadb476IsZBclKvILjopjUii3V+7MnXIQCYh8W3NgNcgwo21n9LXZX6EDIKvqfjYxXebDwxKmQ==" + }, + "node_modules/@emotion/sheet": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/@emotion/sheet/-/sheet-1.2.2.tgz", + "integrity": "sha512-0QBtGvaqtWi+nx6doRwDdBIzhNdZrXUppvTM4dtZZWEGTXL/XE/yJxLMGlDT1Gt+UHH5IX1n+jkXyytE/av7OA==" + }, "node_modules/@emotion/stylis": { "version": "0.8.5", "resolved": "https://registry.npmjs.org/@emotion/stylis/-/stylis-0.8.5.tgz", @@ -582,6 +668,24 @@ "resolved": "https://registry.npmjs.org/@emotion/unitless/-/unitless-0.7.5.tgz", "integrity": "sha512-OWORNpfjMsSSUBVrRBVGECkhWcULOAJz9ZW8uK9qgxD+87M7jHRcvh/A96XXNhXTLmKcoYSQtBEX7lHMO7YRwg==" }, + "node_modules/@emotion/use-insertion-effect-with-fallbacks": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@emotion/use-insertion-effect-with-fallbacks/-/use-insertion-effect-with-fallbacks-1.0.1.tgz", + "integrity": "sha512-jT/qyKZ9rzLErtrjGgdkMBn2OP8wl0G3sQlBb3YPryvKHsjvINUhVaPFfP+fpBcOkmrVOVEEHQFJ7nbj2TH2gw==", + "peerDependencies": { + "react": ">=16.8.0" + } + }, + "node_modules/@emotion/utils": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/@emotion/utils/-/utils-1.2.1.tgz", + "integrity": "sha512-Y2tGf3I+XVnajdItskUCn6LX+VUDmP6lTL4fcqsXAv43dnlbZiuW4MWQW38rW/BVWSE7Q/7+XQocmpnRYILUmg==" + }, + "node_modules/@emotion/weak-memoize": { + "version": "0.3.1", + "resolved": 
"https://registry.npmjs.org/@emotion/weak-memoize/-/weak-memoize-0.3.1.tgz", + "integrity": "sha512-EsBwpc7hBUJWAsNPBmJy4hxWx12v6bshQsldrVmjxJoc3isbxhOrF2IcCpaXxfvq03NwkI7sbsOLXbYuqF/8Ww==" + }, "node_modules/@eslint-community/eslint-utils": { "version": "4.4.0", "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.4.0.tgz", @@ -1765,6 +1869,11 @@ "resolved": "https://registry.npmjs.org/@types/node/-/node-18.15.11.tgz", "integrity": "sha512-E5Kwq2n4SbMzQOn6wnmBjuK9ouqlURrcZDVfbo9ftDDTFt3nk7ZKK4GMOzoYgnpQJKcxwQw+lGaBvvlMo0qN/Q==" }, + "node_modules/@types/parse-json": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/@types/parse-json/-/parse-json-4.0.2.tgz", + "integrity": "sha512-dISoDXWWQwUquiKsyZ4Ng+HX2KsPL7LyHKHQwgGFEA3IaKac4Obd+h2a/a6waisAoepJlBcx9paWqjA8/HVjCw==" + }, "node_modules/@types/prismjs": { "version": "1.26.4", "resolved": "https://registry.npmjs.org/@types/prismjs/-/prismjs-1.26.4.tgz", @@ -1793,6 +1902,14 @@ "@types/react": "*" } }, + "node_modules/@types/react-transition-group": { + "version": "4.4.10", + "resolved": "https://registry.npmjs.org/@types/react-transition-group/-/react-transition-group-4.4.10.tgz", + "integrity": "sha512-hT/+s0VQs2ojCX823m60m5f0sL5idt9SO6Tj6Dg+rdphGPIeJbJ6CxvBYkgkGKrYeDjvIpKTR38UzmtHJOGW3Q==", + "dependencies": { + "@types/react": "*" + } + }, "node_modules/@types/scheduler": { "version": "0.23.0", "resolved": "https://registry.npmjs.org/@types/scheduler/-/scheduler-0.23.0.tgz", @@ -2303,6 +2420,20 @@ "resolved": "https://registry.npmjs.org/b4a/-/b4a-1.6.6.tgz", "integrity": "sha512-5Tk1HLk6b6ctmjIkAcU/Ujv/1WqiDl0F0JdRCR80VsOcUlHcu7pWeWRlOqQLHfDEsVx9YH/aif5AG4ehoCtTmg==" }, + "node_modules/babel-plugin-macros": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/babel-plugin-macros/-/babel-plugin-macros-3.1.0.tgz", + "integrity": "sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg==", + "dependencies": { + 
"@babel/runtime": "^7.12.5", + "cosmiconfig": "^7.0.0", + "resolve": "^1.19.0" + }, + "engines": { + "node": ">=10", + "npm": ">=6" + } + }, "node_modules/babel-plugin-styled-components": { "version": "2.1.4", "resolved": "https://registry.npmjs.org/babel-plugin-styled-components/-/babel-plugin-styled-components-2.1.4.tgz", @@ -2522,7 +2653,6 @@ "version": "3.1.0", "resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz", "integrity": "sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==", - "dev": true, "engines": { "node": ">=6" } @@ -2741,6 +2871,29 @@ "integrity": "sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==", "peer": true }, + "node_modules/cosmiconfig": { + "version": "7.1.0", + "resolved": "https://registry.npmjs.org/cosmiconfig/-/cosmiconfig-7.1.0.tgz", + "integrity": "sha512-AdmX6xUzdNASswsFtmwSt7Vj8po9IuqXm0UXz7QKPuEUmPB4XyjGfaAr2PSuELMwkRMVH1EpIkX5bTZGRB3eCA==", + "dependencies": { + "@types/parse-json": "^4.0.0", + "import-fresh": "^3.2.1", + "parse-json": "^5.0.0", + "path-type": "^4.0.0", + "yaml": "^1.10.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/cosmiconfig/node_modules/yaml": { + "version": "1.10.2", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-1.10.2.tgz", + "integrity": "sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg==", + "engines": { + "node": ">= 6" + } + }, "node_modules/cross-spawn": { "version": "7.0.3", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", @@ -3190,6 +3343,19 @@ "url": "https://github.com/fb55/entities?sponsor=1" } }, + "node_modules/error-ex": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/error-ex/-/error-ex-1.3.2.tgz", + "integrity": "sha512-7dFHNmqeFSEt2ZBsCriorKnn3Z2pj+fd9kmI6QoWw4//DL+icEBfc0U7qJCisqrTsKTjw4fNFy2pW9OqStD84g==", + "dependencies": { + "is-arrayish": "^0.2.1" + } + }, + 
"node_modules/error-ex/node_modules/is-arrayish": { + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/is-arrayish/-/is-arrayish-0.2.1.tgz", + "integrity": "sha512-zz06S8t0ozoDXMG+ube26zeCTNXcKIPJZJi8hBrF4idCLms4CG9QtK7qBl1boi5ODzFpjswb5JPmHCbMpjaYzg==" + }, "node_modules/es-abstract": { "version": "1.23.3", "resolved": "https://registry.npmjs.org/es-abstract/-/es-abstract-1.23.3.tgz", @@ -3360,7 +3526,6 @@ "version": "4.0.0", "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", "integrity": "sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==", - "dev": true, "engines": { "node": ">=10" }, @@ -3906,6 +4071,11 @@ "node": ">=8" } }, + "node_modules/find-root": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/find-root/-/find-root-1.1.0.tgz", + "integrity": "sha512-NKfW6bec6GfKc0SGx1e07QZY9PE99u0Bft/0rzSD5k3sO/vwkVUpDUKVm5Gpp5Ue3YfShPFTX2070tDs5kB9Ng==" + }, "node_modules/find-up": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/find-up/-/find-up-5.0.0.tgz", @@ -4556,7 +4726,6 @@ "version": "3.3.0", "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.0.tgz", "integrity": "sha512-veYYhQa+D1QBKznvhUHxb8faxlrwUnxseDAbAp457E0wLNio2bOSKnjYDhMj+YiAq61xrMGhQk9iXVk5FzgQMw==", - "dev": true, "dependencies": { "parent-module": "^1.0.0", "resolve-from": "^4.0.0" @@ -5140,6 +5309,11 @@ "integrity": "sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ==", "dev": true }, + "node_modules/json-parse-even-better-errors": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/json-parse-even-better-errors/-/json-parse-even-better-errors-2.3.1.tgz", + "integrity": "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==" + }, "node_modules/json-schema-traverse": { "version": "0.4.1", "resolved": 
"https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", @@ -5578,6 +5752,11 @@ "url": "https://opencollective.com/unified" } }, + "node_modules/memoize-one": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/memoize-one/-/memoize-one-6.0.0.tgz", + "integrity": "sha512-rkpe71W0N0c0Xz6QD0eJETuWAJGnJ9afsl1srmwPrI+yBCkge5EycXXbYRyvL29zZVUWQCY7InPRCv3GDXuZNw==" + }, "node_modules/merge2": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", @@ -8918,7 +9097,6 @@ "version": "1.0.1", "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", "integrity": "sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==", - "dev": true, "dependencies": { "callsites": "^3.0.0" }, @@ -8950,6 +9128,23 @@ "resolved": "https://registry.npmjs.org/@types/unist/-/unist-2.0.10.tgz", "integrity": "sha512-IfYcSBWE3hLpBg8+X2SEa8LVkJdJEkT2Ese2aaLs3ptGdVtABxndrMaxuFlQ1qdFf9Q5rDvDpxI3WwgvKFAsQA==" }, + "node_modules/parse-json": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/parse-json/-/parse-json-5.2.0.tgz", + "integrity": "sha512-ayCKvm/phCGxOkYRSCM82iDwct8/EonSEgCSxWxD7ve6jHggsFl4fZVQBPRNgQoKiuV/odhFrGzQXZwbifC8Rg==", + "dependencies": { + "@babel/code-frame": "^7.0.0", + "error-ex": "^1.3.1", + "json-parse-even-better-errors": "^2.3.0", + "lines-and-columns": "^1.1.6" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/parse-numeric-range": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/parse-numeric-range/-/parse-numeric-range-1.3.0.tgz", @@ -9016,7 +9211,6 @@ "version": "4.0.0", "resolved": "https://registry.npmjs.org/path-type/-/path-type-4.0.0.tgz", "integrity": "sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==", - "dev": true, "engines": { "node": ">=8" } @@ -9546,6 +9740,26 @@ } } }, + 
"node_modules/react-select": { + "version": "5.8.0", + "resolved": "https://registry.npmjs.org/react-select/-/react-select-5.8.0.tgz", + "integrity": "sha512-TfjLDo58XrhP6VG5M/Mi56Us0Yt8X7xD6cDybC7yoRMUNm7BGO7qk8J0TLQOua/prb8vUOtsfnXZwfm30HGsAA==", + "dependencies": { + "@babel/runtime": "^7.12.0", + "@emotion/cache": "^11.4.0", + "@emotion/react": "^11.8.1", + "@floating-ui/dom": "^1.0.1", + "@types/react-transition-group": "^4.4.0", + "memoize-one": "^6.0.0", + "prop-types": "^15.6.0", + "react-transition-group": "^4.3.0", + "use-isomorphic-layout-effect": "^1.1.2" + }, + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0" + } + }, "node_modules/react-smooth": { "version": "4.0.1", "resolved": "https://registry.npmjs.org/react-smooth/-/react-smooth-4.0.1.tgz", @@ -9849,7 +10063,6 @@ "version": "4.0.0", "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", "integrity": "sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==", - "dev": true, "engines": { "node": ">=4" } @@ -10169,6 +10382,14 @@ "node": ">=8" } }, + "node_modules/source-map": { + "version": "0.5.7", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.5.7.tgz", + "integrity": "sha512-LbrmJOMUSdEVxIKvdcJzQC+nQhe8FUZQTXQy6+I75skNgn3OoQ0DZA8YnFa7gp8tqtL3KPf1kmo0R5DoApeSGQ==", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/source-map-js": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.0.tgz", @@ -10489,6 +10710,11 @@ "resolved": "https://registry.npmjs.org/styled-tools/-/styled-tools-1.7.2.tgz", "integrity": "sha512-IjLxzM20RMwAsx8M1QoRlCG/Kmq8lKzCGyospjtSXt/BTIIcvgTonaxQAsKnBrsZNwhpHzO9ADx5te0h76ILVg==" }, + "node_modules/stylis": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/stylis/-/stylis-4.2.0.tgz", + "integrity": 
"sha512-Orov6g6BB1sDfYgzWfTHDOxamtX1bE/zo104Dh9e6fqJ3PooipYyfJ0pUmrZO2wAvO8YbEyeFrkV91XTsGMSrw==" + }, "node_modules/sucrase": { "version": "3.35.0", "resolved": "https://registry.npmjs.org/sucrase/-/sucrase-3.35.0.tgz", @@ -11064,6 +11290,19 @@ } } }, + "node_modules/use-isomorphic-layout-effect": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/use-isomorphic-layout-effect/-/use-isomorphic-layout-effect-1.1.2.tgz", + "integrity": "sha512-49L8yCO3iGT/ZF9QttjwLF/ZD9Iwto5LnH5LmEdk/6cFmXddqi2ulF0edxTwjj+7mqvpVVGQWvbXZdn32wRSHA==", + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/use-sidecar": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/use-sidecar/-/use-sidecar-1.1.2.tgz", diff --git a/web/package.json b/web/package.json index 1ba7286d401..190ec9b1aa6 100644 --- a/web/package.json +++ b/web/package.json @@ -39,6 +39,7 @@ "react-icons": "^4.8.0", "react-loader-spinner": "^5.4.5", "react-markdown": "^9.0.1", + "react-select": "^5.8.0", "rehype-prism-plus": "^2.0.0", "remark-gfm": "^4.0.0", "semver": "^7.5.4", diff --git a/web/public/Amazon.webp b/web/public/Amazon.webp new file mode 100644 index 00000000000..cd8574fa59e Binary files /dev/null and b/web/public/Amazon.webp differ diff --git a/web/public/Anthropic.svg b/web/public/Anthropic.svg new file mode 100644 index 00000000000..da352fb54b7 --- /dev/null +++ b/web/public/Anthropic.svg @@ -0,0 +1,8 @@ + + + Anthropic + + + + + diff --git a/web/public/Azure.png b/web/public/Azure.png new file mode 100644 index 00000000000..7c750318da4 Binary files /dev/null and b/web/public/Azure.png differ diff --git a/web/public/GoogleCloudStorage.png b/web/public/GoogleCloudStorage.png new file mode 100644 index 00000000000..4790f2c5244 Binary files /dev/null and b/web/public/GoogleCloudStorage.png differ diff --git a/web/public/OCI.svg b/web/public/OCI.svg new file mode 100644 index 
00000000000..75d359b00bd --- /dev/null +++ b/web/public/OCI.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/public/OpenSource.png b/web/public/OpenSource.png new file mode 100644 index 00000000000..9490ffc6d2e Binary files /dev/null and b/web/public/OpenSource.png differ diff --git a/web/public/Openai.svg b/web/public/Openai.svg new file mode 100644 index 00000000000..e04db75a5bb --- /dev/null +++ b/web/public/Openai.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/public/S3.png b/web/public/S3.png new file mode 100644 index 00000000000..75eb1f19b34 Binary files /dev/null and b/web/public/S3.png differ diff --git a/web/src/app/favicon.ico b/web/public/danswer.ico similarity index 100% rename from web/src/app/favicon.ico rename to web/public/danswer.ico diff --git a/web/public/r2.webp b/web/public/r2.webp new file mode 100644 index 00000000000..d79134ce7d8 Binary files /dev/null and b/web/public/r2.webp differ diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index 10b0ad586aa..665af18743e 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -20,12 +20,12 @@ import Link from "next/link"; import { useEffect, useState } from "react"; import { BooleanFormField, + Label, SelectorFormField, TextFormField, } from "@/components/admin/connectors/Field"; -import { HidableSection } from "./HidableSection"; -import { FiPlus, FiX } from "react-icons/fi"; -import { EE_ENABLED } from "@/lib/constants"; +import CollapsibleSection from "./CollapsibleSection"; +import { FiInfo, FiPlus, FiX } from "react-icons/fi"; import { useUserGroups } from "@/lib/hooks"; import { Bubble } from "@/components/Bubble"; import { GroupsIcon } from "@/components/icons/icons"; @@ -37,6 +37,13 @@ import { ToolSnapshot } from "@/lib/tools/interfaces"; import { checkUserIsNoAuthUser } from "@/lib/user"; import { addAssistantToList } from 
"@/lib/assistants/updateAssistantPreferences"; import { checkLLMSupportsImageInput } from "@/lib/llm/utils"; +import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; +import { + TooltipProvider, + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@radix-ui/react-tooltip"; function findSearchTool(tools: ToolSnapshot[]) { return tools.find((tool) => tool.in_code_tool_id === "SearchTool"); @@ -46,10 +53,8 @@ function findImageGenerationTool(tools: ToolSnapshot[]) { return tools.find((tool) => tool.in_code_tool_id === "ImageGenerationTool"); } -function Label({ children }: { children: string | JSX.Element }) { - return ( -
{children}
- ); +function findInternetSearchTool(tools: ToolSnapshot[]) { + return tools.find((tool) => tool.in_code_tool_id === "InternetSearchTool"); } function SubLabel({ children }: { children: string | JSX.Element }) { @@ -80,6 +85,8 @@ export function AssistantEditor({ const router = useRouter(); const { popup, setPopup } = usePopup(); + const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled(); + // EE only const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups(); @@ -147,16 +154,20 @@ export function AssistantEditor({ const imageGenerationTool = providerSupportingImageGenerationExists ? findImageGenerationTool(tools) : undefined; + const internetSearchTool = findInternetSearchTool(tools); + const customTools = tools.filter( (tool) => tool.in_code_tool_id !== searchTool?.in_code_tool_id && - tool.in_code_tool_id !== imageGenerationTool?.in_code_tool_id + tool.in_code_tool_id !== imageGenerationTool?.in_code_tool_id && + tool.in_code_tool_id !== internetSearchTool?.in_code_tool_id ); const availableTools = [ ...customTools, ...(searchTool ? [searchTool] : []), ...(imageGenerationTool ? [imageGenerationTool] : []), + ...(internetSearchTool ? [internetSearchTool] : []), ]; const enabledToolsMap: { [key: number]: boolean } = {}; availableTools.forEach((tool) => { @@ -199,9 +210,11 @@ export function AssistantEditor({ initialValues={initialValues} validationSchema={Yup.object() .shape({ - name: Yup.string().required("Must give the Assistant a name!"), + name: Yup.string().required( + "Must provide a name for the Assistant" + ), description: Yup.string().required( - "Must give the Assistant a description!" 
+ "Must provide a description for the Assistant" ), system_prompt: Yup.string(), task_prompt: Yup.string(), @@ -224,29 +237,29 @@ export function AssistantEditor({ }) .test( "system-prompt-or-task-prompt", - "Must provide at least one of System Prompt or Task Prompt", - (values) => { - const systemPromptSpecified = values.system_prompt - ? values.system_prompt.length > 0 - : false; - const taskPromptSpecified = values.task_prompt - ? values.task_prompt.length > 0 - : false; + "Must provide either System Prompt or Additional Instructions", + function (values) { + const systemPromptSpecified = + values.system_prompt && values.system_prompt.trim().length > 0; + const taskPromptSpecified = + values.task_prompt && values.task_prompt.trim().length > 0; + if (systemPromptSpecified || taskPromptSpecified) { - setFinalPromptError(""); return true; - } // Return true if at least one field has a value + } - setFinalPromptError( - "Must provide at least one of System Prompt or Task Prompt" - ); + return this.createError({ + path: "system_prompt", + message: + "Must provide either System Prompt or Additional Instructions", + }); } )} onSubmit={async (values, formikHelpers) => { if (finalPromptError) { setPopup({ type: "error", - message: "Cannot submit while there are errors in the form!", + message: "Cannot submit while there are errors in the form", }); return; } @@ -389,79 +402,146 @@ export function AssistantEditor({ return (
- - <> - - - - - { - setFieldValue("system_prompt", e.target.value); - triggerFinalPromptUpdate( - e.target.value, - values.task_prompt, - searchToolEnabled() - ); - }} - error={finalPromptError} - /> - - { - setFieldValue("task_prompt", e.target.value); - triggerFinalPromptUpdate( - values.system_prompt, - e.target.value, - searchToolEnabled() - ); - }} - error={finalPromptError} - /> - - + + + { + setFieldValue("system_prompt", e.target.value); + triggerFinalPromptUpdate( + e.target.value, + values.task_prompt, + searchToolEnabled() + ); + }} + error={finalPromptError} + /> + +
+
+
+ LLM Provider{" "} +
+ + + + + + +

+ Select a Large Language Model (Generative AI model) + to power this Assistant +

+
+
+
+
+
+
+ ({ + name: llmProvider.name, + value: llmProvider.name, + icon: llmProvider.icon, + }))} + includeDefault={true} + onSelect={(selected) => { + if (selected !== values.llm_model_provider_override) { + setFieldValue("llm_model_version_override", null); + } + setFieldValue( + "llm_model_provider_override", + selected + ); + }} + /> +
- {finalPrompt ? ( -
-                        {finalPrompt}
-                      
- ) : ( - "-" + {values.llm_model_provider_override && ( +
+ +
)} - - +
+
- +
+
+
+ Capabilities{" "} +
+ + + + + + +

+ You can give your assistant advanced capabilities + like image generation +

+
+
+
+
+ Advanced +
+
+ +
+ {imageGenerationTool && + checkLLMSupportsImageInput( + providerDisplayNameToProviderName.get( + values.llm_model_provider_override || "" + ) || + defaultProviderName || + "", + values.llm_model_version_override || + defaultModelName || + "" + ) && ( + { + toggleToolInValues(imageGenerationTool.id); + }} + /> + )} - - <> {ccPairs.length > 0 && searchTool && ( <> { setFieldValue("num_chunks", null); toggleToolInValues(searchTool.id); @@ -469,156 +549,148 @@ export function AssistantEditor({ /> {searchToolEnabled() && ( -
- {ccPairs.length > 0 && ( - <> - - -
- - <> - Select which{" "} - {!user || user.role === "admin" ? ( - - Document Sets - - ) : ( - "Document Sets" - )}{" "} - that this Assistant should search through. - If none are specified, the Assistant will - search through all available documents in - order to try and respond to queries. - - -
- - {documentSets.length > 0 ? ( - ( -
-
- {documentSets.map((documentSet) => { - const ind = - values.document_set_ids.indexOf( - documentSet.id - ); - let isSelected = ind !== -1; - return ( - { - if (isSelected) { - arrayHelpers.remove(ind); - } else { - arrayHelpers.push( - documentSet.id - ); - } - }} - /> - ); - })} -
-
- )} - /> - ) : ( - - No Document Sets available.{" "} - {user?.role !== "admin" && ( + +
+ {ccPairs.length > 0 && ( + <> + +
+ <> - If this functionality would be useful, - reach out to the administrators of - Danswer for assistance. + Select which{" "} + {!user || user.role === "admin" ? ( + + Document Sets + + ) : ( + "Document Sets" + )}{" "} + that this Assistant should search + through. If none are specified, the + Assistant will search through all + available documents in order to try and + respond to queries. - )} - - )} + +
- <> - - How many chunks should we feed into the - LLM when generating the final response? - Each chunk is ~400 words long. -
- } - onChange={(e) => { - const value = e.target.value; - // Allow only integer values - if ( - value === "" || - /^[0-9]+$/.test(value) - ) { - setFieldValue("num_chunks", value); + {documentSets.length > 0 ? ( + ( +
+
+ {documentSets.map((documentSet) => { + const ind = + values.document_set_ids.indexOf( + documentSet.id + ); + let isSelected = ind !== -1; + return ( + { + if (isSelected) { + arrayHelpers.remove(ind); + } else { + arrayHelpers.push( + documentSet.id + ); + } + }} + /> + ); + })} +
+
+ )} + /> + ) : ( + + No Document Sets available.{" "} + {user?.role !== "admin" && ( + <> + If this functionality would be useful, + reach out to the administrators of + Danswer for assistance. + + )} + + )} + +
+ { + const value = e.target.value; + if ( + value === "" || + /^[0-9]+$/.test(value) + ) { + setFieldValue("num_chunks", value); + } + }} + /> + + - - - - - - + /> + + +
- - )} -
+ )} +
+ )} )} - {imageGenerationTool && - checkLLMSupportsImageInput( - providerDisplayNameToProviderName.get( - values.llm_model_provider_override || "" - ) || - defaultProviderName || - "", - values.llm_model_version_override || - defaultModelName || - "" - ) && ( - { - toggleToolInValues(imageGenerationTool.id); - }} - /> - )} + {internetSearchTool && ( + { + toggleToolInValues(internetSearchTool.id); + }} + /> + )} {customTools.length > 0 && ( <> {customTools.map((tool) => ( )} - - +
- + {llmProviders.length > 0 && ( + <> + - {llmProviders.length > 0 && ( - <> - - <> - - Pick which LLM to use for this Assistant. If left as - Default, will use{" "} - {defaultModelName} - . -
-
- For more information on the different LLMs, checkout - the{" "} - - OpenAI docs - - . -
- -
-
- LLM Provider - ({ - name: llmProvider.name, - value: llmProvider.name, - }))} - includeDefault={true} - onSelect={(selected) => { - if ( - selected !== - values.llm_model_provider_override - ) { - setFieldValue( - "llm_model_version_override", - null - ); - } - setFieldValue( - "llm_model_provider_override", - selected - ); - }} - /> -
- - {values.llm_model_provider_override && ( -
- Model - -
- )} -
- -
- - - - )} - - - <> -
- - Starter Messages help guide users to use this Assistant. - They are shown to the user as clickable options when - they select this Assistant. When selected, the specified - message is sent to the LLM as the initial user message. - + { + setFieldValue("task_prompt", e.target.value); + triggerFinalPromptUpdate( + values.system_prompt, + e.target.value, + searchToolEnabled() + ); + }} + explanationText="Learn about prompting in our docs!" + explanationLink="https://docs.danswer.dev/guides/assistants" + /> + + )} +
+
+
+ Add Starter Messages (Optional){" "} +
-
- + Shows up as the "title" for this Starter Message. For example, @@ -770,7 +777,7 @@ export function AssistantEditor({
- + A description which tells the user what they might want to use this @@ -800,7 +807,7 @@ export function AssistantEditor({
- + The actual message to be sent as the initial user message if a user selects @@ -863,84 +870,82 @@ export function AssistantEditor({
)} /> - - +
- + {isPaidEnterpriseFeaturesEnabled && + userGroups && + (!user || user.role === "admin") && ( + <> + - {EE_ENABLED && - userGroups && - (!user || user.role === "admin") && ( - <> - - <> - + - {userGroups && - userGroups.length > 0 && - !values.is_public && ( -
- - Select which User Groups should have access to - this Assistant. - -
- {userGroups.map((userGroup) => { - const isSelected = values.groups.includes( - userGroup.id - ); - return ( - { - if (isSelected) { - setFieldValue( - "groups", - values.groups.filter( - (id) => id !== userGroup.id - ) - ); - } else { - setFieldValue("groups", [ - ...values.groups, - userGroup.id, - ]); - } - }} - > -
- -
- {userGroup.name} -
+ {userGroups && + userGroups.length > 0 && + !values.is_public && ( +
+ + Select which User Groups should have access to + this Assistant. + +
+ {userGroups.map((userGroup) => { + const isSelected = values.groups.includes( + userGroup.id + ); + return ( + { + if (isSelected) { + setFieldValue( + "groups", + values.groups.filter( + (id) => id !== userGroup.id + ) + ); + } else { + setFieldValue("groups", [ + ...values.groups, + userGroup.id, + ]); + } + }} + > +
+ +
+ {userGroup.name}
- - ); - })} -
+
+ + ); + })}
- )} - - - - - )} +
+ )} + + )} -
- +
+ +
diff --git a/web/src/app/admin/assistants/CollapsibleSection.tsx b/web/src/app/admin/assistants/CollapsibleSection.tsx new file mode 100644 index 00000000000..72b598846e6 --- /dev/null +++ b/web/src/app/admin/assistants/CollapsibleSection.tsx @@ -0,0 +1,55 @@ +"use client"; +import { Button } from "@tremor/react"; +import React, { ReactNode, useState } from "react"; +import { FiSettings } from "react-icons/fi"; + +interface CollapsibleSectionProps { + children: ReactNode; + prompt?: string; + className?: string; +} + +const CollapsibleSection: React.FC = ({ + children, + prompt, + className = "", +}) => { + const [isCollapsed, setIsCollapsed] = useState(false); + + const toggleCollapse = (e?: React.MouseEvent) => { + // Only toggle if the click is on the border or plus sign + if ( + !e || + e.currentTarget === e.target || + (e.target as HTMLElement).classList.contains("collapse-toggle") + ) { + setIsCollapsed(!isCollapsed); + } + }; + + return ( +
+
+ {isCollapsed ? ( + + + {prompt}{" "} + + ) : ( + <>{children} + )} +
+
+ ); +}; + +export default CollapsibleSection; diff --git a/web/src/app/admin/assistants/HidableSection.tsx b/web/src/app/admin/assistants/HidableSection.tsx index 714f2344cf1..3741e514b88 100644 --- a/web/src/app/admin/assistants/HidableSection.tsx +++ b/web/src/app/admin/assistants/HidableSection.tsx @@ -44,7 +44,7 @@ export function HidableSection({
- {!isHidden &&
{children}
} + {!isHidden &&
{children}
}
); } diff --git a/web/src/app/admin/assistants/PersonaTable.tsx b/web/src/app/admin/assistants/PersonaTable.tsx index 87873611559..73fcbeb2b6a 100644 --- a/web/src/app/admin/assistants/PersonaTable.tsx +++ b/web/src/app/admin/assistants/PersonaTable.tsx @@ -9,7 +9,7 @@ import { useState } from "react"; import { UniqueIdentifier } from "@dnd-kit/core"; import { DraggableTable } from "@/components/table/DraggableTable"; import { deletePersona, personaComparator } from "./lib"; -import { FiEdit } from "react-icons/fi"; +import { FiEdit2 } from "react-icons/fi"; import { TrashIcon } from "@/components/icons/icons"; function PersonaTypeDisplay({ persona }: { persona: Persona }) { @@ -89,7 +89,7 @@ export function PersonasTable({ personas }: { personas: Persona[] }) { cells: [
{!persona.default_persona && ( - router.push( diff --git a/web/src/app/admin/assistants/new/page.tsx b/web/src/app/admin/assistants/new/page.tsx index 343d78c7e59..5123dc4f4f0 100644 --- a/web/src/app/admin/assistants/new/page.tsx +++ b/web/src/app/admin/assistants/new/page.tsx @@ -30,12 +30,10 @@ export default async function Page() { return (
- } /> - {body}
); diff --git a/web/src/app/admin/bot/SlackBotConfigCreationForm.tsx b/web/src/app/admin/bot/SlackBotConfigCreationForm.tsx index 6a598412541..23968d585fe 100644 --- a/web/src/app/admin/bot/SlackBotConfigCreationForm.tsx +++ b/web/src/app/admin/bot/SlackBotConfigCreationForm.tsx @@ -3,7 +3,11 @@ import { ArrayHelpers, FieldArray, Form, Formik } from "formik"; import * as Yup from "yup"; import { usePopup } from "@/components/admin/connectors/Popup"; -import { DocumentSet, SlackBotConfig } from "@/lib/types"; +import { + DocumentSet, + SlackBotConfig, + StandardAnswerCategory, +} from "@/lib/types"; import { BooleanFormField, SectionHeader, @@ -31,20 +35,22 @@ import { useRouter } from "next/navigation"; import { Persona } from "../assistants/interfaces"; import { useState } from "react"; import { BookmarkIcon, RobotIcon } from "@/components/icons/icons"; +import MultiSelectDropdown from "@/components/MultiSelectDropdown"; export const SlackBotCreationForm = ({ documentSets, personas, + standardAnswerCategories, existingSlackBotConfig, }: { documentSets: DocumentSet[]; personas: Persona[]; + standardAnswerCategories: StandardAnswerCategory[]; existingSlackBotConfig?: SlackBotConfig; }) => { const isUpdate = existingSlackBotConfig !== undefined; const { popup, setPopup } = usePopup(); const router = useRouter(); - const existingSlackBotUsesPersona = existingSlackBotConfig?.persona ? !isPersonaASlackBotPersona(existingSlackBotConfig.persona) : false; @@ -71,9 +77,15 @@ export const SlackBotCreationForm = ({ existingSlackBotConfig?.channel_config?.respond_tag_only || false, respond_to_bots: existingSlackBotConfig?.channel_config?.respond_to_bots || false, - respond_team_member_list: + enable_auto_filters: + existingSlackBotConfig?.enable_auto_filters || false, + respond_member_group_list: ( existingSlackBotConfig?.channel_config - ?.respond_team_member_list || ([] as string[]), + ?.respond_team_member_list ?? 
[] + ).concat( + existingSlackBotConfig?.channel_config + ?.respond_slack_group_list ?? [] + ), still_need_help_enabled: existingSlackBotConfig?.channel_config?.follow_up_tags !== undefined, @@ -91,6 +103,9 @@ export const SlackBotCreationForm = ({ ? existingSlackBotConfig.persona.id : null, response_type: existingSlackBotConfig?.response_type || "citations", + standard_answer_categories: existingSlackBotConfig + ? existingSlackBotConfig.standard_answer_categories + : [], }} validationSchema={Yup.object().shape({ channel_names: Yup.array().of(Yup.string()), @@ -101,11 +116,13 @@ export const SlackBotCreationForm = ({ questionmark_prefilter_enabled: Yup.boolean().required(), respond_tag_only: Yup.boolean().required(), respond_to_bots: Yup.boolean().required(), - respond_team_member_list: Yup.array().of(Yup.string()).required(), + enable_auto_filters: Yup.boolean().required(), + respond_member_group_list: Yup.array().of(Yup.string()).required(), still_need_help_enabled: Yup.boolean().required(), follow_up_tags: Yup.array().of(Yup.string()), document_sets: Yup.array().of(Yup.number()), persona_id: Yup.number().nullable(), + standard_answer_categories: Yup.array(), })} onSubmit={async (values, formikHelpers) => { formikHelpers.setSubmitting(true); @@ -116,10 +133,18 @@ export const SlackBotCreationForm = ({ channel_names: values.channel_names.filter( (channelName) => channelName !== "" ), - respond_team_member_list: values.respond_team_member_list.filter( - (teamMemberEmail) => teamMemberEmail !== "" + respond_team_member_list: values.respond_member_group_list.filter( + (teamMemberEmail) => + /^[^\s@]+@[^\s@]+\.[^\s@]+$/.test(teamMemberEmail) + ), + respond_slack_group_list: values.respond_member_group_list.filter( + (slackGroupName) => + !/^[^\s@]+@[^\s@]+\.[^\s@]+$/.test(slackGroupName) ), usePersona: usingPersonas, + standard_answer_categories: values.standard_answer_categories.map( + (category) => category.id + ), }; if (!cleanedValues.still_need_help_enabled) { 
cleanedValues.follow_up_tags = undefined; @@ -128,7 +153,6 @@ export const SlackBotCreationForm = ({ cleanedValues.follow_up_tags = []; } } - let response; if (isUpdate) { response = await updateSlackBotConfig( @@ -153,7 +177,7 @@ export const SlackBotCreationForm = ({ } }} > - {({ isSubmitting, values }) => ( + {({ isSubmitting, values, setFieldValue }) => (
The Basics @@ -226,15 +250,20 @@ export const SlackBotCreationForm = ({ label="Responds to Bot messages" subtext="If not set, DanswerBot will always ignore messages from Bots" /> + @@ -383,6 +412,45 @@ export const SlackBotCreationForm = ({ +
+ + [Optional] Standard Answer Categories + +
+ { + const selected_categories = selected_options.map( + (option) => { + return { + id: Number(option.value), + name: option.label, + }; + } + ); + setFieldValue( + "standard_answer_categories", + selected_categories + ); + }} + creatable={false} + options={standardAnswerCategories.map((category) => ({ + label: category.name, + value: category.id.toString(), + }))} + initialSelectedOptions={values.standard_answer_categories.map( + (category) => ({ + label: category.name, + value: category.id.toString(), + }) + )} + /> +
+
+ + +
diff --git a/web/src/app/admin/connectors/confluence/page.tsx b/web/src/app/admin/connectors/confluence/page.tsx index 25c69cdbd0a..981d63e9513 100644 --- a/web/src/app/admin/connectors/confluence/page.tsx +++ b/web/src/app/admin/connectors/confluence/page.tsx @@ -2,7 +2,10 @@ import * as Yup from "yup"; import { ConfluenceIcon, TrashIcon } from "@/components/icons/icons"; -import { TextFormField } from "@/components/admin/connectors/Field"; +import { + BooleanFormField, + TextFormField, +} from "@/components/admin/connectors/Field"; import { HealthCheckBanner } from "@/components/health/healthcheck"; import { CredentialForm } from "@/components/admin/connectors/CredentialForm"; import { @@ -207,13 +210,17 @@ const Main = () => {

Specify any link to a Confluence page below and click "Index" to Index. Based on the provided link, we will - index the ENTIRE SPACE, not just the specified page. For example, - entering{" "} + index either the entire page and its subpages OR the entire space. + For example, entering{" "} https://danswer.atlassian.net/wiki/spaces/Engineering/overview {" "} and clicking the Index button will index the whole{" "} - Engineering Confluence space. + Engineering Confluence space, but entering + https://danswer.atlassian.net/wiki/spaces/Engineering/pages/164331/example+page + will index that page's children (and optionally, itself). Use + the checkbox below to determine whether or not to index the parent + page in addition to its children.

{confluenceConnectorIndexingStatuses.length > 0 && ( @@ -274,9 +281,8 @@ const Main = () => { )} - -

Add a New Space

+

Add a New Space or Page

nameBuilder={(values) => `ConfluenceConnector-${values.wiki_page_url}` @@ -289,15 +295,21 @@ const Main = () => { formBody={ <> + } validationSchema={Yup.object().shape({ wiki_page_url: Yup.string().required( - "Please enter any link to your confluence e.g. https://danswer.atlassian.net/wiki/spaces/Engineering/overview" + "Please enter any link to a Confluence space or Page e.g. https://danswer.atlassian.net/wiki/spaces/Engineering/overview" ), + index_origin: Yup.boolean(), })} initialValues={{ wiki_page_url: "", + index_origin: true, }} refreshFreq={10 * 60} // 10 minutes credentialId={confluenceCredential.id} diff --git a/web/src/app/admin/connectors/file/page.tsx b/web/src/app/admin/connectors/file/page.tsx index b7e93e01c85..b2857f66c61 100644 --- a/web/src/app/admin/connectors/file/page.tsx +++ b/web/src/app/admin/connectors/file/page.tsx @@ -16,17 +16,24 @@ import { Spinner } from "@/components/Spinner"; import { SingleUseConnectorsTable } from "@/components/admin/connectors/table/SingleUseConnectorsTable"; import { LoadingAnimation } from "@/components/Loading"; import { Form, Formik } from "formik"; -import { TextFormField } from "@/components/admin/connectors/Field"; +import { + BooleanFormField, + TextFormField, +} from "@/components/admin/connectors/Field"; import { FileUpload } from "@/components/admin/connectors/FileUpload"; import { getNameFromPath } from "@/lib/fileUtils"; import { Button, Card, Divider, Text } from "@tremor/react"; import { AdminPageTitle } from "@/components/admin/Title"; +import IsPublicField from "@/components/admin/connectors/IsPublicField"; +import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; const Main = () => { const [selectedFiles, setSelectedFiles] = useState([]); const [filesAreUploading, setFilesAreUploading] = useState(false); const { popup, setPopup } = usePopup(); + const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled(); + const { mutate } = 
useSWRConfig(); const { @@ -85,11 +92,15 @@ const Main = () => { initialValues={{ name: "", selectedFiles: [], + is_public: isPaidEnterpriseFeaturesEnabled ? false : undefined, }} validationSchema={Yup.object().shape({ name: Yup.string().required( "Please enter a descriptive name for the files" ), + ...(isPaidEnterpriseFeaturesEnabled && { + is_public: Yup.boolean().required(), + }), })} onSubmit={async (values, formikHelpers) => { const uploadCreateAndTriggerConnector = async () => { @@ -156,7 +167,8 @@ const Main = () => { const credentialResponse = await linkCredential( connector.id, credentialId, - values.name + values.name, + values.is_public ); if (!credentialResponse.ok) { const credentialResponseJson = @@ -215,6 +227,15 @@ const Main = () => { selectedFiles={selectedFiles} setSelectedFiles={setSelectedFiles} /> + + {isPaidEnterpriseFeaturesEnabled && ( + <> + + + + + )} +
+
+ + ) : ( + <> + +
    +
  • + Provide your GCS Project ID, Client Email, and Private Key for + authentication. +
  • +
  • + These credentials will be used to access your GCS buckets. +
  • +
+
+ + + formBody={ + <> + + + + + } + validationSchema={Yup.object().shape({ + secret_access_key: Yup.string().required( + "Client Email is required" + ), + access_key_id: Yup.string().required("Private Key is required"), + })} + initialValues={{ + secret_access_key: "", + access_key_id: "", + }} + onSubmit={(isSuccess) => { + if (isSuccess) { + refreshCredentials(); + } + }} + /> + + + )} + + + Step 2: Which GCS bucket do you want to make searchable? + + + {gcsConnectorIndexingStatuses.length > 0 && ( + <> + + GCS indexing status + + + The latest changes are fetched every 10 minutes. + +
+ + includeName={true} + connectorIndexingStatuses={gcsConnectorIndexingStatuses} + liveCredential={gcsCredential} + getCredential={(credential) => { + return
; + }} + onCredentialLink={async (connectorId) => { + if (gcsCredential) { + await linkCredential(connectorId, gcsCredential.id); + mutate("/api/manage/admin/connector/indexing-status"); + } + }} + onUpdate={() => + mutate("/api/manage/admin/connector/indexing-status") + } + /> +
+ + )} + + {gcsCredential && ( + <> + +

Create Connection

+ + Press connect below to start the connection to your GCS bucket. + + + nameBuilder={(values) => `GCSConnector-${values.bucket_name}`} + ccPairNameBuilder={(values) => + `GCSConnector-${values.bucket_name}` + } + source="google_cloud_storage" + inputType="poll" + formBodyBuilder={(values) => ( +
+ + +
+ )} + validationSchema={Yup.object().shape({ + bucket_type: Yup.string() + .oneOf(["google_cloud_storage"]) + .required("Bucket type must be google_cloud_storage"), + bucket_name: Yup.string().required( + "Please enter the name of the GCS bucket to index, e.g. my-gcs-bucket" + ), + prefix: Yup.string().default(""), + })} + initialValues={{ + bucket_type: "google_cloud_storage", + bucket_name: "", + prefix: "", + }} + refreshFreq={60 * 60 * 24} // 1 day + credentialId={gcsCredential.id} + /> +
+ + )} + + ); +}; + +export default function Page() { + return ( +
+
+ +
+ } + title="Google Cloud Storage" + /> + +
+ ); +} diff --git a/web/src/app/admin/connectors/oracle-storage/page.tsx b/web/src/app/admin/connectors/oracle-storage/page.tsx new file mode 100644 index 00000000000..34847a4b9a1 --- /dev/null +++ b/web/src/app/admin/connectors/oracle-storage/page.tsx @@ -0,0 +1,272 @@ +"use client"; + +import { AdminPageTitle } from "@/components/admin/Title"; +import { HealthCheckBanner } from "@/components/health/healthcheck"; +import { OCIStorageIcon, TrashIcon } from "@/components/icons/icons"; +import { LoadingAnimation } from "@/components/Loading"; +import { ConnectorForm } from "@/components/admin/connectors/ConnectorForm"; +import { CredentialForm } from "@/components/admin/connectors/CredentialForm"; +import { TextFormField } from "@/components/admin/connectors/Field"; +import { usePopup } from "@/components/admin/connectors/Popup"; +import { ConnectorsTable } from "@/components/admin/connectors/table/ConnectorsTable"; +import { adminDeleteCredential, linkCredential } from "@/lib/credential"; +import { errorHandlingFetcher } from "@/lib/fetcher"; +import { ErrorCallout } from "@/components/ErrorCallout"; +import { usePublicCredentials } from "@/lib/hooks"; + +import { + ConnectorIndexingStatus, + Credential, + OCIConfig, + OCICredentialJson, + R2Config, + R2CredentialJson, +} from "@/lib/types"; +import { Card, Select, SelectItem, Text, Title } from "@tremor/react"; +import useSWR, { useSWRConfig } from "swr"; +import * as Yup from "yup"; +import { useState } from "react"; + +const OCIMain = () => { + const { popup, setPopup } = usePopup(); + + const { mutate } = useSWRConfig(); + const { + data: connectorIndexingStatuses, + isLoading: isConnectorIndexingStatusesLoading, + error: connectorIndexingStatusesError, + } = useSWR[]>( + "/api/manage/admin/connector/indexing-status", + errorHandlingFetcher + ); + const { + data: credentialsData, + isLoading: isCredentialsLoading, + error: credentialsError, + refreshCredentials, + } = usePublicCredentials(); + + if ( + 
(!connectorIndexingStatuses && isConnectorIndexingStatusesLoading) || + (!credentialsData && isCredentialsLoading) + ) { + return ; + } + + if (connectorIndexingStatusesError || !connectorIndexingStatuses) { + return ( + + ); + } + + if (credentialsError || !credentialsData) { + return ( + + ); + } + + const ociConnectorIndexingStatuses: ConnectorIndexingStatus< + OCIConfig, + OCICredentialJson + >[] = connectorIndexingStatuses.filter( + (connectorIndexingStatus) => + connectorIndexingStatus.connector.source === "oci_storage" + ); + + const ociCredential: Credential | undefined = + credentialsData.find((credential) => credential.credential_json?.namespace); + + return ( + <> + {popup} + + Step 1: Provide your access info + + {ociCredential ? ( + <> + {" "} +
+

Existing OCI Access Key ID:

+

+ {ociCredential.credential_json.access_key_id} +

+ {", "} +

Namespace:

+

+ {ociCredential.credential_json.namespace} +

{" "} + +
+ + ) : ( + <> + +
    +
  • + Provide your OCI Access Key ID, Secret Access Key, Namespace, + and Region for authentication. +
  • +
  • + These credentials will be used to access your OCI buckets. +
  • +
+
+ + + formBody={ + <> + + + + + + } + validationSchema={Yup.object().shape({ + access_key_id: Yup.string().required( + "OCI Access Key ID is required" + ), + secret_access_key: Yup.string().required( + "OCI Secret Access Key is required" + ), + namespace: Yup.string().required("Namespace is required"), + region: Yup.string().required("Region is required"), + })} + initialValues={{ + access_key_id: "", + secret_access_key: "", + namespace: "", + region: "", + }} + onSubmit={(isSuccess) => { + if (isSuccess) { + refreshCredentials(); + } + }} + /> + + + )} + + + Step 2: Which OCI bucket do you want to make searchable? + + + {ociConnectorIndexingStatuses.length > 0 && ( + <> + + OCI indexing status + + + The latest changes are fetched every 10 minutes. + +
+ + includeName={true} + connectorIndexingStatuses={ociConnectorIndexingStatuses} + liveCredential={ociCredential} + getCredential={(credential) => { + return
; + }} + onCredentialLink={async (connectorId) => { + if (ociCredential) { + await linkCredential(connectorId, ociCredential.id); + mutate("/api/manage/admin/connector/indexing-status"); + } + }} + onUpdate={() => + mutate("/api/manage/admin/connector/indexing-status") + } + /> +
+ + )} + + {ociCredential && ( + <> + +

Create Connection

+ + Press connect below to start the connection to your OCI bucket. + + + nameBuilder={(values) => `OCIConnector-${values.bucket_name}`} + ccPairNameBuilder={(values) => + `OCIConnector-${values.bucket_name}` + } + source="oci_storage" + inputType="poll" + formBodyBuilder={(values) => ( +
+ + +
+ )} + validationSchema={Yup.object().shape({ + bucket_type: Yup.string() + .oneOf(["oci_storage"]) + .required("Bucket type must be oci_storage"), + bucket_name: Yup.string().required( + "Please enter the name of the OCI bucket to index, e.g. my-test-bucket" + ), + prefix: Yup.string().default(""), + })} + initialValues={{ + bucket_type: "oci_storage", + bucket_name: "", + prefix: "", + }} + refreshFreq={60 * 60 * 24} // 1 day + credentialId={ociCredential.id} + /> +
+ + )} + + ); +}; + +export default function Page() { + return ( +
+
+ +
+ } + title="Oracle Cloud Infrastructure" + /> + +
+ ); +} diff --git a/web/src/app/admin/connectors/r2/page.tsx b/web/src/app/admin/connectors/r2/page.tsx new file mode 100644 index 00000000000..372660acc4f --- /dev/null +++ b/web/src/app/admin/connectors/r2/page.tsx @@ -0,0 +1,265 @@ +"use client"; + +import { AdminPageTitle } from "@/components/admin/Title"; +import { HealthCheckBanner } from "@/components/health/healthcheck"; +import { R2Icon, S3Icon, TrashIcon } from "@/components/icons/icons"; +import { LoadingAnimation } from "@/components/Loading"; +import { ConnectorForm } from "@/components/admin/connectors/ConnectorForm"; +import { CredentialForm } from "@/components/admin/connectors/CredentialForm"; +import { TextFormField } from "@/components/admin/connectors/Field"; +import { usePopup } from "@/components/admin/connectors/Popup"; +import { ConnectorsTable } from "@/components/admin/connectors/table/ConnectorsTable"; +import { adminDeleteCredential, linkCredential } from "@/lib/credential"; +import { errorHandlingFetcher } from "@/lib/fetcher"; +import { ErrorCallout } from "@/components/ErrorCallout"; +import { usePublicCredentials } from "@/lib/hooks"; +import { + ConnectorIndexingStatus, + Credential, + R2Config, + R2CredentialJson, +} from "@/lib/types"; +import { Card, Select, SelectItem, Text, Title } from "@tremor/react"; +import useSWR, { useSWRConfig } from "swr"; +import * as Yup from "yup"; +import { useState } from "react"; + +const R2Main = () => { + const { popup, setPopup } = usePopup(); + + const { mutate } = useSWRConfig(); + const { + data: connectorIndexingStatuses, + isLoading: isConnectorIndexingStatusesLoading, + error: connectorIndexingStatusesError, + } = useSWR[]>( + "/api/manage/admin/connector/indexing-status", + errorHandlingFetcher + ); + const { + data: credentialsData, + isLoading: isCredentialsLoading, + error: credentialsError, + refreshCredentials, + } = usePublicCredentials(); + + if ( + (!connectorIndexingStatuses && isConnectorIndexingStatusesLoading) || + 
(!credentialsData && isCredentialsLoading) + ) { + return ; + } + + if (connectorIndexingStatusesError || !connectorIndexingStatuses) { + return ( + + ); + } + + if (credentialsError || !credentialsData) { + return ( + + ); + } + + const r2ConnectorIndexingStatuses: ConnectorIndexingStatus< + R2Config, + R2CredentialJson + >[] = connectorIndexingStatuses.filter( + (connectorIndexingStatus) => + connectorIndexingStatus.connector.source === "r2" + ); + + const r2Credential: Credential | undefined = + credentialsData.find( + (credential) => credential.credential_json?.account_id + ); + + return ( + <> + {popup} + + Step 1: Provide your access info + + {r2Credential ? ( + <> + {" "} +
+

Existing R2 Access Key ID:

+

+ {r2Credential.credential_json.r2_access_key_id} +

+ {", "} +

Account ID:

+

+ {r2Credential.credential_json.account_id} +

{" "} + +
+ + ) : ( + <> + +
    +
  • + Provide your R2 Access Key ID, Secret Access Key, and Account ID + for authentication. +
  • +
  • These credentials will be used to access your R2 buckets.
  • +
+
+ + + formBody={ + <> + + + + + } + validationSchema={Yup.object().shape({ + r2_access_key_id: Yup.string().required( + "R2 Access Key ID is required" + ), + r2_secret_access_key: Yup.string().required( + "R2 Secret Access Key is required" + ), + account_id: Yup.string().required("Account ID is required"), + })} + initialValues={{ + r2_access_key_id: "", + r2_secret_access_key: "", + account_id: "", + }} + onSubmit={(isSuccess) => { + if (isSuccess) { + refreshCredentials(); + } + }} + /> + + + )} + + + Step 2: Which R2 bucket do you want to make searchable? + + + {r2ConnectorIndexingStatuses.length > 0 && ( + <> + + R2 indexing status + + + The latest changes are fetched every 10 minutes. + +
+ + includeName={true} + connectorIndexingStatuses={r2ConnectorIndexingStatuses} + liveCredential={r2Credential} + getCredential={(credential) => { + return
; + }} + onCredentialLink={async (connectorId) => { + if (r2Credential) { + await linkCredential(connectorId, r2Credential.id); + mutate("/api/manage/admin/connector/indexing-status"); + } + }} + onUpdate={() => + mutate("/api/manage/admin/connector/indexing-status") + } + /> +
+ + )} + + {r2Credential && ( + <> + +

Create Connection

+ + Press connect below to start the connection to your R2 bucket. + + + nameBuilder={(values) => `R2Connector-${values.bucket_name}`} + ccPairNameBuilder={(values) => + `R2Connector-${values.bucket_name}` + } + source="r2" + inputType="poll" + formBodyBuilder={(values) => ( +
+ + +
+ )} + validationSchema={Yup.object().shape({ + bucket_type: Yup.string() + .oneOf(["r2"]) + .required("Bucket type must be r2"), + bucket_name: Yup.string().required( + "Please enter the name of the r2 bucket to index, e.g. my-test-bucket" + ), + prefix: Yup.string().default(""), + })} + initialValues={{ + bucket_type: "r2", + bucket_name: "", + prefix: "", + }} + refreshFreq={60 * 60 * 24} // 1 day + credentialId={r2Credential.id} + /> +
+ + )} + + ); +}; + +export default function Page() { + const [selectedStorage, setSelectedStorage] = useState("s3"); + + return ( +
+
+ +
+ } title="R2 Storage" /> + +
+ ); +} diff --git a/web/src/app/admin/connectors/s3/page.tsx b/web/src/app/admin/connectors/s3/page.tsx new file mode 100644 index 00000000000..81064a70bf1 --- /dev/null +++ b/web/src/app/admin/connectors/s3/page.tsx @@ -0,0 +1,258 @@ +"use client"; + +import { AdminPageTitle } from "@/components/admin/Title"; +import { HealthCheckBanner } from "@/components/health/healthcheck"; +import { S3Icon, TrashIcon } from "@/components/icons/icons"; +import { LoadingAnimation } from "@/components/Loading"; +import { ConnectorForm } from "@/components/admin/connectors/ConnectorForm"; +import { CredentialForm } from "@/components/admin/connectors/CredentialForm"; +import { TextFormField } from "@/components/admin/connectors/Field"; +import { usePopup } from "@/components/admin/connectors/Popup"; +import { ConnectorsTable } from "@/components/admin/connectors/table/ConnectorsTable"; +import { adminDeleteCredential, linkCredential } from "@/lib/credential"; +import { errorHandlingFetcher } from "@/lib/fetcher"; +import { ErrorCallout } from "@/components/ErrorCallout"; +import { usePublicCredentials } from "@/lib/hooks"; +import { + ConnectorIndexingStatus, + Credential, + S3Config, + S3CredentialJson, +} from "@/lib/types"; +import { Card, Text, Title } from "@tremor/react"; +import useSWR, { useSWRConfig } from "swr"; +import * as Yup from "yup"; +import { useState } from "react"; + +const S3Main = () => { + const { popup, setPopup } = usePopup(); + + const { mutate } = useSWRConfig(); + const { + data: connectorIndexingStatuses, + isLoading: isConnectorIndexingStatusesLoading, + error: connectorIndexingStatusesError, + } = useSWR[]>( + "/api/manage/admin/connector/indexing-status", + errorHandlingFetcher + ); + const { + data: credentialsData, + isLoading: isCredentialsLoading, + error: credentialsError, + refreshCredentials, + } = usePublicCredentials(); + + if ( + (!connectorIndexingStatuses && isConnectorIndexingStatusesLoading) || + (!credentialsData && 
isCredentialsLoading) + ) { + return ; + } + + if (connectorIndexingStatusesError || !connectorIndexingStatuses) { + return ( + + ); + } + + if (credentialsError || !credentialsData) { + return ( + + ); + } + + const s3ConnectorIndexingStatuses: ConnectorIndexingStatus< + S3Config, + S3CredentialJson + >[] = connectorIndexingStatuses.filter( + (connectorIndexingStatus) => + connectorIndexingStatus.connector.source === "s3" + ); + + const s3Credential: Credential | undefined = + credentialsData.find( + (credential) => credential.credential_json?.aws_access_key_id + ); + + return ( + <> + {popup} + + Step 1: Provide your access info + + {s3Credential ? ( + <> + {" "} +
+

Existing AWS Access Key ID:

+

+ {s3Credential.credential_json.aws_access_key_id} +

+ +
+ + ) : ( + <> + +
    +
  • + If AWS Access Key ID and AWS Secret Access Key are provided, + they will be used for authenticating the connector. +
  • +
  • Otherwise, the Profile Name will be used (if provided).
  • +
  • + If no credentials are provided, then the connector will try to + authenticate with any default AWS credentials available. +
  • +
+
+ + + formBody={ + <> + + + + } + validationSchema={Yup.object().shape({ + aws_access_key_id: Yup.string().default(""), + aws_secret_access_key: Yup.string().default(""), + })} + initialValues={{ + aws_access_key_id: "", + aws_secret_access_key: "", + }} + onSubmit={(isSuccess) => { + if (isSuccess) { + refreshCredentials(); + } + }} + /> + + + )} + + + Step 2: Which S3 bucket do you want to make searchable? + + + {s3ConnectorIndexingStatuses.length > 0 && ( + <> + + S3 indexing status + + + The latest changes are fetched every 10 minutes. + +
+ + includeName={true} + connectorIndexingStatuses={s3ConnectorIndexingStatuses} + liveCredential={s3Credential} + getCredential={(credential) => { + return
; + }} + onCredentialLink={async (connectorId) => { + if (s3Credential) { + await linkCredential(connectorId, s3Credential.id); + mutate("/api/manage/admin/connector/indexing-status"); + } + }} + onUpdate={() => + mutate("/api/manage/admin/connector/indexing-status") + } + /> +
+ + )} + + {s3Credential && ( + <> + +

Create Connection

+ + Press connect below to start the connection to your S3 bucket. + + + nameBuilder={(values) => `S3Connector-${values.bucket_name}`} + ccPairNameBuilder={(values) => + `S3Connector-${values.bucket_name}` + } + source="s3" + inputType="poll" + formBodyBuilder={(values) => ( +
+ + +
+ )} + validationSchema={Yup.object().shape({ + bucket_type: Yup.string() + .oneOf(["s3"]) + .required("Bucket type must be s3"), + bucket_name: Yup.string().required( + "Please enter the name of the s3 bucket to index, e.g. my-test-bucket" + ), + prefix: Yup.string().default(""), + })} + initialValues={{ + bucket_type: "s3", + bucket_name: "", + prefix: "", + }} + refreshFreq={60 * 60 * 24} // 1 day + credentialId={s3Credential.id} + /> +
+ + )} + + ); +}; + +export default function Page() { + const [selectedStorage, setSelectedStorage] = useState("s3"); + + return ( +
+
+ +
+ } title="S3 Storage" /> + + +
+ ); +} diff --git a/web/src/app/admin/documents/explorer/Explorer.tsx b/web/src/app/admin/documents/explorer/Explorer.tsx index 77e924c15c0..315df323e8b 100644 --- a/web/src/app/admin/documents/explorer/Explorer.tsx +++ b/web/src/app/admin/documents/explorer/Explorer.tsx @@ -160,7 +160,7 @@ export function Explorer({
{popup}
-
+