diff --git a/.gitignore b/.gitignore
index daa62c9e307..e1125f5f6c8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@
 .idea
 /deployment/data/nginx/app.conf
 .vscode/launch.json
+*.sw?
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 7e80baeb2d7..116e78b6f19 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -72,6 +72,10 @@ For convenience here's a command for it:
 python -m venv .venv
 source .venv/bin/activate
 ```
+
+--> Note that this virtual environment MUST NOT be set up WITHIN the danswer
+directory
+
 _For Windows, activate the virtual environment using Command Prompt:_
 ```bash
 .venv\Scripts\activate
diff --git a/backend/alembic/versions/48d14957fe80_add_support_for_custom_tools.py b/backend/alembic/versions/48d14957fe80_add_support_for_custom_tools.py
new file mode 100644
index 00000000000..514389ac216
--- /dev/null
+++ b/backend/alembic/versions/48d14957fe80_add_support_for_custom_tools.py
@@ -0,0 +1,61 @@
+"""Add support for custom tools
+
+Revision ID: 48d14957fe80
+Revises: b85f02ec1308
+Create Date: 2024-06-09 14:58:19.946509
+
+"""
+from alembic import op
+import fastapi_users_db_sqlalchemy
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = "48d14957fe80"
+down_revision = "b85f02ec1308"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.add_column(
+        "tool",
+        sa.Column(
+            "openapi_schema",
+            postgresql.JSONB(astext_type=sa.Text()),
+            nullable=True,
+        ),
+    )
+    op.add_column(
+        "tool",
+        sa.Column(
+            "user_id",
+            fastapi_users_db_sqlalchemy.generics.GUID(),
+            nullable=True,
+        ),
+    )
+    op.create_foreign_key("tool_user_fk", "tool", "user", ["user_id"], ["id"])
+
+    op.create_table(
+        "tool_call",
+        sa.Column("id", sa.Integer(), primary_key=True),
+        sa.Column("tool_id", sa.Integer(), nullable=False),
+        sa.Column("tool_name", sa.String(), nullable=False),
+        sa.Column(
+            "tool_arguments", postgresql.JSONB(astext_type=sa.Text()), nullable=False
+        ),
+        sa.Column(
+            "tool_result", postgresql.JSONB(astext_type=sa.Text()), nullable=False
+        ),
+        sa.Column(
+            "message_id", sa.Integer(), sa.ForeignKey("chat_message.id"), nullable=False
+        ),
+    )
+
+
+def downgrade() -> None:
+    op.drop_table("tool_call")
+
+    op.drop_constraint("tool_user_fk", "tool", type_="foreignkey")
+    op.drop_column("tool", "user_id")
+    op.drop_column("tool", "openapi_schema")
diff --git a/backend/alembic/versions/e209dc5a8156_added_prune_frequency.py b/backend/alembic/versions/e209dc5a8156_added_prune_frequency.py
new file mode 100644
index 00000000000..0d5c250ebb6
--- /dev/null
+++ b/backend/alembic/versions/e209dc5a8156_added_prune_frequency.py
@@ -0,0 +1,22 @@
+"""added-prune-frequency
+
+Revision ID: e209dc5a8156
+Revises: 48d14957fe80
+Create Date: 2024-06-16 16:02:35.273231
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+revision = "e209dc5a8156"
+down_revision = "48d14957fe80"
+branch_labels = None  # type: ignore
+depends_on = None  # type: ignore
+
+
+def upgrade() -> None:
+    op.add_column("connector", sa.Column("prune_freq", sa.Integer(), nullable=True))
+
+
+def downgrade() -> None:
+    op.drop_column("connector", "prune_freq")
diff --git a/backend/danswer/auth/invited_users.py b/backend/danswer/auth/invited_users.py
new file mode 100644
index 00000000000..56a02fc60c4
--- /dev/null
+++ b/backend/danswer/auth/invited_users.py
@@ -0,0 +1,21 @@
+from typing import cast
+
+from danswer.dynamic_configs.factory import get_dynamic_config_store
+from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.dynamic_configs.interface import JSON_ro
+
+USER_STORE_KEY = "INVITED_USERS"
+
+
+def get_invited_users() -> list[str]:
+    try:
+        store = get_dynamic_config_store()
+        return cast(list, store.load(USER_STORE_KEY))
+    except ConfigNotFoundError:
+        return list()
+
+
+def write_invited_users(emails: list[str]) -> int:
+    store = get_dynamic_config_store()
+    store.store(USER_STORE_KEY, cast(JSON_ro, emails))
+    return len(emails)
diff --git a/backend/danswer/auth/schemas.py b/backend/danswer/auth/schemas.py
index 34456d9a777..79d9a7f8098 100644
--- a/backend/danswer/auth/schemas.py
+++ b/backend/danswer/auth/schemas.py
@@ -9,6 +9,12 @@ class UserRole(str, Enum):
     ADMIN = "admin"
 
 
+class UserStatus(str, Enum):
+    LIVE = "live"
+    INVITED = "invited"
+    DEACTIVATED = "deactivated"
+
+
 class UserRead(schemas.BaseUser[uuid.UUID]):
     role: UserRole
 
diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py
index 03e770bd5f5..dd7bc9f7787 100644
--- a/backend/danswer/auth/users.py
+++ b/backend/danswer/auth/users.py
@@ -1,4 +1,3 @@
-import os
 import smtplib
 import uuid
 from collections.abc import AsyncGenerator
@@ -27,6 +26,7 @@
 from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
 from sqlalchemy.orm import Session
 
+from danswer.auth.invited_users import get_invited_users
 from danswer.auth.schemas import UserCreate
 from danswer.auth.schemas import UserRole
 from danswer.configs.app_configs import AUTH_TYPE
@@ -59,9 +59,6 @@
 logger = setup_logger()
 
-USER_WHITELIST_FILE = "/home/danswer_whitelist.txt"
-_user_whitelist: list[str] | None = None
-
 
 def verify_auth_setting() -> None:
     if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
@@ -92,20 +89,8 @@ def user_needs_to_be_verified() -> bool:
     return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
 
 
-def get_user_whitelist() -> list[str]:
-    global _user_whitelist
-    if _user_whitelist is None:
-        if os.path.exists(USER_WHITELIST_FILE):
-            with open(USER_WHITELIST_FILE, "r") as file:
-                _user_whitelist = [line.strip() for line in file]
-        else:
-            _user_whitelist = []
-
-    return _user_whitelist
-
-
 def verify_email_in_whitelist(email: str) -> None:
-    whitelist = get_user_whitelist()
+    whitelist = get_invited_users()
     if (whitelist and email not in whitelist) or not email:
         raise PermissionError("User not on allowed user whitelist")
diff --git a/backend/danswer/background/celery/celery.py b/backend/danswer/background/celery/celery.py
index 216fbd50c37..1a678ea11fa 100644
--- a/backend/danswer/background/celery/celery.py
+++ b/backend/danswer/background/celery/celery.py
@@ -4,13 +4,22 @@
 from celery import Celery  # type: ignore
 from sqlalchemy.orm import Session
 
+from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
+from danswer.background.celery.celery_utils import should_prune_cc_pair
+from danswer.background.celery.celery_utils import should_sync_doc_set
 from danswer.background.connector_deletion import delete_connector_credential_pair
+from danswer.background.connector_deletion import delete_connector_credential_pair_batch
 from danswer.background.task_utils import build_celery_task_wrapper
 from danswer.background.task_utils import name_cc_cleanup_task
+from danswer.background.task_utils import name_cc_prune_task
 from danswer.background.task_utils import name_document_set_sync_task
 from danswer.configs.app_configs import JOB_TIMEOUT
+from danswer.connectors.factory import instantiate_connector
+from danswer.connectors.models import InputType
 from danswer.db.connector_credential_pair import get_connector_credential_pair
+from danswer.db.connector_credential_pair import get_connector_credential_pairs
 from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
+from danswer.db.document import get_documents_for_connector_credential_pair
 from danswer.db.document import prepare_to_modify_documents
 from danswer.db.document_set import delete_document_set
 from danswer.db.document_set import fetch_document_sets
@@ -22,8 +31,6 @@
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.engine import SYNC_DB_API
 from danswer.db.models import DocumentSet
-from danswer.db.tasks import check_live_task_not_timed_out
-from danswer.db.tasks import get_latest_task
 from danswer.document_index.document_index_utils import get_both_index_names
 from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.interfaces import UpdateRequest
@@ -90,6 +97,74 @@ def cleanup_connector_credential_pair_task(
         raise e
 
 
+@build_celery_task_wrapper(name_cc_prune_task)
+@celery_app.task(soft_time_limit=JOB_TIMEOUT)
+def prune_documents_task(connector_id: int, credential_id: int) -> None:
+    """connector pruning task. For a cc pair, this task pulls all document IDs from the source
+    and compares those IDs to locally stored documents and deletes all locally stored IDs missing
+    from the most recently pulled document ID list"""
+    with Session(get_sqlalchemy_engine()) as db_session:
+        try:
+            cc_pair = get_connector_credential_pair(
+                db_session=db_session,
+                connector_id=connector_id,
+                credential_id=credential_id,
+            )
+
+            if not cc_pair:
+                logger.warning(f"ccpair not found for {connector_id} {credential_id}")
+                return
+
+            runnable_connector = instantiate_connector(
+                cc_pair.connector.source,
+                InputType.PRUNE,
+                cc_pair.connector.connector_specific_config,
+                cc_pair.credential,
+                db_session,
+            )
+
+            all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
+                runnable_connector
+            )
+
+            all_indexed_document_ids = {
+                doc.id
+                for doc in get_documents_for_connector_credential_pair(
+                    db_session=db_session,
+                    connector_id=connector_id,
+                    credential_id=credential_id,
+                )
+            }
+
+            doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
+
+            curr_ind_name, sec_ind_name = get_both_index_names(db_session)
+            document_index = get_default_document_index(
+                primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
+            )
+
+            if len(doc_ids_to_remove) == 0:
+                logger.info(
+                    f"No docs to prune from {cc_pair.connector.source} connector"
+                )
+                return
+
+            logger.info(
+                f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
+            )
+            delete_connector_credential_pair_batch(
+                document_ids=doc_ids_to_remove,
+                connector_id=connector_id,
+                credential_id=credential_id,
+                document_index=document_index,
+            )
+        except Exception as e:
+            logger.exception(
+                f"Failed to run pruning for connector id {connector_id} due to {e}"
+            )
+            raise e
+
+
 @build_celery_task_wrapper(name_document_set_sync_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
 def sync_document_set_task(document_set_id: int) -> None:
@@ -188,32 +263,48 @@ def _sync_document_batch(document_ids: list[str], db_session: Session) -> None:
     soft_time_limit=JOB_TIMEOUT,
 )
 def check_for_document_sets_sync_task() -> None:
-    """Runs periodically to check if any document sets are out of sync
-    Creates a task to sync the set if needed"""
+    """Runs periodically to check if any sync tasks should be run and adds them
+    to the queue"""
queue""" with Session(get_sqlalchemy_engine()) as db_session: # check if any document sets are not synced document_set_info = fetch_document_sets( user_id=None, db_session=db_session, include_outdated=True ) for document_set, _ in document_set_info: - if not document_set.is_up_to_date: - task_name = name_document_set_sync_task(document_set.id) - latest_sync = get_latest_task(task_name, db_session) - - if latest_sync and check_live_task_not_timed_out( - latest_sync, db_session - ): - logger.info( - f"Document set '{document_set.id}' is already syncing. Skipping." - ) - continue - - logger.info(f"Document set {document_set.id} syncing now!") + if should_sync_doc_set(document_set, db_session): + logger.info(f"Syncing the {document_set.name} document set") sync_document_set_task.apply_async( kwargs=dict(document_set_id=document_set.id), ) +@celery_app.task( + name="check_for_prune_task", + soft_time_limit=JOB_TIMEOUT, +) +def check_for_prune_task() -> None: + """Runs periodically to check if any prune tasks should be run and adds them + to the queue""" + + with Session(get_sqlalchemy_engine()) as db_session: + all_cc_pairs = get_connector_credential_pairs(db_session) + + for cc_pair in all_cc_pairs: + if should_prune_cc_pair( + connector=cc_pair.connector, + credential=cc_pair.credential, + db_session=db_session, + ): + logger.info(f"Pruning the {cc_pair.connector.name} connector") + + prune_documents_task.apply_async( + kwargs=dict( + connector_id=cc_pair.connector.id, + credential_id=cc_pair.credential.id, + ) + ) + + ##### # Celery Beat (Periodic Tasks) Settings ##### @@ -223,3 +314,11 @@ def check_for_document_sets_sync_task() -> None: "schedule": timedelta(seconds=5), }, } +celery_app.conf.beat_schedule.update( + { + "check-for-prune": { + "task": "check_for_prune_task", + "schedule": timedelta(seconds=5), + }, + } +) diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py index 5c155fac315..48f0295cd09 100644 --- a/backend/danswer/background/celery/celery_utils.py +++ b/backend/danswer/background/celery/celery_utils.py @@ -1,8 +1,25 @@ +from datetime import datetime +from datetime import timezone + from sqlalchemy.orm import Session from danswer.background.task_utils import name_cc_cleanup_task +from danswer.background.task_utils import name_cc_prune_task +from danswer.background.task_utils import name_document_set_sync_task +from danswer.connectors.interfaces import BaseConnector +from danswer.connectors.interfaces import IdConnector +from danswer.connectors.interfaces import LoadConnector +from danswer.connectors.interfaces import PollConnector +from danswer.db.engine import get_db_current_time +from danswer.db.models import Connector +from danswer.db.models import Credential +from danswer.db.models import DocumentSet +from danswer.db.tasks import check_live_task_not_timed_out from danswer.db.tasks import get_latest_task from danswer.server.documents.models import DeletionAttemptSnapshot +from danswer.utils.logger import setup_logger + +logger = setup_logger() def get_deletion_status( @@ -21,3 +38,71 @@ def get_deletion_status( credential_id=credential_id, status=task_state.status, ) + + +def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool: + if document_set.is_up_to_date: + return False + + task_name = name_document_set_sync_task(document_set.id) + latest_sync = get_latest_task(task_name, db_session) + + if latest_sync and check_live_task_not_timed_out(latest_sync, db_session): + 
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.") + return False + + logger.info(f"Document set {document_set.id} syncing now!") + return True + + +def should_prune_cc_pair( + connector: Connector, credential: Credential, db_session: Session +) -> bool: + if not connector.prune_freq: + return False + + pruning_task_name = name_cc_prune_task( + connector_id=connector.id, credential_id=credential.id + ) + last_pruning_task = get_latest_task(pruning_task_name, db_session) + current_db_time = get_db_current_time(db_session) + + if not last_pruning_task: + time_since_initialization = current_db_time - connector.time_created + if time_since_initialization.total_seconds() >= connector.prune_freq: + return True + return False + + if check_live_task_not_timed_out(last_pruning_task, db_session): + logger.info(f"Connector '{connector.name}' is already pruning. Skipping.") + return False + + if not last_pruning_task.start_time: + return False + + time_since_last_pruning = current_db_time - last_pruning_task.start_time + return time_since_last_pruning.total_seconds() >= connector.prune_freq + + +def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]: + """ + If the PruneConnector hasnt been implemented for the given connector, just pull + all docs using the load_from_state and grab out the IDs + """ + all_connector_doc_ids: set[str] = set() + if isinstance(runnable_connector, IdConnector): + all_connector_doc_ids = runnable_connector.retrieve_all_source_ids() + elif isinstance(runnable_connector, LoadConnector): + doc_batch_generator = runnable_connector.load_from_state() + for doc_batch in doc_batch_generator: + all_connector_doc_ids.update(doc.id for doc in doc_batch) + elif isinstance(runnable_connector, PollConnector): + start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp() + end = datetime.now(timezone.utc).timestamp() + doc_batch_generator = runnable_connector.poll_source(start=start, end=end) + for doc_batch in doc_batch_generator: + all_connector_doc_ids.update(doc.id for doc in doc_batch) + else: + raise RuntimeError("Pruning job could not find a valid runnable_connector.") + + return all_connector_doc_ids diff --git a/backend/danswer/background/connector_deletion.py b/backend/danswer/background/connector_deletion.py index d9701e57727..28c58f02dfc 100644 --- a/backend/danswer/background/connector_deletion.py +++ b/backend/danswer/background/connector_deletion.py @@ -41,7 +41,7 @@ _DELETION_BATCH_SIZE = 1000 -def _delete_connector_credential_pair_batch( +def delete_connector_credential_pair_batch( document_ids: list[str], connector_id: int, credential_id: int, @@ -169,7 +169,7 @@ def delete_connector_credential_pair( if not documents: break - _delete_connector_credential_pair_batch( + delete_connector_credential_pair_batch( document_ids=[document.id for document in documents], connector_id=connector_id, credential_id=credential_id, diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index 18b30113cdb..ffbfd8c931b 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -6,11 +6,7 @@ from sqlalchemy.orm import Session -from danswer.background.connector_deletion import ( - _delete_connector_credential_pair_batch, -) from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt -from danswer.configs.app_configs import DISABLE_DOCUMENT_CLEANUP from 
 from danswer.connectors.factory import instantiate_connector
 from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -21,8 +17,6 @@
 from danswer.db.connector import disable_connector
 from danswer.db.connector_credential_pair import get_last_successful_attempt_time
 from danswer.db.connector_credential_pair import update_connector_credential_pair
-from danswer.db.credentials import backend_update_credential_json
-from danswer.db.document import get_documents_for_connector_credential_pair
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.index_attempt import get_index_attempt
 from danswer.db.index_attempt import mark_attempt_failed
@@ -46,7 +40,7 @@ def _get_document_generator(
     attempt: IndexAttempt,
     start_time: datetime,
     end_time: datetime,
-) -> tuple[GenerateDocumentsOutput, bool]:
+) -> GenerateDocumentsOutput:
     """
     NOTE: `start_time` and `end_time` are only used for poll connectors
 
@@ -57,16 +51,13 @@
     task = attempt.connector.input_type
 
     try:
-        runnable_connector, new_credential_json = instantiate_connector(
+        runnable_connector = instantiate_connector(
             attempt.connector.source,
             task,
             attempt.connector.connector_specific_config,
-            attempt.credential.credential_json,
+            attempt.credential,
+            db_session,
         )
-        if new_credential_json is not None:
-            backend_update_credential_json(
-                attempt.credential, new_credential_json, db_session
-            )
     except Exception as e:
         logger.exception(f"Unable to instantiate connector due to {e}")
         disable_connector(attempt.connector.id, db_session)
@@ -75,7 +66,7 @@
     if task == InputType.LOAD_STATE:
         assert isinstance(runnable_connector, LoadConnector)
         doc_batch_generator = runnable_connector.load_from_state()
-        is_listing_complete = True
+
     elif task == InputType.POLL:
         assert isinstance(runnable_connector, PollConnector)
         if attempt.connector_id is None or attempt.credential_id is None:
@@ -88,13 +79,12 @@
         doc_batch_generator = runnable_connector.poll_source(
             start=start_time.timestamp(), end=end_time.timestamp()
         )
-        is_listing_complete = False
 
     else:
         # Event types cannot be handled by a background type
         raise RuntimeError(f"Invalid task type: {task}")
 
-    return doc_batch_generator, is_listing_complete
+    return doc_batch_generator
 
 
 def _run_indexing(
@@ -166,7 +156,7 @@
                 datetime(1970, 1, 1, tzinfo=timezone.utc),
             )
 
-            doc_batch_generator, is_listing_complete = _get_document_generator(
+            doc_batch_generator = _get_document_generator(
                 db_session=db_session,
                 attempt=index_attempt,
                 start_time=window_start,
@@ -224,39 +214,6 @@
                     docs_removed_from_index=0,
                 )
 
-            if is_listing_complete and not DISABLE_DOCUMENT_CLEANUP:
-                # clean up all documents from the index that have not been returned from the connector
-                all_indexed_document_ids = {
-                    d.id
-                    for d in get_documents_for_connector_credential_pair(
-                        db_session=db_session,
-                        connector_id=db_connector.id,
-                        credential_id=db_credential.id,
-                    )
-                }
-                doc_ids_to_remove = list(
-                    all_indexed_document_ids - all_connector_doc_ids
-                )
-                logger.debug(
-                    f"Cleaning up {len(doc_ids_to_remove)} documents that are not contained in the newest connector state"
-                )
-
-                # delete docs from cc-pair and receive the number of completely deleted docs in return
-                _delete_connector_credential_pair_batch(
-                    document_ids=doc_ids_to_remove,
-                    connector_id=db_connector.id,
-                    credential_id=db_credential.id,
-                    document_index=document_index,
-                )
-
-                update_docs_indexed(
-                    db_session=db_session,
-                    index_attempt=index_attempt,
-                    total_docs_indexed=document_count,
-                    new_docs_indexed=net_doc_change,
-                    docs_removed_from_index=len(doc_ids_to_remove),
-                )
-
             run_end_dt = window_end
             if is_primary:
                 update_connector_credential_pair(
diff --git a/backend/danswer/background/task_utils.py b/backend/danswer/background/task_utils.py
index 78a2938c3f7..902abdfec86 100644
--- a/backend/danswer/background/task_utils.py
+++ b/backend/danswer/background/task_utils.py
@@ -22,6 +22,10 @@ def name_document_set_sync_task(document_set_id: int) -> str:
     return f"sync_doc_set_{document_set_id}"
 
 
+def name_cc_prune_task(connector_id: int, credential_id: int) -> str:
+    return f"prune_connector_credential_pair_{connector_id}_{credential_id}"
+
+
 T = TypeVar("T", bound=Callable)
diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py
index 8fa5eecaeeb..7fc526a5cbe 100644
--- a/backend/danswer/chat/models.py
+++ b/backend/danswer/chat/models.py
@@ -106,12 +106,18 @@ class ImageGenerationDisplay(BaseModel):
     file_ids: list[str]
 
 
+class CustomToolResponse(BaseModel):
+    response: dict
+    tool_name: str
+
+
 AnswerQuestionPossibleReturn = (
     DanswerAnswerPiece
     | DanswerQuotes
     | CitationInfo
     | DanswerContexts
     | ImageGenerationDisplay
+    | CustomToolResponse
     | StreamingError
 )
diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index 72381542176..7733bf523d6 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -7,6 +7,7 @@
 from danswer.chat.chat_utils import create_chat_chain
 from danswer.chat.models import CitationInfo
+from danswer.chat.models import CustomToolResponse
 from danswer.chat.models import DanswerAnswerPiece
 from danswer.chat.models import ImageGenerationDisplay
 from danswer.chat.models import LlmDoc
@@ -31,6 +32,7 @@
 from danswer.db.engine import get_session_context_manager
 from danswer.db.llm import fetch_existing_llm_providers
 from danswer.db.models import SearchDoc as DbSearchDoc
+from danswer.db.models import ToolCall
 from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
 from danswer.file_store.models import ChatFileType
@@ -54,7 +56,10 @@
 from danswer.server.query_and_chat.models import ChatMessageDetail
 from danswer.server.query_and_chat.models import CreateChatMessageRequest
 from danswer.server.utils import get_json_line
-from danswer.tools.factory import get_tool_cls
+from danswer.tools.built_in_tools import get_built_in_tool_by_id
+from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
+from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
+from danswer.tools.custom.custom_tool import CustomToolCallSummary
 from danswer.tools.force import ForceUseTool
 from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
 from danswer.tools.images.image_generation_tool import ImageGenerationResponse
@@ -65,6 +70,7 @@
 from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
 from danswer.tools.tool import Tool
 from danswer.tools.tool import ToolResponse
+from danswer.tools.tool_runner import ToolCallFinalResult
 from danswer.tools.utils import compute_all_tool_tokens
 from danswer.tools.utils import explicit_tool_calling_supported
 from danswer.utils.logger import setup_logger
@@ -162,7 +168,7 @@ def _check_should_force_search(
             args = {"query": new_msg_req.message}
 
         return ForceUseTool(
-            tool_name=SearchTool.name(),
+            tool_name=SearchTool.NAME,
             args=args,
         )
     return None
@@ -176,6 +182,7 @@
     | DanswerAnswerPiece
     | CitationInfo
     | ImageGenerationDisplay
+    | CustomToolResponse
 )
 ChatPacketStream = Iterator[ChatPacket]
@@ -389,61 +396,78 @@ def stream_chat_message_objects(
             ),
         )
 
-        persona_tool_classes = [
-            get_tool_cls(tool, db_session) for tool in persona.tools
-        ]
+        # find out what tools to use
+        search_tool: SearchTool | None = None
+        tool_dict: dict[int, list[Tool]] = {}  # tool_id to tool
+        for db_tool_model in persona.tools:
+            # handle in-code tools specially
+            if db_tool_model.in_code_tool_id:
+                tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
+                if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
+                    search_tool = SearchTool(
+                        db_session=db_session,
+                        user=user,
+                        persona=persona,
+                        retrieval_options=retrieval_options,
+                        prompt_config=prompt_config,
+                        llm=llm,
+                        pruning_config=document_pruning_config,
+                        selected_docs=selected_llm_docs,
+                        chunks_above=new_msg_req.chunks_above,
+                        chunks_below=new_msg_req.chunks_below,
+                        full_doc=new_msg_req.full_doc,
+                    )
+                    tool_dict[db_tool_model.id] = [search_tool]
+                elif tool_cls.__name__ == ImageGenerationTool.__name__:
+                    dalle_key = None
+                    if (
+                        llm
+                        and llm.config.api_key
+                        and llm.config.model_provider == "openai"
+                    ):
+                        dalle_key = llm.config.api_key
+                    else:
+                        llm_providers = fetch_existing_llm_providers(db_session)
+                        openai_provider = next(
+                            iter(
+                                [
+                                    llm_provider
+                                    for llm_provider in llm_providers
+                                    if llm_provider.provider == "openai"
+                                ]
+                            ),
+                            None,
+                        )
+                        if not openai_provider or not openai_provider.api_key:
+                            raise ValueError(
+                                "Image generation tool requires an OpenAI API key"
+                            )
+                        dalle_key = openai_provider.api_key
+                    tool_dict[db_tool_model.id] = [
+                        ImageGenerationTool(api_key=dalle_key)
+                    ]
+
+                continue
+
+            # handle all custom tools
+            if db_tool_model.openapi_schema:
+                tool_dict[db_tool_model.id] = cast(
+                    list[Tool],
+                    build_custom_tools_from_openapi_schema(
+                        db_tool_model.openapi_schema
+                    ),
+                )
+
+        tools: list[Tool] = []
+        for tool_list in tool_dict.values():
+            tools.extend(tool_list)
 
         # factor in tool definition size when pruning
-        document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
-            persona_tool_classes
-        )
+        document_pruning_config.tool_num_tokens = compute_all_tool_tokens(tools)
         document_pruning_config.using_tool_message = explicit_tool_calling_supported(
             llm.config.model_provider, llm.config.model_name
         )
 
-        # NOTE: for now, only support SearchTool and ImageGenerationTool
-        # in the future, will support arbitrary user-defined tools
-        search_tool: SearchTool | None = None
-        tools: list[Tool] = []
-        for tool_cls in persona_tool_classes:
-            if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
-                search_tool = SearchTool(
-                    db_session=db_session,
-                    user=user,
-                    persona=persona,
-                    retrieval_options=retrieval_options,
-                    prompt_config=prompt_config,
-                    llm=llm,
-                    pruning_config=document_pruning_config,
-                    selected_docs=selected_llm_docs,
-                    chunks_above=new_msg_req.chunks_above,
-                    chunks_below=new_msg_req.chunks_below,
-                    full_doc=new_msg_req.full_doc,
-                )
-                tools.append(search_tool)
-            elif tool_cls.__name__ == ImageGenerationTool.__name__:
-                dalle_key = None
-                if llm and llm.config.api_key and llm.config.model_provider == "openai":
-                    dalle_key = llm.config.api_key
-                else:
-                    llm_providers = fetch_existing_llm_providers(db_session)
-                    openai_provider = next(
-                        iter(
-                            [
-                                llm_provider
-                                for llm_provider in llm_providers
-                                if llm_provider.provider == "openai"
"openai" - ] - ), - None, - ) - if not openai_provider or not openai_provider.api_key: - raise ValueError( - "Image generation tool requires an OpenAI API key" - ) - dalle_key = openai_provider.api_key - tools.append(ImageGenerationTool(api_key=dalle_key)) - # LLM prompt building, response capturing, etc. answer = Answer( question=final_msg.message, @@ -468,7 +492,9 @@ def stream_chat_message_objects( ], tools=tools, force_use_tool=( - _check_should_force_search(new_msg_req) if search_tool else None + _check_should_force_search(new_msg_req) + if search_tool and len(tools) == 1 + else None ), ) @@ -476,6 +502,7 @@ def stream_chat_message_objects( qa_docs_response = None ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images dropped_indices = None + tool_result = None for packet in answer.processed_streamed_output: if isinstance(packet, ToolResponse): if packet.id == SEARCH_RESPONSE_SUMMARY_ID: @@ -521,8 +548,16 @@ def stream_chat_message_objects( yield ImageGenerationDisplay( file_ids=[str(file_id) for file_id in file_ids] ) + elif packet.id == CUSTOM_TOOL_RESPONSE_ID: + custom_tool_response = cast(CustomToolCallSummary, packet.response) + yield CustomToolResponse( + response=custom_tool_response.tool_result, + tool_name=custom_tool_response.tool_name, + ) else: + if isinstance(packet, ToolCallFinalResult): + tool_result = packet yield cast(ChatPacket, packet) except Exception as e: @@ -551,6 +586,11 @@ def stream_chat_message_objects( ) # Saving Gen AI answer and responding with message info + tool_name_to_tool_id: dict[str, int] = {} + for tool_id, tool_list in tool_dict.items(): + for tool in tool_list: + tool_name_to_tool_id[tool.name()] = tool_id + gen_ai_response_message = partial_response( message=answer.llm_answer, rephrased_query=( @@ -561,6 +601,16 @@ def stream_chat_message_objects( token_count=len(llm_tokenizer_encode_func(answer.llm_answer)), citations=db_citations, error=None, + tool_calls=[ + ToolCall( + tool_id=tool_name_to_tool_id[tool_result.tool_name], + tool_name=tool_result.tool_name, + tool_arguments=tool_result.tool_args, + tool_result=tool_result.tool_result, + ) + ] + if tool_result + else [], ) db_session.commit() # actually save user / assistant message diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 77aa4c17663..996700aef48 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -199,6 +199,8 @@ os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true" ) +DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day + ##### # Indexing Configs @@ -228,10 +230,6 @@ MINI_CHUNK_SIZE = 150 # Timeout to wait for job's last update before killing it, in hours CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3)) -# If set to true, then will not clean up documents that "no longer exist" when running Load connectors -DISABLE_DOCUMENT_CLEANUP = ( - os.environ.get("DISABLE_DOCUMENT_CLEANUP", "").lower() == "true" -) ##### diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 2c28299ccfb..9e0d318c5ec 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -52,6 +52,8 @@ # For combining attributes, doesn't have to be unique/perfect to work INDEX_SEPARATOR = "===" +# For File Connector Metadata override file +DANSWER_METADATA_FILENAME = ".danswer_metadata.json" # Messages DISABLED_GEN_AI_MSG = ( diff --git 
index 322ed44b9b8..f70980a8df5 100644
--- a/backend/danswer/connectors/factory.py
+++ b/backend/danswer/connectors/factory.py
@@ -1,6 +1,8 @@
 from typing import Any
 from typing import Type
 
+from sqlalchemy.orm import Session
+
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.axero.connector import AxeroConnector
 from danswer.connectors.bookstack.connector import BookstackConnector
@@ -40,6 +42,8 @@
 from danswer.connectors.wikipedia.connector import WikipediaConnector
 from danswer.connectors.zendesk.connector import ZendeskConnector
 from danswer.connectors.zulip.connector import ZulipConnector
+from danswer.db.credentials import backend_update_credential_json
+from danswer.db.models import Credential
 
 
 class ConnectorMissingException(Exception):
@@ -119,10 +123,14 @@ def instantiate_connector(
     source: DocumentSource,
     input_type: InputType,
     connector_specific_config: dict[str, Any],
-    credentials: dict[str, Any],
-) -> tuple[BaseConnector, dict[str, Any] | None]:
+    credential: Credential,
+    db_session: Session,
+) -> BaseConnector:
     connector_class = identify_connector_class(source, input_type)
     connector = connector_class(**connector_specific_config)
-    new_credentials = connector.load_credentials(credentials)
+    new_credentials = connector.load_credentials(credential.credential_json)
+
+    if new_credentials is not None:
+        backend_update_credential_json(credential, new_credentials, db_session)
 
-    return connector, new_credentials
+    return connector
diff --git a/backend/danswer/connectors/file/connector.py b/backend/danswer/connectors/file/connector.py
index 7bed4720892..77d01394d4f 100644
--- a/backend/danswer/connectors/file/connector.py
+++ b/backend/danswer/connectors/file/connector.py
@@ -86,7 +86,12 @@ def _process_file(
     all_metadata = {**metadata, **file_metadata} if metadata else file_metadata
 
     # If this is set, we will show this in the UI as the "name" of the file
-    file_display_name_override = all_metadata.get("file_display_name")
+    file_display_name = all_metadata.get("file_display_name") or os.path.basename(
+        file_name
+    )
+    title = (
+        all_metadata["title"] or "" if "title" in all_metadata else file_display_name
+    )
 
     time_updated = all_metadata.get("time_updated", datetime.now(timezone.utc))
     if isinstance(time_updated, str):
@@ -108,6 +113,7 @@
             "secondary_owners",
             "filename",
             "file_display_name",
+            "title",
         ]
     }
 
@@ -131,8 +137,8 @@
             Section(link=all_metadata.get("link"), text=file_content_raw.strip())
         ],
         source=DocumentSource.FILE,
-        semantic_identifier=file_display_name_override
-        or os.path.basename(file_name),
+        semantic_identifier=file_display_name,
+        title=title,
         doc_updated_at=final_time_updated,
         primary_owners=p_owners,
         secondary_owners=s_owners,
diff --git a/backend/danswer/connectors/interfaces.py b/backend/danswer/connectors/interfaces.py
index 74e0865c06f..3bd99792cce 100644
--- a/backend/danswer/connectors/interfaces.py
+++ b/backend/danswer/connectors/interfaces.py
@@ -50,6 +50,12 @@ def poll_source(
         raise NotImplementedError
 
 
+class IdConnector(BaseConnector):
+    @abc.abstractmethod
+    def retrieve_all_source_ids(self) -> set[str]:
+        raise NotImplementedError
+
+
 # Event driven
 class EventConnector(BaseConnector):
     @abc.abstractmethod
diff --git a/backend/danswer/connectors/models.py b/backend/danswer/connectors/models.py
index 85df7bfc9b9..37ed2e22bd5 100644
--- a/backend/danswer/connectors/models.py
+++ b/backend/danswer/connectors/models.py
@@ -13,6 +13,7 @@ class InputType(str, Enum):
     LOAD_STATE = "load_state"  # e.g. loading a current full state or a save state, such as from a file
     POLL = "poll"  # e.g. calling an API to get all documents in the last hour
     EVENT = "event"  # e.g. registered an endpoint as a listener, and processing connector events
+    PRUNE = "prune"
 
 
 class ConnectorMissingCredentialError(PermissionError):
diff --git a/backend/danswer/connectors/salesforce/connector.py b/backend/danswer/connectors/salesforce/connector.py
index 741ec0b2eb9..9f10da78ba4 100644
--- a/backend/danswer/connectors/salesforce/connector.py
+++ b/backend/danswer/connectors/salesforce/connector.py
@@ -11,6 +11,7 @@
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
 from danswer.connectors.interfaces import GenerateDocumentsOutput
+from danswer.connectors.interfaces import IdConnector
 from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.interfaces import PollConnector
 from danswer.connectors.interfaces import SecondsSinceUnixEpoch
@@ -23,11 +24,12 @@
 DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
 MAX_QUERY_LENGTH = 10000  # max query length is 20,000 characters
+ID_PREFIX = "SALESFORCE_"
 
 logger = setup_logger()
 
 
-class SalesforceConnector(LoadConnector, PollConnector):
+class SalesforceConnector(LoadConnector, PollConnector, IdConnector):
     def __init__(
         self,
         batch_size: int = INDEX_BATCH_SIZE,
@@ -77,7 +79,7 @@ def _convert_object_instance_to_document(
         if self.sf_client is None:
             raise ConnectorMissingCredentialError("Salesforce")
 
-        extracted_id = f"SALESFORCE_{object_dict['Id']}"
+        extracted_id = f"{ID_PREFIX}{object_dict['Id']}"
         extracted_link = f"https://{self.sf_client.sf_instance}/{extracted_id}"
         extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
         extracted_object_text = extract_dict_text(object_dict)
@@ -229,8 +231,6 @@
             yield doc_batch
 
     def load_from_state(self) -> GenerateDocumentsOutput:
-        if self.sf_client is None:
-            raise ConnectorMissingCredentialError("Salesforce")
         return self._fetch_from_salesforce()
 
     def poll_source(
@@ -242,6 +242,20 @@
         end_datetime = datetime.utcfromtimestamp(end)
         return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
 
+    def retrieve_all_source_ids(self) -> set[str]:
+        if self.sf_client is None:
+            raise ConnectorMissingCredentialError("Salesforce")
+        all_retrieved_ids: set[str] = set()
+        for parent_object_type in self.parent_object_list:
+            query = f"SELECT Id FROM {parent_object_type}"
+            query_result = self.sf_client.query_all(query)
+            all_retrieved_ids.update(
+                f"{ID_PREFIX}{instance_dict.get('Id', '')}"
+                for instance_dict in query_result["records"]
+            )
+
+        return all_retrieved_ids
+
 
 if __name__ == "__main__":
     connector = SalesforceConnector(
diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py
index 50458efd8d8..7f31ccb264e 100644
--- a/backend/danswer/db/chat.py
+++ b/backend/danswer/db/chat.py
@@ -5,6 +5,7 @@
 from sqlalchemy import or_
 from sqlalchemy import select
 from sqlalchemy.exc import MultipleResultsFound
+from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import Session
 
 from danswer.auth.schemas import UserRole
@@ -16,6 +17,7 @@
 from danswer.db.models import Prompt
 from danswer.db.models import SearchDoc
 from danswer.db.models import SearchDoc as DBSearchDoc
+from danswer.db.models import ToolCall
 from danswer.db.models import User
 from danswer.file_store.models import FileDescriptor
 from danswer.llm.override_models import LLMOverride
@@ -24,6 +26,7 @@
 from danswer.search.models import SavedSearchDoc
 from danswer.search.models import SearchDoc as ServerSearchDoc
 from danswer.server.query_and_chat.models import ChatMessageDetail
+from danswer.tools.tool_runner import ToolCallFinalResult
 from danswer.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -185,6 +188,7 @@ def get_chat_messages_by_session(
     user_id: UUID | None,
     db_session: Session,
     skip_permission_check: bool = False,
+    prefetch_tool_calls: bool = False,
 ) -> list[ChatMessage]:
     if not skip_permission_check:
         get_chat_session_by_id(
@@ -192,12 +196,18 @@
         )
 
     stmt = (
-        select(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id)
-        # Start with the root message which has no parent
+        select(ChatMessage)
+        .where(ChatMessage.chat_session_id == chat_session_id)
         .order_by(nullsfirst(ChatMessage.parent_message))
     )
 
-    result = db_session.execute(stmt).scalars().all()
+    if prefetch_tool_calls:
+        stmt = stmt.options(joinedload(ChatMessage.tool_calls))
+
+    if prefetch_tool_calls:
+        result = db_session.scalars(stmt).unique().all()
+    else:
+        result = db_session.scalars(stmt).all()
 
     return list(result)
@@ -251,6 +261,7 @@ def create_new_chat_message(
     reference_docs: list[DBSearchDoc] | None = None,
     # Maps the citation number [n] to the DB SearchDoc
     citations: dict[int, int] | None = None,
+    tool_calls: list[ToolCall] | None = None,
     commit: bool = True,
 ) -> ChatMessage:
     new_chat_message = ChatMessage(
@@ -264,6 +275,7 @@
         message_type=message_type,
         citations=citations,
         files=files,
+        tool_calls=tool_calls if tool_calls else [],
         error=error,
     )
@@ -459,6 +471,14 @@ def translate_db_message_to_chat_message_detail(
         time_sent=chat_message.time_sent,
         citations=chat_message.citations,
         files=chat_message.files or [],
+        tool_calls=[
+            ToolCallFinalResult(
+                tool_name=tool_call.tool_name,
+                tool_args=tool_call.tool_arguments,
+                tool_result=tool_call.tool_result,
+            )
+            for tool_call in chat_message.tool_calls
+        ],
     )
 
     return chat_msg_detail
diff --git a/backend/danswer/db/connector.py b/backend/danswer/db/connector.py
index b13f6e7eca0..2e4b1ed4c3e 100644
--- a/backend/danswer/db/connector.py
+++ b/backend/danswer/db/connector.py
@@ -7,6 +7,7 @@
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import Session
 
+from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.models import InputType
 from danswer.db.models import Connector
@@ -84,6 +85,9 @@ def create_connector(
         input_type=connector_data.input_type,
         connector_specific_config=connector_data.connector_specific_config,
         refresh_freq=connector_data.refresh_freq,
+        prune_freq=connector_data.prune_freq
+        if connector_data.prune_freq is not None
+        else DEFAULT_PRUNING_FREQ,
         disabled=connector_data.disabled,
     )
     db_session.add(connector)
@@ -113,6 +117,11 @@ def update_connector(
     connector.input_type = connector_data.input_type
     connector.connector_specific_config = connector_data.connector_specific_config
     connector.refresh_freq = connector_data.refresh_freq
+    connector.prune_freq = (
+        connector_data.prune_freq
+        if connector_data.prune_freq is not None
+        else DEFAULT_PRUNING_FREQ
+    )
     connector.disabled = connector_data.disabled
 
     db_session.commit()
@@ -259,6 +268,7 @@ def create_initial_default_connector(db_session: Session) -> None:
         input_type=InputType.LOAD_STATE,
         connector_specific_config={},
         refresh_freq=None,
+        prune_freq=None,
     )
     db_session.add(connector)
     db_session.commit()
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index 3b4b67f0947..29ca74bfc89 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -133,6 +133,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
     prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
     # Personas owned by this user
     personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
+    # Custom tools created by this user
+    custom_tools: Mapped[list["Tool"]] = relationship("Tool", back_populates="user")
 
 
 class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
@@ -330,7 +332,6 @@ class Document(Base):
     primary_owners: Mapped[list[str] | None] = mapped_column(
         postgresql.ARRAY(String), nullable=True
     )
-    # Something like assignee or space owner
     secondary_owners: Mapped[list[str] | None] = mapped_column(
         postgresql.ARRAY(String), nullable=True
     )
@@ -382,6 +383,7 @@ class Connector(Base):
         postgresql.JSONB()
     )
     refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
+    prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
     time_created: Mapped[datetime.datetime] = mapped_column(
         DateTime(timezone=True), server_default=func.now()
     )
@@ -618,6 +620,26 @@ class SearchDoc(Base):
     )
 
 
+class ToolCall(Base):
+    """Represents a single tool call"""
+
+    __tablename__ = "tool_call"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    # not a FK because we want to be able to delete the tool without deleting
+    # this entry
+    tool_id: Mapped[int] = mapped_column(Integer())
+    tool_name: Mapped[str] = mapped_column(String())
+    tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
+    tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
+
+    message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
+
+    message: Mapped["ChatMessage"] = relationship(
+        "ChatMessage", back_populates="tool_calls"
+    )
+
+
 class ChatSession(Base):
     __tablename__ = "chat_session"
 
@@ -723,6 +745,10 @@ class ChatMessage(Base):
         secondary="chat_message__search_doc",
         back_populates="chat_messages",
     )
+    tool_calls: Mapped[list["ToolCall"]] = relationship(
+        "ToolCall",
+        back_populates="message",
+    )
 
 
 class ChatFolder(Base):
@@ -901,9 +927,18 @@ class Tool(Base):
     name: Mapped[str] = mapped_column(String, nullable=False)
     description: Mapped[str] = mapped_column(Text, nullable=True)
     # ID of the tool in the codebase, only applies for in-code tools.
-    # tools defiend via the UI will have this as None
+    # tools defined via the UI will have this as None
     in_code_tool_id: Mapped[str | None] = mapped_column(String, nullable=True)
 
+    # OpenAPI schema for the tool. Only applies to tools defined via the UI.
+    openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
+        postgresql.JSONB(), nullable=True
+    )
+
+    # user who created / owns the tool. Will be None for built-in tools.
+    user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
+
+    user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
 
     # Relationship to Persona through the association table
     personas: Mapped[list["Persona"]] = relationship(
         "Persona",
diff --git a/backend/danswer/db/tools.py b/backend/danswer/db/tools.py
new file mode 100644
index 00000000000..1e75b1c4901
--- /dev/null
+++ b/backend/danswer/db/tools.py
@@ -0,0 +1,74 @@
+from typing import Any
+from uuid import UUID
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from danswer.db.models import Tool
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+def get_tools(db_session: Session) -> list[Tool]:
+    return list(db_session.scalars(select(Tool)).all())
+
+
+def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
+    tool = db_session.scalar(select(Tool).where(Tool.id == tool_id))
+    if not tool:
+        raise ValueError("Tool by specified id does not exist")
+    return tool
+
+
+def create_tool(
+    name: str,
+    description: str | None,
+    openapi_schema: dict[str, Any] | None,
+    user_id: UUID | None,
+    db_session: Session,
+) -> Tool:
+    new_tool = Tool(
+        name=name,
+        description=description,
+        in_code_tool_id=None,
+        openapi_schema=openapi_schema,
+        user_id=user_id,
+    )
+    db_session.add(new_tool)
+    db_session.commit()
+    return new_tool
+
+
+def update_tool(
+    tool_id: int,
+    name: str | None,
+    description: str | None,
+    openapi_schema: dict[str, Any] | None,
+    user_id: UUID | None,
+    db_session: Session,
+) -> Tool:
+    tool = get_tool_by_id(tool_id, db_session)
+    if tool is None:
+        raise ValueError(f"Tool with ID {tool_id} does not exist")
+
+    if name is not None:
+        tool.name = name
+    if description is not None:
+        tool.description = description
+    if openapi_schema is not None:
+        tool.openapi_schema = openapi_schema
+    if user_id is not None:
+        tool.user_id = user_id
+    db_session.commit()
+
+    return tool
+
+
+def delete_tool(tool_id: int, db_session: Session) -> None:
+    tool = get_tool_by_id(tool_id, db_session)
+    if tool is None:
+        raise ValueError(f"Tool with ID {tool_id} does not exist")
+
+    db_session.delete(tool)
+    db_session.commit()
diff --git a/backend/danswer/db/users.py b/backend/danswer/db/users.py
index fa15aa4b789..f8a3938027f 100644
--- a/backend/danswer/db/users.py
+++ b/backend/danswer/db/users.py
@@ -1,15 +1,18 @@
 from collections.abc import Sequence
 
-from sqlalchemy import select
 from sqlalchemy.orm import Session
+from sqlalchemy.schema import Column
 
 from danswer.db.models import User
 
 
-def list_users(db_session: Session) -> Sequence[User]:
+def list_users(db_session: Session, q: str = "") -> Sequence[User]:
     """List all users. No pagination as of now, as the # of users is assumed to be
     relatively small (<< 1 million)"""
-    return db_session.scalars(select(User)).unique().all()
+    query = db_session.query(User)
+    if q:
+        query = query.filter(Column("email").ilike("%{}%".format(q)))
+    return query.all()
 
 
 def get_user_by_email(email: str, db_session: Session) -> User | None:
diff --git a/backend/danswer/file_processing/extract_file_text.py b/backend/danswer/file_processing/extract_file_text.py
index bb964141c4b..14cca7f6f9b 100644
--- a/backend/danswer/file_processing/extract_file_text.py
+++ b/backend/danswer/file_processing/extract_file_text.py
@@ -16,6 +16,7 @@
 from pypdf import PdfReader
 from pypdf.errors import PdfStreamError
 
+from danswer.configs.constants import DANSWER_METADATA_FILENAME
 from danswer.file_processing.html_utils import parse_html_page_basic
 from danswer.utils.logger import setup_logger
 
@@ -88,7 +89,7 @@ def load_files_from_zip(
     with zipfile.ZipFile(zip_file_io, "r") as zip_file:
         zip_metadata = {}
         try:
-            metadata_file_info = zip_file.getinfo(".danswer_metadata.json")
+            metadata_file_info = zip_file.getinfo(DANSWER_METADATA_FILENAME)
             with zip_file.open(metadata_file_info, "r") as metadata_file:
                 try:
                     zip_metadata = json.load(metadata_file)
@@ -96,18 +97,19 @@
                         # convert list of dicts to dict of dicts
                         zip_metadata = {d["filename"]: d for d in zip_metadata}
                 except json.JSONDecodeError:
-                    logger.warn("Unable to load .danswer_metadata.json")
+                    logger.warn(f"Unable to load {DANSWER_METADATA_FILENAME}")
         except KeyError:
-            logger.info("No .danswer_metadata.json file")
+            logger.info(f"No {DANSWER_METADATA_FILENAME} file")
 
     for file_info in zip_file.infolist():
         with zip_file.open(file_info.filename, "r") as file:
             if ignore_dirs and file_info.is_dir():
                 continue
 
-            if ignore_macos_resource_fork_files and is_macos_resource_fork_file(
-                file_info.filename
-            ):
+            if (
+                ignore_macos_resource_fork_files
+                and is_macos_resource_fork_file(file_info.filename)
+            ) or file_info.filename == DANSWER_METADATA_FILENAME:
                 continue
             yield file_info, file, zip_metadata.get(file_info.filename, {})
diff --git a/backend/danswer/file_store/utils.py b/backend/danswer/file_store/utils.py
index 82c027304ae..4b849f70d96 100644
--- a/backend/danswer/file_store/utils.py
+++ b/backend/danswer/file_store/utils.py
@@ -24,7 +24,7 @@ def load_chat_file(
         file_id=file_descriptor["id"],
         content=file_io.read(),
         file_type=file_descriptor["type"],
-        filename=file_descriptor["name"],
+        filename=file_descriptor.get("name"),
     )
diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py
index f26e5f4ea93..55b2be36002 100644
--- a/backend/danswer/llm/answering/answer.py
+++ b/backend/danswer/llm/answering/answer.py
@@ -4,6 +4,7 @@
 from langchain.schema.messages import BaseMessage
 from langchain_core.messages import AIMessageChunk
+from langchain_core.messages import HumanMessage
 
 from danswer.chat.chat_utils import llm_doc_from_inference_section
 from danswer.chat.models import AnswerQuestionPossibleReturn
@@ -33,6 +34,9 @@
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.llm.utils import message_generator_to_string_generator
+from danswer.tools.custom.custom_tool_prompt_builder import (
+    build_user_message_for_custom_tool_for_non_tool_calling_llm,
+)
 from danswer.tools.force import filter_tools_for_force_tool_use
 from danswer.tools.force import ForceUseTool
 from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
@@ -50,7 +54,8 @@
 from danswer.tools.tool_runner import (
     check_which_tools_should_run_for_non_tool_calling_llm,
 )
-from danswer.tools.tool_runner import ToolRunKickoff
+from danswer.tools.tool_runner import ToolCallFinalResult
+from danswer.tools.tool_runner import ToolCallKickoff
 from danswer.tools.tool_runner import ToolRunner
 from danswer.tools.utils import explicit_tool_calling_supported
 
@@ -72,7 +77,7 @@ def _get_answer_stream_processor(
     raise RuntimeError("Not implemented yet")
 
 
-AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolRunKickoff | ToolResponse]
+AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
 
 
 class Answer:
@@ -124,9 +129,10 @@ def __init__(
 
         self._final_prompt: list[BaseMessage] | None = None
 
         self._streamed_output: list[str] | None = None
-        self._processed_stream: (
-            list[AnswerQuestionPossibleReturn | ToolResponse | ToolRunKickoff] | None
-        ) = None
+
+        self._processed_stream: list[
+            AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff
+        ] | None = None
 
     def _update_prompt_builder_for_search_tool(
         self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
@@ -160,7 +166,7 @@
 
     def _raw_output_for_explicit_tool_calling_llms(
         self,
-    ) -> Iterator[str | ToolRunKickoff | ToolResponse]:
+    ) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]:
         prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
 
         tool_call_chunk: AIMessageChunk | None = None
@@ -237,16 +243,18 @@
                 ),
             )
 
-            if tool.name() == SearchTool.name():
+            if tool.name() == SearchTool.NAME:
                 self._update_prompt_builder_for_search_tool(prompt_builder, [])
-            elif tool.name() == ImageGenerationTool.name():
+            elif tool.name() == ImageGenerationTool.NAME:
                 prompt_builder.update_user_prompt(
                     build_image_generation_user_prompt(
                         query=self.question,
                     )
                 )
-            prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
 
+            yield tool_runner.tool_final_result()
+
+            prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
             yield from message_generator_to_string_generator(
                 self.llm.stream(
                     prompt=prompt,
@@ -258,7 +266,7 @@
 
     def _raw_output_for_non_explicit_tool_calling_llms(
         self,
-    ) -> Iterator[str | ToolRunKickoff | ToolResponse]:
+    ) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]:
         prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
 
         chosen_tool_and_args: tuple[Tool, dict] | None = None
@@ -324,7 +332,7 @@
         tool_runner = ToolRunner(tool, tool_args)
         yield tool_runner.kickoff()
 
-        if tool.name() == SearchTool.name():
+        if tool.name() == SearchTool.NAME:
             final_context_documents = None
             for response in tool_runner.tool_responses():
                 if response.id == FINAL_CONTEXT_DOCUMENTS:
@@ -337,7 +345,7 @@
             self._update_prompt_builder_for_search_tool(
                 prompt_builder, final_context_documents
             )
-        elif tool.name() == ImageGenerationTool.name():
+        elif tool.name() == ImageGenerationTool.NAME:
             img_urls = []
             for response in tool_runner.tool_responses():
                 if response.id == IMAGE_GENERATION_RESPONSE_ID:
@@ -354,6 +362,18 @@
                     img_urls=img_urls,
                 )
             )
+        else:
+            prompt_builder.update_user_prompt(
+                HumanMessage(
+                    content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
+                        self.question,
+                        tool.name(),
+                        *tool_runner.tool_responses(),
+                    )
+                )
+            )
+
+        yield tool_runner.tool_final_result()
 
         prompt = prompt_builder.build()
         yield from message_generator_to_string_generator(self.llm.stream(prompt=prompt))
@@ -374,7 +394,7 @@
         )
 
         def _process_stream(
-            stream: Iterator[ToolRunKickoff | ToolResponse | str],
+            stream: Iterator[ToolCallKickoff | ToolResponse | str],
         ) -> AnswerStream:
             message = None
 
@@ -387,7 +407,9 @@
             )
 
             for message in stream:
-                if isinstance(message, ToolRunKickoff):
+                if isinstance(message, ToolCallKickoff) or isinstance(
+                    message, ToolCallFinalResult
+                ):
                     yield message
                 elif isinstance(message, ToolResponse):
                     if message.id == SEARCH_RESPONSE_SUMMARY_ID:
diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index 92caa608f87..c363961d642 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -63,6 +63,7 @@
 from danswer.server.features.persona.api import admin_router as admin_persona_router
 from danswer.server.features.persona.api import basic_router as persona_router
 from danswer.server.features.prompt.api import basic_router as prompt_router
+from danswer.server.features.tool.api import admin_router as admin_tool_router
 from danswer.server.features.tool.api import router as tool_router
 from danswer.server.gpts.api import router as gpts_router
 from danswer.server.manage.administrative import router as admin_router
@@ -281,6 +282,7 @@ def get_application() -> FastAPI:
     include_router_with_global_prefix_prepended(application, admin_persona_router)
     include_router_with_global_prefix_prepended(application, prompt_router)
     include_router_with_global_prefix_prepended(application, tool_router)
+    include_router_with_global_prefix_prepended(application, admin_tool_router)
     include_router_with_global_prefix_prepended(application, state_router)
     include_router_with_global_prefix_prepended(application, danswer_api_router)
     include_router_with_global_prefix_prepended(application, gpts_router)
diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py
index ebdfc992dcb..c131c0c05ca 100644
--- a/backend/danswer/one_shot_answer/answer_question.py
+++ b/backend/danswer/one_shot_answer/answer_question.py
@@ -51,7 +51,7 @@
 from danswer.tools.search.search_tool import SearchTool
 from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
 from danswer.tools.tool import ToolResponse
-from danswer.tools.tool_runner import ToolRunKickoff
+from danswer.tools.tool_runner import ToolCallKickoff
 from danswer.utils.logger import setup_logger
 from danswer.utils.timing import log_generator_function_time
 
@@ -67,7 +67,7 @@
     | StreamingError
     | ChatMessageDetail
     | CitationInfo
-    | ToolRunKickoff
+    | ToolCallKickoff
 ]
diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py
index e1833468054..ad25523817d 100644
--- a/backend/danswer/server/documents/connector.py
+++ b/backend/danswer/server/documents/connector.py
@@ -511,6 +511,7 @@ def update_connector_from_model(
         input_type=updated_connector.input_type,
         connector_specific_config=updated_connector.connector_specific_config,
         refresh_freq=updated_connector.refresh_freq,
+        prune_freq=updated_connector.prune_freq,
         credential_ids=[
             association.credential.id for association in updated_connector.credentials
         ],
@@ -726,6 +727,7 @@ def get_connector_by_id(
         input_type=connector.input_type,
         connector_specific_config=connector.connector_specific_config,
refresh_freq=connector.refresh_freq, + prune_freq=connector.prune_freq, credential_ids=[ association.credential.id for association in connector.credentials ], diff --git a/backend/danswer/server/documents/models.py b/backend/danswer/server/documents/models.py index d574cc361fa..f1eb71dcea2 100644 --- a/backend/danswer/server/documents/models.py +++ b/backend/danswer/server/documents/models.py @@ -68,6 +68,7 @@ class ConnectorBase(BaseModel): input_type: InputType connector_specific_config: dict[str, Any] refresh_freq: int | None # In seconds, None for one time index with no refresh + prune_freq: int | None disabled: bool @@ -86,6 +87,7 @@ def from_connector_db_model(cls, connector: Connector) -> "ConnectorSnapshot": input_type=connector.input_type, connector_specific_config=connector.connector_specific_config, refresh_freq=connector.refresh_freq, + prune_freq=connector.prune_freq, credential_ids=[ association.credential.id for association in connector.credentials ], diff --git a/backend/danswer/server/features/tool/api.py b/backend/danswer/server/features/tool/api.py index 0a9666646a4..b1f57a1a924 100644 --- a/backend/danswer/server/features/tool/api.py +++ b/backend/danswer/server/features/tool/api.py @@ -1,32 +1,132 @@ +from typing import Any + from fastapi import APIRouter from fastapi import Depends +from fastapi import HTTPException from pydantic import BaseModel -from sqlalchemy import select from sqlalchemy.orm import Session +from danswer.auth.users import current_admin_user from danswer.auth.users import current_user from danswer.db.engine import get_session -from danswer.db.models import Tool from danswer.db.models import User - +from danswer.db.tools import create_tool +from danswer.db.tools import delete_tool +from danswer.db.tools import get_tool_by_id +from danswer.db.tools import get_tools +from danswer.db.tools import update_tool +from danswer.server.features.tool.models import ToolSnapshot +from danswer.tools.custom.openapi_parsing import MethodSpec +from danswer.tools.custom.openapi_parsing import openapi_to_method_specs +from danswer.tools.custom.openapi_parsing import validate_openapi_schema router = APIRouter(prefix="/tool") +admin_router = APIRouter(prefix="/admin/tool") -class ToolSnapshot(BaseModel): - id: int +class CustomToolCreate(BaseModel): name: str - description: str - in_code_tool_id: str | None + description: str | None + definition: dict[str, Any] + + +class CustomToolUpdate(BaseModel): + name: str | None + description: str | None + definition: dict[str, Any] | None + + +def _validate_tool_definition(definition: dict[str, Any]) -> None: + try: + validate_openapi_schema(definition) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@admin_router.post("/custom") +def create_custom_tool( + tool_data: CustomToolCreate, + db_session: Session = Depends(get_session), + user: User | None = Depends(current_admin_user), +) -> ToolSnapshot: + _validate_tool_definition(tool_data.definition) + tool = create_tool( + name=tool_data.name, + description=tool_data.description, + openapi_schema=tool_data.definition, + user_id=user.id if user else None, + db_session=db_session, + ) + return ToolSnapshot.from_model(tool) + + +@admin_router.put("/custom/{tool_id}") +def update_custom_tool( + tool_id: int, + tool_data: CustomToolUpdate, + db_session: Session = Depends(get_session), + user: User | None = Depends(current_admin_user), +) -> ToolSnapshot: + if tool_data.definition: + _validate_tool_definition(tool_data.definition) + updated_tool = 
update_tool( + tool_id=tool_id, + name=tool_data.name, + description=tool_data.description, + openapi_schema=tool_data.definition, + user_id=user.id if user else None, + db_session=db_session, + ) + return ToolSnapshot.from_model(updated_tool) + - @classmethod - def from_model(cls, tool: Tool) -> "ToolSnapshot": - return cls( - id=tool.id, - name=tool.name, - description=tool.description, - in_code_tool_id=tool.in_code_tool_id, - ) +@admin_router.delete("/custom/{tool_id}") +def delete_custom_tool( + tool_id: int, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> None: + try: + delete_tool(tool_id, db_session) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + except Exception as e: + # handles case where tool is still used by an Assistant + raise HTTPException(status_code=400, detail=str(e)) + + +class ValidateToolRequest(BaseModel): + definition: dict[str, Any] + + +class ValidateToolResponse(BaseModel): + methods: list[MethodSpec] + + +@admin_router.post("/custom/validate") +def validate_tool( + tool_data: ValidateToolRequest, + _: User | None = Depends(current_admin_user), +) -> ValidateToolResponse: + _validate_tool_definition(tool_data.definition) + method_specs = openapi_to_method_specs(tool_data.definition) + return ValidateToolResponse(methods=method_specs) + + +"""Endpoints for all""" + + +@router.get("/{tool_id}") +def get_custom_tool( + tool_id: int, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_user), +) -> ToolSnapshot: + try: + tool = get_tool_by_id(tool_id, db_session) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + return ToolSnapshot.from_model(tool) @router.get("") @@ -34,5 +134,5 @@ def list_tools( db_session: Session = Depends(get_session), _: User | None = Depends(current_user), ) -> list[ToolSnapshot]: - tools = db_session.execute(select(Tool)).scalars().all() + tools = get_tools(db_session) return [ToolSnapshot.from_model(tool) for tool in tools] diff --git a/backend/danswer/server/features/tool/models.py b/backend/danswer/server/features/tool/models.py new file mode 100644 index 00000000000..feb3ba68269 --- /dev/null +++ b/backend/danswer/server/features/tool/models.py @@ -0,0 +1,23 @@ +from typing import Any + +from pydantic import BaseModel + +from danswer.db.models import Tool + + +class ToolSnapshot(BaseModel): + id: int + name: str + description: str + definition: dict[str, Any] | None + in_code_tool_id: str | None + + @classmethod + def from_model(cls, tool: Tool) -> "ToolSnapshot": + return cls( + id=tool.id, + name=tool.name, + description=tool.description, + definition=tool.openapi_schema, + in_code_tool_id=tool.in_code_tool_id, + ) diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py index 8797913f44f..cbc3251bfd3 100644 --- a/backend/danswer/server/manage/models.py +++ b/backend/danswer/server/manage/models.py @@ -14,6 +14,8 @@ from danswer.db.models import SlackBotResponseType from danswer.indexing.models import EmbeddingModelDetail from danswer.server.features.persona.models import PersonaSnapshot +from danswer.server.models import FullUserSnapshot +from danswer.server.models import InvitedUserSnapshot if TYPE_CHECKING: from danswer.db.models import User as UserModel @@ -152,3 +154,10 @@ def from_model( class FullModelVersionResponse(BaseModel): current_model: EmbeddingModelDetail secondary_model: EmbeddingModelDetail | None + + +class AllUsersResponse(BaseModel): 
+ accepted: list[FullUserSnapshot] + invited: list[InvitedUserSnapshot] + accepted_pages: int + invited_pages: int diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index a643f4e752f..3fbd15c7e98 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -1,4 +1,7 @@ +import re + from fastapi import APIRouter +from fastapi import Body from fastapi import Depends from fastapi import HTTPException from fastapi import status @@ -6,28 +9,41 @@ from sqlalchemy import update from sqlalchemy.orm import Session +from danswer.auth.invited_users import get_invited_users +from danswer.auth.invited_users import write_invited_users from danswer.auth.noauth_user import fetch_no_auth_user from danswer.auth.noauth_user import set_no_auth_user_preferences -from danswer.auth.schemas import UserRead from danswer.auth.schemas import UserRole +from danswer.auth.schemas import UserStatus from danswer.auth.users import current_admin_user from danswer.auth.users import current_user from danswer.auth.users import optional_user from danswer.configs.app_configs import AUTH_TYPE +from danswer.configs.app_configs import VALID_EMAIL_DOMAINS from danswer.configs.constants import AuthType from danswer.db.engine import get_session from danswer.db.models import User from danswer.db.users import get_user_by_email from danswer.db.users import list_users from danswer.dynamic_configs.factory import get_dynamic_config_store +from danswer.server.manage.models import AllUsersResponse from danswer.server.manage.models import UserByEmail from danswer.server.manage.models import UserInfo from danswer.server.manage.models import UserRoleResponse +from danswer.server.models import FullUserSnapshot +from danswer.server.models import InvitedUserSnapshot from danswer.server.models import MinimalUserSnapshot +from danswer.utils.logger import setup_logger + +logger = setup_logger() + router = APIRouter() +USERS_PAGE_SIZE = 10 + + @router.patch("/manage/promote-user-to-admin") def promote_admin( user_email: UserByEmail, @@ -69,11 +85,119 @@ async def demote_admin( @router.get("/manage/users") def list_all_users( + q: str, + accepted_page: int, + invited_page: int, _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), -) -> list[UserRead]: - users = list_users(db_session) - return [UserRead.from_orm(user) for user in users] +) -> AllUsersResponse: + users = list_users(db_session, q=q) + accepted_emails = {user.email for user in users} + invited_emails = get_invited_users() + if q: + invited_emails = [ + email for email in invited_emails if re.search(r"{}".format(q), email, re.I) + ] + + accepted_count = len(accepted_emails) + invited_count = len(invited_emails) + + return AllUsersResponse( + accepted=[ + FullUserSnapshot( + id=user.id, + email=user.email, + role=user.role, + status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED, + ) + for user in users + ][accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE], + invited=[InvitedUserSnapshot(email=email) for email in invited_emails][ + invited_page * USERS_PAGE_SIZE : (invited_page + 1) * USERS_PAGE_SIZE + ], + accepted_pages=accepted_count // USERS_PAGE_SIZE + 1, + invited_pages=invited_count // USERS_PAGE_SIZE + 1, + ) + + +@router.put("/manage/admin/users") +def bulk_invite_users( + emails: list[str] = Body(..., embed=True), + current_user: User | None = Depends(current_admin_user), +) -> int: + if current_user is None: + raise 
HTTPException( + status_code=400, detail="Auth is disabled, cannot invite users" + ) + + all_emails = list(set(emails) | set(get_invited_users())) + return write_invited_users(all_emails) + + +@router.patch("/manage/admin/remove-invited-user") +def remove_invited_user( + user_email: UserByEmail, + _: User | None = Depends(current_admin_user), +) -> int: + user_emails = get_invited_users() + remaining_users = [user for user in user_emails if user != user_email.user_email] + return write_invited_users(remaining_users) + + +@router.patch("/manage/admin/deactivate-user") +def deactivate_user( + user_email: UserByEmail, + current_user: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + if current_user is None: + raise HTTPException( + status_code=400, detail="Auth is disabled, cannot deactivate user" + ) + + if current_user.email == user_email.user_email: + raise HTTPException(status_code=400, detail="You cannot deactivate yourself") + + user_to_deactivate = get_user_by_email( + email=user_email.user_email, db_session=db_session + ) + + if not user_to_deactivate: + raise HTTPException(status_code=404, detail="User not found") + + if user_to_deactivate.is_active is False: + logger.warning("{} is already deactivated".format(user_to_deactivate.email)) + + user_to_deactivate.is_active = False + db_session.add(user_to_deactivate) + db_session.commit() + + +@router.patch("/manage/admin/activate-user") +def activate_user( + user_email: UserByEmail, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + user_to_activate = get_user_by_email( + email=user_email.user_email, db_session=db_session + ) + if not user_to_activate: + raise HTTPException(status_code=404, detail="User not found") + + if user_to_activate.is_active is True: + logger.warning("{} is already activated".format(user_to_activate.email)) + + user_to_activate.is_active = True + db_session.add(user_to_activate) + db_session.commit() + + +@router.get("/manage/admin/valid-domains") +def get_valid_domains( + _: User | None = Depends(current_admin_user), +) -> list[str]: + return VALID_EMAIL_DOMAINS """Endpoints for all""" diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index 21349ae0747..fa70189f11c 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -6,6 +6,9 @@ from pydantic import BaseModel from pydantic.generics import GenericModel +from danswer.auth.schemas import UserRole +from danswer.auth.schemas import UserStatus + DataT = TypeVar("DataT") @@ -29,5 +32,16 @@ class MinimalUserSnapshot(BaseModel): email: str +class FullUserSnapshot(BaseModel): + id: UUID + email: str + role: UserRole + status: UserStatus + + +class InvitedUserSnapshot(BaseModel): + email: str + + class DisplayPriorityRequest(BaseModel): display_priority_map: dict[int, int] diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index d20a4b11101..834453e6e2d 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -129,6 +129,8 @@ def get_chat_session( # we already did a permission check above with the call to # `get_chat_session_by_id`, so we can skip it here skip_permission_check=True, + # we need the tool call objs anyways, so just fetch them in a single call + prefetch_tool_calls=True, ) return ChatSessionDetailResponse( diff --git 
a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 44e8ab84624..09561bf24f8 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -17,6 +17,7 @@ from danswer.search.models import RetrievalDetails from danswer.search.models import SearchDoc from danswer.search.models import Tag +from danswer.tools.models import ToolCallFinalResult class SourceTag(Tag): @@ -176,6 +177,7 @@ class ChatMessageDetail(BaseModel): # Dict mapping citation number to db_doc_id citations: dict[int, int] | None files: list[FileDescriptor] + tool_calls: list[ToolCallFinalResult] def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore initial_dict = super().dict(*args, **kwargs) # type: ignore diff --git a/backend/danswer/tools/custom/custom_tool.py b/backend/danswer/tools/custom/custom_tool.py new file mode 100644 index 00000000000..ea232fc5a74 --- /dev/null +++ b/backend/danswer/tools/custom/custom_tool.py @@ -0,0 +1,233 @@ +import json +from collections.abc import Generator +from typing import Any +from typing import cast + +import requests +from langchain_core.messages import HumanMessage +from langchain_core.messages import SystemMessage +from pydantic import BaseModel + +from danswer.dynamic_configs.interface import JSON_ro +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.interfaces import LLM +from danswer.tools.custom.custom_tool_prompts import ( + SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT, +) +from danswer.tools.custom.custom_tool_prompts import SHOULD_USE_CUSTOM_TOOL_USER_PROMPT +from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_SYSTEM_PROMPT +from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_USER_PROMPT +from danswer.tools.custom.custom_tool_prompts import USE_TOOL +from danswer.tools.custom.openapi_parsing import MethodSpec +from danswer.tools.custom.openapi_parsing import openapi_to_method_specs +from danswer.tools.custom.openapi_parsing import openapi_to_url +from danswer.tools.custom.openapi_parsing import REQUEST_BODY +from danswer.tools.custom.openapi_parsing import validate_openapi_schema +from danswer.tools.tool import Tool +from danswer.tools.tool import ToolResponse +from danswer.utils.logger import setup_logger + +logger = setup_logger() + +CUSTOM_TOOL_RESPONSE_ID = "custom_tool_response" + + +class CustomToolCallSummary(BaseModel): + tool_name: str + tool_result: dict + + +class CustomTool(Tool): + def __init__(self, method_spec: MethodSpec, base_url: str) -> None: + self._base_url = base_url + self._method_spec = method_spec + self._tool_definition = self._method_spec.to_tool_definition() + + self._name = self._method_spec.name + self.description = self._method_spec.summary + + def name(self) -> str: + return self._name + + """For LLMs which support explicit tool calling""" + + def tool_definition(self) -> dict: + return self._tool_definition + + def build_tool_message_content( + self, *args: ToolResponse + ) -> str | list[str | dict[str, Any]]: + response = cast(CustomToolCallSummary, args[0].response) + return json.dumps(response.tool_result) + + """For LLMs which do NOT support explicit tool calling""" + + def get_args_for_non_tool_calling_llm( + self, + query: str, + history: list[PreviousMessage], + llm: LLM, + force_run: bool = False, + ) -> dict[str, Any] | None: + if not force_run: + should_use_result = llm.invoke( + [ + SystemMessage(content=SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT), + HumanMessage( 
+ content=SHOULD_USE_CUSTOM_TOOL_USER_PROMPT.format( + history=history, + query=query, + tool_name=self.name(), + tool_description=self.description, + ) + ), + ] + ) + if cast(str, should_use_result.content).strip() != USE_TOOL: + return None + + args_result = llm.invoke( + [ + SystemMessage(content=TOOL_ARG_SYSTEM_PROMPT), + HumanMessage( + content=TOOL_ARG_USER_PROMPT.format( + history=history, + query=query, + tool_name=self.name(), + tool_description=self.description, + tool_args=self.tool_definition()["function"]["parameters"], + ) + ), + ] + ) + args_result_str = cast(str, args_result.content) + + try: + return json.loads(args_result_str.strip()) + except json.JSONDecodeError: + pass + + # try removing ``` + try: + return json.loads(args_result_str.strip("```")) + except json.JSONDecodeError: + pass + + # try removing ```json + try: + return json.loads(args_result_str.strip("```").strip("json")) + except json.JSONDecodeError: + pass + + # pretend like nothing happened if not parse-able + logger.error( + f"Failed to parse args for '{self.name()}' tool. Received: {args_result_str}" + ) + return None + + """Actual execution of the tool""" + + def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: + request_body = kwargs.get(REQUEST_BODY) + + path_params = {} + for path_param_schema in self._method_spec.get_path_param_schemas(): + path_params[path_param_schema["name"]] = kwargs[path_param_schema["name"]] + + query_params = {} + for query_param_schema in self._method_spec.get_query_param_schemas(): + if query_param_schema["name"] in kwargs: + query_params[query_param_schema["name"]] = kwargs[ + query_param_schema["name"] + ] + + url = self._method_spec.build_url(self._base_url, path_params, query_params) + method = self._method_spec.method + + response = requests.request(method, url, json=request_body) + + yield ToolResponse( + id=CUSTOM_TOOL_RESPONSE_ID, + response=CustomToolCallSummary( + tool_name=self._name, tool_result=response.json() + ), + ) + + def final_result(self, *args: ToolResponse) -> JSON_ro: + return cast(CustomToolCallSummary, args[0].response).tool_result + + +def build_custom_tools_from_openapi_schema( + openapi_schema: dict[str, Any] +) -> list[CustomTool]: + url = openapi_to_url(openapi_schema) + method_specs = openapi_to_method_specs(openapi_schema) + return [CustomTool(method_spec, url) for method_spec in method_specs] + + +if __name__ == "__main__": + import openai + + openapi_schema = { + "openapi": "3.0.0", + "info": { + "version": "1.0.0", + "title": "Assistants API", + "description": "An API for managing assistants", + }, + "servers": [ + {"url": "http://localhost:8080"}, + ], + "paths": { + "/assistant/{assistant_id}": { + "get": { + "summary": "Get a specific Assistant", + "operationId": "getAssistant", + "parameters": [ + { + "name": "assistant_id", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + }, + "post": { + "summary": "Create a new Assistant", + "operationId": "createAssistant", + "parameters": [ + { + "name": "assistant_id", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "requestBody": { + "required": True, + "content": {"application/json": {"schema": {"type": "object"}}}, + }, + }, + } + }, + } + validate_openapi_schema(openapi_schema) + + tools = build_custom_tools_from_openapi_schema(openapi_schema) + + openai_client = openai.OpenAI() + response = openai_client.chat.completions.create( + model="gpt-4o", + messages=[ + {"role": "system", "content": "You are a helpful
assistant."}, + {"role": "user", "content": "Can you fetch assistant with ID 10"}, + ], + tools=[tool.tool_definition() for tool in tools], # type: ignore + ) + choice = response.choices[0] + if choice.message.tool_calls: + print(choice.message.tool_calls) + for tool_response in tools[0].run( + **json.loads(choice.message.tool_calls[0].function.arguments) + ): + print(tool_response) diff --git a/backend/danswer/tools/custom/custom_tool_prompt_builder.py b/backend/danswer/tools/custom/custom_tool_prompt_builder.py new file mode 100644 index 00000000000..8016363acc9 --- /dev/null +++ b/backend/danswer/tools/custom/custom_tool_prompt_builder.py @@ -0,0 +1,21 @@ +from typing import cast + +from danswer.tools.custom.custom_tool import CustomToolCallSummary +from danswer.tools.models import ToolResponse + + +def build_user_message_for_custom_tool_for_non_tool_calling_llm( + query: str, + tool_name: str, + *args: ToolResponse, +) -> str: + tool_run_summary = cast(CustomToolCallSummary, args[0].response).tool_result + return f""" +Here's the result from the {tool_name} tool: + +{tool_run_summary} + +Now respond to the following: + +{query} +""".strip() diff --git a/backend/danswer/tools/custom/custom_tool_prompts.py b/backend/danswer/tools/custom/custom_tool_prompts.py new file mode 100644 index 00000000000..14e8b007ef0 --- /dev/null +++ b/backend/danswer/tools/custom/custom_tool_prompts.py @@ -0,0 +1,57 @@ +from danswer.prompts.constants import GENERAL_SEP_PAT + +DONT_USE_TOOL = "Don't use tool" +USE_TOOL = "Use tool" + + +"""Prompts to determine if we should use a custom tool or not.""" + + +SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT = ( + "You are a large language model whose only job is to determine if the system should call an " + "external tool to be able to answer the user's last message." +).strip() + +SHOULD_USE_CUSTOM_TOOL_USER_PROMPT = f""" +Given the conversation history and a follow up query, determine if the system should use the \ +'{{tool_name}}' tool to answer the user's query. The '{{tool_name}}' tool is a tool defined as: '{{tool_description}}'. + +Respond with "{USE_TOOL}" if you think the tool would be helpful in respnding to the users query. +Respond with "{DONT_USE_TOOL}" otherwise. + +Conversation History: +{GENERAL_SEP_PAT} +{{history}} +{GENERAL_SEP_PAT} + +If you are at all unsure, respond with {DONT_USE_TOOL}. +Respond with EXACTLY and ONLY "{DONT_USE_TOOL}" or "{USE_TOOL}" + +Follow up input: +{{query}} +""".strip() + + +"""Prompts to figure out the arguments to pass to a custom tool.""" + + +TOOL_ARG_SYSTEM_PROMPT = ( + "You are a large language model whose only job is to determine the arguments to pass to an " + "external tool." +).strip() + + +TOOL_ARG_USER_PROMPT = f""" +Given the following conversation and a follow up input, generate a \ +dictionary of arguments to pass to the '{{tool_name}}' tool. \ +The '{{tool_name}}' tool is a tool defined as: '{{tool_description}}'. \ +The expected arguments are: {{tool_args}}. + +Conversation: +{{history}} + +Follow up input: +{{query}} + +Respond with ONLY and EXACTLY a JSON object specifying the values of the arguments to pass to the tool. 
+""".strip() # noqa: F541 diff --git a/backend/danswer/tools/custom/openapi_parsing.py b/backend/danswer/tools/custom/openapi_parsing.py new file mode 100644 index 00000000000..40ed5544d8b --- /dev/null +++ b/backend/danswer/tools/custom/openapi_parsing.py @@ -0,0 +1,225 @@ +from typing import Any +from typing import cast + +from openai import BaseModel + +REQUEST_BODY = "requestBody" + + +class PathSpec(BaseModel): + path: str + methods: dict[str, Any] + + +class MethodSpec(BaseModel): + name: str + summary: str + path: str + method: str + spec: dict[str, Any] + + def get_request_body_schema(self) -> dict[str, Any]: + content = self.spec.get("requestBody", {}).get("content", {}) + if "application/json" in content: + return content["application/json"].get("schema") + + if content: + raise ValueError( + f"Unsupported content type: '{list(content.keys())[0]}'. " + f"Only 'application/json' is supported." + ) + + return {} + + def get_query_param_schemas(self) -> list[dict[str, Any]]: + return [ + param + for param in self.spec.get("parameters", []) + if "schema" in param and "in" in param and param["in"] == "query" + ] + + def get_path_param_schemas(self) -> list[dict[str, Any]]: + return [ + param + for param in self.spec.get("parameters", []) + if "schema" in param and "in" in param and param["in"] == "path" + ] + + def build_url( + self, base_url: str, path_params: dict[str, str], query_params: dict[str, str] + ) -> str: + url = f"{base_url}{self.path}" + try: + url = url.format(**path_params) + except KeyError as e: + raise ValueError(f"Missing path parameter: {e}") + if query_params: + url += "?" + for param, value in query_params.items(): + url += f"{param}={value}&" + url = url[:-1] + return url + + def to_tool_definition(self) -> dict[str, Any]: + tool_definition: Any = { + "type": "function", + "function": { + "name": self.name, + "description": self.summary, + "parameters": {"type": "object", "properties": {}}, + }, + } + + request_body_schema = self.get_request_body_schema() + if request_body_schema: + tool_definition["function"]["parameters"]["properties"][ + REQUEST_BODY + ] = request_body_schema + + query_param_schemas = self.get_query_param_schemas() + if query_param_schemas: + tool_definition["function"]["parameters"]["properties"].update( + {param["name"]: param["schema"] for param in query_param_schemas} + ) + + path_param_schemas = self.get_path_param_schemas() + if path_param_schemas: + tool_definition["function"]["parameters"]["properties"].update( + {param["name"]: param["schema"] for param in path_param_schemas} + ) + return tool_definition + + def validate_spec(self) -> None: + # Validate url construction + path_param_schemas = self.get_path_param_schemas() + dummy_path_dict = {param["name"]: "value" for param in path_param_schemas} + query_param_schemas = self.get_query_param_schemas() + dummy_query_dict = {param["name"]: "value" for param in query_param_schemas} + self.build_url("", dummy_path_dict, dummy_query_dict) + + # Make sure request body doesn't throw an exception + self.get_request_body_schema() + + # Ensure the method is valid + if not self.method: + raise ValueError("HTTP method is not specified.") + if self.method.upper() not in ["GET", "POST", "PUT", "DELETE", "PATCH"]: + raise ValueError(f"HTTP method '{self.method}' is not supported.") + + +"""Path-level utils""" + + +def openapi_to_path_specs(openapi_spec: dict[str, Any]) -> list[PathSpec]: + path_specs = [] + + for path, methods in openapi_spec.get("paths", {}).items(): + 
path_specs.append(PathSpec(path=path, methods=methods)) + + return path_specs + + +"""Method-level utils""" + + +def openapi_to_method_specs(openapi_spec: dict[str, Any]) -> list[MethodSpec]: + path_specs = openapi_to_path_specs(openapi_spec) + + method_specs = [] + for path_spec in path_specs: + for method_name, method in path_spec.methods.items(): + name = method.get("operationId") + if not name: + raise ValueError( + f"Operation ID is not specified for {method_name.upper()} {path_spec.path}" + ) + + summary = method.get("summary") or method.get("description") + if not summary: + raise ValueError( + f"Summary is not specified for {method_name.upper()} {path_spec.path}" + ) + + method_specs.append( + MethodSpec( + name=name, + summary=summary, + path=path_spec.path, + method=method_name, + spec=method, + ) + ) + + if not method_specs: + raise ValueError("No methods found in OpenAPI schema") + + return method_specs + + +def openapi_to_url(openapi_schema: dict[str, dict | str]) -> str: + """ + Extract the base URL from the servers section of an OpenAPI schema. + + Args: + openapi_schema (Dict[str, Union[Dict, str, List]]): The OpenAPI schema in dictionary format. + + Returns: + str: The single base URL defined in the servers section. + + Raises: + ValueError: If the schema does not define exactly one server URL. + """ + urls: list[str] = [] + + servers = cast(list[dict[str, Any]], openapi_schema.get("servers", [])) + for server in servers: + url = server.get("url") + if url: + urls.append(url) + + if len(urls) != 1: + raise ValueError( + f"Expected exactly one URL in OpenAPI schema, but found {urls}" + ) + + return urls[0] + + +def validate_openapi_schema(schema: dict[str, Any]) -> None: + """ + Validate the given JSON schema as an OpenAPI schema. + + Parameters: + - schema (dict): The JSON schema to validate. + + Raises: + - ValueError: If the schema is missing required fields or is otherwise invalid.
+ """ + + # check basic structure + if "info" not in schema: + raise ValueError("`info` section is required in OpenAPI schema") + + info = schema["info"] + if "title" not in info: + raise ValueError("`title` is required in `info` section of OpenAPI schema") + if "description" not in info: + raise ValueError( + "`description` is required in `info` section of OpenAPI schema" + ) + + if "openapi" not in schema: + raise ValueError( + "`openapi` field which specifies OpenAPI schema version is required" + ) + openapi_version = schema["openapi"] + if not openapi_version.startswith("3."): + raise ValueError(f"OpenAPI version '{openapi_version}' is not supported") + + if "paths" not in schema: + raise ValueError("`paths` section is required in OpenAPI schema") + + url = openapi_to_url(schema) + if not url: + raise ValueError("OpenAPI schema does not contain a valid URL in `servers`") + + method_specs = openapi_to_method_specs(schema) + for method_spec in method_specs: + method_spec.validate_spec() diff --git a/backend/danswer/tools/factory.py b/backend/danswer/tools/factory.py deleted file mode 100644 index 197bdd6619a..00000000000 --- a/backend/danswer/tools/factory.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Type - -from sqlalchemy.orm import Session - -from danswer.db.models import Tool as ToolDBModel -from danswer.tools.built_in_tools import get_built_in_tool_by_id -from danswer.tools.tool import Tool - - -def get_tool_cls(tool: ToolDBModel, db_session: Session) -> Type[Tool]: - # Currently only support built-in tools - return get_built_in_tool_by_id(tool.id, db_session) diff --git a/backend/danswer/tools/images/image_generation_tool.py b/backend/danswer/tools/images/image_generation_tool.py index da66271322f..22aa40993b6 100644 --- a/backend/danswer/tools/images/image_generation_tool.py +++ b/backend/danswer/tools/images/image_generation_tool.py @@ -8,6 +8,7 @@ from danswer.chat.chat_utils import combine_message_chain from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF +from danswer.dynamic_configs.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM from danswer.llm.utils import build_content_with_imgs @@ -53,6 +54,8 @@ class ImageGenerationResponse(BaseModel): class ImageGenerationTool(Tool): + NAME = "run_image_generation" + def __init__( self, api_key: str, model: str = "dall-e-3", num_imgs: int = 2 ) -> None: @@ -60,16 +63,14 @@ def __init__( self.model = model self.num_imgs = num_imgs - @classmethod def name(self) -> str: - return "run_image_generation" + return self.NAME - @classmethod - def tool_definition(cls) -> dict: + def tool_definition(self) -> dict: return { "type": "function", "function": { - "name": cls.name(), + "name": self.name(), "description": "Generate an image from a prompt", "parameters": { "type": "object", @@ -162,3 +163,12 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: id=IMAGE_GENERATION_RESPONSE_ID, response=results, ) + + def final_result(self, *args: ToolResponse) -> JSON_ro: + image_generation_responses = cast( + list[ImageGenerationResponse], args[0].response + ) + return [ + image_generation_response.dict() + for image_generation_response in image_generation_responses + ] diff --git a/backend/danswer/tools/models.py b/backend/danswer/tools/models.py new file mode 100644 index 00000000000..53940dcea49 --- /dev/null +++ b/backend/danswer/tools/models.py @@ -0,0 +1,39 @@ +from typing import Any + +from pydantic import BaseModel +from pydantic import 
root_validator + + +class ToolResponse(BaseModel): + id: str | None = None + response: Any + + +class ToolCallKickoff(BaseModel): + tool_name: str + tool_args: dict[str, Any] + + +class ToolRunnerResponse(BaseModel): + tool_run_kickoff: ToolCallKickoff | None = None + tool_response: ToolResponse | None = None + tool_message_content: str | list[str | dict[str, Any]] | None = None + + @root_validator + def validate_tool_runner_response( + cls, values: dict[str, ToolResponse | str] + ) -> dict[str, ToolResponse | str]: + fields = ["tool_response", "tool_message_content", "tool_run_kickoff"] + provided = sum(1 for field in fields if values.get(field) is not None) + + if provided != 1: + raise ValueError( + "Exactly one of 'tool_response', 'tool_message_content', " + "or 'tool_run_kickoff' must be provided" + ) + + return values + + +class ToolCallFinalResult(ToolCallKickoff): + tool_result: Any # we would like to use JSON_ro, but can't due to its recursive nature diff --git a/backend/danswer/tools/search/search_tool.py b/backend/danswer/tools/search/search_tool.py index 30ec47d1664..b0b45bd8f40 100644 --- a/backend/danswer/tools/search/search_tool.py +++ b/backend/danswer/tools/search/search_tool.py @@ -10,6 +10,7 @@ from danswer.chat.models import LlmDoc from danswer.db.models import Persona from danswer.db.models import User +from danswer.dynamic_configs.interface import JSON_ro from danswer.llm.answering.doc_pruning import prune_documents from danswer.llm.answering.models import DocumentPruningConfig from danswer.llm.answering.models import PreviousMessage @@ -55,6 +56,8 @@ class SearchResponseSummary(BaseModel): class SearchTool(Tool): + NAME = "run_search" + def __init__( self, db_session: Session, @@ -87,18 +90,16 @@ def __init__( self.bypass_acl = bypass_acl self.db_session = db_session - @classmethod - def name(cls) -> str: - return "run_search" + def name(self) -> str: + return self.NAME """For explicit tool calling""" - @classmethod - def tool_definition(cls) -> dict: + def tool_definition(self) -> dict: return { "type": "function", "function": { - "name": cls.name(), + "name": self.name(), "description": search_tool_description, "parameters": { "type": "object", @@ -241,3 +242,13 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: document_pruning_config=self.pruning_config, ) yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=final_context_documents) + + def final_result(self, *args: ToolResponse) -> JSON_ro: + final_docs = cast( + list[LlmDoc], + next(arg.response for arg in args if arg.id == FINAL_CONTEXT_DOCUMENTS), + ) + # NOTE: need to do this json.loads(doc.json()) stuff because there are some + # subfields that are not serializable by default (datetime) + # this forces pydantic to make them JSON serializable for us + return [json.loads(doc.json()) for doc in final_docs] diff --git a/backend/danswer/tools/tool.py b/backend/danswer/tools/tool.py index dd443757e67..e335a049838 100644 --- a/backend/danswer/tools/tool.py +++ b/backend/danswer/tools/tool.py @@ -2,26 +2,19 @@ from collections.abc import Generator from typing import Any -from pydantic import BaseModel - +from danswer.dynamic_configs.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM - - -class ToolResponse(BaseModel): - id: str | None = None - response: Any +from danswer.tools.models import ToolResponse class Tool(abc.ABC): - @classmethod @abc.abstractmethod def name(self) -> str: raise NotImplementedError """For LLMs 
which support explicit tool calling""" - @classmethod @abc.abstractmethod def tool_definition(self) -> dict: raise NotImplementedError @@ -49,3 +42,11 @@ def get_args_for_non_tool_calling_llm( @abc.abstractmethod def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: raise NotImplementedError + + @abc.abstractmethod + def final_result(self, *args: ToolResponse) -> JSON_ro: + """ + This is the "final summary" result of the tool. + It is the result that will be stored in the database. + """ + raise NotImplementedError diff --git a/backend/danswer/tools/tool_runner.py b/backend/danswer/tools/tool_runner.py index 46f247b06dc..a4367d865d5 100644 --- a/backend/danswer/tools/tool_runner.py +++ b/backend/danswer/tools/tool_runner.py @@ -1,42 +1,15 @@ from collections.abc import Generator from typing import Any -from pydantic import BaseModel -from pydantic import root_validator - from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM +from danswer.tools.models import ToolCallFinalResult +from danswer.tools.models import ToolCallKickoff from danswer.tools.tool import Tool from danswer.tools.tool import ToolResponse from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel -class ToolRunKickoff(BaseModel): - tool_name: str - tool_args: dict[str, Any] - - -class ToolRunnerResponse(BaseModel): - tool_run_kickoff: ToolRunKickoff | None = None - tool_response: ToolResponse | None = None - tool_message_content: str | list[str | dict[str, Any]] | None = None - - @root_validator - def validate_tool_runner_response( - cls, values: dict[str, ToolResponse | str] - ) -> dict[str, ToolResponse | str]: - fields = ["tool_response", "tool_message_content", "tool_run_kickoff"] - provided = sum(1 for field in fields if values.get(field) is not None) - - if provided != 1: - raise ValueError( - "Exactly one of 'tool_response', 'tool_message_content', " - "or 'tool_run_kickoff' must be provided" - ) - - return values - - class ToolRunner: def __init__(self, tool: Tool, args: dict[str, Any]): self.tool = tool @@ -44,8 +17,8 @@ def __init__(self, tool: Tool, args: dict[str, Any]): self._tool_responses: list[ToolResponse] | None = None - def kickoff(self) -> ToolRunKickoff: - return ToolRunKickoff(tool_name=self.tool.name(), tool_args=self.args) + def kickoff(self) -> ToolCallKickoff: + return ToolCallKickoff(tool_name=self.tool.name(), tool_args=self.args) def tool_responses(self) -> Generator[ToolResponse, None, None]: if self._tool_responses is not None: @@ -62,6 +35,13 @@ def tool_message_content(self) -> str | list[str | dict[str, Any]]: tool_responses = list(self.tool_responses()) return self.tool.build_tool_message_content(*tool_responses) + def tool_final_result(self) -> ToolCallFinalResult: + return ToolCallFinalResult( + tool_name=self.tool.name(), + tool_args=self.args, + tool_result=self.tool.final_result(*self.tool_responses()), + ) + def check_which_tools_should_run_for_non_tool_calling_llm( tools: list[Tool], query: str, history: list[PreviousMessage], llm: LLM diff --git a/backend/danswer/tools/utils.py b/backend/danswer/tools/utils.py index 831021cdab3..7fb2156df59 100644 --- a/backend/danswer/tools/utils.py +++ b/backend/danswer/tools/utils.py @@ -1,5 +1,4 @@ import json -from typing import Type from tiktoken import Encoding @@ -17,15 +16,13 @@ def explicit_tool_calling_supported(model_provider: str, model_name: str) -> boo return False -def compute_tool_tokens( - tool: Tool | Type[Tool], llm_tokenizer: Encoding | None = None 
-) -> int: +def compute_tool_tokens(tool: Tool, llm_tokenizer: Encoding | None = None) -> int: if not llm_tokenizer: llm_tokenizer = get_default_llm_tokenizer() return len(llm_tokenizer.encode(json.dumps(tool.tool_definition()))) def compute_all_tool_tokens( - tools: list[Tool] | list[Type[Tool]], llm_tokenizer: Encoding | None = None + tools: list[Tool], llm_tokenizer: Encoding | None = None ) -> int: return sum(compute_tool_tokens(tool, llm_tokenizer) for tool in tools) diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 0eb2f22ecad..598b0e6de17 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -26,6 +26,7 @@ httpx[http2]==0.23.3 httpx-oauth==0.11.2 huggingface-hub==0.20.1 jira==3.5.1 +jsonref==1.1.0 langchain==0.1.17 langchain-community==0.0.36 langchain-core==0.1.50 diff --git a/backend/scripts/restart_containers.sh b/backend/scripts/restart_containers.sh index bfd3bd74572..c60d1905eb5 100755 --- a/backend/scripts/restart_containers.sh +++ b/backend/scripts/restart_containers.sh @@ -1,23 +1,40 @@ #!/bin/bash +# Usage of the script with optional volume arguments +# ./restart_containers.sh [vespa_volume] [postgres_volume] + +VESPA_VOLUME=${1:-""} # Default is empty if not provided +POSTGRES_VOLUME=${2:-""} # Default is empty if not provided + # Stop and remove the existing containers echo "Stopping and removing existing containers..." docker stop danswer_postgres danswer_vespa docker rm danswer_postgres danswer_vespa -# Start the PostgreSQL container +# Start the PostgreSQL container with optional volume echo "Starting PostgreSQL container..." -docker run -p 5432:5432 --name danswer_postgres -e POSTGRES_PASSWORD=password -d postgres +if [[ -n "$POSTGRES_VOLUME" ]]; then + docker run -p 5432:5432 --name danswer_postgres -e POSTGRES_PASSWORD=password -d -v "$POSTGRES_VOLUME":/var/lib/postgresql/data postgres +else + docker run -p 5432:5432 --name danswer_postgres -e POSTGRES_PASSWORD=password -d postgres +fi -# Start the Vespa container +# Start the Vespa container with optional volume echo "Starting Vespa container..." -docker run --detach --name danswer_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8 +if [[ -n "$VESPA_VOLUME" ]]; then + docker run --detach --name danswer_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 -v "$VESPA_VOLUME":/opt/vespa/var vespaengine/vespa:8 +else + docker run --detach --name danswer_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8 +fi # Ensure alembic runs in the correct directory SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" PARENT_DIR="$(dirname "$SCRIPT_DIR")" cd "$PARENT_DIR" +# Give Postgres a second to start +sleep 1 + # Run Alembic upgrade echo "Running Alembic migration..."
alembic upgrade head diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index c7175faa1f5..cec7cac25a9 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -159,7 +159,6 @@ services: - GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-} - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-} - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-} - - DISABLE_DOCUMENT_CLEANUP=${DISABLE_DOCUMENT_CLEANUP:-} # Danswer SlackBot Configs - DANSWER_BOT_SLACK_APP_TOKEN=${DANSWER_BOT_SLACK_APP_TOKEN:-} - DANSWER_BOT_SLACK_BOT_TOKEN=${DANSWER_BOT_SLACK_BOT_TOKEN:-} diff --git a/deployment/docker_compose/docker-compose.gpu-dev.yml b/deployment/docker_compose/docker-compose.gpu-dev.yml index 89dd673f91b..cce0f43028b 100644 --- a/deployment/docker_compose/docker-compose.gpu-dev.yml +++ b/deployment/docker_compose/docker-compose.gpu-dev.yml @@ -159,7 +159,6 @@ services: - GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-} - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-} - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-} - - DISABLE_DOCUMENT_CLEANUP=${DISABLE_DOCUMENT_CLEANUP:-} # Danswer SlackBot Configs - DANSWER_BOT_SLACK_APP_TOKEN=${DANSWER_BOT_SLACK_APP_TOKEN:-} - DANSWER_BOT_SLACK_BOT_TOKEN=${DANSWER_BOT_SLACK_BOT_TOKEN:-} diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index d6427a67929..10b0ad586aa 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -147,40 +147,56 @@ export function AssistantEditor({ const imageGenerationTool = providerSupportingImageGenerationExists ? findImageGenerationTool(tools) : undefined; + const customTools = tools.filter( + (tool) => + tool.in_code_tool_id !== searchTool?.in_code_tool_id && + tool.in_code_tool_id !== imageGenerationTool?.in_code_tool_id + ); + + const availableTools = [ + ...customTools, + ...(searchTool ? [searchTool] : []), + ...(imageGenerationTool ? [imageGenerationTool] : []), + ]; + const enabledToolsMap: { [key: number]: boolean } = {}; + availableTools.forEach((tool) => { + enabledToolsMap[tool.id] = personaCurrentToolIds.includes(tool.id); + }); + + const initialValues = { + name: existingPersona?.name ?? "", + description: existingPersona?.description ?? "", + system_prompt: existingPrompt?.system_prompt ?? "", + task_prompt: existingPrompt?.task_prompt ?? "", + is_public: existingPersona?.is_public ?? defaultPublic, + document_set_ids: + existingPersona?.document_sets?.map((documentSet) => documentSet.id) ?? + ([] as number[]), + num_chunks: existingPersona?.num_chunks ?? null, + include_citations: existingPersona?.prompts[0]?.include_citations ?? true, + llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false, + llm_model_provider_override: + existingPersona?.llm_model_provider_override ?? null, + llm_model_version_override: + existingPersona?.llm_model_version_override ?? null, + starter_messages: existingPersona?.starter_messages ?? [], + enabled_tools_map: enabledToolsMap, + // search_tool_enabled: existingPersona + // ? personaCurrentToolIds.includes(searchTool!.id) + // : ccPairs.length > 0, + // image_generation_tool_enabled: imageGenerationTool + // ? 
personaCurrentToolIds.includes(imageGenerationTool.id) + // : false, + // EE Only + groups: existingPersona?.groups ?? [], + }; return (
{popup} documentSet.id - ) ?? ([] as number[]), - num_chunks: existingPersona?.num_chunks ?? null, - include_citations: - existingPersona?.prompts[0]?.include_citations ?? true, - llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false, - llm_model_provider_override: - existingPersona?.llm_model_provider_override ?? null, - llm_model_version_override: - existingPersona?.llm_model_version_override ?? null, - starter_messages: existingPersona?.starter_messages ?? [], - // EE Only - groups: existingPersona?.groups ?? [], - search_tool_enabled: existingPersona - ? personaCurrentToolIds.includes(searchTool!.id) - : ccPairs.length > 0, - image_generation_tool_enabled: imageGenerationTool - ? personaCurrentToolIds.includes(imageGenerationTool.id) - : false, - }} + initialValues={initialValues} validationSchema={Yup.object() .shape({ name: Yup.string().required("Must give the Assistant a name!"), @@ -205,8 +221,6 @@ export function AssistantEditor({ ), // EE Only groups: Yup.array().of(Yup.number()), - search_tool_enabled: Yup.boolean().required(), - image_generation_tool_enabled: Yup.boolean().required(), }) .test( "system-prompt-or-task-prompt", @@ -251,30 +265,36 @@ export function AssistantEditor({ formikHelpers.setSubmitting(true); - const tools = []; - if (values.search_tool_enabled && ccPairs.length > 0) { - tools.push(searchTool!.id); - } - if ( - values.image_generation_tool_enabled && - imageGenerationTool && - checkLLMSupportsImageInput( - providerDisplayNameToProviderName.get( - values.llm_model_provider_override || "" - ) || - defaultProviderName || - "", - values.llm_model_version_override || defaultModelName || "" - ) - ) { - tools.push(imageGenerationTool.id); + let enabledTools = Object.keys(values.enabled_tools_map) + .map((toolId) => Number(toolId)) + .filter((toolId) => values.enabled_tools_map[toolId]); + const searchToolEnabled = searchTool + ? enabledTools.includes(searchTool.id) + : false; + const imageGenerationToolEnabled = imageGenerationTool + ? enabledTools.includes(imageGenerationTool.id) + : false; + + if (imageGenerationToolEnabled) { + if ( + !checkLLMSupportsImageInput( + providerDisplayNameToProviderName.get( + values.llm_model_provider_override || "" + ) || + defaultProviderName || + "", + values.llm_model_version_override || defaultModelName || "" + ) + ) { + enabledTools = enabledTools.filter( + (toolId) => toolId !== imageGenerationTool!.id + ); + } } // if disable_retrieval is set, set num_chunks to 0 // to tell the backend to not fetch any documents - const numChunks = values.search_tool_enabled - ? values.num_chunks || 10 - : 0; + const numChunks = searchToolEnabled ? values.num_chunks || 10 : 0; // don't set groups if marked as public const groups = values.is_public ? [] : values.groups; @@ -290,7 +310,7 @@ export function AssistantEditor({ users: user && !checkUserIsNoAuthUser(user.id) ? [user.id] : undefined, groups, - tool_ids: tools, + tool_ids: enabledTools, }); } else { [promptResponse, personaResponse] = await createPersona({ @@ -299,7 +319,7 @@ export function AssistantEditor({ users: user && !checkUserIsNoAuthUser(user.id) ? [user.id] : undefined, groups, - tool_ids: tools, + tool_ids: enabledTools, }); } @@ -351,357 +371,386 @@ export function AssistantEditor({ } }} > - {({ isSubmitting, values, setFieldValue }) => ( -
-
- - <> - - - - - { - setFieldValue("system_prompt", e.target.value); - triggerFinalPromptUpdate( - e.target.value, - values.task_prompt, - values.search_tool_enabled - ); - }} - error={finalPromptError} - /> - - { + function toggleToolInValues(toolId: number) { + const updatedEnabledToolsMap = { + ...values.enabled_tools_map, + [toolId]: !values.enabled_tools_map[toolId], + }; + setFieldValue("enabled_tools_map", updatedEnabledToolsMap); + } + + function searchToolEnabled() { + return searchTool && values.enabled_tools_map[searchTool.id] + ? true + : false; + } + + return ( + +
+ + <> + + + + + { + setFieldValue("system_prompt", e.target.value); + triggerFinalPromptUpdate( + e.target.value, + values.task_prompt, + searchToolEnabled() + ); + }} + error={finalPromptError} + /> + + { - setFieldValue("task_prompt", e.target.value); - triggerFinalPromptUpdate( - values.system_prompt, - e.target.value, - values.search_tool_enabled - ); - }} - error={finalPromptError} - /> - - - - {finalPrompt ? ( -
-                      {finalPrompt}
-                    
- ) : ( - "-" - )} - -
- - + onChange={(e) => { + setFieldValue("task_prompt", e.target.value); + triggerFinalPromptUpdate( + values.system_prompt, + e.target.value, + searchToolEnabled() + ); + }} + error={finalPromptError} + /> + + + + {finalPrompt ? ( +
+                        {finalPrompt}
+                      
+ ) : ( + "-" + )} + + + + + + + <> + {ccPairs.length > 0 && searchTool && ( + <> + { + setFieldValue("num_chunks", null); + toggleToolInValues(searchTool.id); + }} + /> - - <> - {ccPairs.length > 0 && ( - <> - { - setFieldValue("num_chunks", null); - setFieldValue( - "search_tool_enabled", - e.target.checked - ); - }} - /> - - {values.search_tool_enabled && ( -
- {ccPairs.length > 0 && ( - <> - + {searchToolEnabled() && ( +
+ {ccPairs.length > 0 && ( + <> + -
- - <> - Select which{" "} - {!user || user.role === "admin" ? ( - - Document Sets - - ) : ( - "Document Sets" - )}{" "} - that this Assistant should search through. - If none are specified, the Assistant will - search through all available documents in - order to try and respond to queries. - - -
+
+ + <> + Select which{" "} + {!user || user.role === "admin" ? ( + + Document Sets + + ) : ( + "Document Sets" + )}{" "} + this Assistant should search through. + If none are specified, the Assistant will + search through all available documents in + order to try and respond to queries. + + +
- {documentSets.length > 0 ? ( - ( -
-
- {documentSets.map((documentSet) => { - const ind = - values.document_set_ids.indexOf( - documentSet.id + {documentSets.length > 0 ? ( + ( +
+
+ {documentSets.map((documentSet) => { + const ind = + values.document_set_ids.indexOf( + documentSet.id + ); + let isSelected = ind !== -1; + return ( + { + if (isSelected) { + arrayHelpers.remove(ind); + } else { + arrayHelpers.push( + documentSet.id + ); + } + }} + /> ); - let isSelected = ind !== -1; - return ( - { - if (isSelected) { - arrayHelpers.remove(ind); - } else { - arrayHelpers.push( - documentSet.id - ); - } - }} - /> - ); - })} + })} +
-
- )} - /> - ) : ( - - No Document Sets available.{" "} - {user?.role !== "admin" && ( - <> - If this functionality would be useful, - reach out to the administrators of Danswer - for assistance. - - )} - - )} + )} + /> + ) : ( + + No Document Sets available.{" "} + {user?.role !== "admin" && ( + <> + If this functionality would be useful, + reach out to the administrators of + Danswer for assistance. + + )} + + )} + + <> + + How many chunks should we feed into the + LLM when generating the final response? + Each chunk is ~400 words long. +
+ } + onChange={(e) => { + const value = e.target.value; + // Allow only integer values + if ( + value === "" || + /^[0-9]+$/.test(value) + ) { + setFieldValue("num_chunks", value); + } + }} + /> - <> - - How many chunks should we feed into the - LLM when generating the final response? - Each chunk is ~400 words long. -
- } - onChange={(e) => { - const value = e.target.value; - // Allow only integer values - if ( - value === "" || - /^[0-9]+$/.test(value) - ) { - setFieldValue("num_chunks", value); + + + - - - - - - + + + /> + - - )} -
- )} - - )} - - {imageGenerationTool && - checkLLMSupportsImageInput( - providerDisplayNameToProviderName.get( - values.llm_model_provider_override || "" - ) || - defaultProviderName || - "", - values.llm_model_version_override || - defaultModelName || - "" - ) && ( - { - setFieldValue( - "image_generation_tool_enabled", - e.target.checked - ); - }} - /> + )} +
+ )} + )} - -
- + {imageGenerationTool && + checkLLMSupportsImageInput( + providerDisplayNameToProviderName.get( + values.llm_model_provider_override || "" + ) || + defaultProviderName || + "", + values.llm_model_version_override || + defaultModelName || + "" + ) && ( + { + toggleToolInValues(imageGenerationTool.id); + }} + /> + )} - {llmProviders.length > 0 && ( - <> - - <> - - Pick which LLM to use for this Assistant. If left as - Default, will use{" "} - {defaultModelName} - . -
-
- For more information on the different LLMs, checkout the{" "} - - OpenAI docs - - . -
- -
-
- LLM Provider - ({ - name: llmProvider.name, - value: llmProvider.name, - }))} - includeDefault={true} - onSelect={(selected) => { - if ( - selected !== values.llm_model_provider_override - ) { - setFieldValue( - "llm_model_version_override", - null - ); - } - setFieldValue( - "llm_model_provider_override", - selected - ); + {customTools.length > 0 && ( + <> + {customTools.map((tool) => ( + { + toggleToolInValues(tool.id); }} /> -
- - {values.llm_model_provider_override && ( -
- Model + ))} + + )} + + + + + + {llmProviders.length > 0 && ( + <> + + <> + + Pick which LLM to use for this Assistant. If left as + Default, will use{" "} + {defaultModelName} + . +
+
+ For more information on the different LLMs, check out + the{" "} + + OpenAI docs + + . +
+ +
+
+ LLM Provider ({ + name: llmProvider.name, + value: llmProvider.name, + }))} + includeDefault={true} + onSelect={(selected) => { + if ( + selected !== values.llm_model_provider_override - ) || [] - } - maxHeight="max-h-72" + ) { + setFieldValue( + "llm_model_version_override", + null + ); + } + setFieldValue( + "llm_model_provider_override", + selected + ); + }} />
- )} -
- -
- - - - )} - - - <> -
- - Starter Messages help guide users to use this Assistant. - They are shown to the user as clickable options when they - select this Assistant. When selected, the specified - message is sent to the LLM as the initial user message. - -
- - ) => ( -
- {values.starter_messages && - values.starter_messages.length > 0 && - values.starter_messages.map((_, index) => { - return ( -
-
-
-
- - - Shows up as the "title" for - this Starter Message. For example, - "Write an email". - - + Model + +
+ )} +
+ + + + + + )} + + + <> +
+ + Starter Messages help guide users to use this Assistant. + They are shown to the user as clickable options when + they select this Assistant. When selected, the specified + message is sent to the LLM as the initial user message. + +
+ + + ) => ( +
+ {values.starter_messages && + values.starter_messages.length > 0 && + values.starter_messages.map((_, index) => { + return ( +
+
+
+
+ + + Shows up as the "title" for + this Starter Message. For example, + "Write an email". + + - -
+ autoComplete="off" + /> + +
-
- - - A description which tells the user what - they might want to use this Starter - Message for. For example "to a - client about a new feature" - - + + + A description which tells the user + what they might want to use this + Starter Message for. For example + "to a client about a new + feature" + + - -
+ autoComplete="off" + /> + +
-
- - - The actual message to be sent as the - initial user message if a user selects - this starter prompt. For example, - "Write me an email to a client - about a new billing feature we just - released." - - + + + The actual message to be sent as the + initial user message if a user selects + this starter prompt. For example, + "Write me an email to a client + about a new billing feature we just + released." + + - + +
+
+
+ + arrayHelpers.remove(index) + } />
-
- arrayHelpers.remove(index)} - /> -
-
- ); - })} - - -
- )} - /> - -
+ ); + })} + + +
+ )} + /> + + - + - {EE_ENABLED && userGroups && (!user || user.role === "admin") && ( - <> - + {EE_ENABLED && + userGroups && + (!user || user.role === "admin") && ( <> - - - {userGroups && - userGroups.length > 0 && - !values.is_public && ( -
- - Select which User Groups should have access to - this Assistant. - -
- {userGroups.map((userGroup) => { - const isSelected = values.groups.includes( - userGroup.id - ); - return ( - { - if (isSelected) { - setFieldValue( - "groups", - values.groups.filter( - (id) => id !== userGroup.id - ) - ); - } else { - setFieldValue("groups", [ - ...values.groups, - userGroup.id, - ]); - } - }} - > -
- -
- {userGroup.name} -
-
-
- ); - })} -
-
- )} + + <> + + + {userGroups && + userGroups.length > 0 && + !values.is_public && ( +
+ + Select which User Groups should have access to + this Assistant. + +
+ {userGroups.map((userGroup) => { + const isSelected = values.groups.includes( + userGroup.id + ); + return ( + { + if (isSelected) { + setFieldValue( + "groups", + values.groups.filter( + (id) => id !== userGroup.id + ) + ); + } else { + setFieldValue("groups", [ + ...values.groups, + userGroup.id, + ]); + } + }} + > +
+ +
+ {userGroup.name} +
+
+
+ ); + })} +
+
+ )} + +
+ -
- - - )} - -
- + )} + +
+ +
-
- - )} + + ); + }}
); diff --git a/web/src/app/admin/connectors/file/page.tsx b/web/src/app/admin/connectors/file/page.tsx index 3e4af0a85d4..b7e93e01c85 100644 --- a/web/src/app/admin/connectors/file/page.tsx +++ b/web/src/app/admin/connectors/file/page.tsx @@ -122,6 +122,7 @@ const Main = () => { file_locations: filePaths, }, refresh_freq: null, + prune_freq: 0, disabled: false, }); if (connectorErrorMsg || !connector) { diff --git a/web/src/app/admin/connectors/gmail/page.tsx b/web/src/app/admin/connectors/gmail/page.tsx index e81db40047c..f0800d293e4 100644 --- a/web/src/app/admin/connectors/gmail/page.tsx +++ b/web/src/app/admin/connectors/gmail/page.tsx @@ -141,9 +141,16 @@ const Main = () => { const { popup, setPopup } = usePopup(); + const appCredentialSuccessfullyFetched = + appCredentialData || + (isAppCredentialError && isAppCredentialError.status === 404); + const serviceAccountKeySuccessfullyFetched = + serviceAccountKeyData || + (isServiceAccountKeyError && isServiceAccountKeyError.status === 404); + if ( - (!appCredentialData && isAppCredentialLoading) || - (!serviceAccountKeyData && isServiceAccountKeyLoading) || + (!appCredentialSuccessfullyFetched && isAppCredentialLoading) || + (!serviceAccountKeySuccessfullyFetched && isServiceAccountKeyLoading) || (!connectorIndexingStatuses && isConnectorIndexingStatusesLoading) || (!credentialsData && isCredentialsLoading) ) { @@ -170,7 +177,10 @@ const Main = () => { ); } - if (isAppCredentialError || isServiceAccountKeyError) { + if ( + !appCredentialSuccessfullyFetched || + !serviceAccountKeySuccessfullyFetched + ) { return (
diff --git a/web/src/app/admin/connectors/google-sites/page.tsx b/web/src/app/admin/connectors/google-sites/page.tsx index 45ea4bcd1d9..20728633a59 100644 --- a/web/src/app/admin/connectors/google-sites/page.tsx +++ b/web/src/app/admin/connectors/google-sites/page.tsx @@ -114,6 +114,7 @@ export default function GoogleSites() { zip_path: filePaths[0], }, refresh_freq: null, + prune_freq: 0, disabled: false, }); if (connectorErrorMsg || !connector) { diff --git a/web/src/app/admin/connectors/salesforce/page.tsx b/web/src/app/admin/connectors/salesforce/page.tsx index 020d9a3d2bd..8771b14f94b 100644 --- a/web/src/app/admin/connectors/salesforce/page.tsx +++ b/web/src/app/admin/connectors/salesforce/page.tsx @@ -221,13 +221,32 @@ const MainSection = () => { // formBody={<>} formBodyBuilder={TextArrayFieldBuilder({ name: "requested_objects", - label: "requested_objects:", - subtext: - "Optionally, specify the Salesforce object type you would like us to index by. For example, specifying the object " + - "'Lead' will cause us to generate a document based on each Lead. " + - "These documents would contain all the information for this Lead, including its child objects (E.g. associated contacts, companies, etc.). " + - "If no requested objects are specified, we will default to indexing by 'Account'." + - "Make sure to use the singular form of the object (E.g. Opportunity instead of Opportunities)", + label: "Specify Salesforce objects to organize by:", + subtext: ( + <> +
+ Specify the Salesforce object types you want us to index.{" "} +
+
+ Click + + {" "} + here{" "} + + for an example of how Danswer uses the objects.
+
+ If unsure, don't specify any objects and Danswer will + default to indexing by 'Account'. +
+
+ Hint: Use the singular form of the object name (e.g., + 'Opportunity' instead of 'Opportunities'). + + ), })} validationSchema={Yup.object().shape({ requested_objects: Yup.array() diff --git a/web/src/app/admin/connectors/web/page.tsx b/web/src/app/admin/connectors/web/page.tsx index 547a164499a..410d187920e 100644 --- a/web/src/app/admin/connectors/web/page.tsx +++ b/web/src/app/admin/connectors/web/page.tsx @@ -118,6 +118,7 @@ export default function Web() { web_connector_type: undefined, }} refreshFreq={60 * 60 * 24} // 1 day + pruneFreq={0} // Don't prune /> diff --git a/web/src/app/admin/tools/ToolEditor.tsx b/web/src/app/admin/tools/ToolEditor.tsx new file mode 100644 index 00000000000..89046d21f1f --- /dev/null +++ b/web/src/app/admin/tools/ToolEditor.tsx @@ -0,0 +1,261 @@ +"use client"; + +import { useState, useEffect, useCallback } from "react"; +import { useRouter } from "next/navigation"; +import { Formik, Form, Field, ErrorMessage } from "formik"; +import * as Yup from "yup"; +import { MethodSpec, ToolSnapshot } from "@/lib/tools/interfaces"; +import { TextFormField } from "@/components/admin/connectors/Field"; +import { Button, Divider } from "@tremor/react"; +import { + createCustomTool, + updateCustomTool, + validateToolDefinition, +} from "@/lib/tools/edit"; +import { usePopup } from "@/components/admin/connectors/Popup"; +import debounce from "lodash/debounce"; + +function parseJsonWithTrailingCommas(jsonString: string) { + // Regular expression to remove trailing commas before } or ] + let cleanedJsonString = jsonString.replace(/,\s*([}\]])/g, "$1"); + // Replace True with true, False with false, and None with null + cleanedJsonString = cleanedJsonString + .replace(/\bTrue\b/g, "true") + .replace(/\bFalse\b/g, "false") + .replace(/\bNone\b/g, "null"); + // Now parse the cleaned JSON string + return JSON.parse(cleanedJsonString); +} + +function prettifyDefinition(definition: any) { + return JSON.stringify(definition, null, 2); +} + +function ToolForm({ + existingTool, + values, + setFieldValue, + isSubmitting, + definitionErrorState, + methodSpecsState, +}: { + existingTool?: ToolSnapshot; + values: ToolFormValues; + setFieldValue: (field: string, value: string) => void; + isSubmitting: boolean; + definitionErrorState: [ + string | null, + React.Dispatch>, + ]; + methodSpecsState: [ + MethodSpec[] | null, + React.Dispatch>, + ]; +}) { + const [definitionError, setDefinitionError] = definitionErrorState; + const [methodSpecs, setMethodSpecs] = methodSpecsState; + + const debouncedValidateDefinition = useCallback( + debounce(async (definition: string) => { + try { + const parsedDefinition = parseJsonWithTrailingCommas(definition); + const response = await validateToolDefinition({ + definition: parsedDefinition, + }); + if (response.error) { + setMethodSpecs(null); + setDefinitionError(response.error); + } else { + setMethodSpecs(response.data); + setDefinitionError(null); + } + } catch (error) { + console.log(error); + setMethodSpecs(null); + setDefinitionError("Invalid JSON format"); + } + }, 300), + [] + ); + + useEffect(() => { + if (values.definition) { + debouncedValidateDefinition(values.definition); + } + }, [values.definition, debouncedValidateDefinition]); + + return ( +
+
+ + +
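+        {/* inline feedback from the debounced OpenAPI-definition validation above */}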
+ {definitionError && ( +
{definitionError}
+ )} + + + {methodSpecs && methodSpecs.length > 0 && ( +
+

Available methods

+
+ + + + + + + + + + + {methodSpecs?.map((method: MethodSpec, index: number) => ( + + + + + + + ))} + +
NameSummaryMethodPath
{method.name}{method.summary} + {method.method.toUpperCase()} + {method.path}
+
+
+ )} + + +
+ +
+ + ); +} + +interface ToolFormValues { + definition: string; +} + +const ToolSchema = Yup.object().shape({ + definition: Yup.string().required("Tool definition is required"), +}); + +export function ToolEditor({ tool }: { tool?: ToolSnapshot }) { + const router = useRouter(); + const { popup, setPopup } = usePopup(); + const [definitionError, setDefinitionError] = useState(null); + const [methodSpecs, setMethodSpecs] = useState(null); + + const prettifiedDefinition = tool?.definition + ? prettifyDefinition(tool.definition) + : ""; + + return ( +
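+      {/* transient success/error toast from usePopup */}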
+      {popup}
+      {
+          let definition: any;
+          try {
+            definition = parseJsonWithTrailingCommas(values.definition);
+          } catch (error) {
+            setDefinitionError("Invalid JSON in tool definition");
+            return;
+          }
+
+          const name = definition?.info?.title;
+          const description = definition?.info?.description;
+          const toolData = {
+            name: name,
+            description: description || "",
+            definition: definition,
+          };
+          let response;
+          if (tool) {
+            response = await updateCustomTool(tool.id, toolData);
+          } else {
+            response = await createCustomTool(toolData);
+          }
+          if (response.error) {
+            setPopup({
+              // say "update" vs. "create" depending on which call actually ran
+              message: `Failed to ${tool ? "update" : "create"} tool - ${response.error}`,
+              type: "error",
+            });
+            return;
+          }
+          router.push(`/admin/tools?u=${Date.now()}`);
+        }}
+      >
+        {({ isSubmitting, values, setFieldValue }) => {
+          return (
+
+          );
+        }}
+
+ ); +} diff --git a/web/src/app/admin/tools/ToolsTable.tsx b/web/src/app/admin/tools/ToolsTable.tsx new file mode 100644 index 00000000000..8017a2431a0 --- /dev/null +++ b/web/src/app/admin/tools/ToolsTable.tsx @@ -0,0 +1,107 @@ +"use client"; + +import { + Text, + Table, + TableHead, + TableRow, + TableHeaderCell, + TableBody, + TableCell, +} from "@tremor/react"; +import { ToolSnapshot } from "@/lib/tools/interfaces"; +import { useRouter } from "next/navigation"; +import { usePopup } from "@/components/admin/connectors/Popup"; +import { FiCheckCircle, FiEdit, FiXCircle } from "react-icons/fi"; +import { TrashIcon } from "@/components/icons/icons"; +import { deleteCustomTool } from "@/lib/tools/edit"; + +export function ToolsTable({ tools }: { tools: ToolSnapshot[] }) { + const router = useRouter(); + const { popup, setPopup } = usePopup(); + + const sortedTools = [...tools]; + sortedTools.sort((a, b) => a.id - b.id); + + return ( +
+ {popup} + + + + + Name + Description + Built In? + Delete + + + + {sortedTools.map((tool) => ( + + +
+ {tool.in_code_tool_id === null && ( + + router.push( + `/admin/tools/edit/${tool.id}?u=${Date.now()}` + ) + } + /> + )} +

+ {tool.name} +

+
+
+ + {tool.description} + + + {tool.in_code_tool_id === null ? ( + + + No + + ) : ( + + + Yes + + )} + + +
+ {tool.in_code_tool_id === null ? ( +
+
{ + const response = await deleteCustomTool(tool.id); + if (response.data) { + router.refresh(); + } else { + setPopup({ + message: `Failed to delete tool - ${response.error}`, + type: "error", + }); + } + }} + > + +
+
+ ) : ( + "-" + )} +
+
+
+ ))} +
+
+
+ ); +} diff --git a/web/src/app/admin/tools/edit/[toolId]/DeleteToolButton.tsx b/web/src/app/admin/tools/edit/[toolId]/DeleteToolButton.tsx new file mode 100644 index 00000000000..c02e141b54a --- /dev/null +++ b/web/src/app/admin/tools/edit/[toolId]/DeleteToolButton.tsx @@ -0,0 +1,28 @@ +"use client"; + +import { Button } from "@tremor/react"; +import { FiTrash } from "react-icons/fi"; +import { deleteCustomTool } from "@/lib/tools/edit"; +import { useRouter } from "next/navigation"; + +export function DeleteToolButton({ toolId }: { toolId: number }) { + const router = useRouter(); + + return ( + + ); +} diff --git a/web/src/app/admin/tools/edit/[toolId]/page.tsx b/web/src/app/admin/tools/edit/[toolId]/page.tsx new file mode 100644 index 00000000000..8dd54be46b3 --- /dev/null +++ b/web/src/app/admin/tools/edit/[toolId]/page.tsx @@ -0,0 +1,55 @@ +import { ErrorCallout } from "@/components/ErrorCallout"; +import { Card, Text, Title } from "@tremor/react"; +import { ToolEditor } from "@/app/admin/tools/ToolEditor"; +import { fetchToolByIdSS } from "@/lib/tools/fetchTools"; +import { DeleteToolButton } from "./DeleteToolButton"; +import { FiTool } from "react-icons/fi"; +import { AdminPageTitle } from "@/components/admin/Title"; +import { BackButton } from "@/components/BackButton"; + +export default async function Page({ params }: { params: { toolId: string } }) { + const tool = await fetchToolByIdSS(params.toolId); + + let body; + if (!tool) { + body = ( +
+ +
+ ); + } else { + body = ( +
+
+
+ + + + + Delete Tool + Click the button below to permanently delete this tool. +
+ +
+
+
+
+ ); + } + + return ( +
+ + + } + /> + + {body} +
+ ); +} diff --git a/web/src/app/admin/tools/new/page.tsx b/web/src/app/admin/tools/new/page.tsx new file mode 100644 index 00000000000..5d1723f96ac --- /dev/null +++ b/web/src/app/admin/tools/new/page.tsx @@ -0,0 +1,24 @@ +"use client"; + +import { ToolEditor } from "@/app/admin/tools/ToolEditor"; +import { BackButton } from "@/components/BackButton"; +import { AdminPageTitle } from "@/components/admin/Title"; +import { Card } from "@tremor/react"; +import { FiTool } from "react-icons/fi"; + +export default function NewToolPage() { + return ( +
+ + + } + /> + + + + +
+ ); +} diff --git a/web/src/app/admin/tools/page.tsx b/web/src/app/admin/tools/page.tsx new file mode 100644 index 00000000000..7b9edf7abe0 --- /dev/null +++ b/web/src/app/admin/tools/page.tsx @@ -0,0 +1,68 @@ +import { ToolsTable } from "./ToolsTable"; +import { ToolSnapshot } from "@/lib/tools/interfaces"; +import { FiPlusSquare, FiTool } from "react-icons/fi"; +import Link from "next/link"; +import { Divider, Text, Title } from "@tremor/react"; +import { fetchSS } from "@/lib/utilsSS"; +import { ErrorCallout } from "@/components/ErrorCallout"; +import { AdminPageTitle } from "@/components/admin/Title"; + +export default async function Page() { + const toolResponse = await fetchSS("/tool"); + + if (!toolResponse.ok) { + return ( + + ); + } + + const tools = (await toolResponse.json()) as ToolSnapshot[]; + + return ( +
+ } + title="Tools" + /> + + + Tools allow assistants to retrieve information or take actions. + + +
+ + + Create a Tool + +
+ + New Tool +
+ + + + + Existing Tools + +
+
+ ); +} diff --git a/web/src/app/admin/users/page.tsx b/web/src/app/admin/users/page.tsx index acc4c5b5fc3..db0c6c03fa0 100644 --- a/web/src/app/admin/users/page.tsx +++ b/web/src/app/admin/users/page.tsx @@ -1,37 +1,89 @@ "use client"; +import InvitedUserTable from "@/components/admin/users/InvitedUserTable"; +import SignedUpUserTable from "@/components/admin/users/SignedUpUserTable"; +import { SearchBar } from "@/components/search/SearchBar"; +import { useState } from "react"; +import { FiPlusSquare } from "react-icons/fi"; +import { Modal } from "@/components/Modal"; -import { - Table, - TableHead, - TableRow, - TableHeaderCell, - TableBody, - TableCell, - Button, -} from "@tremor/react"; +import { Button, Text } from "@tremor/react"; import { LoadingAnimation } from "@/components/Loading"; import { AdminPageTitle } from "@/components/admin/Title"; -import { usePopup } from "@/components/admin/connectors/Popup"; +import { usePopup, PopupSpec } from "@/components/admin/connectors/Popup"; import { UsersIcon } from "@/components/icons/icons"; import { errorHandlingFetcher } from "@/lib/fetcher"; -import { User } from "@/lib/types"; +import { type User, UserStatus } from "@/lib/types"; import useSWR, { mutate } from "swr"; import { ErrorCallout } from "@/components/ErrorCallout"; +import { HidableSection } from "@/app/admin/assistants/HidableSection"; +import BulkAdd from "@/components/admin/users/BulkAdd"; -const UsersTable = () => { - const { popup, setPopup } = usePopup(); +const ValidDomainsDisplay = ({ validDomains }: { validDomains: string[] }) => { + if (!validDomains.length) { + return ( +
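+          {/* no VALID_EMAIL_DOMAINS configured; explain how access can be restricted */}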
+ No invited users. Anyone can sign up with a valid email address. To + restrict access you can: +
+            (1) Invite users above. Once a user has been invited, only emails
+            that have explicitly been invited will be able to sign up.
+
+            (2) Set the{" "}
+            VALID_EMAIL_DOMAINS{" "}
+            environment variable to a comma-separated list of email domains.
+            This will restrict access to users with email addresses from these
+            domains.
+
+ ); + } + + return ( +
+ No invited users. Anyone with an email address with any of the following + domains can sign up: {validDomains.join(", ")}. +
+        To further restrict access you can invite users above. Once a user has
+        been invited, only emails that have explicitly been invited will be
+        able to sign up.
+
+ ); +}; +interface UsersResponse { + accepted: User[]; + invited: User[]; + accepted_pages: number; + invited_pages: number; +} + +const UsersTables = ({ + q, + setPopup, +}: { + q: string; + setPopup: (spec: PopupSpec) => void; +}) => { + const [invitedPage, setInvitedPage] = useState(1); + const [acceptedPage, setAcceptedPage] = useState(1); + const { data, isLoading, mutate, error } = useSWR( + `/api/manage/users?q=${encodeURI(q)}&accepted_page=${ + acceptedPage - 1 + }&invited_page=${invitedPage - 1}`, + errorHandlingFetcher + ); const { - data: users, - isLoading, - error, - } = useSWR("/api/manage/users", errorHandlingFetcher); + data: validDomains, + isLoading: isLoadingDomains, + error: domainsError, + } = useSWR("/api/manage/admin/valid-domains", errorHandlingFetcher); - if (isLoading) { + if (isLoading || isLoadingDomains) { return ; } - if (error || !users) { + if (error || !data) { return ( { ); } + if (domainsError || !validDomains) { + return ( + + ); + } + + const { accepted, invited, accepted_pages, invited_pages } = data; + + // remove users that are already accepted + const finalInvited = invited.filter( + (user) => !accepted.map((u) => u.email).includes(user.email) + ); + + return ( + <> + + {invited.length > 0 ? ( + finalInvited.length > 0 ? ( + + ) : ( +
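+              {/* all invited users have already signed up; see the finalInvited filter above */}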
+ To invite additional teammates, use the Invite Users button + above! +
+ ) + ) : ( + + )} +
+ + + ); +}; + +const SearchableTables = () => { + const { popup, setPopup } = usePopup(); + const [query, setQuery] = useState(""); + const [q, setQ] = useState(""); + return (
{popup} - - - - Email - Role - -
-
Actions
-
-
-
-
- - {users.map((user) => ( - - {user.email} - - {user.role === "admin" ? "Admin" : "User"} - - -
- {user.role !== "admin" && ( - - )} - {user.role === "admin" && ( - - )} -
-
-
- ))} -
-
+
+
+ +
+ setQ(query)} + /> +
+
+ +
); }; +const AddUserButton = ({ + setPopup, +}: { + setPopup: (spec: PopupSpec) => void; +}) => { + const [modal, setModal] = useState(false); + const onSuccess = () => { + mutate( + (key) => typeof key === "string" && key.startsWith("/api/manage/users") + ); + setModal(false); + setPopup({ + message: "Users invited!", + type: "success", + }); + }; + const onFailure = async (res: Response) => { + const error = (await res.json()).detail; + setPopup({ + message: `Failed to invite users - ${error}`, + type: "error", + }); + }; + return ( + <> + + {modal && ( + setModal(false)}> +
+
+          Add the email addresses to import, separated by whitespace.
+
+
+ )} + + ); +}; + const Page = () => { return (
} /> - - +
); }; diff --git a/web/src/app/auth/error/page.tsx b/web/src/app/auth/error/page.tsx index 16d33f49c72..4f288cd205f 100644 --- a/web/src/app/auth/error/page.tsx +++ b/web/src/app/auth/error/page.tsx @@ -1,6 +1,21 @@ +"use client"; + +import { Button } from "@tremor/react"; +import Link from "next/link"; +import { FiLogIn } from "react-icons/fi"; + const Page = () => { return ( -
Unable to login, please try again and/or contact an administrator
+
+
+ Unable to login, please try again and/or contact an administrator. +
+ + + +
); }; diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 047b115c2d3..92e4384f912 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -12,7 +12,8 @@ import { Message, RetrievalType, StreamingError, - ToolRunKickoff, + ToolCallFinalResult, + ToolCallMetadata, } from "./interfaces"; import { ChatSidebar } from "./sessionSidebar/ChatSidebar"; import { Persona } from "../admin/assistants/interfaces"; @@ -265,6 +266,7 @@ export function ChatPage({ message: "", type: "system", files: [], + toolCalls: [], parentMessageId: null, childrenMessageIds: [firstMessageId], latestChildMessageId: firstMessageId, @@ -307,7 +309,6 @@ export function ChatPage({ return newCompleteMessageMap; }; const messageHistory = buildLatestMessageChain(completeMessageMap); - const [currentTool, setCurrentTool] = useState(null); const [isStreaming, setIsStreaming] = useState(false); // uploaded files @@ -535,6 +536,7 @@ export function ChatPage({ message: currMessage, type: "user", files: currentMessageFiles, + toolCalls: [], parentMessageId: parentMessage?.messageId || null, }, ]; @@ -569,6 +571,7 @@ export function ChatPage({ let aiMessageImages: FileDescriptor[] | null = null; let error: string | null = null; let finalMessage: BackendMessage | null = null; + let toolCalls: ToolCallMetadata[] = []; try { const lastSuccessfulMessageId = getLastSuccessfulMessageId(currMessageHistory); @@ -627,7 +630,13 @@ export function ChatPage({ } ); } else if (Object.hasOwn(packet, "tool_name")) { - setCurrentTool((packet as ToolRunKickoff).tool_name); + toolCalls = [ + { + tool_name: (packet as ToolCallMetadata).tool_name, + tool_args: (packet as ToolCallMetadata).tool_args, + tool_result: (packet as ToolCallMetadata).tool_result, + }, + ]; } else if (Object.hasOwn(packet, "error")) { error = (packet as StreamingError).error; } else if (Object.hasOwn(packet, "message_id")) { @@ -657,6 +666,7 @@ export function ChatPage({ message: currMessage, type: "user", files: currentMessageFiles, + toolCalls: [], parentMessageId: parentMessage?.messageId || null, childrenMessageIds: [newAssistantMessageId], latestChildMessageId: newAssistantMessageId, @@ -670,6 +680,7 @@ export function ChatPage({ documents: finalMessage?.context_docs?.top_documents || documents, citations: finalMessage?.citations || {}, files: finalMessage?.files || aiMessageImages || [], + toolCalls: finalMessage?.tool_calls || toolCalls, parentMessageId: newUserMessageId, }, ]); @@ -687,6 +698,7 @@ export function ChatPage({ message: currMessage, type: "user", files: currentMessageFiles, + toolCalls: [], parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID, }, { @@ -694,6 +706,7 @@ export function ChatPage({ message: errorMsg, type: "error", files: aiMessageImages || [], + toolCalls: [], parentMessageId: TEMP_USER_MESSAGE_ID, }, ], @@ -1031,7 +1044,7 @@ export function ChatPage({ citedDocuments={getCitedDocumentsFromMessage( message )} - currentTool={currentTool} + toolCall={message.toolCalls[0]} isComplete={ i !== messageHistory.length - 1 || !isStreaming diff --git a/web/src/app/chat/interfaces.ts b/web/src/app/chat/interfaces.ts index 9f3d647340e..52121ed264b 100644 --- a/web/src/app/chat/interfaces.ts +++ b/web/src/app/chat/interfaces.ts @@ -34,6 +34,18 @@ export interface FileDescriptor { isUploading?: boolean; } +export interface ToolCallMetadata { + tool_name: string; + tool_args: Record; + tool_result?: Record; +} + +export interface ToolCallFinalResult { + tool_name: string; + tool_args: Record; + 
tool_result: Record; +} + export interface ChatSession { id: number; name: string; @@ -52,6 +64,7 @@ export interface Message { documents?: DanswerDocument[] | null; citations?: CitationMap; files: FileDescriptor[]; + toolCalls: ToolCallMetadata[]; // for rebuilding the message tree parentMessageId: number | null; childrenMessageIds?: number[]; @@ -79,6 +92,7 @@ export interface BackendMessage { time_sent: string; citations: CitationMap; files: FileDescriptor[]; + tool_calls: ToolCallFinalResult[]; } export interface DocumentsResponse { @@ -90,11 +104,6 @@ export interface ImageGenerationDisplay { file_ids: string[]; } -export interface ToolRunKickoff { - tool_name: string; - tool_args: Record; -} - export interface StreamingError { error: string; } diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index 9d65e64ce41..b17413d9304 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -15,7 +15,7 @@ import { Message, RetrievalType, StreamingError, - ToolRunKickoff, + ToolCallMetadata, } from "./interfaces"; import { Persona } from "../admin/assistants/interfaces"; import { ReadonlyURLSearchParams } from "next/navigation"; @@ -138,7 +138,7 @@ export async function* sendMessage({ | DocumentsResponse | BackendMessage | ImageGenerationDisplay - | ToolRunKickoff + | ToolCallMetadata | StreamingError >(sendMessageResponse); } @@ -384,6 +384,7 @@ export function processRawChatHistory( citations: messageInfo?.citations || {}, } : {}), + toolCalls: messageInfo.tool_calls, parentMessageId: messageInfo.parent_message, childrenMessageIds: [], latestChildMessageId: messageInfo.latest_child_message, diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index b188c128ea7..3a7b9d6d8b4 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -9,6 +9,7 @@ import { FiEdit2, FiChevronRight, FiChevronLeft, + FiTool, } from "react-icons/fi"; import { FeedbackType } from "../types"; import { useEffect, useRef, useState } from "react"; @@ -20,9 +21,12 @@ import { ThreeDots } from "react-loader-spinner"; import { SkippedSearch } from "./SkippedSearch"; import remarkGfm from "remark-gfm"; import { CopyButton } from "@/components/CopyButton"; -import { ChatFileType, FileDescriptor } from "../interfaces"; -import { IMAGE_GENERATION_TOOL_NAME } from "../tools/constants"; -import { ToolRunningAnimation } from "../tools/ToolRunningAnimation"; +import { ChatFileType, FileDescriptor, ToolCallMetadata } from "../interfaces"; +import { + IMAGE_GENERATION_TOOL_NAME, + SEARCH_TOOL_NAME, +} from "../tools/constants"; +import { ToolRunDisplay } from "../tools/ToolRunningAnimation"; import { Hoverable } from "@/components/Hoverable"; import { DocumentPreview } from "../files/documents/DocumentPreview"; import { InMessageImage } from "../files/images/InMessageImage"; @@ -35,6 +39,11 @@ import Prism from "prismjs"; import "prismjs/themes/prism-tomorrow.css"; import "./custom-code-styles.css"; +const TOOLS_WITH_CUSTOM_HANDLING = [ + SEARCH_TOOL_NAME, + IMAGE_GENERATION_TOOL_NAME, +]; + function FileDisplay({ files }: { files: FileDescriptor[] }) { const imageFiles = files.filter((file) => file.type === ChatFileType.IMAGE); const nonImgFiles = files.filter((file) => file.type !== ChatFileType.IMAGE); @@ -77,7 +86,7 @@ export const AIMessage = ({ query, personaName, citedDocuments, - currentTool, + toolCall, isComplete, hasDocs, handleFeedback, @@ -93,7 +102,7 @@ export const AIMessage = ({ query?: string; personaName?: string; 
citedDocuments?: [string, DanswerDocument][] | null; - currentTool?: string | null; + toolCall?: ToolCallMetadata; isComplete?: boolean; hasDocs?: boolean; handleFeedback?: (feedbackType: FeedbackType) => void; @@ -133,28 +142,23 @@ export const AIMessage = ({ content = trimIncompleteCodeSection(content); } - const loader = - currentTool === IMAGE_GENERATION_TOOL_NAME ? ( -
- } - /> -
- ) : ( -
- -
- ); + const shouldShowLoader = + !toolCall || + (toolCall.tool_name === SEARCH_TOOL_NAME && query === undefined); + const defaultLoader = shouldShowLoader ? ( +
+ +
+ ) : undefined; return (
@@ -189,28 +193,61 @@ export const AIMessage = ({
- {query !== undefined && - handleShowRetrieved !== undefined && - isCurrentlyShowingRetrieved !== undefined && - !retrievalDisabled && ( -
- + {query !== undefined && + handleShowRetrieved !== undefined && + isCurrentlyShowingRetrieved !== undefined && + !retrievalDisabled && ( +
+ +
+ )} + {handleForceSearch && + content && + query === undefined && + !hasDocs && + !retrievalDisabled && ( +
+ +
+ )} + + )} + + {toolCall && + !TOOLS_WITH_CUSTOM_HANDLING.includes(toolCall.tool_name) && ( +
+ } + isRunning={!toolCall.tool_result || !content} />
)} - {handleForceSearch && - content && - query === undefined && - !hasDocs && - !retrievalDisabled && ( -
- + + {toolCall && + toolCall.tool_name === IMAGE_GENERATION_TOOL_NAME && + !toolCall.tool_result && ( +
+ } + isRunning={!toolCall.tool_result} + />
)} @@ -260,7 +297,7 @@ export const AIMessage = ({ )} ) : isComplete ? null : ( - loader + defaultLoader )} {citedDocuments && citedDocuments.length > 0 && (
diff --git a/web/src/app/chat/tools/ToolRunningAnimation.tsx b/web/src/app/chat/tools/ToolRunningAnimation.tsx index bd0414295fe..139c9e92151 100644 --- a/web/src/app/chat/tools/ToolRunningAnimation.tsx +++ b/web/src/app/chat/tools/ToolRunningAnimation.tsx @@ -1,16 +1,18 @@ import { LoadingAnimation } from "@/components/Loading"; -export function ToolRunningAnimation({ +export function ToolRunDisplay({ toolName, toolLogo, + isRunning, }: { toolName: string; - toolLogo: JSX.Element; + toolLogo?: JSX.Element; + isRunning: boolean; }) { return ( -
+
{toolLogo} - + {isRunning ? : toolName}
); } diff --git a/web/src/components/PageSelector.tsx b/web/src/components/PageSelector.tsx index 1e443e3c887..cf81b8f3ccc 100644 --- a/web/src/components/PageSelector.tsx +++ b/web/src/components/PageSelector.tsx @@ -79,7 +79,7 @@ const PageLink = ({
); -interface PageSelectorProps { +export interface PageSelectorProps { currentPage: number; totalPages: number; onPageChange: (newPage: number) => void; diff --git a/web/src/components/admin/Layout.tsx b/web/src/components/admin/Layout.tsx index e403b00a168..1411d6c14b6 100644 --- a/web/src/components/admin/Layout.tsx +++ b/web/src/components/admin/Layout.tsx @@ -2,15 +2,12 @@ import { Header } from "@/components/header/Header"; import { AdminSidebar } from "@/components/admin/connectors/AdminSidebar"; import { NotebookIcon, - KeyIcon, UsersIcon, ThumbsUpIcon, BookmarkIcon, - CPUIcon, ZoomInIcon, RobotIcon, ConnectorIcon, - SlackIcon, } from "@/components/icons/icons"; import { User } from "@/lib/types"; import { @@ -19,13 +16,7 @@ import { getCurrentUserSS, } from "@/lib/userSS"; import { redirect } from "next/navigation"; -import { - FiCpu, - FiLayers, - FiPackage, - FiSettings, - FiSlack, -} from "react-icons/fi"; +import { FiCpu, FiPackage, FiSettings, FiSlack, FiTool } from "react-icons/fi"; export async function Layout({ children }: { children: React.ReactNode }) { const tasks = [getAuthTypeMetadataSS(), getCurrentUserSS()]; @@ -142,6 +133,15 @@ export async function Layout({ children }: { children: React.ReactNode }) { ), link: "/admin/bot", }, + { + name: ( +
+ +
Tools
+
+ ), + link: "/admin/tools", + }, ], }, { diff --git a/web/src/components/admin/connectors/ConnectorForm.tsx b/web/src/components/admin/connectors/ConnectorForm.tsx index c91b22bbf88..db04e482756 100644 --- a/web/src/components/admin/connectors/ConnectorForm.tsx +++ b/web/src/components/admin/connectors/ConnectorForm.tsx @@ -70,6 +70,7 @@ interface BaseProps { responseJson: Connector | undefined ) => void; refreshFreq?: number; + pruneFreq?: number; // If specified, then we will create an empty credential and associate // the connector with it. If credentialId is specified, then this will be ignored shouldCreateEmptyCredentialForConnector?: boolean; @@ -91,6 +92,7 @@ export function ConnectorForm({ validationSchema, initialValues, refreshFreq, + pruneFreq, onSubmit, shouldCreateEmptyCredentialForConnector, }: ConnectorFormProps): JSX.Element { @@ -144,6 +146,7 @@ export function ConnectorForm({ input_type: inputType, connector_specific_config: connectorConfig, refresh_freq: refreshFreq || 0, + prune_freq: pruneFreq ?? null, disabled: false, }); @@ -281,6 +284,7 @@ export function UpdateConnectorForm({ input_type: existingConnector.input_type, connector_specific_config: values, refresh_freq: existingConnector.refresh_freq, + prune_freq: existingConnector.prune_freq, disabled: false, }, existingConnector.id diff --git a/web/src/components/admin/connectors/Field.tsx b/web/src/components/admin/connectors/Field.tsx index 88e482d77e8..563a5263396 100644 --- a/web/src/components/admin/connectors/Field.tsx +++ b/web/src/components/admin/connectors/Field.tsx @@ -43,6 +43,10 @@ export function TextFormField({ disabled = false, autoCompleteDisabled = true, error, + defaultHeight, + isCode = false, + fontSize, + hideError, }: { name: string; label: string; @@ -54,7 +58,16 @@ export function TextFormField({ disabled?: boolean; autoCompleteDisabled?: boolean; error?: string; + defaultHeight?: string; + isCode?: boolean; + fontSize?: "text-sm" | "text-base" | "text-lg"; + hideError?: boolean; }) { + let heightString = defaultHeight || ""; + if (isTextArea && !heightString) { + heightString = "h-28"; + } + return (
@@ -64,18 +77,19 @@ export function TextFormField({ type={type} name={name} id={name} - className={ - ` - border - border-border - rounded - w-full - py-2 - px-3 - mt-1 - ${isTextArea ? " h-28" : ""} - ` + (disabled ? " bg-background-strong" : " bg-background-emphasis") - } + className={` + border + border-border + rounded + w-full + py-2 + px-3 + mt-1 + ${heightString} + ${fontSize} + ${disabled ? " bg-background-strong" : " bg-background-emphasis"} + ${isCode ? " font-mono" : ""} + `} disabled={disabled} placeholder={placeholder} autoComplete={autoCompleteDisabled ? "off" : undefined} @@ -84,11 +98,13 @@ export function TextFormField({ {error ? ( {error} ) : ( - + !hideError && ( + + ) )}
); diff --git a/web/src/components/admin/connectors/Popup.tsx b/web/src/components/admin/connectors/Popup.tsx index f3ab7221965..adfc0665c25 100644 --- a/web/src/components/admin/connectors/Popup.tsx +++ b/web/src/components/admin/connectors/Popup.tsx @@ -7,8 +7,8 @@ export interface PopupSpec { export const Popup: React.FC = ({ message, type }) => (
{message} diff --git a/web/src/components/admin/users/BulkAdd.tsx b/web/src/components/admin/users/BulkAdd.tsx new file mode 100644 index 00000000000..858067c4a4b --- /dev/null +++ b/web/src/components/admin/users/BulkAdd.tsx @@ -0,0 +1,87 @@ +"use client"; + +import { withFormik, FormikProps, FormikErrors, Form, Field } from "formik"; + +import { Button } from "@tremor/react"; + +const WHITESPACE_SPLIT = /\s+/; +const EMAIL_REGEX = /[^@]+@[^.]+\.[^.]/; + +const addUsers = async (url: string, { arg }: { arg: Array }) => { + return await fetch(url, { + method: "PUT", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ emails: arg }), + }); +}; + +interface FormProps { + onSuccess: () => void; + onFailure: (res: Response) => void; +} + +interface FormValues { + emails: string; +} + +const AddUserFormRenderer = ({ + touched, + errors, + isSubmitting, +}: FormikProps) => ( +
+
+ + {touched.emails && errors.emails && ( +
{errors.emails}
+ )} + +
+
+); + +const AddUserForm = withFormik({ + mapPropsToValues: (props) => { + return { + emails: "", + }; + }, + validate: (values: FormValues): FormikErrors => { + const emails = values.emails.trim().split(WHITESPACE_SPLIT); + if (!emails.some(Boolean)) { + return { emails: "Required" }; + } + for (let email of emails) { + if (!email.match(EMAIL_REGEX)) { + return { emails: `${email} is not a valid email` }; + } + } + return {}; + }, + handleSubmit: async (values: FormValues, formikBag) => { + const emails = values.emails.trim().split(WHITESPACE_SPLIT); + await addUsers("/api/manage/admin/users", { arg: emails }).then((res) => { + if (res.ok) { + formikBag.props.onSuccess(); + } else { + formikBag.props.onFailure(res); + } + }); + }, +})(AddUserFormRenderer); + +const BulkAdd = ({ onSuccess, onFailure }: FormProps) => { + return ; +}; + +export default BulkAdd; diff --git a/web/src/components/admin/users/CenteredPageSelector.tsx b/web/src/components/admin/users/CenteredPageSelector.tsx new file mode 100644 index 00000000000..2779872e3fd --- /dev/null +++ b/web/src/components/admin/users/CenteredPageSelector.tsx @@ -0,0 +1,20 @@ +import { + PageSelector, + type PageSelectorProps as Props, +} from "@/components/PageSelector"; + +const CenteredPageSelector = ({ + currentPage, + totalPages, + onPageChange, +}: Props) => ( +
+ +
+); + +export default CenteredPageSelector; diff --git a/web/src/components/admin/users/InvitedUserTable.tsx b/web/src/components/admin/users/InvitedUserTable.tsx new file mode 100644 index 00000000000..6bc74213a79 --- /dev/null +++ b/web/src/components/admin/users/InvitedUserTable.tsx @@ -0,0 +1,109 @@ +import { PopupSpec } from "@/components/admin/connectors/Popup"; +import { HidableSection } from "@/app/admin/assistants/HidableSection"; +import { + Table, + TableHead, + TableRow, + TableHeaderCell, + TableBody, + TableCell, + Button, +} from "@tremor/react"; +import userMutationFetcher from "@/lib/admin/users/userMutationFetcher"; +import CenteredPageSelector from "./CenteredPageSelector"; +import { type PageSelectorProps } from "@/components/PageSelector"; +import useSWR from "swr"; +import { type User, UserStatus } from "@/lib/types"; +import useSWRMutation from "swr/mutation"; + +interface Props { + users: Array; + setPopup: (spec: PopupSpec) => void; + mutate: () => void; +} + +const RemoveUserButton = ({ + user, + onSuccess, + onError, +}: { + user: User; + onSuccess: () => void; + onError: (message: string) => void; +}) => { + const { trigger } = useSWRMutation( + "/api/manage/admin/remove-invited-user", + userMutationFetcher, + { onSuccess, onError } + ); + return ( + + ); +}; + +const InvitedUserTable = ({ + users, + setPopup, + currentPage, + totalPages, + onPageChange, + mutate, +}: Props & PageSelectorProps) => { + if (!users.length) return null; + + const onRemovalSuccess = () => { + mutate(); + setPopup({ + message: "User uninvited!", + type: "success", + }); + }; + const onRemovalError = (errorMsg: string) => { + setPopup({ + message: `Unable to uninvite user - ${errorMsg}`, + type: "error", + }); + }; + + return ( + <> + + + + Email + +
Actions
+
+
+
+ + {users.map((user) => ( + + {user.email} + +
+ +
+
+
+ ))} +
+
+ {totalPages > 1 ? ( + + ) : null} + + ); +}; + +export default InvitedUserTable; diff --git a/web/src/components/admin/users/SignedUpUserTable.tsx b/web/src/components/admin/users/SignedUpUserTable.tsx new file mode 100644 index 00000000000..731703fed7c --- /dev/null +++ b/web/src/components/admin/users/SignedUpUserTable.tsx @@ -0,0 +1,187 @@ +import { type User, UserStatus } from "@/lib/types"; +import CenteredPageSelector from "./CenteredPageSelector"; +import { type PageSelectorProps } from "@/components/PageSelector"; +import { HidableSection } from "@/app/admin/assistants/HidableSection"; +import { PopupSpec } from "@/components/admin/connectors/Popup"; +import userMutationFetcher from "@/lib/admin/users/userMutationFetcher"; +import useSWRMutation from "swr/mutation"; +import { + Table, + TableHead, + TableRow, + TableHeaderCell, + TableBody, + TableCell, + Button, +} from "@tremor/react"; + +interface Props { + users: Array; + setPopup: (spec: PopupSpec) => void; + mutate: () => void; +} + +const PromoterButton = ({ + user, + promote, + onSuccess, + onError, +}: { + user: User; + promote: boolean; + onSuccess: () => void; + onError: (message: string) => void; +}) => { + const { trigger, isMutating } = useSWRMutation( + promote + ? "/api/manage/promote-user-to-admin" + : "/api/manage/demote-admin-to-basic", + userMutationFetcher, + { onSuccess, onError } + ); + return ( + + ); +}; + +const DeactivaterButton = ({ + user, + deactivate, + setPopup, + mutate, +}: { + user: User; + deactivate: boolean; + setPopup: (spec: PopupSpec) => void; + mutate: () => void; +}) => { + const { trigger, isMutating } = useSWRMutation( + deactivate + ? "/api/manage/admin/deactivate-user" + : "/api/manage/admin/activate-user", + userMutationFetcher, + { + onSuccess: () => { + mutate(); + setPopup({ + message: `User ${deactivate ? "deactivated" : "activated"}!`, + type: "success", + }); + }, + onError: (errorMsg) => setPopup({ message: errorMsg, type: "error" }), + } + ); + return ( + + ); +}; + +const SignedUpUserTable = ({ + users, + setPopup, + currentPage, + totalPages, + onPageChange, + mutate, +}: Props & PageSelectorProps) => { + if (!users.length) return null; + + const onSuccess = (message: string) => { + mutate(); + setPopup({ + message, + type: "success", + }); + }; + const onError = (message: string) => { + setPopup({ + message, + type: "error", + }); + }; + const onPromotionSuccess = () => { + onSuccess("User promoted to admin user!"); + }; + const onPromotionError = (errorMsg: string) => { + onError(`Unable to promote user - ${errorMsg}`); + }; + const onDemotionSuccess = () => { + onSuccess("Admin demoted to basic user!"); + }; + const onDemotionError = (errorMsg: string) => { + onError(`Unable to demote admin - ${errorMsg}`); + }; + + return ( + + <> + {totalPages > 1 ? ( + + ) : null} + + + + Email + Role + Status + +
+
Actions
+
+
+
+
+ + {users.map((user) => ( + + {user.email} + + {user.role === "admin" ? "Admin" : "User"} + + + {user.status === "live" ? "Active" : "Inactive"} + + +
+ + +
+
+
+ ))} +
+
+ +
+ ); +}; + +export default SignedUpUserTable; diff --git a/web/src/lib/admin/users/userMutationFetcher.ts b/web/src/lib/admin/users/userMutationFetcher.ts new file mode 100644 index 00000000000..87e5fa81f14 --- /dev/null +++ b/web/src/lib/admin/users/userMutationFetcher.ts @@ -0,0 +1,20 @@ +const userMutationFetcher = async ( + url: string, + { arg }: { arg: { user_email: string } } +) => { + return fetch(url, { + method: "PATCH", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + user_email: arg.user_email, + }), + }).then(async (res) => { + if (res.ok) return res.json(); + const errorDetail = (await res.json()).detail; + throw Error(errorDetail); + }); +}; + +export default userMutationFetcher; diff --git a/web/src/lib/tools/edit.ts b/web/src/lib/tools/edit.ts new file mode 100644 index 00000000000..841870a93e7 --- /dev/null +++ b/web/src/lib/tools/edit.ts @@ -0,0 +1,111 @@ +import { MethodSpec, ToolSnapshot } from "./interfaces"; + +interface ApiResponse { + data: T | null; + error: string | null; +} + +export async function createCustomTool(toolData: { + name: string; + description?: string; + definition: Record; +}): Promise> { + try { + const response = await fetch("/api/admin/tool/custom", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(toolData), + }); + + if (!response.ok) { + const errorDetail = (await response.json()).detail; + return { data: null, error: `Failed to create tool: ${errorDetail}` }; + } + + const tool: ToolSnapshot = await response.json(); + return { data: tool, error: null }; + } catch (error) { + console.error("Error creating tool:", error); + return { data: null, error: "Error creating tool" }; + } +} + +export async function updateCustomTool( + toolId: number, + toolData: { + name?: string; + description?: string; + definition?: Record; + } +): Promise> { + try { + const response = await fetch(`/api/admin/tool/custom/${toolId}`, { + method: "PUT", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(toolData), + }); + + if (!response.ok) { + const errorDetail = (await response.json()).detail; + return { data: null, error: `Failed to update tool: ${errorDetail}` }; + } + + const updatedTool: ToolSnapshot = await response.json(); + return { data: updatedTool, error: null }; + } catch (error) { + console.error("Error updating tool:", error); + return { data: null, error: "Error updating tool" }; + } +} + +export async function deleteCustomTool( + toolId: number +): Promise> { + try { + const response = await fetch(`/api/admin/tool/custom/${toolId}`, { + method: "DELETE", + headers: { + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + const errorDetail = (await response.json()).detail; + return { data: false, error: `Failed to delete tool: ${errorDetail}` }; + } + + return { data: true, error: null }; + } catch (error) { + console.error("Error deleting tool:", error); + return { data: false, error: "Error deleting tool" }; + } +} + +export async function validateToolDefinition(toolData: { + definition: Record; +}): Promise> { + try { + const response = await fetch(`/api/admin/tool/custom/validate`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(toolData), + }); + + if (!response.ok) { + const errorDetail = (await response.json()).detail; + return { data: null, error: errorDetail }; + } + + const responseJson = await response.json(); + return { data: responseJson.methods, error: null 
}; + } catch (error) { + console.error("Error validating tool:", error); + return { data: null, error: "Unexpected error validating tool definition" }; + } +} diff --git a/web/src/lib/tools/fetchTools.ts b/web/src/lib/tools/fetchTools.ts index 51969c6db7f..3ea6cd31f73 100644 --- a/web/src/lib/tools/fetchTools.ts +++ b/web/src/lib/tools/fetchTools.ts @@ -14,3 +14,21 @@ export async function fetchToolsSS(): Promise { return null; } } + +export async function fetchToolByIdSS( + toolId: string +): Promise { + try { + const response = await fetchSS(`/tool/${toolId}`); + if (!response.ok) { + throw new Error( + `Failed to fetch tool with ID ${toolId}: ${await response.text()}` + ); + } + const tool: ToolSnapshot = await response.json(); + return tool; + } catch (error) { + console.error(`Error fetching tool with ID ${toolId}:`, error); + return null; + } +} diff --git a/web/src/lib/tools/interfaces.ts b/web/src/lib/tools/interfaces.ts index f8882e6bfdb..bcb5df50a2a 100644 --- a/web/src/lib/tools/interfaces.ts +++ b/web/src/lib/tools/interfaces.ts @@ -2,5 +2,21 @@ export interface ToolSnapshot { id: number; name: string; description: string; + + // only specified for Custom Tools. OpenAPI schema which represents + // the tool's API. + definition: Record | null; + + // only specified for Custom Tools. ID of the tool in the codebase. in_code_tool_id: string | null; } + +export interface MethodSpec { + /* Defines a single method that is part of a custom tool. Each method maps to a single + action that the LLM can choose to take. */ + name: string; + summary: string; + path: string; + method: string; + spec: Record; +} diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index cef5c0a46fc..138765c5c45 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -4,6 +4,12 @@ export interface UserPreferences { chosen_assistants: number[] | null; } +export enum UserStatus { + live = "live", + invited = "invited", + deactivated = "deactivated", +} + export interface User { id: string; email: string; @@ -12,6 +18,7 @@ export interface User { is_verified: string; role: "basic" | "admin"; preferences: UserPreferences; + status: UserStatus; } export interface MinimalUserSnapshot { @@ -78,6 +85,7 @@ export interface ConnectorBase { source: ValidSources; connector_specific_config: T; refresh_freq: number | null; + prune_freq: number | null; disabled: boolean; }
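
As a quick end-to-end sketch of how the custom-tool helpers added in
`web/src/lib/tools/edit.ts` fit together: the editor first validates the pasted
OpenAPI definition server-side (rendering the returned `MethodSpec[]`), then
persists the tool with its name taken from `info.title`, exactly as
`ToolEditor`'s `onSubmit` does. The OpenAPI document below is illustrative
only; any standard OpenAPI 3 definition with an `info.title` should behave the
same way.

```ts
import { createCustomTool, validateToolDefinition } from "@/lib/tools/edit";

async function createWeatherTool() {
  // Illustrative OpenAPI 3 definition; `info.title` becomes the tool name.
  const definition = {
    openapi: "3.0.0",
    info: { title: "Weather Tool", description: "Look up current weather" },
    paths: {
      "/weather": {
        get: {
          operationId: "getWeather",
          summary: "Get weather for a city",
          parameters: [
            {
              name: "city",
              in: "query",
              required: true,
              schema: { type: "string" },
            },
          ],
        },
      },
    },
  };

  // Server-side validation; returns the MethodSpec[] that ToolForm renders.
  const validation = await validateToolDefinition({ definition });
  if (validation.error) {
    console.error(validation.error); // shown inline as `definitionError`
    return;
  }
  validation.data?.forEach((m) =>
    console.log(m.method.toUpperCase(), m.path, m.name)
  );

  // Persist the tool the same way ToolEditor's onSubmit does.
  const created = await createCustomTool({
    name: definition.info.title,
    description: definition.info.description || "",
    definition,
  });
  if (created.error) {
    console.error(created.error);
  }
}
```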