From 82b9cb4cc10aea32833e8fc2cac475b9c3001b7c Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 18 Apr 2024 21:57:01 -0700 Subject: [PATCH 01/25] Add check to ensure auth is enabled for every endpoint unless explicitly whitelisted --- backend/danswer/main.py | 4 + .../danswer/search/retrieval/search_runner.py | 2 +- backend/danswer/server/auth_check.py | 81 +++++++++++++++++++ .../danswer/server/danswer_api/ingestion.py | 22 ----- 4 files changed, 86 insertions(+), 23 deletions(-) create mode 100644 backend/danswer/server/auth_check.py diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 0e43e9754cc..e8afa0838b1 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -52,6 +52,7 @@ from danswer.llm.utils import get_default_llm_version from danswer.search.retrieval.search_runner import download_nltk_data from danswer.search.search_nlp_models import warm_up_encoders +from danswer.server.auth_check import check_router_auth from danswer.server.danswer_api.ingestion import get_danswer_api_key from danswer.server.danswer_api.ingestion import router as danswer_api_router from danswer.server.documents.cc_pair import router as cc_pair_router @@ -353,6 +354,9 @@ def get_application() -> FastAPI: allow_headers=["*"], ) + # Ensure all routes have auth enabled or are explicitly marked as public + check_router_auth(application) + return application diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py index 092b755d9d8..411db5b0f56 100644 --- a/backend/danswer/search/retrieval/search_runner.py +++ b/backend/danswer/search/retrieval/search_runner.py @@ -32,7 +32,7 @@ logger = setup_logger() -def download_nltk_data(): +def download_nltk_data() -> None: resources = { "stopwords": "corpora/stopwords", "wordnet": "corpora/wordnet", diff --git a/backend/danswer/server/auth_check.py b/backend/danswer/server/auth_check.py new file mode 100644 index 00000000000..54efc49564a --- /dev/null +++ b/backend/danswer/server/auth_check.py @@ -0,0 +1,81 @@ +from typing import cast + +from fastapi import FastAPI +from fastapi.dependencies.models import Dependant + +from danswer.auth.users import current_admin_user +from danswer.auth.users import current_user +from danswer.server.danswer_api.ingestion import api_key_dep + + +PUBLIC_ENDPOINT_SPECS = [ + # built-in documentation functions + ("/openapi.json", {"GET", "HEAD"}), + ("/docs", {"GET", "HEAD"}), + ("/docs/oauth2-redirect", {"GET", "HEAD"}), + ("/redoc", {"GET", "HEAD"}), + # should always be callable, will just return 401 if not authenticated + ("/manage/me", {"GET"}), + # just returns 200 to validate that the server is up + ("/health", {"GET"}), + # just returns auth type, needs to be accessible before the user is logged + # in to determine what flow to give the user + ("/auth/type", {"GET"}), + # just gets the version of Danswer (e.g. 0.3.11) + ("/version", {"GET"}), + # stuff related to basic auth + ("/auth/register", {"POST"}), + ("/auth/login", {"POST"}), + ("/auth/logout", {"POST"}), + ("/auth/forgot-password", {"POST"}), + ("/auth/reset-password", {"POST"}), + ("/auth/request-verify-token", {"POST"}), + ("/auth/verify", {"POST"}), + ("/users/me", {"GET"}), + ("/users/me", {"PATCH"}), + ("/users/{id}", {"GET"}), + ("/users/{id}", {"PATCH"}), + ("/users/{id}", {"DELETE"}), + # oauth + ("/auth/oauth/authorize", {"GET"}), + ("/auth/oauth/callback", {"GET"}), +] + + +def check_router_auth(application: FastAPI) -> None: + """Ensures that all endpoints on the passed in application either + (1) have auth enabled OR + (2) are explicitly marked as a public endpoint + """ + for route in application.routes: + # explicitly marked as public + if ( + hasattr(route, "path") + and hasattr(route, "methods") + and (route.path, route.methods) in PUBLIC_ENDPOINT_SPECS + ): + continue + + # check for auth + found_auth = False + route_dependant_obj = cast( + Dependant | None, route.dependant if hasattr(route, "dependant") else None + ) + if route_dependant_obj: + for dependency in route_dependant_obj.dependencies: + depends_fn = dependency.cache_key[0] + if ( + depends_fn == current_user + or depends_fn == current_admin_user + or depends_fn == api_key_dep + ): + found_auth = True + break + + if not found_auth: + # uncomment to print out all route(s) that are missing auth + # print(f"(\"{route.path}\", {set(route.methods)}),") + + raise RuntimeError( + f"Did not find current_user or current_admin_user dependency in route - {route}" + ) diff --git a/backend/danswer/server/danswer_api/ingestion.py b/backend/danswer/server/danswer_api/ingestion.py index 7fce8d1d38c..2b0bed6c31a 100644 --- a/backend/danswer/server/danswer_api/ingestion.py +++ b/backend/danswer/server/danswer_api/ingestion.py @@ -1,5 +1,4 @@ import secrets -from typing import cast from fastapi import APIRouter from fastapi import Depends @@ -25,7 +24,6 @@ from danswer.indexing.indexing_pipeline import build_indexing_pipeline from danswer.server.danswer_api.models import IngestionDocument from danswer.server.danswer_api.models import IngestionResult -from danswer.server.models import ApiKey from danswer.utils.logger import setup_logger logger = setup_logger() @@ -69,26 +67,6 @@ def api_key_dep(authorization: str = Header(...)) -> str: return token -# Provides a way to recover if the api key is deleted for some reason -# Can also just restart the server to regenerate a new one -def api_key_dep_if_exist(authorization: str | None = Header(None)) -> str | None: - token = authorization.removeprefix("Bearer ").strip() if authorization else None - saved_key = get_danswer_api_key(dont_regenerate=True) - if not saved_key: - return None - - if token != saved_key: - raise HTTPException(status_code=401, detail="Invalid API key") - - return token - - -@router.post("/regenerate-key") -def regenerate_key(_: str | None = Depends(api_key_dep_if_exist)) -> ApiKey: - delete_danswer_api_key() - return ApiKey(api_key=cast(str, get_danswer_api_key())) - - @router.post("/doc-ingestion") def document_ingestion( doc_info: IngestionDocument, From 87f304dfd0fa9157043869a12bd51b4e60acea13 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Fri, 19 Apr 2024 10:38:15 -0700 Subject: [PATCH 02/25] Swap Index Early (#1353) --- backend/danswer/background/update.py | 51 +----------------------- backend/danswer/db/swap_index.py | 58 ++++++++++++++++++++++++++++ backend/danswer/main.py | 2 + 3 files changed, 62 insertions(+), 49 deletions(-) create mode 100644 backend/danswer/db/swap_index.py diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 6042e02b1cd..ae629696579 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -17,19 +17,12 @@ from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP from danswer.configs.app_configs import NUM_INDEXING_WORKERS from danswer.db.connector import fetch_connectors -from danswer.db.connector_credential_pair import get_connector_credential_pairs from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed -from danswer.db.connector_credential_pair import resync_cc_pair from danswer.db.connector_credential_pair import update_connector_credential_pair from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.embedding_model import get_secondary_db_embedding_model -from danswer.db.embedding_model import update_embedding_model_status from danswer.db.engine import get_db_current_time from danswer.db.engine import get_sqlalchemy_engine -from danswer.db.index_attempt import cancel_indexing_attempts_past_model -from danswer.db.index_attempt import ( - count_unique_cc_pairs_with_successful_index_attempts, -) from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import get_inprogress_index_attempts @@ -41,6 +34,7 @@ from danswer.db.models import IndexAttempt from danswer.db.models import IndexingStatus from danswer.db.models import IndexModelStatus +from danswer.db.swap_index import check_index_swap from danswer.search.search_nlp_models import warm_up_encoders from danswer.utils.logger import setup_logger from shared_configs.configs import INDEXING_MODEL_SERVER_HOST @@ -354,51 +348,10 @@ def kickoff_indexing_jobs( return existing_jobs_copy -def check_index_swap(db_session: Session) -> None: - """Get count of cc-pairs and count of successful index_attempts for the - new model grouped by connector + credential, if it's the same, then assume - new index is done building. If so, swap the indices and expire the old one.""" - # Default CC-pair created for Ingestion API unused here - all_cc_pairs = get_connector_credential_pairs(db_session) - cc_pair_count = len(all_cc_pairs) - 1 - embedding_model = get_secondary_db_embedding_model(db_session) - - if not embedding_model: - return - - unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts( - embedding_model_id=embedding_model.id, db_session=db_session - ) - - if unique_cc_indexings > cc_pair_count: - raise RuntimeError("More unique indexings than cc pairs, should not occur") - - if cc_pair_count == unique_cc_indexings: - # Swap indices - now_old_embedding_model = get_current_db_embedding_model(db_session) - update_embedding_model_status( - embedding_model=now_old_embedding_model, - new_status=IndexModelStatus.PAST, - db_session=db_session, - ) - - update_embedding_model_status( - embedding_model=embedding_model, - new_status=IndexModelStatus.PRESENT, - db_session=db_session, - ) - - # Expire jobs for the now past index/embedding model - cancel_indexing_attempts_past_model(db_session) - - # Recount aggregates - for cc_pair in all_cc_pairs: - resync_cc_pair(cc_pair, db_session=db_session) - - def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None: engine = get_sqlalchemy_engine() with Session(engine) as db_session: + check_index_swap(db_session=db_session) db_embedding_model = get_current_db_embedding_model(db_session) # So that the first time users aren't surprised by really slow speed of first diff --git a/backend/danswer/db/swap_index.py b/backend/danswer/db/swap_index.py new file mode 100644 index 00000000000..93eb4714ac3 --- /dev/null +++ b/backend/danswer/db/swap_index.py @@ -0,0 +1,58 @@ +from sqlalchemy.orm import Session + +from danswer.db.connector_credential_pair import get_connector_credential_pairs +from danswer.db.connector_credential_pair import resync_cc_pair +from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.db.embedding_model import get_secondary_db_embedding_model +from danswer.db.embedding_model import update_embedding_model_status +from danswer.db.enums import IndexModelStatus +from danswer.db.index_attempt import cancel_indexing_attempts_past_model +from danswer.db.index_attempt import ( + count_unique_cc_pairs_with_successful_index_attempts, +) +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def check_index_swap(db_session: Session) -> None: + """Get count of cc-pairs and count of successful index_attempts for the + new model grouped by connector + credential, if it's the same, then assume + new index is done building. If so, swap the indices and expire the old one.""" + # Default CC-pair created for Ingestion API unused here + all_cc_pairs = get_connector_credential_pairs(db_session) + cc_pair_count = max(len(all_cc_pairs) - 1, 0) + embedding_model = get_secondary_db_embedding_model(db_session) + + if not embedding_model: + return + + unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts( + embedding_model_id=embedding_model.id, db_session=db_session + ) + + if unique_cc_indexings > cc_pair_count: + raise RuntimeError("More unique indexings than cc pairs, should not occur") + + if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings: + # Swap indices + now_old_embedding_model = get_current_db_embedding_model(db_session) + update_embedding_model_status( + embedding_model=now_old_embedding_model, + new_status=IndexModelStatus.PAST, + db_session=db_session, + ) + + update_embedding_model_status( + embedding_model=embedding_model, + new_status=IndexModelStatus.PRESENT, + db_session=db_session, + ) + + if cc_pair_count > 0: + # Expire jobs for the now past index/embedding model + cancel_indexing_attempts_past_model(db_session) + + # Recount aggregates + for cc_pair in all_cc_pairs: + resync_cc_pair(cc_pair, db_session=db_session) diff --git a/backend/danswer/main.py b/backend/danswer/main.py index e8afa0838b1..8e27482ae07 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -46,6 +46,7 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.index_attempt import cancel_indexing_attempts_past_model from danswer.db.index_attempt import expire_index_attempts +from danswer.db.swap_index import check_index_swap from danswer.document_index.factory import get_default_document_index from danswer.dynamic_configs.port_configs import port_filesystem_to_postgres from danswer.llm.factory import get_default_llm @@ -180,6 +181,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: ) with Session(engine) as db_session: + check_index_swap(db_session=db_session) db_embedding_model = get_current_db_embedding_model(db_session) secondary_db_embedding_model = get_secondary_db_embedding_model(db_session) From 58545ccf3a530a195a2c9d2e98674b111f3b8dee Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Fri, 19 Apr 2024 21:52:53 -0700 Subject: [PATCH 03/25] Pre download models (#1354) --- backend/Dockerfile | 9 +++++++++ backend/Dockerfile.model_server | 10 ++++++++++ 2 files changed, 19 insertions(+) diff --git a/backend/Dockerfile b/backend/Dockerfile index a61864fa2c0..533a94b24a2 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -42,6 +42,15 @@ RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cma rm -rf /var/lib/apt/lists/* && \ rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key +# Pre-downloading models for setups with limited egress +RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')" + +# Pre-downloading NLTK for setups with limited egress +RUN python -c "import nltk; \ +nltk.download('stopwords', quiet=True); \ +nltk.download('wordnet', quiet=True); \ +nltk.download('punkt', quiet=True);" + # Set up application files WORKDIR /app COPY ./danswer /app/danswer diff --git a/backend/Dockerfile.model_server b/backend/Dockerfile.model_server index 365a553c9f1..89f24e2ac26 100644 --- a/backend/Dockerfile.model_server +++ b/backend/Dockerfile.model_server @@ -17,6 +17,16 @@ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt RUN apt-get remove -y --allow-remove-essential perl-base && \ apt-get autoremove -y +# Pre-downloading models for setups with limited egress +RUN python -c "from transformers import AutoModel, AutoTokenizer, TFDistilBertForSequenceClassification; \ +from huggingface_hub import snapshot_download; \ +AutoTokenizer.from_pretrained('danswer/intent-model'); \ +AutoTokenizer.from_pretrained('intfloat/e5-base-v2'); \ +AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \ +snapshot_download('danswer/intent-model'); \ +snapshot_download('intfloat/e5-base-v2'); \ +snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1')" + WORKDIR /app # Utils used by model server From 4e9605e6522ebe25937a589ffee412fea854788c Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sat, 20 Apr 2024 09:25:08 -0700 Subject: [PATCH 04/25] Only Log Index Attempt CC Pair Miscount (#1355) --- backend/danswer/db/index_attempt.py | 3 +++ backend/danswer/db/swap_index.py | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/backend/danswer/db/index_attempt.py b/backend/danswer/db/index_attempt.py index 4580140a5f1..df42d869d65 100644 --- a/backend/danswer/db/index_attempt.py +++ b/backend/danswer/db/index_attempt.py @@ -295,6 +295,9 @@ def count_unique_cc_pairs_with_successful_index_attempts( embedding_model_id: int | None, db_session: Session, ) -> int: + """Collect all of the Index Attempts that are successful and for the specified embedding model + Then do distinct by connector_id and credential_id which is equivalent to the cc-pair. Finally, + do a count to get the total number of unique cc-pairs with successful attempts""" unique_pairs_count = ( db_session.query(IndexAttempt.connector_id, IndexAttempt.credential_id) .filter( diff --git a/backend/danswer/db/swap_index.py b/backend/danswer/db/swap_index.py index 93eb4714ac3..f14a45f296e 100644 --- a/backend/danswer/db/swap_index.py +++ b/backend/danswer/db/swap_index.py @@ -31,8 +31,10 @@ def check_index_swap(db_session: Session) -> None: embedding_model_id=embedding_model.id, db_session=db_session ) + # Index Attempts are cleaned up as well when the cc-pair is deleted so the logic in this + # function is correct. The unique_cc_indexings are specifically for the existing cc-pairs if unique_cc_indexings > cc_pair_count: - raise RuntimeError("More unique indexings than cc pairs, should not occur") + logger.error("More unique indexings than cc pairs, should not occur") if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings: # Swap indices From 7d51549b1b45643667d8ecf0a415ab6d93d56fa0 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sat, 20 Apr 2024 10:27:41 -0700 Subject: [PATCH 05/25] Remove Unused Volumes (#1356) --- deployment/docker_compose/docker-compose.dev.yml | 12 +++++------- .../docker-compose.prod-no-letsencrypt.yml | 12 +++++------- deployment/docker_compose/docker-compose.prod.yml | 12 +++++------- 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 9b5115f801f..f3a56934016 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -80,7 +80,6 @@ services: volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage - - model_cache_nltk:/root/nltk_data/ extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -174,7 +173,6 @@ services: volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage - - model_cache_nltk:/root/nltk_data/ extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -218,7 +216,6 @@ services: # Set to debug to get more fine-grained logs - LOG_LEVEL=${LOG_LEVEL:-info} volumes: - - model_cache_torch:/root/.cache/torch/ - model_cache_huggingface:/root/.cache/huggingface/ logging: driver: json-file @@ -246,7 +243,6 @@ services: # Set to debug to get more fine-grained logs - LOG_LEVEL=${LOG_LEVEL:-info} volumes: - - model_cache_torch:/root/.cache/torch/ - model_cache_huggingface:/root/.cache/huggingface/ logging: driver: json-file @@ -314,10 +310,12 @@ services: volumes: + # local_dynamic_storage is legacy only now local_dynamic_storage: - file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them + # used to store files uploaded by the user temporarily while we are indexing them + # file_connector_tmp_storage is legacy only now + file_connector_tmp_storage: db_volume: vespa_volume: - model_cache_torch: - model_cache_nltk: + # Created by the container itself model_cache_huggingface: diff --git a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml index 5c5cd5a4663..32a7dae2659 100644 --- a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml +++ b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml @@ -23,7 +23,6 @@ services: volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage - - model_cache_nltk:/root/nltk_data/ extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -54,7 +53,6 @@ services: volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage - - model_cache_nltk:/root/nltk_data/ extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -104,7 +102,6 @@ services: # Set to debug to get more fine-grained logs - LOG_LEVEL=${LOG_LEVEL:-info} volumes: - - model_cache_torch:/root/.cache/torch/ - model_cache_huggingface:/root/.cache/huggingface/ logging: driver: json-file @@ -132,7 +129,6 @@ services: # Set to debug to get more fine-grained logs - LOG_LEVEL=${LOG_LEVEL:-info} volumes: - - model_cache_torch:/root/.cache/torch/ - model_cache_huggingface:/root/.cache/huggingface/ logging: driver: json-file @@ -204,10 +200,12 @@ services: volumes: + # local_dynamic_storage is legacy only now local_dynamic_storage: - file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them + # used to store files uploaded by the user temporarily while we are indexing them + # file_connector_tmp_storage is legacy only now + file_connector_tmp_storage: db_volume: vespa_volume: - model_cache_torch: - model_cache_nltk: + # Created by the container itself model_cache_huggingface: diff --git a/deployment/docker_compose/docker-compose.prod.yml b/deployment/docker_compose/docker-compose.prod.yml index 9c7202abd3c..c5b8177c4c4 100644 --- a/deployment/docker_compose/docker-compose.prod.yml +++ b/deployment/docker_compose/docker-compose.prod.yml @@ -23,7 +23,6 @@ services: volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage - - model_cache_nltk:/root/nltk_data/ extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -54,7 +53,6 @@ services: volumes: - local_dynamic_storage:/home/storage - file_connector_tmp_storage:/home/file_connector_storage - - model_cache_nltk:/root/nltk_data/ extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -117,7 +115,6 @@ services: # Set to debug to get more fine-grained logs - LOG_LEVEL=${LOG_LEVEL:-info} volumes: - - model_cache_torch:/root/.cache/torch/ - model_cache_huggingface:/root/.cache/huggingface/ logging: driver: json-file @@ -145,7 +142,6 @@ services: # Set to debug to get more fine-grained logs - LOG_LEVEL=${LOG_LEVEL:-info} volumes: - - model_cache_torch:/root/.cache/torch/ - model_cache_huggingface:/root/.cache/huggingface/ logging: driver: json-file @@ -221,10 +217,12 @@ services: volumes: + # local_dynamic_storage is legacy only now local_dynamic_storage: - file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them + # used to store files uploaded by the user temporarily while we are indexing them + # file_connector_tmp_storage is legacy only now + file_connector_tmp_storage: db_volume: vespa_volume: - model_cache_torch: - model_cache_nltk: + # Created by the container itself model_cache_huggingface: From f616b7e6e562e2b416b265619136caf1e7db0a0e Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sat, 20 Apr 2024 15:24:00 -0700 Subject: [PATCH 06/25] Web Connector to only allow Global IPs (#1357) --- backend/danswer/connectors/web/connector.py | 64 ++++++++++++++++----- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/backend/danswer/connectors/web/connector.py b/backend/danswer/connectors/web/connector.py index 6114f0a867a..c5fb6db26d4 100644 --- a/backend/danswer/connectors/web/connector.py +++ b/backend/danswer/connectors/web/connector.py @@ -1,4 +1,6 @@ import io +import ipaddress +import socket from enum import Enum from typing import Any from typing import cast @@ -27,7 +29,6 @@ from danswer.connectors.models import Section from danswer.utils.logger import setup_logger - logger = setup_logger() @@ -43,16 +44,33 @@ class WEB_CONNECTOR_VALID_SETTINGS(str, Enum): def protected_url_check(url: str) -> None: + """Couple considerations: + - DNS mapping changes over time so we don't want to cache the results + - Fetching this is assumed to be relatively fast compared to other bottlenecks like reading + the page or embedding the contents + - To be extra safe, all IPs associated with the URL must be global + """ parse = urlparse(url) - if parse.scheme == "file": - raise ValueError("Not permitted to read local files via Web Connector.") - if ( - parse.scheme == "localhost" - or parse.scheme == "127.0.0.1" - or parse.hostname == "localhost" - or parse.hostname == "127.0.0.1" - ): - raise ValueError("Not permitted to read localhost urls.") + if parse.scheme != "http" and parse.scheme != "https": + raise ValueError("URL must be of scheme https?://") + + if not parse.hostname: + raise ValueError("URL must include a hostname") + + try: + # This may give a large list of IP addresses for domains with extensive DNS configurations + # such as large distributed systems of CDNs + info = socket.getaddrinfo(parse.hostname, None) + except socket.gaierror as e: + raise ConnectionError(f"DNS resolution failed for {parse.hostname}: {e}") + + for address in info: + ip = address[4][0] + if not ipaddress.ip_address(ip).is_global: + raise ValueError( + f"Non-global IP address detected: {ip}, skipping page {url}. " + f"The Web Connector is not allowed to read loopback, link-local, or private ranges" + ) def check_internet_connection(url: str) -> None: @@ -194,6 +212,10 @@ def load_from_state(self) -> GenerateDocumentsOutput: base_url = to_visit[0] # For the recursive case doc_batch: list[Document] = [] + # Needed to report error + at_least_one_doc = False + last_error = None + playwright, context = start_playwright() restart_playwright = False while to_visit: @@ -202,7 +224,12 @@ def load_from_state(self) -> GenerateDocumentsOutput: continue visited_links.add(current_url) - protected_url_check(current_url) + try: + protected_url_check(current_url) + except Exception as e: + last_error = f"Invalid URL {current_url} due to {e}" + logger.warning(last_error) + continue logger.info(f"Visiting {current_url}") @@ -251,9 +278,8 @@ def load_from_state(self) -> GenerateDocumentsOutput: to_visit.append(link) if page_response and str(page_response.status)[0] in ("4", "5"): - logger.info( - f"Skipped indexing {current_url} due to HTTP {page_response.status} response" - ) + last_error = f"Skipped indexing {current_url} due to HTTP {page_response.status} response" + logger.info(last_error) continue parsed_html = web_html_cleanup(soup, self.mintlify_cleanup) @@ -272,7 +298,8 @@ def load_from_state(self) -> GenerateDocumentsOutput: page.close() except Exception as e: - logger.error(f"Failed to fetch '{current_url}': {e}") + last_error = f"Failed to fetch '{current_url}': {e}" + logger.error(last_error) playwright.stop() restart_playwright = True continue @@ -280,13 +307,20 @@ def load_from_state(self) -> GenerateDocumentsOutput: if len(doc_batch) >= self.batch_size: playwright.stop() restart_playwright = True + at_least_one_doc = True yield doc_batch doc_batch = [] if doc_batch: playwright.stop() + at_least_one_doc = True yield doc_batch + if not at_least_one_doc: + if last_error: + raise RuntimeError(last_error) + raise RuntimeError("No valid pages found.") + if __name__ == "__main__": connector = WebConnector("https://docs.danswer.dev/") From b407edbe4919b624f5fa8d6e88b15b945b64271e Mon Sep 17 00:00:00 2001 From: Weves Date: Sat, 20 Apr 2024 16:20:26 -0700 Subject: [PATCH 07/25] Personal assistants --- backend/danswer/chat/load_yamls.py | 10 +- backend/danswer/db/chat.py | 79 +- backend/danswer/db/models.py | 4 +- backend/danswer/db/persona.py | 20 +- backend/danswer/db/slack_bot_config.py | 3 +- .../one_shot_answer/answer_question.py | 2 +- .../danswer/server/features/persona/api.py | 94 +- .../danswer/server/features/persona/models.py | 12 +- backend/danswer/server/features/prompt/api.py | 12 +- .../danswer/server/features/prompt/models.py | 3 - backend/danswer/server/manage/slack_bot.py | 2 +- backend/danswer/server/models.py | 6 + .../server/query_and_chat/chat_backend.py | 2 +- .../AssistantEditor.tsx} | 385 ++++--- .../HidableSection.tsx | 0 .../{personas => assistants}/PersonaTable.tsx | 24 +- .../[personaId]/DeletePersonaButton.tsx | 0 .../app/admin/assistants/[personaId]/page.tsx | 53 + web/src/app/admin/assistants/enums.ts | 4 + .../{personas => assistants}/interfaces.ts | 6 +- .../app/admin/{personas => assistants}/lib.ts | 13 +- web/src/app/admin/assistants/new/page.tsx | 42 + .../admin/{personas => assistants}/page.tsx | 14 +- .../admin/bot/SlackBotConfigCreationForm.tsx | 2 +- web/src/app/admin/bot/[id]/page.tsx | 2 +- web/src/app/admin/bot/lib.ts | 2 +- web/src/app/admin/bot/new/page.tsx | 2 +- .../app/admin/personas/[personaId]/page.tsx | 92 -- web/src/app/admin/personas/new/page.tsx | 66 -- web/src/app/admin/settings/interfaces.ts | 16 + web/src/app/assistants/edit/[id]/page.tsx | 59 ++ web/src/app/assistants/new/page.tsx | 59 ++ web/src/app/chat/Chat.tsx | 966 ----------------- web/src/app/chat/ChatIntro.tsx | 2 +- web/src/app/chat/ChatPage.tsx | 999 +++++++++++++++++- web/src/app/chat/ChatPersonaSelector.tsx | 2 +- web/src/app/chat/StarterMessage.tsx | 2 +- web/src/app/chat/lib.tsx | 4 +- web/src/app/chat/page.tsx | 22 +- web/src/app/chat/searchParams.ts | 2 +- .../app/chat/sessionSidebar/AssistantsTab.tsx | 105 ++ .../app/chat/sessionSidebar/ChatSidebar.tsx | 132 ++- web/src/app/chat/sessionSidebar/ChatTab.tsx | 40 + web/src/app/chat/sessionSidebar/constants.ts | 6 + web/src/app/chat/shared/[chatId]/page.tsx | 10 +- web/src/app/layout.tsx | 13 +- web/src/app/page.tsx | 6 +- web/src/app/search/page.tsx | 19 +- web/src/components/admin/Layout.tsx | 15 +- .../documentSet/DocumentSetSelectable.tsx | 54 + web/src/components/{ => header}/Header.tsx | 23 +- web/src/components/header/HeaderWrapper.tsx | 11 + web/src/components/search/PersonaSelector.tsx | 2 +- web/src/components/search/SearchSection.tsx | 16 +- .../components/settings/SettingsProvider.tsx | 20 + web/src/components/settings/lib.ts | 30 + .../assistants/fetchPersonaEditorInfoSS.ts | 103 ++ web/src/lib/constants.ts | 5 + web/src/lib/filters.ts | 2 +- web/src/lib/search/interfaces.ts | 2 +- web/src/lib/settings.ts | 10 - web/src/lib/sources.ts | 2 +- web/src/lib/types.ts | 7 +- 63 files changed, 2128 insertions(+), 1594 deletions(-) rename web/src/app/admin/{personas/PersonaEditor.tsx => assistants/AssistantEditor.tsx} (74%) rename web/src/app/admin/{personas => assistants}/HidableSection.tsx (100%) rename web/src/app/admin/{personas => assistants}/PersonaTable.tsx (88%) rename web/src/app/admin/{personas => assistants}/[personaId]/DeletePersonaButton.tsx (100%) create mode 100644 web/src/app/admin/assistants/[personaId]/page.tsx create mode 100644 web/src/app/admin/assistants/enums.ts rename web/src/app/admin/{personas => assistants}/interfaces.ts (86%) rename web/src/app/admin/{personas => assistants}/lib.ts (96%) create mode 100644 web/src/app/admin/assistants/new/page.tsx rename web/src/app/admin/{personas => assistants}/page.tsx (82%) delete mode 100644 web/src/app/admin/personas/[personaId]/page.tsx delete mode 100644 web/src/app/admin/personas/new/page.tsx create mode 100644 web/src/app/assistants/edit/[id]/page.tsx create mode 100644 web/src/app/assistants/new/page.tsx delete mode 100644 web/src/app/chat/Chat.tsx create mode 100644 web/src/app/chat/sessionSidebar/AssistantsTab.tsx create mode 100644 web/src/app/chat/sessionSidebar/ChatTab.tsx create mode 100644 web/src/app/chat/sessionSidebar/constants.ts create mode 100644 web/src/components/documentSet/DocumentSetSelectable.tsx rename web/src/components/{ => header}/Header.tsx (88%) create mode 100644 web/src/components/header/HeaderWrapper.tsx create mode 100644 web/src/components/settings/SettingsProvider.tsx create mode 100644 web/src/components/settings/lib.ts create mode 100644 web/src/lib/assistants/fetchPersonaEditorInfoSS.ts delete mode 100644 web/src/lib/settings.ts diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py index ccc75443749..1b1e615bb72 100644 --- a/backend/danswer/chat/load_yamls.py +++ b/backend/danswer/chat/load_yamls.py @@ -24,7 +24,7 @@ def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None: with Session(get_sqlalchemy_engine()) as db_session: for prompt in all_prompts: upsert_prompt( - user_id=None, + user=None, prompt_id=prompt.get("id"), name=prompt["name"], description=prompt["description"].strip(), @@ -34,7 +34,6 @@ def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None: datetime_aware=prompt.get("datetime_aware", True), default_prompt=True, personas=None, - shared=True, db_session=db_session, commit=True, ) @@ -67,9 +66,7 @@ def load_personas_from_yaml( prompts: list[PromptDBModel | None] | None = None else: prompts = [ - get_prompt_by_name( - prompt_name, user_id=None, shared=True, db_session=db_session - ) + get_prompt_by_name(prompt_name, user=None, db_session=db_session) for prompt_name in prompt_set_names ] if any([prompt is None for prompt in prompts]): @@ -80,7 +77,7 @@ def load_personas_from_yaml( p_id = persona.get("id") upsert_persona( - user_id=None, + user=None, # Negative to not conflict with existing personas persona_id=(-1 * p_id) if p_id is not None else None, name=persona["name"], @@ -96,7 +93,6 @@ def load_personas_from_yaml( prompts=cast(list[PromptDBModel] | None, prompts), document_sets=doc_sets, default_persona=True, - shared=True, is_public=True, db_session=db_session, ) diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 738d02a1657..d45fe95a78d 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -12,6 +12,7 @@ from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.orm import Session +from danswer.auth.schemas import UserRole from danswer.configs.chat_configs import HARD_DELETE_CHATS from danswer.configs.constants import MessageType from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX @@ -27,6 +28,7 @@ from danswer.db.models import SearchDoc from danswer.db.models import SearchDoc as DBSearchDoc from danswer.db.models import StarterMessage +from danswer.db.models import User from danswer.db.models import User__UserGroup from danswer.llm.override_models import LLMOverride from danswer.llm.override_models import PromptOverride @@ -313,13 +315,16 @@ def set_as_latest_chat_message( def get_prompt_by_id( prompt_id: int, - user_id: UUID | None, + user: User | None, db_session: Session, include_deleted: bool = False, ) -> Prompt: - stmt = select(Prompt).where( - Prompt.id == prompt_id, or_(Prompt.user_id == user_id, Prompt.user_id.is_(None)) - ) + stmt = select(Prompt).where(Prompt.id == prompt_id) + + # if user is not specified OR they are an admin, they should + # have access to all prompts, so this where clause is not needed + if user and user.role != UserRole.ADMIN: + stmt = stmt.where(or_(Prompt.user_id == user.id, Prompt.user_id.is_(None))) if not include_deleted: stmt = stmt.where(Prompt.deleted.is_(False)) @@ -351,14 +356,16 @@ def get_default_prompt() -> Prompt: def get_persona_by_id( persona_id: int, - # if user_id is `None` assume the user is an admin or auth is disabled - user_id: UUID | None, + # if user is `None` assume the user is an admin or auth is disabled + user: User | None, db_session: Session, include_deleted: bool = False, ) -> Persona: stmt = select(Persona).where(Persona.id == persona_id) - if user_id is not None: - stmt = stmt.where(or_(Persona.user_id == user_id, Persona.user_id.is_(None))) + + # if user is an admin, they should have access to all Personas + if user is not None and user.role != UserRole.ADMIN: + stmt = stmt.where(or_(Persona.user_id == user.id, Persona.user_id.is_(None))) if not include_deleted: stmt = stmt.where(Persona.deleted.is_(False)) @@ -397,33 +404,33 @@ def get_personas_by_ids( def get_prompt_by_name( - prompt_name: str, user_id: UUID | None, shared: bool, db_session: Session + prompt_name: str, user: User | None, db_session: Session ) -> Prompt | None: - """Cannot do shared and user owned simultaneously as there may be two of those""" stmt = select(Prompt).where(Prompt.name == prompt_name) - if shared: - stmt = stmt.where(Prompt.user_id.is_(None)) - else: - stmt = stmt.where(Prompt.user_id == user_id) + + # if user is not specified OR they are an admin, they should + # have access to all prompts, so this where clause is not needed + if user and user.role != UserRole.ADMIN: + stmt = stmt.where(Prompt.user_id == user.id) + result = db_session.execute(stmt).scalar_one_or_none() return result def get_persona_by_name( - persona_name: str, user_id: UUID | None, shared: bool, db_session: Session + persona_name: str, user: User | None, db_session: Session ) -> Persona | None: - """Cannot do shared and user owned simultaneously as there may be two of those""" + """Admins can see all, regular users can only fetch their own. + If user is None, assume the user is an admin or auth is disabled.""" stmt = select(Persona).where(Persona.name == persona_name) - if shared: - stmt = stmt.where(Persona.user_id.is_(None)) - else: - stmt = stmt.where(Persona.user_id == user_id) + if user and user.role != UserRole.ADMIN: + stmt = stmt.where(Persona.user_id == user.id) result = db_session.execute(stmt).scalar_one_or_none() return result def upsert_prompt( - user_id: UUID | None, + user: User | None, name: str, description: str, system_prompt: str, @@ -431,7 +438,6 @@ def upsert_prompt( include_citations: bool, datetime_aware: bool, personas: list[Persona] | None, - shared: bool, db_session: Session, prompt_id: int | None = None, default_prompt: bool = True, @@ -440,9 +446,7 @@ def upsert_prompt( if prompt_id is not None: prompt = db_session.query(Prompt).filter_by(id=prompt_id).first() else: - prompt = get_prompt_by_name( - prompt_name=name, user_id=user_id, shared=shared, db_session=db_session - ) + prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session) if prompt: if not default_prompt and prompt.default_prompt: @@ -463,7 +467,7 @@ def upsert_prompt( else: prompt = Prompt( id=prompt_id, - user_id=None if shared else user_id, + user_id=user.id if user else None, name=name, description=description, system_prompt=system_prompt, @@ -485,7 +489,7 @@ def upsert_prompt( def upsert_persona( - user_id: UUID | None, + user: User | None, name: str, description: str, num_chunks: float, @@ -496,7 +500,6 @@ def upsert_persona( document_sets: list[DBDocumentSet] | None, llm_model_version_override: str | None, starter_messages: list[StarterMessage] | None, - shared: bool, is_public: bool, db_session: Session, persona_id: int | None = None, @@ -507,7 +510,7 @@ def upsert_persona( persona = db_session.query(Persona).filter_by(id=persona_id).first() else: persona = get_persona_by_name( - persona_name=name, user_id=user_id, shared=shared, db_session=db_session + persona_name=name, user=user, db_session=db_session ) if persona: @@ -539,7 +542,7 @@ def upsert_persona( else: persona = Persona( id=persona_id, - user_id=None if shared else user_id, + user_id=user.id if user else None, is_public=is_public, name=name, description=description, @@ -566,24 +569,20 @@ def upsert_persona( def mark_prompt_as_deleted( prompt_id: int, - user_id: UUID | None, + user: User | None, db_session: Session, ) -> None: - prompt = get_prompt_by_id( - prompt_id=prompt_id, user_id=user_id, db_session=db_session - ) + prompt = get_prompt_by_id(prompt_id=prompt_id, user=user, db_session=db_session) prompt.deleted = True db_session.commit() def mark_persona_as_deleted( persona_id: int, - user_id: UUID | None, + user: User | None, db_session: Session, ) -> None: - persona = get_persona_by_id( - persona_id=persona_id, user_id=user_id, db_session=db_session - ) + persona = get_persona_by_id(persona_id=persona_id, user=user, db_session=db_session) persona.deleted = True db_session.commit() @@ -621,9 +620,7 @@ def update_persona_visibility( is_visible: bool, db_session: Session, ) -> None: - persona = get_persona_by_id( - persona_id=persona_id, user_id=None, db_session=db_session - ) + persona = get_persona_by_id(persona_id=persona_id, user=None, db_session=db_session) persona.is_visible = is_visible db_session.commit() diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 004025d7ee2..8e1540f20f1 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -736,7 +736,6 @@ class Prompt(Base): __tablename__ = "prompt" id: Mapped[int] = mapped_column(primary_key=True) - # If not belong to a user, then it's shared user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) name: Mapped[str] = mapped_column(String) description: Mapped[str] = mapped_column(String) @@ -770,7 +769,6 @@ class Persona(Base): __tablename__ = "persona" id: Mapped[int] = mapped_column(primary_key=True) - # If not belong to a user, then it's shared user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) name: Mapped[str] = mapped_column(String) description: Mapped[str] = mapped_column(String) @@ -824,7 +822,7 @@ class Persona(Base): back_populates="personas", ) # Owner - user: Mapped[User] = relationship("User", back_populates="personas") + user: Mapped[User | None] = relationship("User", back_populates="personas") # Other users with access users: Mapped[list[User]] = relationship( "User", diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py index 38351b18b02..7b1116b5ffa 100644 --- a/backend/danswer/db/persona.py +++ b/backend/danswer/db/persona.py @@ -6,6 +6,7 @@ from danswer.db.chat import get_prompts_by_ids from danswer.db.chat import upsert_persona from danswer.db.document_set import get_document_sets_by_ids +from danswer.db.models import Persona__User from danswer.db.models import User from danswer.server.features.persona.models import CreatePersonaRequest from danswer.server.features.persona.models import PersonaSnapshot @@ -21,9 +22,19 @@ def make_persona_private( group_ids: list[int] | None, db_session: Session, ) -> None: + if user_ids is not None: + db_session.query(Persona__User).filter( + Persona__User.persona_id == persona_id + ).delete(synchronize_session="fetch") + + for user_uuid in user_ids: + db_session.add(Persona__User(persona_id=persona_id, user_id=user_uuid)) + + db_session.commit() + # May cause error if someone switches down to MIT from EE - if user_ids or group_ids: - raise NotImplementedError("Danswer MIT does not support private Document Sets") + if group_ids: + raise NotImplementedError("Danswer MIT does not support private Personas") def create_update_persona( @@ -32,8 +43,6 @@ def create_update_persona( user: User | None, db_session: Session, ) -> PersonaSnapshot: - user_id = user.id if user is not None else None - # Permission to actually use these is checked later document_sets = list( get_document_sets_by_ids( @@ -51,7 +60,7 @@ def create_update_persona( try: persona = upsert_persona( persona_id=persona_id, - user_id=user_id, + user=user, name=create_persona_request.name, description=create_persona_request.description, num_chunks=create_persona_request.num_chunks, @@ -62,7 +71,6 @@ def create_update_persona( document_sets=document_sets, llm_model_version_override=create_persona_request.llm_model_version_override, starter_messages=create_persona_request.starter_messages, - shared=create_persona_request.shared, is_public=create_persona_request.is_public, db_session=db_session, ) diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py index c3b463e35d2..9b792ff08bb 100644 --- a/backend/danswer/db/slack_bot_config.py +++ b/backend/danswer/db/slack_bot_config.py @@ -49,7 +49,7 @@ def create_slack_bot_persona( # create/update persona associated with the slack bot persona_name = _build_persona_name(channel_names) persona = upsert_persona( - user_id=None, # Slack Bot Personas are not attached to users + user=None, # Slack Bot Personas are not attached to users persona_id=existing_persona_id, name=persona_name, description="", @@ -61,7 +61,6 @@ def create_slack_bot_persona( document_sets=document_sets, llm_model_version_override=None, starter_messages=None, - shared=True, is_public=True, default_persona=False, db_session=db_session, diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index c0c036339fc..ff6e04a217f 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -173,7 +173,7 @@ def stream_answer_objects( prompt = None if query_req.prompt_id is not None: prompt = get_prompt_by_id( - prompt_id=query_req.prompt_id, user_id=user_id, db_session=db_session + prompt_id=query_req.prompt_id, user=user, db_session=db_session ) if prompt is None: if not chat_session.persona.prompts: diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index b4359f6a1fb..bfaea792f66 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -28,35 +28,6 @@ basic_router = APIRouter(prefix="/persona") -@admin_router.post("") -def create_persona( - create_persona_request: CreatePersonaRequest, - user: User | None = Depends(current_admin_user), - db_session: Session = Depends(get_session), -) -> PersonaSnapshot: - return create_update_persona( - persona_id=None, - create_persona_request=create_persona_request, - user=user, - db_session=db_session, - ) - - -@admin_router.patch("/{persona_id}") -def update_persona( - persona_id: int, - update_persona_request: CreatePersonaRequest, - user: User | None = Depends(current_admin_user), - db_session: Session = Depends(get_session), -) -> PersonaSnapshot: - return create_update_persona( - persona_id=persona_id, - create_persona_request=update_persona_request, - user=user, - db_session=db_session, - ) - - class IsVisibleRequest(BaseModel): is_visible: bool @@ -92,19 +63,6 @@ def patch_persona_display_priority( ) -@admin_router.delete("/{persona_id}") -def delete_persona( - persona_id: int, - user: User | None = Depends(current_admin_user), - db_session: Session = Depends(get_session), -) -> None: - mark_persona_as_deleted( - persona_id=persona_id, - user_id=user.id if user is not None else None, - db_session=db_session, - ) - - @admin_router.get("") def list_personas_admin( _: User | None = Depends(current_admin_user), @@ -124,6 +82,48 @@ def list_personas_admin( """Endpoints for all""" +@basic_router.post("") +def create_persona( + create_persona_request: CreatePersonaRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> PersonaSnapshot: + return create_update_persona( + persona_id=None, + create_persona_request=create_persona_request, + user=user, + db_session=db_session, + ) + + +@basic_router.patch("/{persona_id}") +def update_persona( + persona_id: int, + update_persona_request: CreatePersonaRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> PersonaSnapshot: + return create_update_persona( + persona_id=persona_id, + create_persona_request=update_persona_request, + user=user, + db_session=db_session, + ) + + +@basic_router.delete("/{persona_id}") +def delete_persona( + persona_id: int, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + mark_persona_as_deleted( + persona_id=persona_id, + user=user, + db_session=db_session, + ) + + @basic_router.get("") def list_personas( user: User | None = Depends(current_user), @@ -148,7 +148,7 @@ def get_persona( return PersonaSnapshot.from_model( get_persona_by_id( persona_id=persona_id, - user_id=user.id if user is not None else None, + user=user, db_session=db_session, ) ) @@ -194,9 +194,9 @@ def build_final_template_prompt( ] -@admin_router.get("/utils/list-available-models") +@basic_router.get("/utils/list-available-models") def list_available_model_versions( - _: User | None = Depends(current_admin_user), + _: User | None = Depends(current_user), ) -> list[str]: # currently only support selecting different models for OpenAI if GEN_AI_MODEL_PROVIDER != "openai": @@ -205,9 +205,9 @@ def list_available_model_versions( return GPT_4_MODEL_VERSIONS + GPT_3_5_TURBO_MODEL_VERSIONS -@admin_router.get("/utils/default-model") +@basic_router.get("/utils/default-model") def get_default_model( - _: User | None = Depends(current_admin_user), + _: User | None = Depends(current_user), ) -> str: # currently only support selecting different models for OpenAI if GEN_AI_MODEL_PROVIDER != "openai": diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py index 4cc80eec0ee..8826be2c307 100644 --- a/backend/danswer/server/features/persona/models.py +++ b/backend/danswer/server/features/persona/models.py @@ -7,12 +7,12 @@ from danswer.search.enums import RecencyBiasSetting from danswer.server.features.document_set.models import DocumentSet from danswer.server.features.prompt.models import PromptSnapshot +from danswer.server.models import MinimalUserSnapshot class CreatePersonaRequest(BaseModel): name: str description: str - shared: bool num_chunks: float llm_relevance_filter: bool is_public: bool @@ -29,8 +29,8 @@ class CreatePersonaRequest(BaseModel): class PersonaSnapshot(BaseModel): id: int + owner: MinimalUserSnapshot | None name: str - shared: bool is_visible: bool is_public: bool display_priority: int | None @@ -43,6 +43,7 @@ class PersonaSnapshot(BaseModel): default_persona: bool prompts: list[PromptSnapshot] document_sets: list[DocumentSet] + users: list[UUID] groups: list[int] @classmethod @@ -53,7 +54,11 @@ def from_model(cls, persona: Persona) -> "PersonaSnapshot": return PersonaSnapshot( id=persona.id, name=persona.name, - shared=persona.user_id is None, + owner=( + MinimalUserSnapshot(id=persona.user.id, email=persona.user.email) + if persona.user + else None + ), is_visible=persona.is_visible, is_public=persona.is_public, display_priority=persona.display_priority, @@ -69,6 +74,7 @@ def from_model(cls, persona: Persona) -> "PersonaSnapshot": DocumentSet.from_model(document_set_model) for document_set_model in persona.document_sets ], + users=[user.id for user in persona.users], groups=[user_group.id for user_group in persona.groups], ) diff --git a/backend/danswer/server/features/prompt/api.py b/backend/danswer/server/features/prompt/api.py index b9f27675dc3..24c886ab915 100644 --- a/backend/danswer/server/features/prompt/api.py +++ b/backend/danswer/server/features/prompt/api.py @@ -4,7 +4,6 @@ from sqlalchemy.orm import Session from starlette import status -from danswer.auth.users import current_admin_user from danswer.auth.users import current_user from danswer.db.chat import get_personas_by_ids from danswer.db.chat import get_prompt_by_id @@ -32,8 +31,6 @@ def create_update_prompt( user: User | None, db_session: Session, ) -> PromptSnapshot: - user_id = user.id if user is not None else None - personas = ( list( get_personas_by_ids( @@ -47,7 +44,7 @@ def create_update_prompt( prompt = upsert_prompt( prompt_id=prompt_id, - user_id=user_id, + user=user, name=create_prompt_request.name, description=create_prompt_request.description, system_prompt=create_prompt_request.system_prompt, @@ -55,7 +52,6 @@ def create_update_prompt( include_citations=create_prompt_request.include_citations, datetime_aware=create_prompt_request.datetime_aware, personas=personas, - shared=create_prompt_request.shared, db_session=db_session, ) return PromptSnapshot.from_model(prompt) @@ -64,7 +60,7 @@ def create_update_prompt( @basic_router.post("") def create_prompt( create_prompt_request: CreatePromptRequest, - user: User | None = Depends(current_admin_user), + user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> PromptSnapshot: try: @@ -124,7 +120,7 @@ def delete_prompt( ) -> None: mark_prompt_as_deleted( prompt_id=prompt_id, - user_id=user.id if user is not None else None, + user=user, db_session=db_session, ) @@ -150,7 +146,7 @@ def get_prompt( return PromptSnapshot.from_model( get_prompt_by_id( prompt_id=prompt_id, - user_id=user.id if user is not None else None, + user=user, db_session=db_session, ) ) diff --git a/backend/danswer/server/features/prompt/models.py b/backend/danswer/server/features/prompt/models.py index 0ae70c58d0c..1cc9452f435 100644 --- a/backend/danswer/server/features/prompt/models.py +++ b/backend/danswer/server/features/prompt/models.py @@ -6,7 +6,6 @@ class CreatePromptRequest(BaseModel): name: str description: str - shared: bool system_prompt: str task_prompt: str include_citations: bool = False @@ -17,7 +16,6 @@ class CreatePromptRequest(BaseModel): class PromptSnapshot(BaseModel): id: int name: str - shared: bool description: str system_prompt: str task_prompt: str @@ -34,7 +32,6 @@ def from_model(cls, prompt: Prompt) -> "PromptSnapshot": return PromptSnapshot( id=prompt.id, name=prompt.name, - shared=prompt.user_id is None, description=prompt.description, system_prompt=prompt.system_prompt, task_prompt=prompt.task_prompt, diff --git a/backend/danswer/server/manage/slack_bot.py b/backend/danswer/server/manage/slack_bot.py index 19003f09d68..40e8663b054 100644 --- a/backend/danswer/server/manage/slack_bot.py +++ b/backend/danswer/server/manage/slack_bot.py @@ -140,7 +140,7 @@ def patch_slack_bot_config( existing_persona_id = existing_slack_bot_config.persona_id if existing_persona_id is not None: persona = get_persona_by_id( - persona_id=existing_persona_id, user_id=None, db_session=db_session + persona_id=existing_persona_id, user=None, db_session=db_session ) if not persona.name.startswith(SLACK_BOT_PERSONA_PREFIX): diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index d616edd4f86..ca23f0a15a7 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -1,6 +1,7 @@ from typing import Generic from typing import Optional from typing import TypeVar +from uuid import UUID from pydantic import BaseModel from pydantic.generics import GenericModel @@ -21,3 +22,8 @@ class ApiKey(BaseModel): class IdReturn(BaseModel): id: int + + +class MinimalUserSnapshot(BaseModel): + id: UUID + email: str diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 52d879dfe69..bbc8eb425b2 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -327,7 +327,7 @@ def get_max_document_tokens( try: persona = get_persona_by_id( persona_id=persona_id, - user_id=user.id if user else None, + user=user, db_session=db_session, ) except ValueError: diff --git a/web/src/app/admin/personas/PersonaEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx similarity index 74% rename from web/src/app/admin/personas/PersonaEditor.tsx rename to web/src/app/admin/assistants/AssistantEditor.tsx index 6ce77edb561..05c5a71a277 100644 --- a/web/src/app/admin/personas/PersonaEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -1,7 +1,7 @@ "use client"; -import { DocumentSet, UserGroup } from "@/lib/types"; -import { Button, Divider, Text } from "@tremor/react"; +import { CCPairBasicInfo, DocumentSet, User, UserGroup } from "@/lib/types"; +import { Button, Divider, Italic, Text } from "@tremor/react"; import { ArrayHelpers, ErrorMessage, @@ -29,6 +29,8 @@ import { EE_ENABLED } from "@/lib/constants"; import { useUserGroups } from "@/lib/hooks"; import { Bubble } from "@/components/Bubble"; import { GroupsIcon } from "@/components/icons/icons"; +import { SuccessfulPersonaUpdateRedirectType } from "./enums"; +import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable"; function Label({ children }: { children: string | JSX.Element }) { return ( @@ -40,16 +42,24 @@ function SubLabel({ children }: { children: string | JSX.Element }) { return
{children}
; } -export function PersonaEditor({ +export function AssistantEditor({ existingPersona, + ccPairs, documentSets, llmOverrideOptions, defaultLLM, + user, + defaultPublic, + redirectType, }: { existingPersona?: Persona | null; + ccPairs: CCPairBasicInfo[]; documentSets: DocumentSet[]; llmOverrideOptions: string[]; defaultLLM: string; + user: User | null; + defaultPublic: boolean; + redirectType: SuccessfulPersonaUpdateRedirectType; }) { const router = useRouter(); const { popup, setPopup } = usePopup(); @@ -99,7 +109,7 @@ export function PersonaEditor({ system_prompt: existingPrompt?.system_prompt ?? "", task_prompt: existingPrompt?.task_prompt ?? "", disable_retrieval: (existingPersona?.num_chunks ?? 10) === 0, - is_public: existingPersona?.is_public ?? true, + is_public: existingPersona?.is_public ?? defaultPublic, document_set_ids: existingPersona?.document_sets?.map( (documentSet) => documentSet.id @@ -116,9 +126,9 @@ export function PersonaEditor({ }} validationSchema={Yup.object() .shape({ - name: Yup.string().required("Must give the Persona a name!"), + name: Yup.string().required("Must give the Assistant a name!"), description: Yup.string().required( - "Must give the Persona a description!" + "Must give the Assistant a description!" ), system_prompt: Yup.string(), task_prompt: Yup.string(), @@ -187,12 +197,14 @@ export function PersonaEditor({ existingPromptId: existingPrompt?.id, ...values, num_chunks: numChunks, + users: user ? [user.id] : undefined, groups, }); } else { [promptResponse, personaResponse] = await createPersona({ ...values, num_chunks: numChunks, + users: user ? [user.id] : undefined, groups, }); } @@ -201,51 +213,53 @@ export function PersonaEditor({ if (!promptResponse.ok) { error = await promptResponse.text(); } - if (personaResponse && !personaResponse.ok) { + if (!personaResponse) { + error = "Failed to create Assistant - no response received"; + } else if (!personaResponse.ok) { error = await personaResponse.text(); } - if (error) { + if (error || !personaResponse) { setPopup({ type: "error", - message: `Failed to create Persona - ${error}`, + message: `Failed to create Assistant - ${error}`, }); formikHelpers.setSubmitting(false); } else { - router.push(`/admin/personas?u=${Date.now()}`); + router.push( + redirectType === SuccessfulPersonaUpdateRedirectType.ADMIN + ? `/admin/assistants?u=${Date.now()}` + : `/chat?assistantId=${ + ((await personaResponse.json()) as Persona).id + }` + ); } }} > {({ isSubmitting, values, setFieldValue }) => (
- + <> - - - - - - <> { setFieldValue("system_prompt", e.target.value); @@ -260,11 +274,11 @@ export function PersonaEditor({ { setFieldValue("task_prompt", e.target.value); triggerFinalPromptUpdate( @@ -276,35 +290,6 @@ export function PersonaEditor({ error={finalPromptError} /> - {!values.disable_retrieval && ( - - )} - - { - setFieldValue("disable_retrieval", e.target.checked); - triggerFinalPromptUpdate( - values.system_prompt, - values.task_prompt, - e.target.checked - ); - }} - /> - {finalPrompt ? ( @@ -319,73 +304,100 @@ export function PersonaEditor({ - {!values.disable_retrieval && ( + {ccPairs.length > 0 && ( <> - + <> - ( + { + setFieldValue("disable_retrieval", e.target.checked); + triggerFinalPromptUpdate( + values.system_prompt, + values.task_prompt, + e.target.checked + ); + }} + /> + + {!values.disable_retrieval && ( + <>
-
- - <> - Select which{" "} + + <> + Select which{" "} + {!user || user.role === "admin" ? ( Document Sets - {" "} - that this Persona should search through. If - none are specified, the Persona will search - through all available documents in order to - try and response to queries. - - -
-
- {documentSets.map((documentSet) => { - const ind = values.document_set_ids.indexOf( - documentSet.id - ); - let isSelected = ind !== -1; - return ( -
{ - if (isSelected) { - arrayHelpers.remove(ind); - } else { - arrayHelpers.push(documentSet.id); - } - }} - > -
- {documentSet.name} -
-
- ); - })} -
+ + ) : ( + "Document Sets" + )}{" "} + that this Assistant should search through. If + none are specified, the Assistant will search + through all available documents in order to try + and respond to queries. + +
- )} - /> + + {documentSets.length > 0 ? ( + ( +
+
+ {documentSets.map((documentSet) => { + const ind = + values.document_set_ids.indexOf( + documentSet.id + ); + let isSelected = ind !== -1; + return ( + { + if (isSelected) { + arrayHelpers.remove(ind); + } else { + arrayHelpers.push(documentSet.id); + } + }} + /> + ); + })} +
+
+ )} + /> + ) : ( + + No Document Sets available.{" "} + {user?.role !== "admin" && ( + <> + If this functionality would be useful, reach + out to the administrators of Danswer for + assistance. + + )} + + )} + + )}
@@ -393,73 +405,38 @@ export function PersonaEditor({ )} - {EE_ENABLED && userGroups && ( + {!values.disable_retrieval && ( <> - + <> - - {userGroups && - userGroups.length > 0 && - !values.is_public && ( -
- - Select which User Groups should have access to - this Persona. - -
- {userGroups.map((userGroup) => { - const isSelected = values.groups.includes( - userGroup.id - ); - return ( - { - if (isSelected) { - setFieldValue( - "groups", - values.groups.filter( - (id) => id !== userGroup.id - ) - ); - } else { - setFieldValue("groups", [ - ...values.groups, - userGroup.id, - ]); - } - }} - > -
- -
- {userGroup.name} -
-
-
- ); - })} -
-
- )}
+ )} {llmOverrideOptions.length > 0 && defaultLLM && ( <> - + <> - Pick which LLM to use for this Persona. If left as + Pick which LLM to use for this Assistant. If left as Default, will use {defaultLLM} .
@@ -496,7 +473,10 @@ export function PersonaEditor({ {!values.disable_retrieval && ( <> - + <> How many chunks should we feed into the LLM when generating the final response? Each chunk is ~400 - words long. If you are using gpt-3.5-turbo or other - similar models, setting this to a value greater than - 5 will result in errors at query time due to the - model's input length limit. + words long.

If unspecified, will use 10 chunks. @@ -537,14 +514,17 @@ export function PersonaEditor({ )} - + <>
- Starter Messages help guide users to use this Persona. + Starter Messages help guide users to use this Assistant. They are shown to the user as clickable options when they - select this Persona. When selected, the specified message - is sent to the LLM as the initial user message. + select this Assistant. When selected, the specified + message is sent to the LLM as the initial user message.
@@ -686,6 +666,67 @@ export function PersonaEditor({ + {EE_ENABLED && userGroups && (!user || user.role === "admin") && ( + <> + + <> + + + {userGroups && + userGroups.length > 0 && + !values.is_public && ( +
+ + Select which User Groups should have access to + this Assistant. + +
+ {userGroups.map((userGroup) => { + const isSelected = values.groups.includes( + userGroup.id + ); + return ( + { + if (isSelected) { + setFieldValue( + "groups", + values.groups.filter( + (id) => id !== userGroup.id + ) + ); + } else { + setFieldValue("groups", [ + ...values.groups, + userGroup.id, + ]); + } + }} + > +
+ +
+ {userGroup.name} +
+
+
+ ); + })} +
+
+ )} + +
+ + + )} +