diff --git a/.github/workflows/docker-build-push-backend-container-on-tag.yml b/.github/workflows/docker-build-push-backend-container-on-tag.yml index e95c143fb49..a7d46a09736 100644 --- a/.github/workflows/docker-build-push-backend-container-on-tag.yml +++ b/.github/workflows/docker-build-push-backend-container-on-tag.yml @@ -38,5 +38,7 @@ jobs: - name: Run Trivy vulnerability scanner uses: aquasecurity/trivy-action@master with: + # To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }} severity: 'CRITICAL,HIGH' + trivyignores: ./backend/.trivyignore diff --git a/.github/workflows/pr-python-checks.yml b/.github/workflows/pr-python-checks.yml index 792fe4d46b3..6c604e93d43 100644 --- a/.github/workflows/pr-python-checks.yml +++ b/.github/workflows/pr-python-checks.yml @@ -20,10 +20,12 @@ jobs: cache-dependency-path: | backend/requirements/default.txt backend/requirements/dev.txt + backend/requirements/model_server.txt - run: | python -m pip install --upgrade pip pip install -r backend/requirements/default.txt pip install -r backend/requirements/dev.txt + pip install -r backend/requirements/model_server.txt - name: Run MyPy run: | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4d88752da2b..7e80baeb2d7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -85,6 +85,7 @@ Install the required python dependencies: ```bash pip install -r danswer/backend/requirements/default.txt pip install -r danswer/backend/requirements/dev.txt +pip install -r danswer/backend/requirements/model_server.txt ``` Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend. @@ -112,26 +113,24 @@ docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational (index refers to Vespa and relational_db refers to Postgres) #### Running Danswer - -Setup a folder to store config. Navigate to `danswer/backend` and run: -```bash -mkdir dynamic_config_storage -``` - To start the frontend, navigate to `danswer/web` and run: ```bash npm run dev ``` -Package the Vespa schema. This will only need to be done when the Vespa schema is updated locally. - -Navigate to `danswer/backend/danswer/document_index/vespa/app_config` and run: +Next, start the model server which runs the local NLP models. +Navigate to `danswer/backend` and run: ```bash -zip -r ../vespa-app.zip . +uvicorn model_server.main:app --reload --port 9000 +``` +_For Windows (for compatibility with both PowerShell and Command Prompt):_ +```bash +powershell -Command " + uvicorn model_server.main:app --reload --port 9000 +" ``` -- Note: If you don't have the `zip` utility, you will need to install it prior to running the above -The first time running Danswer, you will also need to run the DB migrations for Postgres. +The first time running Danswer, you will need to run the DB migrations for Postgres. After the first time, this is no longer required unless the DB models change. 
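As a quick sanity check while following the steps above, the sketch below probes whether the local model server started with `uvicorn model_server.main:app --reload --port 9000` is accepting connections before the backend is launched. It is only an illustrative helper (the host and port simply mirror the `--port 9000` default shown above) and is not part of the documented setup.

```python
import socket


def model_server_reachable(
    host: str = "127.0.0.1", port: int = 9000, timeout: float = 2.0
) -> bool:
    """Return True if something is accepting TCP connections on host:port."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


if __name__ == "__main__":
    print("model server reachable:", model_server_reachable())
```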
Navigate to `danswer/backend` and with the venv active, run: @@ -149,17 +148,12 @@ python ./scripts/dev_run_background_jobs.py To run the backend API server, navigate back to `danswer/backend` and run: ```bash -AUTH_TYPE=disabled \ -DYNAMIC_CONFIG_DIR_PATH=./dynamic_config_storage \ -VESPA_DEPLOYMENT_ZIP=./danswer/document_index/vespa/vespa-app.zip \ -uvicorn danswer.main:app --reload --port 8080 +AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080 ``` _For Windows (for compatibility with both PowerShell and Command Prompt):_ ```bash powershell -Command " $env:AUTH_TYPE='disabled' - $env:DYNAMIC_CONFIG_DIR_PATH='./dynamic_config_storage' - $env:VESPA_DEPLOYMENT_ZIP='./danswer/document_index/vespa/vespa-app.zip' uvicorn danswer.main:app --reload --port 8080 " ``` @@ -178,20 +172,16 @@ pre-commit install Additionally, we use `mypy` for static type checking. Danswer is fully type-annotated, and we would like to keep it that way! -Right now, there is no automated type checking at the moment (coming soon), but we ask you to manually run it before -creating a pull requests with `python -m mypy .` from the `danswer/backend` directory. +To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory. #### Web We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `danswer/web` directory. To run the formatter, use `npx prettier --write .` from the `danswer/web` directory. -Like `mypy`, we have no automated formatting yet (coming soon), but we request that, for now, -you run this manually before creating a pull request. +Please double check that prettier passes before creating a pull request. ### Release Process Danswer follows the semver versioning standard. A set of Docker containers will be pushed automatically to DockerHub with every tag. You can see the containers [here](https://hub.docker.com/search?q=danswer%2F). - -As pre-1.0 software, even patch releases may contain breaking or non-backwards-compatible changes. diff --git a/README.md b/README.md index 3e70e7259c7..edd8328c31e 100644 --- a/README.md +++ b/README.md @@ -22,11 +22,12 @@
-[Danswer](https://www.danswer.ai/) is the ChatGPT for teams. Danswer provides a Chat interface and plugs into any LLM of -your choice. Danswer can be deployed anywhere and for any scale - on a laptop, on-premise, or to cloud. Since you own -the deployment, your user data and chats are fully in your own control. Danswer is MIT licensed and designed to be -modular and easily extensible. The system also comes fully ready for production usage with user authentication, role -management (admin/basic users), chat persistence, and a UI for configuring Personas (AI Assistants) and their Prompts. +[Danswer](https://www.danswer.ai/) is the AI Assistant connected to your company's docs, apps, and people. +Danswer provides a Chat interface and plugs into any LLM of your choice. Danswer can be deployed anywhere and for any +scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your +own control. Danswer is MIT licensed and designed to be modular and easily extensible. The system also comes fully ready +for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for +configuring Personas (AI Assistants) and their Prompts. Danswer also serves as a Unified Search across all common workplace tools such as Slack, Google Drive, Confluence, etc. By combining LLMs and team specific knowledge, Danswer becomes a subject matter expert for the team. Imagine ChatGPT if diff --git a/backend/.trivyignore b/backend/.trivyignore new file mode 100644 index 00000000000..e8351b40741 --- /dev/null +++ b/backend/.trivyignore @@ -0,0 +1,46 @@ +# https://github.com/madler/zlib/issues/868 +# Pulled in with base Debian image, it's part of the contrib folder but unused +# zlib1g is fine +# Will be gone with Debian image upgrade +# No impact in our settings +CVE-2023-45853 + +# krb5 related, worst case is denial of service by resource exhaustion +# Accept the risk +CVE-2024-26458 +CVE-2024-26461 +CVE-2024-26462 +CVE-2024-26458 +CVE-2024-26461 +CVE-2024-26462 +CVE-2024-26458 +CVE-2024-26461 +CVE-2024-26462 +CVE-2024-26458 +CVE-2024-26461 +CVE-2024-26462 + +# Specific to Firefox which we do not use +# No impact in our settings +CVE-2024-0743 + +# bind9 related, worst case is denial of service by CPU resource exhaustion +# Accept the risk +CVE-2023-50387 +CVE-2023-50868 +CVE-2023-50387 +CVE-2023-50868 + +# libexpat1, XML parsing resource exhaustion +# We don't parse any user provided XMLs +# No impact in our settings +CVE-2023-52425 +CVE-2024-28757 + +# sqlite, only used by NLTK library to grab word lemmatizer and stopwords +# No impact in our settings +CVE-2023-7104 + +# libharfbuzz0b, O(n^2) growth, worst case is denial of service +# Accept the risk +CVE-2023-25193 diff --git a/backend/Dockerfile b/backend/Dockerfile index a9bc852a5a2..a61864fa2c0 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -1,5 +1,10 @@ FROM python:3.11.7-slim-bookworm +LABEL com.danswer.maintainer="founders@danswer.ai" +LABEL com.danswer.description="This image is for the backend of Danswer. It is MIT Licensed and \ +free for all to use. You can find it at https://hub.docker.com/r/danswer/danswer-backend. For \ +more details, visit https://github.com/danswer-ai/danswer." + # Default DANSWER_VERSION, typically overriden during builds by GitHub Actions. 
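The `.trivyignore` added earlier in this diff lists a few CVE IDs more than once (for example CVE-2024-26458 and CVE-2023-50387). The duplicates appear harmless, but a small stand-alone check like the following can keep the file tidy; it is only an illustrative sketch, not a script that exists in the repository.

```python
from collections import Counter
from pathlib import Path


def duplicate_cves(trivyignore_path: str = "backend/.trivyignore") -> list[str]:
    """Return CVE IDs that appear more than once in a .trivyignore file."""
    entries = [
        line.strip()
        for line in Path(trivyignore_path).read_text().splitlines()
        if line.strip() and not line.strip().startswith("#")
    ]
    return [cve for cve, count in Counter(entries).items() if count > 1]


if __name__ == "__main__":
    print(duplicate_cves())
```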
ARG DANSWER_VERSION=0.3-dev ENV DANSWER_VERSION=${DANSWER_VERSION} @@ -12,7 +17,9 @@ RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}" # zip for Vespa step futher down # ca-certificates for HTTPS RUN apt-get update && \ - apt-get install -y cmake curl zip ca-certificates libgnutls30=3.7.9-2+deb12u2 && \ + apt-get install -y cmake curl zip ca-certificates libgnutls30=3.7.9-2+deb12u2 \ + libblkid1=2.38.1-5+deb12u1 libmount1=2.38.1-5+deb12u1 libsmartcols1=2.38.1-5+deb12u1 \ + libuuid1=2.38.1-5+deb12u1 && \ rm -rf /var/lib/apt/lists/* && \ apt-get clean @@ -29,7 +36,8 @@ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \ # xserver-common and xvfb included by playwright installation but not needed after # perl-base is part of the base Python Debian image but not needed for Danswer functionality # perl-base could only be removed with --allow-remove-essential -RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake libldap-2.5-0 libldap-2.5-0 && \ +RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake \ + libldap-2.5-0 libldap-2.5-0 && \ apt-get autoremove -y && \ rm -rf /var/lib/apt/lists/* && \ rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key @@ -37,7 +45,7 @@ RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cma # Set up application files WORKDIR /app COPY ./danswer /app/danswer -COPY ./shared_models /app/shared_models +COPY ./shared_configs /app/shared_configs COPY ./alembic /app/alembic COPY ./alembic.ini /app/alembic.ini COPY supervisord.conf /usr/etc/supervisord.conf diff --git a/backend/Dockerfile.model_server b/backend/Dockerfile.model_server index 624bdd37fcd..365a553c9f1 100644 --- a/backend/Dockerfile.model_server +++ b/backend/Dockerfile.model_server @@ -1,5 +1,11 @@ FROM python:3.11.7-slim-bookworm +LABEL com.danswer.maintainer="founders@danswer.ai" +LABEL com.danswer.description="This image is for the Danswer model server which runs all of the \ +AI models for Danswer. This container and all the code is MIT Licensed and free for all to use. \ +You can find it at https://hub.docker.com/r/danswer/danswer-model-server. For more details, \ +visit https://github.com/danswer-ai/danswer." + # Default DANSWER_VERSION, typically overriden during builds by GitHub Actions. 
ARG DANSWER_VERSION=0.3-dev ENV DANSWER_VERSION=${DANSWER_VERSION} @@ -13,23 +19,14 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \ WORKDIR /app -# Needed for model configs and defaults -COPY ./danswer/configs /app/danswer/configs -COPY ./danswer/dynamic_configs /app/danswer/dynamic_configs - # Utils used by model server COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py -COPY ./danswer/utils/timing.py /app/danswer/utils/timing.py -COPY ./danswer/utils/telemetry.py /app/danswer/utils/telemetry.py # Place to fetch version information COPY ./danswer/__init__.py /app/danswer/__init__.py -# Shared implementations for running NLP models locally -COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py - -# Request/Response models -COPY ./shared_models /app/shared_models +# Shared between Danswer Backend and Model Server +COPY ./shared_configs /app/shared_configs # Model Server main code COPY ./model_server /app/model_server diff --git a/backend/alembic/versions/38eda64af7fe_add_chat_session_sharing.py b/backend/alembic/versions/38eda64af7fe_add_chat_session_sharing.py new file mode 100644 index 00000000000..e77ee186f42 --- /dev/null +++ b/backend/alembic/versions/38eda64af7fe_add_chat_session_sharing.py @@ -0,0 +1,41 @@ +"""Add chat session sharing + +Revision ID: 38eda64af7fe +Revises: 776b3bbe9092 +Create Date: 2024-03-27 19:41:29.073594 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "38eda64af7fe" +down_revision = "776b3bbe9092" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "chat_session", + sa.Column( + "shared_status", + sa.Enum( + "PUBLIC", + "PRIVATE", + name="chatsessionsharedstatus", + native_enum=False, + ), + nullable=True, + ), + ) + op.execute("UPDATE chat_session SET shared_status='PRIVATE'") + op.alter_column( + "chat_session", + "shared_status", + nullable=False, + ) + + +def downgrade() -> None: + op.drop_column("chat_session", "shared_status") diff --git a/backend/alembic/versions/475fcefe8826_add_name_to_api_key.py b/backend/alembic/versions/475fcefe8826_add_name_to_api_key.py new file mode 100644 index 00000000000..356766e6acc --- /dev/null +++ b/backend/alembic/versions/475fcefe8826_add_name_to_api_key.py @@ -0,0 +1,23 @@ +"""Add name to api_key + +Revision ID: 475fcefe8826 +Revises: ecab2b3f1a3b +Create Date: 2024-04-11 11:05:18.414438 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "475fcefe8826" +down_revision = "ecab2b3f1a3b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column("api_key", sa.Column("name", sa.String(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("api_key", "name") diff --git a/backend/alembic/versions/72bdc9929a46_permission_auto_sync_framework.py b/backend/alembic/versions/72bdc9929a46_permission_auto_sync_framework.py new file mode 100644 index 00000000000..0e04478f221 --- /dev/null +++ b/backend/alembic/versions/72bdc9929a46_permission_auto_sync_framework.py @@ -0,0 +1,81 @@ +"""Permission Auto Sync Framework + +Revision ID: 72bdc9929a46 +Revises: 475fcefe8826 +Create Date: 2024-04-14 21:15:28.659634 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = "72bdc9929a46" +down_revision = "475fcefe8826" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "email_to_external_user_cache", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("external_user_id", sa.String(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("user_email", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "external_permission", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("user_email", sa.String(), nullable=False), + sa.Column( + "source_type", + sa.String(), + nullable=False, + ), + sa.Column("external_permission_group", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "permission_sync_run", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "source_type", + sa.String(), + nullable=False, + ), + sa.Column("update_type", sa.String(), nullable=False), + sa.Column("cc_pair_id", sa.Integer(), nullable=True), + sa.Column( + "status", + sa.String(), + nullable=False, + ), + sa.Column("error_msg", sa.Text(), nullable=True), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["cc_pair_id"], + ["connector_credential_pair.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + + +def downgrade() -> None: + op.drop_table("permission_sync_run") + op.drop_table("external_permission") + op.drop_table("email_to_external_user_cache") diff --git a/backend/alembic/versions/ecab2b3f1a3b_add_overrides_to_the_chat_session.py b/backend/alembic/versions/ecab2b3f1a3b_add_overrides_to_the_chat_session.py new file mode 100644 index 00000000000..791d7e42e07 --- /dev/null +++ b/backend/alembic/versions/ecab2b3f1a3b_add_overrides_to_the_chat_session.py @@ -0,0 +1,40 @@ +"""Add overrides to the chat session + +Revision ID: ecab2b3f1a3b +Revises: 38eda64af7fe +Create Date: 2024-04-01 19:08:21.359102 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = "ecab2b3f1a3b" +down_revision = "38eda64af7fe" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "chat_session", + sa.Column( + "llm_override", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + op.add_column( + "chat_session", + sa.Column( + "prompt_override", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + + +def downgrade() -> None: + op.drop_column("chat_session", "prompt_override") + op.drop_column("chat_session", "llm_override") diff --git a/backend/danswer/background/indexing/job_client.py b/backend/danswer/background/indexing/job_client.py index d37690627f5..6b1344b59f8 100644 --- a/backend/danswer/background/indexing/job_client.py +++ b/backend/danswer/background/indexing/job_client.py @@ -6,18 +6,15 @@ https://github.com/celery/celery/issues/7007#issuecomment-1740139367""" from collections.abc import Callable from dataclasses import dataclass +from multiprocessing import Process from typing import Any from typing import Literal from typing import Optional -from typing import TYPE_CHECKING from danswer.utils.logger import setup_logger logger = setup_logger() -if TYPE_CHECKING: - from torch.multiprocessing import Process - JobStatusType = ( Literal["error"] | Literal["finished"] @@ -89,8 +86,6 @@ def _cleanup_completed_jobs(self) -> None: def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None: """NOTE: `pure` arg is needed so this can be a drop in replacement for Dask""" - from torch.multiprocessing import Process - self._cleanup_completed_jobs() if len(self.jobs) >= self.n_workers: logger.debug("No available workers to run job") diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index 6241af6f56b..9e8ee6b7fe5 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -330,20 +330,15 @@ def _run_indexing( ) -def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None: +def run_indexing_entrypoint(index_attempt_id: int) -> None: """Entrypoint for indexing run when using dask distributed. 
Wraps the actual logic in a `try` block so that we can catch any exceptions and mark the attempt as failed.""" - import torch - try: # set the indexing attempt ID so that all log messages from this process # will have it added as a prefix IndexAttemptSingleton.set_index_attempt_id(index_attempt_id) - logger.info(f"Setting task to use {num_threads} threads") - torch.set_num_threads(num_threads) - with Session(get_sqlalchemy_engine()) as db_session: attempt = get_index_attempt( db_session=db_session, index_attempt_id=index_attempt_id diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index b77ddee859a..6042e02b1cd 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -15,9 +15,7 @@ from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP -from danswer.configs.app_configs import LOG_LEVEL from danswer.configs.app_configs import NUM_INDEXING_WORKERS -from danswer.configs.model_configs import MIN_THREADS_ML_MODELS from danswer.db.connector import fetch_connectors from danswer.db.connector_credential_pair import get_connector_credential_pairs from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed @@ -29,7 +27,9 @@ from danswer.db.engine import get_db_current_time from danswer.db.engine import get_sqlalchemy_engine from danswer.db.index_attempt import cancel_indexing_attempts_past_model -from danswer.db.index_attempt import count_unique_cc_pairs_with_index_attempts +from danswer.db.index_attempt import ( + count_unique_cc_pairs_with_successful_index_attempts, +) from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import get_inprogress_index_attempts @@ -41,7 +41,11 @@ from danswer.db.models import IndexAttempt from danswer.db.models import IndexingStatus from danswer.db.models import IndexModelStatus +from danswer.search.search_nlp_models import warm_up_encoders from danswer.utils.logger import setup_logger +from shared_configs.configs import INDEXING_MODEL_SERVER_HOST +from shared_configs.configs import LOG_LEVEL +from shared_configs.configs import MODEL_SERVER_PORT logger = setup_logger() @@ -54,18 +58,6 @@ ) -"""Util funcs""" - - -def _get_num_threads() -> int: - """Get # of "threads" to use for ML models in an indexing job. By default uses - the torch implementation, which returns the # of physical cores on the machine. - """ - import torch - - return max(MIN_THREADS_ML_MODELS, torch.get_num_threads()) - - def _should_create_new_indexing( connector: Connector, last_index: IndexAttempt | None, @@ -344,12 +336,10 @@ def kickoff_indexing_jobs( if use_secondary_index: run = secondary_client.submit( - run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False + run_indexing_entrypoint, attempt.id, pure=False ) else: - run = client.submit( - run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False - ) + run = client.submit(run_indexing_entrypoint, attempt.id, pure=False) if run: secondary_str = "(secondary index) " if use_secondary_index else "" @@ -365,9 +355,9 @@ def kickoff_indexing_jobs( def check_index_swap(db_session: Session) -> None: - """Get count of cc-pairs and count of index_attempts for the new model grouped by - connector + credential, if it's the same, then assume new index is done building. 
- This does not take into consideration if the attempt failed or not""" + """Get count of cc-pairs and count of successful index_attempts for the + new model grouped by connector + credential, if it's the same, then assume + new index is done building. If so, swap the indices and expire the old one.""" # Default CC-pair created for Ingestion API unused here all_cc_pairs = get_connector_credential_pairs(db_session) cc_pair_count = len(all_cc_pairs) - 1 @@ -376,7 +366,7 @@ def check_index_swap(db_session: Session) -> None: if not embedding_model: return - unique_cc_indexings = count_unique_cc_pairs_with_index_attempts( + unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts( embedding_model_id=embedding_model.id, db_session=db_session ) @@ -407,6 +397,20 @@ def check_index_swap(db_session: Session) -> None: def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None: + engine = get_sqlalchemy_engine() + with Session(engine) as db_session: + db_embedding_model = get_current_db_embedding_model(db_session) + + # So that the first time users aren't surprised by really slow speed of first + # batch of documents indexed + logger.info("Running a first inference to warm up embedding model") + warm_up_encoders( + model_name=db_embedding_model.model_name, + normalize=db_embedding_model.normalize, + model_server_host=INDEXING_MODEL_SERVER_HOST, + model_server_port=MODEL_SERVER_PORT, + ) + client_primary: Client | SimpleJobClient client_secondary: Client | SimpleJobClient if DASK_JOB_CLIENT_ENABLED: @@ -433,7 +437,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non client_secondary = SimpleJobClient(n_workers=num_workers) existing_jobs: dict[int, Future | SimpleJob] = {} - engine = get_sqlalchemy_engine() with Session(engine) as db_session: # Previous version did not always clean up cc-pairs well leaving some connectors undeleteable @@ -470,14 +473,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non def update__main() -> None: - # needed for CUDA to work with multiprocessing - # NOTE: needs to be done on application startup - # before any other torch code has been run - import torch - - if not DASK_JOB_CLIENT_ENABLED: - torch.multiprocessing.set_start_method("spawn") - logger.info("Starting Indexing Loop") update_loop() diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index ee2f582c954..66b0f37de5c 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -7,16 +7,19 @@ from danswer.chat.models import LlmDoc from danswer.db.chat import get_chat_messages_by_session from danswer.db.models import ChatMessage -from danswer.indexing.models import InferenceChunk +from danswer.search.models import InferenceChunk +from danswer.search.models import InferenceSection from danswer.utils.logger import setup_logger logger = setup_logger() -def llm_doc_from_inference_chunk(inf_chunk: InferenceChunk) -> LlmDoc: +def llm_doc_from_inference_section(inf_chunk: InferenceSection) -> LlmDoc: return LlmDoc( document_id=inf_chunk.document_id, - content=inf_chunk.content, + # This one is using the combined content of all the chunks of the section + # In default settings, this is the same as just the content of base chunk + content=inf_chunk.combined_content, blurb=inf_chunk.blurb, semantic_identifier=inf_chunk.semantic_identifier, source_type=inf_chunk.source_type, @@ -55,7 +58,7 @@ def create_chat_chain( id_to_msg = {msg.id: msg for msg in 
all_chat_messages} if not all_chat_messages: - raise ValueError("No messages in Chat Session") + raise RuntimeError("No messages in Chat Session") root_message = all_chat_messages[0] if root_message.parent_message is not None: diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 270afc67e29..21b87296f2b 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from danswer.chat.chat_utils import create_chat_chain -from danswer.chat.chat_utils import llm_doc_from_inference_chunk +from danswer.chat.chat_utils import llm_doc_from_inference_section from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc @@ -34,7 +34,9 @@ from danswer.llm.answering.models import AnswerStyleConfig from danswer.llm.answering.models import CitationConfig from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.models import LLMConfig from danswer.llm.answering.models import PreviousMessage +from danswer.llm.answering.models import PromptConfig from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llm from danswer.llm.utils import get_default_llm_tokenizer @@ -42,7 +44,7 @@ from danswer.search.models import SearchRequest from danswer.search.pipeline import SearchPipeline from danswer.search.retrieval.search_runner import inference_documents_from_ids -from danswer.search.utils import chunks_to_search_docs +from danswer.search.utils import chunks_or_sections_to_search_docs from danswer.secondary_llm_flows.choose_search import check_if_need_search from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail @@ -93,6 +95,10 @@ def stream_chat_message_objects( # For flow with search, don't include as many chunks as possible since we need to leave space # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE, + # if specified, uses the last user message and does not create a new user message based + # on the `new_msg_req.message`. Currently, requires a state where the last message is a + # user message (e.g. this can only be used for the chat-seeding flow). + use_existing_user_message: bool = False, ) -> ChatPacketStream: """Streams in order: 1. [conditional] Retrieved documents if a search needs to be run @@ -159,33 +165,43 @@ def stream_chat_message_objects( else: parent_message = root_message - # Create new message at the right place in the tree and update the parent's child pointer - # Don't commit yet until we verify the chat message chain - new_user_message = create_new_chat_message( - chat_session_id=chat_session_id, - parent_message=parent_message, - prompt_id=prompt_id, - message=message_text, - token_count=len(llm_tokenizer_encode_func(message_text)), - message_type=MessageType.USER, - db_session=db_session, - commit=False, - ) - - # Create linear history of messages - final_msg, history_msgs = create_chat_chain( - chat_session_id=chat_session_id, db_session=db_session - ) - - if final_msg.id != new_user_message.id: - db_session.rollback() - raise RuntimeError( - "The new message was not on the mainline. " - "Be sure to update the chat pointers before calling this." 
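For readers unfamiliar with the message-tree layout that `create_chat_chain` and the "mainline" check operate over: chat messages form a tree via parent pointers, and the linear history is one root-to-leaf path. The toy model below illustrates only that idea; it uses a made-up `ToyMessage` structure rather than the real `ChatMessage` model, so treat it as a conceptual sketch and nothing more.

```python
from dataclasses import dataclass


@dataclass
class ToyMessage:
    id: int
    parent_id: int | None = None


def mainline(messages: dict[int, ToyMessage], leaf_id: int) -> list[int]:
    """Walk parent pointers from a leaf back to the root, then reverse."""
    path = []
    current: int | None = leaf_id
    while current is not None:
        path.append(current)
        current = messages[current].parent_id
    return list(reversed(path))


if __name__ == "__main__":
    # root(1) -> 2 -> 3 is the mainline; 4 is an abandoned sibling branch off 1
    msgs = {
        1: ToyMessage(1, None),
        2: ToyMessage(2, 1),
        3: ToyMessage(3, 2),
        4: ToyMessage(4, 1),
    }
    print(mainline(msgs, 3))  # -> [1, 2, 3]
```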
+ if not use_existing_user_message: + # Create new message at the right place in the tree and update the parent's child pointer + # Don't commit yet until we verify the chat message chain + user_message = create_new_chat_message( + chat_session_id=chat_session_id, + parent_message=parent_message, + prompt_id=prompt_id, + message=message_text, + token_count=len(llm_tokenizer_encode_func(message_text)), + message_type=MessageType.USER, + db_session=db_session, + commit=False, ) + # re-create linear history of messages + final_msg, history_msgs = create_chat_chain( + chat_session_id=chat_session_id, db_session=db_session + ) + if final_msg.id != user_message.id: + db_session.rollback() + raise RuntimeError( + "The new message was not on the mainline. " + "Be sure to update the chat pointers before calling this." + ) - # Save now to save the latest chat message - db_session.commit() + # Save now to save the latest chat message + db_session.commit() + else: + # re-create linear history of messages + final_msg, history_msgs = create_chat_chain( + chat_session_id=chat_session_id, db_session=db_session + ) + if final_msg.message_type != MessageType.USER: + raise RuntimeError( + "The last message was not a user message. Cannot call " + "`stream_chat_message_objects` with `is_regenerate=True` " + "when the last message is not a user message." + ) run_search = False # Retrieval options are only None if reference_doc_ids are provided @@ -200,6 +216,7 @@ def stream_chat_message_objects( ) rephrased_query = None + llm_relevance_list = None if reference_doc_ids: identifier_tuples = get_doc_query_identifiers_from_model( search_doc_ids=reference_doc_ids, @@ -247,13 +264,16 @@ def stream_chat_message_objects( persona=persona, offset=retrieval_options.offset if retrieval_options else None, limit=retrieval_options.limit if retrieval_options else None, + chunks_above=new_msg_req.chunks_above, + chunks_below=new_msg_req.chunks_below, + full_doc=new_msg_req.full_doc, ), user=user, db_session=db_session, ) - top_chunks = search_pipeline.reranked_docs - top_docs = chunks_to_search_docs(top_chunks) + top_sections = search_pipeline.reranked_sections + top_docs = chunks_or_sections_to_search_docs(top_sections) reference_db_search_docs = [ create_db_search_doc(server_search_doc=top_doc, db_session=db_session) @@ -278,7 +298,7 @@ def stream_chat_message_objects( # Yield the list of LLM selected chunks for showing the LLM selected icons in the UI llm_relevance_filtering_response = LLMRelevanceFilterResponse( - relevant_chunk_indices=search_pipeline.relevant_chunk_indicies + relevant_chunk_indices=search_pipeline.relevant_chunk_indices ) yield llm_relevance_filtering_response @@ -289,9 +309,13 @@ def stream_chat_message_objects( else default_num_chunks ), max_window_percentage=max_document_percentage, + use_sections=search_pipeline.ran_merge_chunk, ) - llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in top_chunks] + llm_docs = [ + llm_doc_from_inference_section(section) for section in top_sections + ] + llm_relevance_list = search_pipeline.section_relevance_list else: llm_docs = [] @@ -302,7 +326,7 @@ def stream_chat_message_objects( partial_response = partial( create_new_chat_message, chat_session_id=chat_session_id, - parent_message=new_user_message, + parent_message=final_msg, prompt_id=prompt_id, # message=, rephrased_query=rephrased_query, @@ -343,8 +367,17 @@ def stream_chat_message_objects( ), document_pruning_config=document_pruning_config, ), - prompt=final_msg.prompt, - persona=persona, + 
prompt_config=PromptConfig.from_model( + final_msg.prompt, + prompt_override=( + new_msg_req.prompt_override or chat_session.prompt_override + ), + ), + llm_config=LLMConfig.from_persona( + persona, + llm_override=(new_msg_req.llm_override or chat_session.llm_override), + ), + doc_relevance_list=llm_relevance_list, message_history=[ PreviousMessage.from_chat_message(msg) for msg in history_msgs ], @@ -393,12 +426,14 @@ def stream_chat_message_objects( def stream_chat_message( new_msg_req: CreateChatMessageRequest, user: User | None, + use_existing_user_message: bool = False, ) -> Iterator[str]: with get_session_context_manager() as db_session: objects = stream_chat_message_objects( new_msg_req=new_msg_req, user=user, db_session=db_session, + use_existing_user_message=use_existing_user_message, ) for obj in objects: yield get_json_line(obj.dict()) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 08ac2fc23df..1e4809d0716 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -157,6 +157,11 @@ ) if ignored_tag ] +JIRA_CONNECTOR_LABELS_TO_SKIP = [ + ignored_tag + for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",") + if ignored_tag +] GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME") @@ -204,23 +209,6 @@ ) -##### -# Model Server Configs -##### -# If MODEL_SERVER_HOST is set, the NLP models required for Danswer are offloaded to the server via -# requests. Be sure to include the scheme in the MODEL_SERVER_HOST value. -MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None -MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0" -MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000") - -# specify this env variable directly to have a different model server for the background -# indexing job vs the api server so that background indexing does not effect query-time -# performance -INDEXING_MODEL_SERVER_HOST = ( - os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST -) - - ##### # Miscellaneous ##### @@ -245,5 +233,7 @@ ) # Anonymous usage telemetry DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true" -# notset, debug, info, warning, error, or critical -LOG_LEVEL = os.environ.get("LOG_LEVEL", "info") + +TOKEN_BUDGET_GLOBALLY_ENABLED = ( + os.environ.get("TOKEN_BUDGET_GLOBALLY_ENABLED", "").lower() == "true" +) diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 356fc2831f6..b961cdfb39e 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -40,6 +40,10 @@ SESSION_KEY = "session" QUERY_EVENT_ID = "query_event_id" LLM_CHUNKS = "llm_chunks" +TOKEN_BUDGET = "token_budget" +TOKEN_BUDGET_TIME_PERIOD = "token_budget_time_period" +ENABLE_TOKEN_BUDGET = "enable_token_budget" +TOKEN_BUDGET_SETTINGS = "token_budget_settings" # For chunking/processing chunks TITLE_SEPARATOR = "\n\r\n" @@ -87,6 +91,7 @@ class DocumentSource(str, Enum): ZENDESK = "zendesk" LOOPIO = "loopio" SHAREPOINT = "sharepoint" + AXERO = "axero" class DocumentIndexType(str, Enum): diff --git a/backend/danswer/configs/danswerbot_configs.py b/backend/danswer/configs/danswerbot_configs.py index 5935c9b999e..192a0594d13 100644 --- a/backend/danswer/configs/danswerbot_configs.py +++ b/backend/danswer/configs/danswerbot_configs.py @@ -21,6 +21,14 @@ DANSWER_REACT_EMOJI = os.environ.get("DANSWER_REACT_EMOJI") or "eyes" # When User needs more 
help, what should the emoji be DANSWER_FOLLOWUP_EMOJI = os.environ.get("DANSWER_FOLLOWUP_EMOJI") or "sos" +# What kind of message should be shown when someone gives an AI answer feedback to DanswerBot +# Defaults to Private if not provided or invalid +# Private: Only visible to user clicking the feedback +# Anonymous: Public but anonymous +# Public: Visible with the user name who submitted the feedback +DANSWER_BOT_FEEDBACK_VISIBILITY = ( + os.environ.get("DANSWER_BOT_FEEDBACK_VISIBILITY") or "private" +) # Should DanswerBot send an apology message if it's not able to find an answer # That way the user isn't confused as to why DanswerBot reacted but then said nothing # Off by default to be less intrusive (don't want to give a notif that just says we couldnt help) diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py index f6cd71f31db..e0d774c82b3 100644 --- a/backend/danswer/configs/model_configs.py +++ b/backend/danswer/configs/model_configs.py @@ -37,36 +37,13 @@ ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ") # Purely an optimization, memory limitation consideration BATCH_SIZE_ENCODE_CHUNKS = 8 -# This controls the minimum number of pytorch "threads" to allocate to the embedding -# model. If torch finds more threads on its own, this value is not used. -MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1) - -# Cross Encoder Settings -ENABLE_RERANKING_ASYNC_FLOW = ( - os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true" -) -ENABLE_RERANKING_REAL_TIME_FLOW = ( - os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true" -) -# https://www.sbert.net/docs/pretrained-models/ce-msmarco.html -CROSS_ENCODER_MODEL_ENSEMBLE = [ - "cross-encoder/ms-marco-MiniLM-L-4-v2", - "cross-encoder/ms-marco-TinyBERT-L-2-v2", -] -# For score normalizing purposes, only way is to know the expected ranges +# For score display purposes, only way is to know the expected ranges CROSS_ENCODER_RANGE_MAX = 12 CROSS_ENCODER_RANGE_MIN = -12 -CROSS_EMBED_CONTEXT_SIZE = 512 # Unused currently, can't be used with the current default encoder model due to its output range SEARCH_DISTANCE_CUTOFF = 0 -# Intent model max context size -QUERY_MAX_CONTEXT_SIZE = 256 - -# Danswer custom Deep Learning Models -INTENT_MODEL_VERSION = "danswer/intent-model" - ##### # Generative AI Model Configs diff --git a/backend/shared_models/__init__.py b/backend/danswer/connectors/axero/__init__.py similarity index 100% rename from backend/shared_models/__init__.py rename to backend/danswer/connectors/axero/__init__.py diff --git a/backend/danswer/connectors/axero/connector.py b/backend/danswer/connectors/axero/connector.py new file mode 100644 index 00000000000..f82c6b4494a --- /dev/null +++ b/backend/danswer/connectors/axero/connector.py @@ -0,0 +1,363 @@ +import time +from datetime import datetime +from datetime import timezone +from typing import Any + +import requests +from pydantic import BaseModel + +from danswer.configs.app_configs import INDEX_BATCH_SIZE +from danswer.configs.constants import DocumentSource +from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic +from danswer.connectors.cross_connector_utils.miscellaneous_utils import ( + process_in_batches, +) +from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc +from danswer.connectors.cross_connector_utils.rate_limit_wrapper import ( + rate_limit_builder, +) +from 
danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder +from danswer.connectors.interfaces import GenerateDocumentsOutput +from danswer.connectors.interfaces import PollConnector +from danswer.connectors.interfaces import SecondsSinceUnixEpoch +from danswer.connectors.models import ConnectorMissingCredentialError +from danswer.connectors.models import Document +from danswer.connectors.models import Section +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + + +ENTITY_NAME_MAP = {1: "Forum", 3: "Article", 4: "Blog", 9: "Wiki"} + + +def _get_auth_header(api_key: str) -> dict[str, str]: + return {"Rest-Api-Key": api_key} + + +@retry_builder() +@rate_limit_builder(max_calls=5, period=1) +def _rate_limited_request( + endpoint: str, headers: dict, params: dict | None = None +) -> Any: + # https://my.axerosolutions.com/spaces/5/communifire-documentation/wiki/view/370/rest-api + return requests.get(endpoint, headers=headers, params=params) + + +# https://my.axerosolutions.com/spaces/5/communifire-documentation/wiki/view/595/rest-api-get-content-list +def _get_entities( + entity_type: int, + api_key: str, + axero_base_url: str, + start: datetime, + end: datetime, + space_id: str | None = None, +) -> list[dict]: + endpoint = axero_base_url + "api/content/list" + page_num = 1 + pages_fetched = 0 + pages_to_return = [] + break_out = False + while True: + params = { + "EntityType": str(entity_type), + "SortColumn": "DateUpdated", + "SortOrder": "1", # descending + "StartPage": str(page_num), + } + + if space_id is not None: + params["SpaceID"] = space_id + + res = _rate_limited_request( + endpoint, headers=_get_auth_header(api_key), params=params + ) + res.raise_for_status() + + # Axero limitations: + # No next page token, can paginate but things may have changed + # for example, a doc that hasn't been read in by Danswer is updated and is now front of the list + # due to this limitation and the fact that Axero has no rate limiting but API calls can cause + # increased latency for the team, we have to just fetch all the pages quickly to reduce the + # chance of missing a document due to an update (it will still get updated next pass) + # Assumes the volume of data isn't too big to store in memory (probably fine) + data = res.json() + total_records = data["TotalRecords"] + contents = data["ResponseData"] + pages_fetched += len(contents) + logger.debug(f"Fetched {pages_fetched} {ENTITY_NAME_MAP[entity_type]}") + + for page in contents: + update_time = time_str_to_utc(page["DateUpdated"]) + + if update_time > end: + continue + + if update_time < start: + break_out = True + break + + pages_to_return.append(page) + + if pages_fetched >= total_records: + break + + page_num += 1 + + if break_out: + break + + return pages_to_return + + +def _get_obj_by_id(obj_id: int, api_key: str, axero_base_url: str) -> dict: + endpoint = axero_base_url + f"api/content/{obj_id}" + res = _rate_limited_request(endpoint, headers=_get_auth_header(api_key)) + res.raise_for_status() + + return res.json() + + +class AxeroForum(BaseModel): + doc_id: str + title: str + link: str + initial_content: str + responses: list[str] + last_update: datetime + + +def _map_post_to_parent( + posts: dict, + api_key: str, + axero_base_url: str, +) -> list[AxeroForum]: + """Cannot handle in batches since the posts aren't ordered or structured in any way + may need to map any number of them to the initial post""" + epoch_str = "1970-01-01T00:00:00.000" + post_map: dict[int, AxeroForum] = {} + + for ind, 
post in enumerate(posts): + if (ind + 1) % 25 == 0: + logger.debug(f"Processed {ind + 1} posts or responses") + + post_time = time_str_to_utc( + post.get("DateUpdated") or post.get("DateCreated") or epoch_str + ) + p_id = post.get("ParentContentID") + if p_id in post_map: + axero_forum = post_map[p_id] + axero_forum.responses.insert(0, post.get("ContentSummary")) + axero_forum.last_update = max(axero_forum.last_update, post_time) + else: + initial_post_d = _get_obj_by_id(p_id, api_key, axero_base_url)[ + "ResponseData" + ] + initial_post_time = time_str_to_utc( + initial_post_d.get("DateUpdated") + or initial_post_d.get("DateCreated") + or epoch_str + ) + post_map[p_id] = AxeroForum( + doc_id="AXERO_" + str(initial_post_d.get("ContentID")), + title=initial_post_d.get("ContentTitle"), + link=initial_post_d.get("ContentURL"), + initial_content=initial_post_d.get("ContentSummary"), + responses=[post.get("ContentSummary")], + last_update=max(post_time, initial_post_time), + ) + + return list(post_map.values()) + + +def _get_forums( + api_key: str, + axero_base_url: str, + space_id: str | None = None, +) -> list[dict]: + endpoint = axero_base_url + "api/content/list" + page_num = 1 + pages_fetched = 0 + pages_to_return = [] + break_out = False + + while True: + params = { + "EntityType": "54", + "SortColumn": "DateUpdated", + "SortOrder": "1", # descending + "StartPage": str(page_num), + } + + if space_id is not None: + params["SpaceID"] = space_id + + res = _rate_limited_request( + endpoint, headers=_get_auth_header(api_key), params=params + ) + res.raise_for_status() + + data = res.json() + total_records = data["TotalRecords"] + contents = data["ResponseData"] + pages_fetched += len(contents) + logger.debug(f"Fetched {pages_fetched} forums") + + for page in contents: + pages_to_return.append(page) + + if pages_fetched >= total_records: + break + + page_num += 1 + + if break_out: + break + + return pages_to_return + + +def _translate_forum_to_doc(af: AxeroForum) -> Document: + doc = Document( + id=af.doc_id, + sections=[Section(link=af.link, text=reply) for reply in af.responses], + source=DocumentSource.AXERO, + semantic_identifier=af.title, + doc_updated_at=af.last_update, + metadata={}, + ) + + return doc + + +def _translate_content_to_doc(content: dict) -> Document: + page_text = "" + summary = content.get("ContentSummary") + body = content.get("ContentBody") + if summary: + page_text += f"{summary}\n" + + if body: + content_parsed = parse_html_page_basic(body) + page_text += content_parsed + + doc = Document( + id="AXERO_" + str(content["ContentID"]), + sections=[Section(link=content["ContentURL"], text=page_text)], + source=DocumentSource.AXERO, + semantic_identifier=content["ContentTitle"], + doc_updated_at=time_str_to_utc(content["DateUpdated"]), + metadata={"space": content["SpaceName"]}, + ) + + return doc + + +class AxeroConnector(PollConnector): + def __init__( + self, + # Strings of the integer ids of the spaces + spaces: list[str] | None = None, + include_article: bool = True, + include_blog: bool = True, + include_wiki: bool = True, + include_forum: bool = True, + batch_size: int = INDEX_BATCH_SIZE, + ) -> None: + self.include_article = include_article + self.include_blog = include_blog + self.include_wiki = include_wiki + self.include_forum = include_forum + self.batch_size = batch_size + self.space_ids = spaces + self.axero_key = None + self.base_url = None + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + self.axero_key = 
credentials["axero_api_token"] + # As the API key specifically applies to a particular deployment, this is + # included as part of the credential + base_url = credentials["base_url"] + if not base_url.endswith("/"): + base_url += "/" + self.base_url = base_url + return None + + def poll_source( + self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch + ) -> GenerateDocumentsOutput: + if not self.axero_key or not self.base_url: + raise ConnectorMissingCredentialError("Axero") + + start_datetime = datetime.utcfromtimestamp(start).replace(tzinfo=timezone.utc) + end_datetime = datetime.utcfromtimestamp(end).replace(tzinfo=timezone.utc) + + entity_types = [] + if self.include_article: + entity_types.append(3) + if self.include_blog: + entity_types.append(4) + if self.include_wiki: + entity_types.append(9) + + iterable_space_ids = self.space_ids if self.space_ids else [None] + + for space_id in iterable_space_ids: + for entity in entity_types: + axero_obj = _get_entities( + entity_type=entity, + api_key=self.axero_key, + axero_base_url=self.base_url, + start=start_datetime, + end=end_datetime, + space_id=space_id, + ) + yield from process_in_batches( + objects=axero_obj, + process_function=_translate_content_to_doc, + batch_size=self.batch_size, + ) + + if self.include_forum: + forums_posts = _get_forums( + api_key=self.axero_key, + axero_base_url=self.base_url, + space_id=space_id, + ) + + all_axero_forums = _map_post_to_parent( + posts=forums_posts, + api_key=self.axero_key, + axero_base_url=self.base_url, + ) + + filtered_forums = [ + f + for f in all_axero_forums + if f.last_update >= start_datetime and f.last_update <= end_datetime + ] + + yield from process_in_batches( + objects=filtered_forums, + process_function=_translate_forum_to_doc, + batch_size=self.batch_size, + ) + + +if __name__ == "__main__": + import os + + connector = AxeroConnector() + connector.load_credentials( + { + "axero_api_token": os.environ["AXERO_API_TOKEN"], + "base_url": os.environ["AXERO_BASE_URL"], + } + ) + current = time.time() + + one_year_ago = current - 24 * 60 * 60 * 360 + latest_docs = connector.poll_source(one_year_ago, current) + + print(next(latest_docs)) diff --git a/backend/danswer/connectors/cross_connector_utils/miscellaneous_utils.py b/backend/danswer/connectors/cross_connector_utils/miscellaneous_utils.py index 10c8315601b..8faf6bfadaf 100644 --- a/backend/danswer/connectors/cross_connector_utils/miscellaneous_utils.py +++ b/backend/danswer/connectors/cross_connector_utils/miscellaneous_utils.py @@ -1,5 +1,8 @@ +from collections.abc import Callable +from collections.abc import Iterator from datetime import datetime from datetime import timezone +from typing import TypeVar from dateutil.parser import parse @@ -43,3 +46,14 @@ def get_experts_stores_representations( reps = [basic_expert_info_representation(owner) for owner in experts] return [owner for owner in reps if owner is not None] + + +T = TypeVar("T") +U = TypeVar("U") + + +def process_in_batches( + objects: list[T], process_function: Callable[[T], U], batch_size: int +) -> Iterator[list[U]]: + for i in range(0, len(objects), batch_size): + yield [process_function(obj) for obj in objects[i : i + batch_size]] diff --git a/backend/danswer/connectors/danswer_jira/connector.py b/backend/danswer/connectors/danswer_jira/connector.py index 5ef833e581d..dfed7ebd16c 100644 --- a/backend/danswer/connectors/danswer_jira/connector.py +++ b/backend/danswer/connectors/danswer_jira/connector.py @@ -8,6 +8,7 @@ from jira.resources import Issue from 
danswer.configs.app_configs import INDEX_BATCH_SIZE +from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP from danswer.configs.constants import DocumentSource from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc from danswer.connectors.interfaces import GenerateDocumentsOutput @@ -68,6 +69,7 @@ def fetch_jira_issues_batch( jira_client: JIRA, batch_size: int = INDEX_BATCH_SIZE, comment_email_blacklist: tuple[str, ...] = (), + labels_to_skip: set[str] | None = None, ) -> tuple[list[Document], int]: doc_batch = [] @@ -82,6 +84,15 @@ def fetch_jira_issues_batch( logger.warning(f"Found Jira object not of type Issue {jira}") continue + if labels_to_skip and any( + label in jira.fields.labels for label in labels_to_skip + ): + logger.info( + f"Skipping {jira.key} because it has a label to skip. Found " + f"labels: {jira.fields.labels}. Labels to skip: {labels_to_skip}." + ) + continue + comments = _get_comment_strs(jira, comment_email_blacklist) semantic_rep = f"{jira.fields.description}\n" + "\n".join( [f"Comment: {comment}" for comment in comments] @@ -143,12 +154,18 @@ def __init__( jira_project_url: str, comment_email_blacklist: list[str] | None = None, batch_size: int = INDEX_BATCH_SIZE, + # if a ticket has one of the labels specified in this list, we will just + # skip it. This is generally used to avoid indexing extra sensitive + # tickets. + labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP, ) -> None: self.batch_size = batch_size self.jira_base, self.jira_project = extract_jira_project(jira_project_url) self.jira_client: JIRA | None = None self._comment_email_blacklist = comment_email_blacklist or [] + self.labels_to_skip = set(labels_to_skip) + @property def comment_email_blacklist(self) -> tuple: return tuple(email.strip() for email in self._comment_email_blacklist) @@ -182,6 +199,8 @@ def load_from_state(self) -> GenerateDocumentsOutput: start_index=start_ind, jira_client=self.jira_client, batch_size=self.batch_size, + comment_email_blacklist=self.comment_email_blacklist, + labels_to_skip=self.labels_to_skip, ) if doc_batch: @@ -218,6 +237,7 @@ def poll_source( jira_client=self.jira_client, batch_size=self.batch_size, comment_email_blacklist=self.comment_email_blacklist, + labels_to_skip=self.labels_to_skip, ) if doc_batch: diff --git a/backend/danswer/connectors/factory.py b/backend/danswer/connectors/factory.py index f4a9ee29083..5e6438088b3 100644 --- a/backend/danswer/connectors/factory.py +++ b/backend/danswer/connectors/factory.py @@ -2,6 +2,7 @@ from typing import Type from danswer.configs.constants import DocumentSource +from danswer.connectors.axero.connector import AxeroConnector from danswer.connectors.bookstack.connector import BookstackConnector from danswer.connectors.confluence.connector import ConfluenceConnector from danswer.connectors.danswer_jira.connector import JiraConnector @@ -70,6 +71,7 @@ def identify_connector_class( DocumentSource.ZENDESK: ZendeskConnector, DocumentSource.LOOPIO: LoopioConnector, DocumentSource.SHAREPOINT: SharepointConnector, + DocumentSource.AXERO: AxeroConnector, } connector_by_source = connector_map.get(source, {}) diff --git a/backend/danswer/connectors/hubspot/connector.py b/backend/danswer/connectors/hubspot/connector.py index 861f53ee6f8..bd13c1e75b7 100644 --- a/backend/danswer/connectors/hubspot/connector.py +++ b/backend/danswer/connectors/hubspot/connector.py @@ -94,6 +94,8 @@ def _process_tickets( note = api_client.crm.objects.notes.basic_api.get_by_id( 
note_id=note.id, properties=["content", "hs_body_preview"] ) + if note.properties["hs_body_preview"] is None: + continue associated_notes.append(note.properties["hs_body_preview"]) associated_emails_str = " ,".join(associated_emails) diff --git a/backend/danswer/connectors/notion/connector.py b/backend/danswer/connectors/notion/connector.py index 28fb47a44d5..2fab138781a 100644 --- a/backend/danswer/connectors/notion/connector.py +++ b/backend/danswer/connectors/notion/connector.py @@ -93,7 +93,9 @@ def __init__( self.recursive_index_enabled = recursive_index_enabled or self.root_page_id @retry(tries=3, delay=1, backoff=2) - def _fetch_blocks(self, block_id: str, cursor: str | None = None) -> dict[str, Any]: + def _fetch_child_blocks( + self, block_id: str, cursor: str | None = None + ) -> dict[str, Any] | None: """Fetch all child blocks via the Notion API.""" logger.debug(f"Fetching children of block with ID '{block_id}'") block_url = f"https://api.notion.com/v1/blocks/{block_id}/children" @@ -107,6 +109,15 @@ def _fetch_blocks(self, block_id: str, cursor: str | None = None) -> dict[str, A try: res.raise_for_status() except Exception as e: + if res.status_code == 404: + # this happens when a page is not shared with the integration + # in this case, we should just ignore the page + logger.error( + f"Unable to access block with ID '{block_id}'. " + f"This is likely due to the block not being shared " + f"with the Danswer integration. Exact exception:\n\n{e}" + ) + return None logger.exception(f"Error fetching blocks - {res.json()}") raise e return res.json() @@ -187,24 +198,30 @@ def _read_pages_from_database(self, database_id: str) -> list[str]: return result_pages def _read_blocks( - self, page_block_id: str + self, base_block_id: str ) -> tuple[list[tuple[str, str]], list[str]]: - """Reads blocks for a page""" + """Reads all child blocks for the specified block""" result_lines: list[tuple[str, str]] = [] child_pages: list[str] = [] cursor = None while True: - data = self._fetch_blocks(page_block_id, cursor) + data = self._fetch_child_blocks(base_block_id, cursor) + + # this happens when a block is not shared with the integration + if data is None: + return result_lines, child_pages for result in data["results"]: - logger.debug(f"Found block for page '{page_block_id}': {result}") + logger.debug( + f"Found child block for block with ID '{base_block_id}': {result}" + ) result_block_id = result["id"] result_type = result["type"] result_obj = result[result_type] if result_type == "ai_block": logger.warning( - f"Skipping 'ai_block' ('{result_block_id}') for page '{page_block_id}': " + f"Skipping 'ai_block' ('{result_block_id}') for base block '{base_block_id}': " f"Notion API does not currently support reading AI blocks (as of 24/02/09) " f"(discussion: https://github.com/danswer-ai/danswer/issues/1053)" ) @@ -416,8 +433,8 @@ def poll_source( ) if len(pages) > 0: yield from batch_generator(self._read_pages(pages), self.batch_size) - if db_res.has_more: - query_dict["start_cursor"] = db_res.next_cursor + if db_res.has_more: + query_dict["start_cursor"] = db_res.next_cursor else: break diff --git a/backend/danswer/connectors/web/connector.py b/backend/danswer/connectors/web/connector.py index 38f30a28edf..6114f0a867a 100644 --- a/backend/danswer/connectors/web/connector.py +++ b/backend/danswer/connectors/web/connector.py @@ -1,5 +1,4 @@ import io -import socket from enum import Enum from typing import Any from typing import cast @@ -43,15 +42,25 @@ class WEB_CONNECTOR_VALID_SETTINGS(str, Enum): 
UPLOAD = "upload" -def check_internet_connection() -> None: - dns_servers = [("1.1.1.1", 53), ("8.8.8.8", 53)] - for server in dns_servers: - try: - socket.create_connection(server, timeout=3) - return - except OSError: - continue - raise Exception("Unable to contact DNS server - check your internet connection") +def protected_url_check(url: str) -> None: + parse = urlparse(url) + if parse.scheme == "file": + raise ValueError("Not permitted to read local files via Web Connector.") + if ( + parse.scheme == "localhost" + or parse.scheme == "127.0.0.1" + or parse.hostname == "localhost" + or parse.hostname == "127.0.0.1" + ): + raise ValueError("Not permitted to read localhost urls.") + + +def check_internet_connection(url: str) -> None: + try: + response = requests.get(url, timeout=3) + response.raise_for_status() + except (requests.RequestException, ValueError): + raise Exception(f"Unable to reach {url} - check your internet connection") def is_valid_url(url: str) -> bool: @@ -185,7 +194,6 @@ def load_from_state(self) -> GenerateDocumentsOutput: base_url = to_visit[0] # For the recursive case doc_batch: list[Document] = [] - check_internet_connection() playwright, context = start_playwright() restart_playwright = False while to_visit: @@ -194,9 +202,12 @@ def load_from_state(self) -> GenerateDocumentsOutput: continue visited_links.add(current_url) + protected_url_check(current_url) + logger.info(f"Visiting {current_url}") try: + check_internet_connection(current_url) if restart_playwright: playwright, context = start_playwright() restart_playwright = False diff --git a/backend/danswer/danswerbot/slack/constants.py b/backend/danswer/danswerbot/slack/constants.py index a4930b593c3..1e524025fc7 100644 --- a/backend/danswer/danswerbot/slack/constants.py +++ b/backend/danswer/danswerbot/slack/constants.py @@ -1,3 +1,5 @@ +from enum import Enum + LIKE_BLOCK_ACTION_ID = "feedback-like" DISLIKE_BLOCK_ACTION_ID = "feedback-dislike" FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button" @@ -6,3 +8,9 @@ FOLLOWUP_BUTTON_RESOLVED_ACTION_ID = "followup-resolved-button" SLACK_CHANNEL_ID = "channel_id" VIEW_DOC_FEEDBACK_ID = "view-doc-feedback" + + +class FeedbackVisibility(str, Enum): + PRIVATE = "private" + ANONYMOUS = "anonymous" + PUBLIC = "public" diff --git a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py index 0ca030612f3..bec1959e3cc 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py @@ -15,6 +15,7 @@ from danswer.danswerbot.slack.blocks import get_document_feedback_blocks from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID +from danswer.danswerbot.slack.constants import FeedbackVisibility from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID from danswer.danswerbot.slack.constants import VIEW_DOC_FEEDBACK_ID from danswer.danswerbot.slack.utils import build_feedback_id @@ -22,6 +23,7 @@ from danswer.danswerbot.slack.utils import fetch_groupids_from_names from danswer.danswerbot.slack.utils import fetch_userids_from_emails from danswer.danswerbot.slack.utils import get_channel_name_from_id +from danswer.danswerbot.slack.utils import get_feedback_visibility from danswer.danswerbot.slack.utils import respond_in_thread from danswer.danswerbot.slack.utils import update_emote_react from danswer.db.engine import 
get_sqlalchemy_engine @@ -120,13 +122,33 @@ def handle_slack_feedback( else: logger_base.error(f"Feedback type '{feedback_type}' not supported") - # post message to slack confirming that feedback was received - client.chat_postEphemeral( - channel=channel_id_to_post_confirmation, - user=user_id_to_post_confirmation, - thread_ts=thread_ts_to_post_confirmation, - text="Thanks for your feedback!", - ) + if get_feedback_visibility() == FeedbackVisibility.PRIVATE or feedback_type not in [ + LIKE_BLOCK_ACTION_ID, + DISLIKE_BLOCK_ACTION_ID, + ]: + client.chat_postEphemeral( + channel=channel_id_to_post_confirmation, + user=user_id_to_post_confirmation, + thread_ts=thread_ts_to_post_confirmation, + text="Thanks for your feedback!", + ) + else: + feedback_response_txt = ( + "liked" if feedback_type == LIKE_BLOCK_ACTION_ID else "disliked" + ) + + if get_feedback_visibility() == FeedbackVisibility.ANONYMOUS: + msg = f"A user has {feedback_response_txt} the AI Answer" + else: + msg = f"<@{user_id_to_post_confirmation}> has {feedback_response_txt} the AI Answer" + + respond_in_thread( + client=client, + channel=channel_id_to_post_confirmation, + text=msg, + thread_ts=thread_ts_to_post_confirmation, + unfurl=False, + ) def handle_followup_button( diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index b3fdb79c88b..fc1c038aebf 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -38,7 +38,9 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import SlackBotConfig from danswer.db.models import SlackBotResponseType -from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens +from danswer.llm.answering.prompts.citations_prompt import ( + compute_max_document_tokens_for_persona, +) from danswer.llm.utils import check_number_of_tokens from danswer.llm.utils import get_default_llm_version from danswer.llm.utils import get_max_input_tokens @@ -49,6 +51,7 @@ from danswer.search.models import OptionalSearchSetting from danswer.search.models import RetrievalDetails from danswer.utils.logger import setup_logger +from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW logger_base = setup_logger() @@ -247,7 +250,7 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse: query_text = new_message_request.messages[0].message if persona: - max_document_tokens = compute_max_document_tokens( + max_document_tokens = compute_max_document_tokens_for_persona( persona=persona, actual_user_input=query_text, max_llm_token_override=remaining_tokens, @@ -308,6 +311,7 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse: persona_id=persona.id if persona is not None else 0, retrieval_options=retrieval_details, chain_of_thought=not disable_cot, + skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW, ) ) except Exception as e: diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py index fc7055577cb..829c5bbf6c2 100644 --- a/backend/danswer/danswerbot/slack/listener.py +++ b/backend/danswer/danswerbot/slack/listener.py @@ -3,7 +3,6 @@ from typing import Any from typing import cast -import nltk # type: ignore from slack_sdk import WebClient from slack_sdk.socket_mode import SocketModeClient from slack_sdk.socket_mode.request import SocketModeRequest @@ -13,7 +12,6 @@ from danswer.configs.constants import MessageType from 
danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER -from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID @@ -43,9 +41,12 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.one_shot_answer.models import ThreadMessage -from danswer.search.search_nlp_models import warm_up_models +from danswer.search.retrieval.search_runner import download_nltk_data +from danswer.search.search_nlp_models import warm_up_encoders from danswer.server.manage.models import SlackBotTokens from danswer.utils.logger import setup_logger +from shared_configs.configs import MODEL_SERVER_HOST +from shared_configs.configs import MODEL_SERVER_PORT logger = setup_logger() @@ -374,8 +375,7 @@ def _initialize_socket_client(socket_client: SocketModeClient) -> None: socket_client: SocketModeClient | None = None logger.info("Verifying query preprocessing (NLTK) data is downloaded") - nltk.download("stopwords", quiet=True) - nltk.download("punkt", quiet=True) + download_nltk_data() while True: try: @@ -390,10 +390,11 @@ def _initialize_socket_client(socket_client: SocketModeClient) -> None: with Session(get_sqlalchemy_engine()) as db_session: embedding_model = get_current_db_embedding_model(db_session) - warm_up_models( + warm_up_encoders( model_name=embedding_model.model_name, normalize=embedding_model.normalize, - skip_cross_encoders=not ENABLE_RERANKING_ASYNC_FLOW, + model_server_host=MODEL_SERVER_HOST, + model_server_port=MODEL_SERVER_PORT, ) slack_bot_tokens = latest_slack_bot_tokens diff --git a/backend/danswer/danswerbot/slack/utils.py b/backend/danswer/danswerbot/slack/utils.py index 5d761dec0ee..5895dc52f91 100644 --- a/backend/danswer/danswerbot/slack/utils.py +++ b/backend/danswer/danswerbot/slack/utils.py @@ -18,11 +18,13 @@ from danswer.configs.app_configs import DISABLE_TELEMETRY from danswer.configs.constants import ID_SEPARATOR from danswer.configs.constants import MessageType +from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_VISIBILITY from danswer.configs.danswerbot_configs import DANSWER_BOT_MAX_QPM from danswer.configs.danswerbot_configs import DANSWER_BOT_MAX_WAIT_TIME from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES from danswer.connectors.slack.utils import make_slack_api_rate_limited from danswer.connectors.slack.utils import SlackTextCleaner +from danswer.danswerbot.slack.constants import FeedbackVisibility from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID from danswer.danswerbot.slack.tokens import fetch_tokens from danswer.db.engine import get_sqlalchemy_engine @@ -449,3 +451,10 @@ def waiter(self, func_randid: int) -> None: self.refill() del self.waiting_questions[0] + + +def get_feedback_visibility() -> FeedbackVisibility: + try: + return FeedbackVisibility(DANSWER_BOT_FEEDBACK_VISIBILITY.lower()) + except ValueError: + return FeedbackVisibility.PRIVATE diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 6dfa02c2f9c..738d02a1657 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -18,6 +18,7 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import 
ChatMessage from danswer.db.models import ChatSession +from danswer.db.models import ChatSessionSharedStatus from danswer.db.models import DocumentSet as DBDocumentSet from danswer.db.models import Persona from danswer.db.models import Persona__User @@ -27,6 +28,8 @@ from danswer.db.models import SearchDoc as DBSearchDoc from danswer.db.models import StarterMessage from danswer.db.models import User__UserGroup +from danswer.llm.override_models import LLMOverride +from danswer.llm.override_models import PromptOverride from danswer.search.enums import RecencyBiasSetting from danswer.search.models import RetrievalDocs from danswer.search.models import SavedSearchDoc @@ -42,13 +45,19 @@ def get_chat_session_by_id( user_id: UUID | None, db_session: Session, include_deleted: bool = False, + is_shared: bool = False, ) -> ChatSession: stmt = select(ChatSession).where(ChatSession.id == chat_session_id) - # if user_id is None, assume this is an admin who should be able - # to view all chat sessions - if user_id is not None: - stmt = stmt.where(ChatSession.user_id == user_id) + if is_shared: + stmt = stmt.where(ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC) + else: + # if user_id is None, assume this is an admin who should be able + # to view all chat sessions + if user_id is not None: + stmt = stmt.where( + or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None)) + ) result = db_session.execute(stmt) chat_session = result.scalar_one_or_none() @@ -87,12 +96,16 @@ def create_chat_session( description: str, user_id: UUID | None, persona_id: int | None = None, + llm_override: LLMOverride | None = None, + prompt_override: PromptOverride | None = None, one_shot: bool = False, ) -> ChatSession: chat_session = ChatSession( user_id=user_id, persona_id=persona_id, description=description, + llm_override=llm_override, + prompt_override=prompt_override, one_shot=one_shot, ) @@ -103,7 +116,11 @@ def create_chat_session( def update_chat_session( - user_id: UUID | None, chat_session_id: int, description: str, db_session: Session + db_session: Session, + user_id: UUID | None, + chat_session_id: int, + description: str | None = None, + sharing_status: ChatSessionSharedStatus | None = None, ) -> ChatSession: chat_session = get_chat_session_by_id( chat_session_id=chat_session_id, user_id=user_id, db_session=db_session @@ -112,7 +129,10 @@ def update_chat_session( if chat_session.deleted: raise ValueError("Trying to rename a deleted chat session") - chat_session.description = description + if description is not None: + chat_session.description = description + if sharing_status is not None: + chat_session.shared_status = sharing_status db_session.commit() @@ -724,7 +744,8 @@ def create_db_search_doc( boost=server_search_doc.boost, hidden=server_search_doc.hidden, doc_metadata=server_search_doc.metadata, - score=server_search_doc.score, + # For docs further down that aren't reranked, we can't use the retrieval score + score=server_search_doc.score or 0.0, match_highlights=server_search_doc.match_highlights, updated_at=server_search_doc.updated_at, primary_owners=server_search_doc.primary_owners, @@ -745,6 +766,7 @@ def get_db_search_doc_by_id(doc_id: int, db_session: Session) -> DBSearchDoc | N def translate_db_search_doc_to_server_search_doc( db_search_doc: SearchDoc, + remove_doc_content: bool = False, ) -> SavedSearchDoc: return SavedSearchDoc( db_doc_id=db_search_doc.id, @@ -752,22 +774,30 @@ def translate_db_search_doc_to_server_search_doc( chunk_ind=db_search_doc.chunk_ind, 
semantic_identifier=db_search_doc.semantic_id, link=db_search_doc.link, - blurb=db_search_doc.blurb, + blurb=db_search_doc.blurb if not remove_doc_content else "", source_type=db_search_doc.source_type, boost=db_search_doc.boost, hidden=db_search_doc.hidden, - metadata=db_search_doc.doc_metadata, + metadata=db_search_doc.doc_metadata if not remove_doc_content else {}, score=db_search_doc.score, - match_highlights=db_search_doc.match_highlights, - updated_at=db_search_doc.updated_at, - primary_owners=db_search_doc.primary_owners, - secondary_owners=db_search_doc.secondary_owners, + match_highlights=db_search_doc.match_highlights + if not remove_doc_content + else [], + updated_at=db_search_doc.updated_at if not remove_doc_content else None, + primary_owners=db_search_doc.primary_owners if not remove_doc_content else [], + secondary_owners=db_search_doc.secondary_owners + if not remove_doc_content + else [], ) -def get_retrieval_docs_from_chat_message(chat_message: ChatMessage) -> RetrievalDocs: +def get_retrieval_docs_from_chat_message( + chat_message: ChatMessage, remove_doc_content: bool = False +) -> RetrievalDocs: top_documents = [ - translate_db_search_doc_to_server_search_doc(db_doc) + translate_db_search_doc_to_server_search_doc( + db_doc, remove_doc_content=remove_doc_content + ) for db_doc in chat_message.search_docs ] top_documents = sorted(top_documents, key=lambda doc: doc.score, reverse=True) # type: ignore @@ -775,7 +805,7 @@ def get_retrieval_docs_from_chat_message(chat_message: ChatMessage) -> Retrieval def translate_db_message_to_chat_message_detail( - chat_message: ChatMessage, + chat_message: ChatMessage, remove_doc_content: bool = False ) -> ChatMessageDetail: chat_msg_detail = ChatMessageDetail( message_id=chat_message.id, @@ -783,7 +813,9 @@ def translate_db_message_to_chat_message_detail( latest_child_message=chat_message.latest_child_message, message=chat_message.message, rephrased_query=chat_message.rephrased_query, - context_docs=get_retrieval_docs_from_chat_message(chat_message), + context_docs=get_retrieval_docs_from_chat_message( + chat_message, remove_doc_content=remove_doc_content + ), message_type=chat_message.message_type, time_sent=chat_message.time_sent, citations=chat_message.citations, diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index 22f1193fe82..1be57179c70 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -4,7 +4,6 @@ from datetime import datetime from typing import ContextManager -from ddtrace import tracer from sqlalchemy import text from sqlalchemy.engine import create_engine from sqlalchemy.engine import Engine @@ -77,9 +76,11 @@ def get_session_context_manager() -> ContextManager: def get_session() -> Generator[Session, None, None]: - with tracer.trace("db.get_session"): - with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session: - yield session + # The line below was added to monitor the latency caused by Postgres connections + # during API calls. 
+ # with tracer.trace("db.get_session"): + with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session: + yield session async def get_async_session() -> AsyncGenerator[AsyncSession, None]: diff --git a/backend/danswer/db/enums.py b/backend/danswer/db/enums.py new file mode 100644 index 00000000000..2a02e078c60 --- /dev/null +++ b/backend/danswer/db/enums.py @@ -0,0 +1,35 @@ +from enum import Enum as PyEnum + + +class IndexingStatus(str, PyEnum): + NOT_STARTED = "not_started" + IN_PROGRESS = "in_progress" + SUCCESS = "success" + FAILED = "failed" + + +# these may differ in the future, which is why we're okay with this duplication +class DeletionStatus(str, PyEnum): + NOT_STARTED = "not_started" + IN_PROGRESS = "in_progress" + SUCCESS = "success" + FAILED = "failed" + + +# Consistent with Celery task statuses +class TaskStatus(str, PyEnum): + PENDING = "PENDING" + STARTED = "STARTED" + SUCCESS = "SUCCESS" + FAILURE = "FAILURE" + + +class IndexModelStatus(str, PyEnum): + PAST = "PAST" + PRESENT = "PRESENT" + FUTURE = "FUTURE" + + +class ChatSessionSharedStatus(str, PyEnum): + PUBLIC = "public" + PRIVATE = "private" diff --git a/backend/danswer/db/index_attempt.py b/backend/danswer/db/index_attempt.py index ce913098eb3..4580140a5f1 100644 --- a/backend/danswer/db/index_attempt.py +++ b/backend/danswer/db/index_attempt.py @@ -291,7 +291,7 @@ def cancel_indexing_attempts_past_model( db_session.commit() -def count_unique_cc_pairs_with_index_attempts( +def count_unique_cc_pairs_with_successful_index_attempts( embedding_model_id: int | None, db_session: Session, ) -> int: @@ -299,12 +299,7 @@ def count_unique_cc_pairs_with_index_attempts( db_session.query(IndexAttempt.connector_id, IndexAttempt.credential_id) .filter( IndexAttempt.embedding_model_id == embedding_model_id, - # Should not be able to hang since indexing jobs expire after a limit - # It will then be marked failed, and the next cycle it will be in a completed state - or_( - IndexAttempt.status == IndexingStatus.SUCCESS, - IndexAttempt.status == IndexingStatus.FAILED, - ), + IndexAttempt.status == IndexingStatus.SUCCESS, ) .distinct() .count() diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index faafd2aedf8..004025d7ee2 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -35,40 +35,18 @@ from danswer.configs.constants import MessageType from danswer.configs.constants import SearchFeedbackType from danswer.connectors.models import InputType +from danswer.db.enums import ChatSessionSharedStatus +from danswer.db.enums import IndexingStatus +from danswer.db.enums import IndexModelStatus +from danswer.db.enums import TaskStatus +from danswer.db.pydantic_type import PydanticType from danswer.dynamic_configs.interface import JSON_ro +from danswer.llm.override_models import LLMOverride +from danswer.llm.override_models import PromptOverride from danswer.search.enums import RecencyBiasSetting from danswer.search.enums import SearchType -class IndexingStatus(str, PyEnum): - NOT_STARTED = "not_started" - IN_PROGRESS = "in_progress" - SUCCESS = "success" - FAILED = "failed" - - -# these may differ in the future, which is why we're okay with this duplication -class DeletionStatus(str, PyEnum): - NOT_STARTED = "not_started" - IN_PROGRESS = "in_progress" - SUCCESS = "success" - FAILED = "failed" - - -# Consistent with Celery task statuses -class TaskStatus(str, PyEnum): - PENDING = "PENDING" - STARTED = "STARTED" - SUCCESS = "SUCCESS" - FAILURE = "FAILURE" - - -class 
IndexModelStatus(str, PyEnum): - PAST = "PAST" - PRESENT = "PRESENT" - FUTURE = "FUTURE" - - class Base(DeclarativeBase): pass @@ -109,6 +87,7 @@ class ApiKey(Base): __tablename__ = "api_key" id: Mapped[int] = mapped_column(Integer, primary_key=True) + name: Mapped[str | None] = mapped_column(String, nullable=True) hashed_api_key: Mapped[str] = mapped_column(String, unique=True) api_key_display: Mapped[str] = mapped_column(String, unique=True) # the ID of the "user" who represents the access credentials for the API key @@ -586,6 +565,25 @@ class ChatSession(Base): one_shot: Mapped[bool] = mapped_column(Boolean, default=False) # Only ever set to True if system is set to not hard-delete chats deleted: Mapped[bool] = mapped_column(Boolean, default=False) + # controls whether or not this conversation is viewable by others + shared_status: Mapped[ChatSessionSharedStatus] = mapped_column( + Enum(ChatSessionSharedStatus, native_enum=False), + default=ChatSessionSharedStatus.PRIVATE, + ) + + # the latest "overrides" specified by the user. These take precedence over + # the attached persona. However, overrides specified directly in the + # `send-message` call will take precedence over these. + # NOTE: currently only used by the chat seeding flow, will be used in the + # future once we allow users to override default values via the Chat UI + # itself + llm_override: Mapped[LLMOverride | None] = mapped_column( + PydanticType(LLMOverride), nullable=True + ) + prompt_override: Mapped[PromptOverride | None] = mapped_column( + PydanticType(PromptOverride), nullable=True + ) + time_updated: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), @@ -1043,3 +1041,87 @@ class UserGroup(Base): secondary=DocumentSet__UserGroup.__table__, viewonly=True, ) + + +"""Tables related to Permission Sync""" + + +class PermissionSyncStatus(str, PyEnum): + IN_PROGRESS = "in_progress" + SUCCESS = "success" + FAILED = "failed" + + +class PermissionSyncJobType(str, PyEnum): + USER_LEVEL = "user_level" + GROUP_LEVEL = "group_level" + + +class PermissionSyncRun(Base): + """Represents one run of a permission sync job. For some given cc_pair, it is either sync-ing + the users or it is sync-ing the groups""" + + __tablename__ = "permission_sync_run" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + # Not strictly needed but makes it easy to use without fetching from cc_pair + source_type: Mapped[DocumentSource] = mapped_column(Enum(DocumentSource)) + # Currently all sync jobs are handled as a group permission sync or a user permission sync + update_type: Mapped[PermissionSyncJobType] = mapped_column( + Enum(PermissionSyncJobType) + ) + cc_pair_id: Mapped[int | None] = mapped_column( + ForeignKey("connector_credential_pair.id"), nullable=True + ) + status: Mapped[PermissionSyncStatus] = mapped_column(Enum(PermissionSyncStatus)) + error_msg: Mapped[str | None] = mapped_column(Text, default=None) + updated_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) + + cc_pair: Mapped[ConnectorCredentialPair] = relationship("ConnectorCredentialPair") + + +class ExternalPermission(Base): + """Maps user info both internal and external to the name of the external group + This maps the user to all of their external groups so that the external group name can be + attached to the ACL list matching during query time. 
User level permissions can be handled by + directly adding the Danswer user to the doc ACL list""" + + __tablename__ = "external_permission" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + # Email is needed because we want to keep track of users not in Danswer to simplify process + # when the user joins + user_email: Mapped[str] = mapped_column(String) + source_type: Mapped[DocumentSource] = mapped_column(Enum(DocumentSource)) + external_permission_group: Mapped[str] = mapped_column(String) + user = relationship("User") + + +class EmailToExternalUserCache(Base): + """A way to map users IDs in the external tool to a user in Danswer or at least an email for + when the user joins. Used as a cache for when fetching external groups which have their own + user ids, this can easily be mapped back to users already known in Danswer without needing + to call external APIs to get the user emails. + + This way when groups are updated in the external tool and we need to update the mapping of + internal users to the groups, we can sync the internal users to the external groups they are + part of using this. + + Ie. User Chris is part of groups alpha, beta, and we can update this if Chris is no longer + part of alpha in some external tool. + """ + + __tablename__ = "email_to_external_user_cache" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + external_user_id: Mapped[str] = mapped_column(String) + user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + # Email is needed because we want to keep track of users not in Danswer to simplify process + # when the user joins + user_email: Mapped[str] = mapped_column(String) + source_type: Mapped[DocumentSource] = mapped_column(Enum(DocumentSource)) + + user = relationship("User") diff --git a/backend/danswer/db/pydantic_type.py b/backend/danswer/db/pydantic_type.py new file mode 100644 index 00000000000..1f37152a851 --- /dev/null +++ b/backend/danswer/db/pydantic_type.py @@ -0,0 +1,32 @@ +import json +from typing import Any +from typing import Optional +from typing import Type + +from pydantic import BaseModel +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.types import TypeDecorator + + +class PydanticType(TypeDecorator): + impl = JSONB + + def __init__( + self, pydantic_model: Type[BaseModel], *args: Any, **kwargs: Any + ) -> None: + super().__init__(*args, **kwargs) + self.pydantic_model = pydantic_model + + def process_bind_param( + self, value: Optional[BaseModel], dialect: Any + ) -> Optional[dict]: + if value is not None: + return json.loads(value.json()) + return None + + def process_result_value( + self, value: Optional[dict], dialect: Any + ) -> Optional[BaseModel]: + if value is not None: + return self.pydantic_model.parse_obj(value) + return None diff --git a/backend/danswer/document_index/document_index_utils.py b/backend/danswer/document_index/document_index_utils.py index 51e6433cbc4..271fd0cc2e7 100644 --- a/backend/danswer/document_index/document_index_utils.py +++ b/backend/danswer/document_index/document_index_utils.py @@ -6,7 +6,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.embedding_model import get_secondary_db_embedding_model from danswer.indexing.models import IndexChunk -from danswer.indexing.models import InferenceChunk +from danswer.search.models import InferenceChunk DEFAULT_BATCH_SIZE = 30 diff --git 
a/backend/danswer/document_index/interfaces.py b/backend/danswer/document_index/interfaces.py index 787ee3889ab..e59874fa03a 100644 --- a/backend/danswer/document_index/interfaces.py +++ b/backend/danswer/document_index/interfaces.py @@ -5,8 +5,8 @@ from danswer.access.models import DocumentAccess from danswer.indexing.models import DocMetadataAwareIndexChunk -from danswer.indexing.models import InferenceChunk from danswer.search.models import IndexFilters +from danswer.search.models import InferenceChunk @dataclass(frozen=True) @@ -183,7 +183,8 @@ class IdRetrievalCapable(abc.ABC): def id_based_retrieval( self, document_id: str, - chunk_ind: int | None, + min_chunk_ind: int | None, + max_chunk_ind: int | None, filters: IndexFilters, ) -> list[InferenceChunk]: """ @@ -196,7 +197,8 @@ def id_based_retrieval( Parameters: - document_id: document id for which to retrieve the chunk(s) - - chunk_ind: chunk index to return, if None, return all of the chunks in order + - min_chunk_ind: if None then fetch from the start of doc + - max_chunk_ind: if None then fetch to the end of doc - filters: standard filters object, in this case only the access filter is applied as a permission check diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py index 9f78f05c20e..6b8d7bf6a6b 100644 --- a/backend/danswer/document_index/vespa/index.py +++ b/backend/danswer/document_index/vespa/index.py @@ -62,8 +62,8 @@ from danswer.document_index.interfaces import UpdateRequest from danswer.document_index.vespa.utils import remove_invalid_unicode_chars from danswer.indexing.models import DocMetadataAwareIndexChunk -from danswer.indexing.models import InferenceChunk from danswer.search.models import IndexFilters +from danswer.search.models import InferenceChunk from danswer.search.retrieval.search_runner import query_processing from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation from danswer.utils.batching import batch_generator @@ -112,13 +112,13 @@ def _does_document_exist( """Returns whether the document already exists and the users/group whitelists Specifically in this case, document refers to a vespa document which is equivalent to a Danswer chunk.
This checks for whether the chunk exists already in the index""" - doc_fetch_response = http_client.get( - f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}" - ) + doc_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}" + doc_fetch_response = http_client.get(doc_url) if doc_fetch_response.status_code == 404: return False if doc_fetch_response.status_code != 200: + logger.debug(f"Failed to check for document with URL {doc_url}") raise RuntimeError( f"Unexpected fetch document by ID value from Vespa " f"with error {doc_fetch_response.status_code}" @@ -157,7 +157,24 @@ def _get_vespa_chunk_ids_by_document_id( "hits": hits_per_page, } while True: - results = requests.post(SEARCH_ENDPOINT, json=params).json() + res = requests.post(SEARCH_ENDPOINT, json=params) + try: + res.raise_for_status() + except requests.HTTPError as e: + request_info = f"Headers: {res.request.headers}\nPayload: {params}" + response_info = ( + f"Status Code: {res.status_code}\nResponse Content: {res.text}" + ) + error_base = f"Error occurred getting chunk by Document ID {document_id}" + logger.error( + f"{error_base}:\n" + f"{request_info}\n" + f"{response_info}\n" + f"Exception: {e}" + ) + raise requests.HTTPError(error_base) from e + + results = res.json() hits = results["root"].get("children", []) doc_chunk_ids.extend( @@ -179,10 +196,14 @@ def _delete_vespa_doc_chunks( ) for chunk_id in doc_chunk_ids: - res = http_client.delete( - f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{chunk_id}" - ) - res.raise_for_status() + try: + res = http_client.delete( + f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{chunk_id}" + ) + res.raise_for_status() + except httpx.HTTPStatusError as e: + logger.error(f"Failed to delete chunk, details: {e.response.text}") + raise def _delete_vespa_docs( @@ -559,18 +580,35 @@ def _query_vespa(query_params: Mapping[str, str | int | float]) -> list[Inferenc if "query" in query_params and not cast(str, query_params["query"]).strip(): raise ValueError("No/empty query received") + params = dict( + **query_params, + **{ + "presentation.timing": True, + } + if LOG_VESPA_TIMING_INFORMATION + else {}, + ) + response = requests.post( SEARCH_ENDPOINT, - json=dict( - **query_params, - **{ - "presentation.timing": True, - } - if LOG_VESPA_TIMING_INFORMATION - else {}, - ), + json=params, ) - response.raise_for_status() + try: + response.raise_for_status() + except requests.HTTPError as e: + request_info = f"Headers: {response.request.headers}\nPayload: {params}" + response_info = ( + f"Status Code: {response.status_code}\n" + f"Response Content: {response.text}" + ) + error_base = "Failed to query Vespa" + logger.error( + f"{error_base}:\n" + f"{request_info}\n" + f"{response_info}\n" + f"Exception: {e}" + ) + raise requests.HTTPError(error_base) from e response_json: dict[str, Any] = response.json() if LOG_VESPA_TIMING_INFORMATION: @@ -826,10 +864,11 @@ def delete(self, doc_ids: list[str]) -> None: def id_based_retrieval( self, document_id: str, - chunk_ind: int | None, + min_chunk_ind: int | None, + max_chunk_ind: int | None, filters: IndexFilters, ) -> list[InferenceChunk]: - if chunk_ind is None: + if min_chunk_ind is None and max_chunk_ind is None: vespa_chunk_ids = _get_vespa_chunk_ids_by_document_id( document_id=document_id, index_name=self.index_name, @@ -850,14 +889,22 @@ def id_based_retrieval( inference_chunks.sort(key=lambda chunk: chunk.chunk_id) return inference_chunks - else: - filters_str = _build_vespa_filters(filters=filters, 
include_hidden=True) - yql = ( - VespaIndex.yql_base.format(index_name=self.index_name) - + filters_str - + f"({DOCUMENT_ID} contains '{document_id}' and {CHUNK_ID} contains '{chunk_ind}')" - ) - return _query_vespa({"yql": yql}) + filters_str = _build_vespa_filters(filters=filters, include_hidden=True) + yql = ( + VespaIndex.yql_base.format(index_name=self.index_name) + + filters_str + + f"({DOCUMENT_ID} contains '{document_id}'" + ) + + if min_chunk_ind is not None: + yql += f" and {min_chunk_ind} <= {CHUNK_ID}" + if max_chunk_ind is not None: + yql += f" and {max_chunk_ind} >= {CHUNK_ID}" + yql = yql + ")" + + inference_chunks = _query_vespa({"yql": yql}) + inference_chunks.sort(key=lambda chunk: chunk.chunk_id) + return inference_chunks def keyword_retrieval( self, diff --git a/backend/danswer/indexing/chunker.py b/backend/danswer/indexing/chunker.py index 9be9348b9f9..b6f59d18901 100644 --- a/backend/danswer/indexing/chunker.py +++ b/backend/danswer/indexing/chunker.py @@ -5,18 +5,22 @@ from danswer.configs.app_configs import BLURB_SIZE from danswer.configs.app_configs import CHUNK_OVERLAP from danswer.configs.app_configs import MINI_CHUNK_SIZE +from danswer.configs.constants import DocumentSource from danswer.configs.constants import SECTION_SEPARATOR from danswer.configs.constants import TITLE_SEPARATOR from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.connectors.models import Document from danswer.indexing.models import DocAwareChunk from danswer.search.search_nlp_models import get_default_tokenizer +from danswer.utils.logger import setup_logger from danswer.utils.text_processing import shared_precompare_cleanup - if TYPE_CHECKING: from transformers import AutoTokenizer # type:ignore + +logger = setup_logger() + ChunkFunc = Callable[[Document], list[DocAwareChunk]] @@ -178,4 +182,7 @@ def chunk(self, document: Document) -> list[DocAwareChunk]: class DefaultChunker(Chunker): def chunk(self, document: Document) -> list[DocAwareChunk]: + # Specifically for reproducing an issue with gmail + if document.source == DocumentSource.GMAIL: + logger.debug(f"Chunking {document.semantic_identifier}") return chunk_document(document) diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py index 3be10f5b41c..0aaeb35526a 100644 --- a/backend/danswer/indexing/embedder.py +++ b/backend/danswer/indexing/embedder.py @@ -4,8 +4,6 @@ from sqlalchemy.orm import Session from danswer.configs.app_configs import ENABLE_MINI_CHUNK -from danswer.configs.app_configs import INDEXING_MODEL_SERVER_HOST -from danswer.configs.app_configs import MODEL_SERVER_PORT from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.db.embedding_model import get_current_db_embedding_model @@ -16,9 +14,12 @@ from danswer.indexing.models import ChunkEmbedding from danswer.indexing.models import DocAwareChunk from danswer.indexing.models import IndexChunk +from danswer.search.enums import EmbedTextType from danswer.search.search_nlp_models import EmbeddingModel -from danswer.search.search_nlp_models import EmbedTextType +from danswer.utils.batching import batch_list from danswer.utils.logger import setup_logger +from shared_configs.configs import INDEXING_MODEL_SERVER_HOST +from shared_configs.configs import MODEL_SERVER_PORT logger = setup_logger() @@ -73,6 +74,8 @@ def embed_chunks( title_embed_dict: dict[str, list[float]] = {} embedded_chunks: list[IndexChunk] = [] + # Create 
Mini Chunks for more precise matching of details + # Off by default with unedited settings chunk_texts = [] chunk_mini_chunks_count = {} for chunk_ind, chunk in enumerate(chunks): @@ -85,23 +88,43 @@ def embed_chunks( chunk_texts.extend(mini_chunk_texts) chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts) - text_batches = [ - chunk_texts[i : i + batch_size] - for i in range(0, len(chunk_texts), batch_size) - ] + # Batching for embedding + text_batches = batch_list(chunk_texts, batch_size) embeddings: list[list[float]] = [] len_text_batches = len(text_batches) for idx, text_batch in enumerate(text_batches, start=1): - logger.debug(f"Embedding text batch {idx} of {len_text_batches}") - # Normalize embeddings is only configured via model_configs.py, be sure to use right value for the set loss + logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}") + # Normalize embeddings is only configured via model_configs.py, be sure to use right + # value for the set loss embeddings.extend( self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE) ) - # Replace line above with the line below for easy debugging of indexing flow, skipping the actual model + # Replace line above with the line below for easy debugging of indexing flow + # skipping the actual model # embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))]) + chunk_titles = { + chunk.source_document.get_title_for_document_index() for chunk in chunks + } + + # Drop any None or empty strings + chunk_titles_list = [title for title in chunk_titles if title] + + # Embed Titles in batches + title_batches = batch_list(chunk_titles_list, batch_size) + len_title_batches = len(title_batches) + for ind_batch, title_batch in enumerate(title_batches, start=1): + logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}") + title_embeddings = self.embedding_model.encode( + title_batch, text_type=EmbedTextType.PASSAGE + ) + title_embed_dict.update( + {title: vector for title, vector in zip(title_batch, title_embeddings)} + ) + + # Mapping embeddings to chunks embedding_ind_start = 0 for chunk_ind, chunk in enumerate(chunks): num_embeddings = chunk_mini_chunks_count[chunk_ind] @@ -114,16 +137,19 @@ def embed_chunks( title_embedding = None if title: if title in title_embed_dict: - # Using cached value for speedup + # Using cached value to avoid recalculating for every chunk title_embedding = title_embed_dict[title] else: + logger.error( + "Title had to be embedded separately, this should not happen!" 
+ ) + title_embedding = self.embedding_model.encode( + [title], text_type=EmbedTextType.PASSAGE )[0] + title_embed_dict[title] = title_embedding new_embedded_chunk = IndexChunk( - **{k: getattr(chunk, k) for k in chunk.__dataclass_fields__}, + **chunk.dict(), embeddings=ChunkEmbedding( full_embedding=chunk_embeddings[0], mini_chunk_embeddings=chunk_embeddings[1:], diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py index c875c88bdd2..5fc32cd9afa 100644 --- a/backend/danswer/indexing/models.py +++ b/backend/danswer/indexing/models.py @@ -1,14 +1,14 @@ -from dataclasses import dataclass -from dataclasses import fields -from datetime import datetime +from typing import TYPE_CHECKING from pydantic import BaseModel from danswer.access.models import DocumentAccess -from danswer.configs.constants import DocumentSource from danswer.connectors.models import Document from danswer.utils.logger import setup_logger +if TYPE_CHECKING: + from danswer.db.models import EmbeddingModel + logger = setup_logger() @@ -16,14 +16,12 @@ Embedding = list[float] -@dataclass -class ChunkEmbedding: +class ChunkEmbedding(BaseModel): full_embedding: Embedding mini_chunk_embeddings: list[Embedding] -@dataclass -class BaseChunk: +class BaseChunk(BaseModel): chunk_id: int blurb: str # The first sentence(s) of the first Section of the chunk content: str @@ -33,7 +31,6 @@ class BaseChunk: section_continuation: bool # True if this Chunk's start is not at the start of a Section -@dataclass class DocAwareChunk(BaseChunk): # During indexing flow, we have access to a complete "Document" # During inference we only have access to the document id and do not reconstruct the Document @@ -46,13 +43,11 @@ def to_short_descriptor(self) -> str: ) -@dataclass class IndexChunk(DocAwareChunk): embeddings: ChunkEmbedding title_embedding: Embedding | None -@dataclass class DocMetadataAwareIndexChunk(IndexChunk): """An `IndexChunk` that contains all necessary metadata to be indexed. This includes the following: @@ -77,56 +72,28 @@ def from_index_chunk( document_sets: set[str], boost: int, ) -> "DocMetadataAwareIndexChunk": + index_chunk_data = index_chunk.dict() return cls( - **{ - field.name: getattr(index_chunk, field.name) - for field in fields(index_chunk) - }, + **index_chunk_data, access=access, document_sets=document_sets, boost=boost, ) -@dataclass -class InferenceChunk(BaseChunk): - document_id: str - source_type: DocumentSource - semantic_identifier: str - boost: int - recency_bias: float - score: float | None - hidden: bool - metadata: dict[str, str | list[str]] - # Matched sections in the chunk. Uses Vespa syntax e.g.
[Truncated web UI hunks: the Axero connector admin page adds setup guidance ("To use the Axero connector, first follow the guide here to generate an API Key."), and the chat session entry renders {chatName || `Chat ${chatSession.id}`} (within the isSelected / isRenamingChat branch) along with {humanReadableFormat(chatSession.time_created)}.]