diff --git a/backend/Dockerfile.model_server b/backend/Dockerfile.model_server
index 78e3eaa15ab..073580bab7f 100644
--- a/backend/Dockerfile.model_server
+++ b/backend/Dockerfile.model_server
@@ -27,16 +27,15 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \
 # Download tokenizers, distilbert for the Danswer model
 # Download model weights
 # Run Nomic to pull in the custom architecture and have it cached locally
-# RUN python -c "from transformers import AutoTokenizer; \
-# from huggingface_hub import snapshot_download; \
-# snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
-# snapshot_download('nomic-ai/nomic-embed-text-v1'); \
-# from sentence_transformers import SentenceTransformer; \
-# SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
+RUN python -c "from transformers import AutoTokenizer; \
+AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
+from huggingface_hub import snapshot_download; \
+snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
+snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1');"
 
 # In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
 # running Danswer, don't overwrite it with the built in cache folder
-# RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
+RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
 
 WORKDIR /app
diff --git a/backend/alembic/versions/401c1ac29467_add_tables_for_ui_based_llm_.py b/backend/alembic/versions/401c1ac29467_add_tables_for_ui_based_llm_.py
index 34469b13c17..dcc766fe287 100644
--- a/backend/alembic/versions/401c1ac29467_add_tables_for_ui_based_llm_.py
+++ b/backend/alembic/versions/401c1ac29467_add_tables_for_ui_based_llm_.py
@@ -5,7 +5,6 @@
 Create Date: 2024-04-13 18:07:29.153817
 
 """
-
 from alembic import op
 import sqlalchemy as sa
 from sqlalchemy.dialects import postgresql
diff --git a/backend/alembic/versions/703313b75876_add_tokenratelimit_tables.py b/backend/alembic/versions/703313b75876_add_tokenratelimit_tables.py
index 7dbc63ce345..ed1993efed3 100644
--- a/backend/alembic/versions/703313b75876_add_tokenratelimit_tables.py
+++ b/backend/alembic/versions/703313b75876_add_tokenratelimit_tables.py
@@ -5,7 +5,6 @@
 Create Date: 2024-04-15 01:36:02.952809
 
 """
-
 import json
 from typing import cast
 from alembic import op
diff --git a/backend/alembic/versions/ef7da92f7213_add_files_to_chatmessage.py b/backend/alembic/versions/ef7da92f7213_add_files_to_chatmessage.py
index 10ec5d93a49..eb04a1b8208 100644
--- a/backend/alembic/versions/ef7da92f7213_add_files_to_chatmessage.py
+++ b/backend/alembic/versions/ef7da92f7213_add_files_to_chatmessage.py
@@ -5,7 +5,6 @@
 Create Date: 2024-04-28 16:59:33.199153
 
 """
-
 from alembic import op
 import sqlalchemy as sa
 from sqlalchemy.dialects import postgresql
diff --git a/backend/alembic/versions/fad14119fb92_delete_tags_with_wrong_enum.py b/backend/alembic/versions/fad14119fb92_delete_tags_with_wrong_enum.py
index ccd9df6c01b..b9c428640eb 100644
--- a/backend/alembic/versions/fad14119fb92_delete_tags_with_wrong_enum.py
+++ b/backend/alembic/versions/fad14119fb92_delete_tags_with_wrong_enum.py
@@ -5,7 +5,6 @@
 Create Date: 2024-04-25 17:05:09.695703
 
 """
-
 from alembic import op
 
 # revision identifiers, used by Alembic.
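The four alembic hunks above only drop a stray blank line after each migration docstring; the substantive change in this range is the re-enabled warm-up in Dockerfile.model_server, which bakes the models into the image's Hugging Face cache at build time. A sketch of what that one-line RUN step does, expanded into a plain script (it assumes transformers and huggingface_hub are already installed in the image; the repo ids and pinned revision are taken verbatim from the Dockerfile):

```python
# Readable expansion of the `RUN python -c ...` step above; repo ids
# and revision come straight from the Dockerfile, and the cache
# location is the Hugging Face default (~/.cache/huggingface).
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

# Cache the reranker tokenizer so the first request at runtime
# doesn't block on a download.
AutoTokenizer.from_pretrained("mixedbread-ai/mxbai-rerank-xsmall-v1")

# Pull full snapshots: the intent classifier at a pinned revision,
# plus the reranker weights.
snapshot_download(repo_id="danswer/hybrid-intent-token-classifier", revision="v1.0.3")
snapshot_download("mixedbread-ai/mxbai-rerank-xsmall-v1")
```

The re-enabled `mv` then stashes the baked-in cache at /root/.cache/temp_huggingface so that, per the comment, a cache volume a user has mounted at /root/.cache/huggingface is not overwritten; presumably a startup step restores it when no such volume is present.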
diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py
index 60b6a208113..ac02d125850 100644
--- a/backend/danswer/auth/users.py
+++ b/backend/danswer/auth/users.py
@@ -231,6 +231,7 @@ async def oauth_callback(
         associate_by_email: bool = False,
         is_verified_by_default: bool = False,
     ) -> models.UOAP:
+        verify_email_in_whitelist(account_email)
         verify_email_domain(account_email)
 
         user = await super().oauth_callback(  # type: ignore
diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py
index aa7c53aa3d2..8c43fb2eec3 100644
--- a/backend/danswer/background/celery/celery_app.py
+++ b/backend/danswer/background/celery/celery_app.py
@@ -1169,6 +1169,77 @@ def on_setup_logging(
     task_logger.propagate = False
 
+class CeleryTaskPlainFormatter(PlainFormatter):
+    def format(self, record: logging.LogRecord) -> str:
+        task = current_task
+        if task and task.request:
+            record.__dict__.update(task_id=task.request.id, task_name=task.name)
+            record.msg = f"[{task.name}({task.request.id})] {record.msg}"
+
+        return super().format(record)
+
+
+class CeleryTaskColoredFormatter(ColoredFormatter):
+    def format(self, record: logging.LogRecord) -> str:
+        task = current_task
+        if task and task.request:
+            record.__dict__.update(task_id=task.request.id, task_name=task.name)
+            record.msg = f"[{task.name}({task.request.id})] {record.msg}"
+
+        return super().format(record)
+
+
+@signals.setup_logging.connect
+def on_setup_logging(
+    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
+) -> None:
+    # TODO: could unhardcode format and colorize and accept these as options from
+    # celery's config
+
+    # reformats celery's worker logger
+    root_logger = logging.getLogger()
+
+    root_handler = logging.StreamHandler()  # Set up a handler for the root logger
+    root_formatter = ColoredFormatter(
+        "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
+        datefmt="%m/%d/%Y %I:%M:%S %p",
+    )
+    root_handler.setFormatter(root_formatter)
+    root_logger.addHandler(root_handler)  # Apply the handler to the root logger
+
+    if logfile:
+        root_file_handler = logging.FileHandler(logfile)
+        root_file_formatter = PlainFormatter(
+            "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
+            datefmt="%m/%d/%Y %I:%M:%S %p",
+        )
+        root_file_handler.setFormatter(root_file_formatter)
+        root_logger.addHandler(root_file_handler)
+
+    root_logger.setLevel(loglevel)
+
+    # reformats celery's task logger
+    task_formatter = CeleryTaskColoredFormatter(
+        "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
+        datefmt="%m/%d/%Y %I:%M:%S %p",
+    )
+    task_handler = logging.StreamHandler()  # Set up a handler for the task logger
+    task_handler.setFormatter(task_formatter)
+    task_logger.addHandler(task_handler)  # Apply the handler to the task logger
+
+    if logfile:
+        task_file_handler = logging.FileHandler(logfile)
+        task_file_formatter = CeleryTaskPlainFormatter(
+            "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
+            datefmt="%m/%d/%Y %I:%M:%S %p",
+        )
+        task_file_handler.setFormatter(task_file_formatter)
+        task_logger.addHandler(task_file_handler)
+
+    task_logger.setLevel(loglevel)
+    task_logger.propagate = False
+
+
 
 #####
 # Celery Beat (Periodic Tasks) Settings
 #####
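The new formatters above tag every Celery log record with the originating task before deferring to the base format(). A standalone sketch of that prefixing behavior, runnable outside a worker; the task object, name, and id below are made-up stand-ins for celery's current_task:

```python
import logging

# Hypothetical stand-ins for celery.current_task, just so the
# prefixing logic can be demonstrated without a running worker.
class _FakeRequest:
    id = "abc123"

class _FakeTask:
    name = "sync_task"
    request = _FakeRequest()

current_task = _FakeTask()


class TaskPrefixFormatter(logging.Formatter):
    """Same idea as CeleryTaskPlainFormatter above: prepend the
    active task's name and id to every record."""

    def format(self, record: logging.LogRecord) -> str:
        task = current_task
        if task and task.request:
            record.msg = f"[{task.name}({task.request.id})] {record.msg}"
        return super().format(record)


handler = logging.StreamHandler()
handler.setFormatter(
    TaskPrefixFormatter(
        "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
        datefmt="%m/%d/%Y %I:%M:%S %p",
    )
)
demo_logger = logging.getLogger("task_demo")
demo_logger.addHandler(handler)
demo_logger.setLevel(logging.INFO)

# Prints something like:
# 06/01/2024 10:00:00 AM    demo.py   40: [sync_task(abc123)] starting
demo_logger.info("starting")
```

The same record.msg mutation is what both CeleryTaskPlainFormatter and CeleryTaskColoredFormatter perform, so the plain (file) and colored (console) variants stay in sync.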
diff --git a/backend/danswer/connectors/google_site/connector.py b/backend/danswer/connectors/google_site/connector.py
index 3720ff5f433..9cfcf224e3f 100644
--- a/backend/danswer/connectors/google_site/connector.py
+++ b/backend/danswer/connectors/google_site/connector.py
@@ -119,11 +119,9 @@ def load_from_state(self) -> GenerateDocumentsOutput:
                     semantic_identifier=title,
                     sections=[
                         Section(
-                            link=(
-                                (self.base_url.rstrip("/") + "/" + path.lstrip("/"))
-                                if path
-                                else ""
-                            ),
+                            link=(self.base_url.rstrip("/") + "/" + path.lstrip("/"))
+                            if path
+                            else "",
                             text=parsed_html.cleaned_text,
                         )
                     ],
diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py
index 208dbdba3c1..94b5d0123cc 100644
--- a/backend/danswer/db/engine.py
+++ b/backend/danswer/db/engine.py
@@ -14,7 +14,6 @@
 from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker
-from sqlalchemy.pool import NullPool
 
 from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
 from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
@@ -131,7 +130,6 @@ def init_sqlalchemy_engine(app_name: str) -> None:
 
 
 def get_sqlalchemy_engine() -> Engine:
-    connect_args = {"sslmode": "disable"}
     global _SYNC_ENGINE
     if _SYNC_ENGINE is None:
         connection_string = build_connection_string(
@@ -148,7 +146,6 @@ def get_sqlalchemy_engine() -> Engine:
 
 
 def get_sqlalchemy_async_engine() -> AsyncEngine:
-    connect_args = {"ssl": "disable"}
     global _ASYNC_ENGINE
     if _ASYNC_ENGINE is None:
         # underlying asyncpg cannot accept application_name directly in the connection string
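Two notes on the hunks above: the engine.py edits delete connect_args assignments that were never used (along with the now-unneeded NullPool import), and the google_site change only reformats the link expression, with identical behavior. The strip-then-join idiom it preserves tolerates any mix of trailing and leading slashes; a tiny worked example, with a made-up URL:

```python
# Worked example of the rstrip/lstrip URL join kept by the
# google_site hunk; base_url and the paths here are made up.
base_url = "https://sites.google.com/site/demo/"
for path in ["/about", "about", ""]:
    link = (base_url.rstrip("/") + "/" + path.lstrip("/")) if path else ""
    print(repr(link))
# 'https://sites.google.com/site/demo/about'
# 'https://sites.google.com/site/demo/about'
# ''
```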
- """ - - @property - def requires_warm_up(self) -> bool: - """GPT4All models are lazy loaded, load them on server start so that the - first inference isn't extremely delayed""" - return True - - @property - def requires_api_key(self) -> bool: - return False - - def __init__( - self, - timeout: int, - model_version: str, - max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, - temperature: float = GEN_AI_TEMPERATURE, - ): - self.timeout = timeout - self.max_output_tokens = max_output_tokens - self.temperature = temperature - self.gpt4all_model = GPT4All(model_version) - - def log_model_configs(self) -> None: - logger.debug( - f"GPT4All Model: {self.gpt4all_model}, Temperature: {self.temperature}" - ) - - def invoke(self, prompt: LanguageModelInput) -> str: - prompt_basic = convert_lm_input_to_basic_string(prompt) - return self.gpt4all_model.generate(prompt_basic) - - def stream(self, prompt: LanguageModelInput) -> Iterator[str]: - prompt_basic = convert_lm_input_to_basic_string(prompt) - return self.gpt4all_model.generate(prompt_basic, streaming=True) diff --git a/backend/danswer/tools/built_in_tools.py b/backend/danswer/tools/built_in_tools.py index 1bfecef7ce4..99b2ae3bbb6 100644 --- a/backend/danswer/tools/built_in_tools.py +++ b/backend/danswer/tools/built_in_tools.py @@ -146,6 +146,7 @@ def auto_add_search_tool_to_personas(db_session: Session) -> None: db_session.commit() logger.notice("Completed adding SearchTool to relevant Personas.") + _built_in_tools_cache: dict[int, Type[Tool]] | None = None diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index f1bde878b93..5b9d57b9d35 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -10,7 +10,6 @@ dask==2023.8.1 ddtrace==2.6.5 distributed==2023.8.1 fastapi==0.109.2 -fastapi-health==0.4.0 fastapi-users==12.1.3 fastapi-users-db-sqlalchemy==5.0.0 filelock==3.15.4