diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 98ab55034ce..8a944689de4 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -7,6 +7,7 @@ from sqlalchemy import pool from sqlalchemy.engine import Connection from sqlalchemy.ext.asyncio import create_async_engine +from celery.backends.database.session import ResultModelBase # type: ignore # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -21,7 +22,7 @@ # for 'autogenerate' support # from myapp import mymodel # target_metadata = mymodel.Base.metadata -target_metadata = Base.metadata +target_metadata = [Base.metadata, ResultModelBase.metadata] # other values from the config, defined by the needs of env.py, # can be acquired: @@ -44,7 +45,7 @@ def run_migrations_offline() -> None: url = build_connection_string() context.configure( url=url, - target_metadata=target_metadata, + target_metadata=target_metadata, # type: ignore literal_binds=True, dialect_opts={"paramstyle": "named"}, ) @@ -54,7 +55,7 @@ def run_migrations_offline() -> None: def do_run_migrations(connection: Connection) -> None: - context.configure(connection=connection, target_metadata=target_metadata) + context.configure(connection=connection, target_metadata=target_metadata) # type: ignore with context.begin_transaction(): context.run_migrations() diff --git a/backend/alembic/versions/78dbe7e38469_task_tracking.py b/backend/alembic/versions/78dbe7e38469_task_tracking.py new file mode 100644 index 00000000000..33eac0c39f2 --- /dev/null +++ b/backend/alembic/versions/78dbe7e38469_task_tracking.py @@ -0,0 +1,48 @@ +"""Task Tracking + +Revision ID: 78dbe7e38469 +Revises: 7ccea01261f6 +Create Date: 2023-10-15 23:40:50.593262 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = "78dbe7e38469" +down_revision = "7ccea01261f6" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "task_queue_jobs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("task_id", sa.String(), nullable=False), + sa.Column("task_name", sa.String(), nullable=False), + sa.Column( + "status", + sa.Enum( + "PENDING", + "STARTED", + "SUCCESS", + "FAILURE", + name="taskstatus", + native_enum=False, + ), + nullable=False, + ), + sa.Column("start_time", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "register_time", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + + +def downgrade() -> None: + op.drop_table("task_queue_jobs") diff --git a/backend/danswer/background/celery/celery.py b/backend/danswer/background/celery/celery.py index 47de830b1f2..8d232cf3cde 100644 --- a/backend/danswer/background/celery/celery.py +++ b/backend/danswer/background/celery/celery.py @@ -1,9 +1,39 @@ -from celery import Celery +import os +from datetime import timedelta +from pathlib import Path +from typing import cast -from danswer.background.connector_deletion import cleanup_connector_credential_pair +from celery import Celery # type: ignore +from celery.result import AsyncResult +from sqlalchemy.orm import Session + +from danswer.background.connector_deletion import _delete_connector_credential_pair +from danswer.background.task_utils import name_document_set_sync_task +from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH +from danswer.configs.app_configs import JOB_TIMEOUT +from danswer.connectors.file.utils import file_age_in_hours +from danswer.datastores.document_index import get_default_document_index +from danswer.datastores.interfaces import DocumentIndex +from danswer.datastores.interfaces import UpdateRequest +from danswer.db.connector_credential_pair import get_connector_credential_pair +from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed +from danswer.db.document import prepare_to_modify_documents +from danswer.db.document_set import delete_document_set +from danswer.db.document_set import fetch_document_sets +from danswer.db.document_set import fetch_document_sets_for_documents +from danswer.db.document_set import fetch_documents_for_document_set +from danswer.db.document_set import get_document_set_by_id +from danswer.db.document_set import mark_document_set_as_synced from danswer.db.engine import build_connection_string +from danswer.db.engine import get_sqlalchemy_engine from danswer.db.engine import SYNC_DB_API -from danswer.document_set.document_set import sync_document_set +from danswer.db.models import DocumentSet +from danswer.db.tasks import check_live_task_not_timed_out +from danswer.db.tasks import get_latest_task +from danswer.db.tasks import mark_task_finished +from danswer.db.tasks import mark_task_start +from danswer.db.tasks import register_task +from danswer.utils.batching import batch_generator from danswer.utils.logger import setup_logger logger = setup_logger() @@ -13,17 +43,193 @@ celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url) -@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit +_ExistingTaskCache: dict[int, AsyncResult] = {} +_SYNC_BATCH_SIZE = 1000 + + +##### +# Tasks that need to be run in job queue, registered via APIs +# +# If imports from this module are needed, use local imports to avoid circular importing +##### 
+@celery_app.task(soft_time_limit=JOB_TIMEOUT) def cleanup_connector_credential_pair_task( - connector_id: int, credential_id: int + connector_id: int, + credential_id: int, ) -> int: - return cleanup_connector_credential_pair(connector_id, credential_id) + """Connector deletion task. This is run as an async task because it is a somewhat slow job. + Needs to potentially update a large number of Postgres and Vespa docs, including deleting them + or updating the ACL""" + engine = get_sqlalchemy_engine() + with Session(engine) as db_session: + # validate that the connector / credential pair is deletable + cc_pair = get_connector_credential_pair( + db_session=db_session, + connector_id=connector_id, + credential_id=credential_id, + ) + if not cc_pair or not check_deletion_attempt_is_allowed( + connector_credential_pair=cc_pair + ): + raise ValueError( + "Cannot run deletion attempt - connector_credential_pair is not deletable. " + "This is likely because there is an ongoing / planned indexing attempt OR the " + "connector is not disabled." + ) + try: + # The bulk of the work is in here, updates Postgres and Vespa + return _delete_connector_credential_pair( + db_session=db_session, + document_index=get_default_document_index(), + cc_pair=cc_pair, + ) + except Exception as e: + logger.exception(f"Failed to run connector_deletion due to {e}") + raise e -@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit + +@celery_app.task(soft_time_limit=JOB_TIMEOUT) def sync_document_set_task(document_set_id: int) -> None: - try: - return sync_document_set(document_set_id=document_set_id) - except Exception: - logger.exception("Failed to sync document set %s", document_set_id) - raise + """For document sets marked as not up to date, sync the state from postgres + into the datastore. Also handles deletions.""" + + def _sync_document_batch( + document_ids: list[str], document_index: DocumentIndex + ) -> None: + logger.debug(f"Syncing document sets for: {document_ids}") + # begin a transaction, release lock at the end + with Session(get_sqlalchemy_engine()) as db_session: + # acquires a lock on the documents so that no other process can modify them + prepare_to_modify_documents( + db_session=db_session, document_ids=document_ids + ) + + # get current state of document sets for these documents + document_set_map = { + document_id: document_sets + for document_id, document_sets in fetch_document_sets_for_documents( + document_ids=document_ids, db_session=db_session + ) + } + + # update Vespa + document_index.update( + update_requests=[ + UpdateRequest( + document_ids=[document_id], + document_sets=set(document_set_map.get(document_id, [])), + ) + for document_id in document_ids + ] + ) + + with Session(get_sqlalchemy_engine()) as db_session: + task_name = name_document_set_sync_task(document_set_id) + mark_task_start(task_name, db_session) + + try: + document_index = get_default_document_index() + documents_to_update = fetch_documents_for_document_set( + document_set_id=document_set_id, + db_session=db_session, + current_only=False, + ) + for document_batch in batch_generator( + documents_to_update, _SYNC_BATCH_SIZE + ): + _sync_document_batch( + document_ids=[document.id for document in document_batch], + document_index=document_index, + ) + + # if there are no connectors, then delete the document set. Otherwise, just + # mark it as successfully synced. 
+ document_set = cast( + DocumentSet, + get_document_set_by_id( + db_session=db_session, document_set_id=document_set_id + ), + ) # casting since we "know" a document set with this ID exists + if not document_set.connector_credential_pairs: + delete_document_set( + document_set_row=document_set, db_session=db_session + ) + logger.info( + f"Successfully deleted document set with ID: '{document_set_id}'!" + ) + else: + mark_document_set_as_synced( + document_set_id=document_set_id, db_session=db_session + ) + logger.info(f"Document set sync for '{document_set_id}' complete!") + + except Exception: + logger.exception("Failed to sync document set %s", document_set_id) + mark_task_finished(task_name, db_session, success=False) + raise + + mark_task_finished(task_name, db_session) + + +##### +# Periodic Tasks +##### +@celery_app.task( + name="check_for_document_sets_sync_task", + soft_time_limit=JOB_TIMEOUT, +) +def check_for_document_sets_sync_task() -> None: + """Runs periodically to check if any document sets are out of sync + Creates a task to sync the set if needed""" + with Session(get_sqlalchemy_engine()) as db_session: + # check if any document sets are not synced + document_set_info = fetch_document_sets( + db_session=db_session, include_outdated=True + ) + for document_set, _ in document_set_info: + if not document_set.is_up_to_date: + task_name = name_document_set_sync_task(document_set.id) + latest_sync = get_latest_task(task_name, db_session) + + if latest_sync and check_live_task_not_timed_out( + latest_sync, db_session + ): + logger.info( + f"Document set '{document_set.id}' is already syncing. Skipping." + ) + continue + + logger.info(f"Document set {document_set.id} syncing now!") + task = sync_document_set_task.apply_async( + kwargs=dict(document_set_id=document_set.id), + ) + register_task(task.id, task_name, db_session) + + +@celery_app.task(name="clean_old_temp_files_task", soft_time_limit=JOB_TIMEOUT) +def clean_old_temp_files_task( + age_threshold_in_hours: float | int = 24 * 7, # 1 week, + base_path: Path | str = FILE_CONNECTOR_TMP_STORAGE_PATH, +) -> None: + """Files added via the File connector need to be deleted after ingestion + Currently handled async of the indexing job""" + os.makedirs(base_path, exist_ok=True) + for file in os.listdir(base_path): + if file_age_in_hours(file) > age_threshold_in_hours: + os.remove(Path(base_path) / file) + + +##### +# Celery Beat (Periodic Tasks) Settings +##### +celery_app.conf.beat_schedule = { + "check-for-document-set-sync": { + "task": "check_for_document_sets_sync_task", + "schedule": timedelta(seconds=5), + }, + "clean-old-temp-files": { + "task": "clean_old_temp_files_task", + "schedule": timedelta(minutes=30), + }, +} diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py index c58624065f9..516b8a2ddb4 100644 --- a/backend/danswer/background/celery/celery_utils.py +++ b/backend/danswer/background/celery/celery_utils.py @@ -1,11 +1,15 @@ import json +from typing import cast from celery.result import AsyncResult from sqlalchemy import text from sqlalchemy.orm import Session from danswer.background.celery.celery import celery_app +from danswer.background.task_utils import name_cc_cleanup_task from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.models import DeletionStatus +from danswer.server.models import DeletionAttemptSnapshot def get_celery_task(task_id: str) -> AsyncResult: @@ -35,3 +39,37 @@ def get_celery_task_status(task_id: str) -> str | 
None: return task.status return None + + +def get_deletion_status( + connector_id: int, credential_id: int +) -> DeletionAttemptSnapshot | None: + cleanup_task_id = name_cc_cleanup_task( + connector_id=connector_id, credential_id=credential_id + ) + deletion_task = get_celery_task(task_id=cleanup_task_id) + deletion_task_status = get_celery_task_status(task_id=cleanup_task_id) + + deletion_status = None + error_msg = None + num_docs_deleted = 0 + if deletion_task_status == "SUCCESS": + deletion_status = DeletionStatus.SUCCESS + num_docs_deleted = cast(int, deletion_task.get(propagate=False)) + elif deletion_task_status == "FAILURE": + deletion_status = DeletionStatus.FAILED + error_msg = deletion_task.get(propagate=False) + elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING": + deletion_status = DeletionStatus.IN_PROGRESS + + return ( + DeletionAttemptSnapshot( + connector_id=connector_id, + credential_id=credential_id, + status=deletion_status, + error_msg=str(error_msg), + num_docs_deleted=num_docs_deleted, + ) + if deletion_status + else None + ) diff --git a/backend/danswer/background/celery/deletion_utils.py b/backend/danswer/background/celery/deletion_utils.py deleted file mode 100644 index c6d056022bb..00000000000 --- a/backend/danswer/background/celery/deletion_utils.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import cast - -from danswer.background.celery.celery_utils import get_celery_task -from danswer.background.celery.celery_utils import get_celery_task_status -from danswer.background.connector_deletion import get_cleanup_task_id -from danswer.db.models import DeletionStatus -from danswer.server.models import DeletionAttemptSnapshot - - -def get_deletion_status( - connector_id: int, credential_id: int -) -> DeletionAttemptSnapshot | None: - cleanup_task_id = get_cleanup_task_id( - connector_id=connector_id, credential_id=credential_id - ) - deletion_task = get_celery_task(task_id=cleanup_task_id) - deletion_task_status = get_celery_task_status(task_id=cleanup_task_id) - - deletion_status = None - error_msg = None - num_docs_deleted = 0 - if deletion_task_status == "SUCCESS": - deletion_status = DeletionStatus.SUCCESS - num_docs_deleted = cast(int, deletion_task.get(propagate=False)) - elif deletion_task_status == "FAILURE": - deletion_status = DeletionStatus.FAILED - error_msg = deletion_task.get(propagate=False) - elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING": - deletion_status = DeletionStatus.IN_PROGRESS - - return ( - DeletionAttemptSnapshot( - connector_id=connector_id, - credential_id=credential_id, - status=deletion_status, - error_msg=str(error_msg), - num_docs_deleted=num_docs_deleted, - ) - if deletion_status - else None - ) diff --git a/backend/danswer/background/connector_deletion.py b/backend/danswer/background/connector_deletion.py index f61553bb635..6aea6285019 100644 --- a/backend/danswer/background/connector_deletion.py +++ b/backend/danswer/background/connector_deletion.py @@ -17,15 +17,12 @@ from sqlalchemy.orm import Session from danswer.access.access import get_access_for_documents -from danswer.datastores.document_index import get_default_document_index from danswer.datastores.interfaces import DocumentIndex from danswer.datastores.interfaces import UpdateRequest from danswer.db.connector import fetch_connector_by_id from danswer.db.connector_credential_pair import ( delete_connector_credential_pair__no_commit, ) -from danswer.db.connector_credential_pair import get_connector_credential_pair -from 
danswer.db.deletion_attempt import check_deletion_attempt_is_allowed from danswer.db.document import delete_document_by_connector_credential_pair from danswer.db.document import delete_documents_complete from danswer.db.document import get_document_connector_cnts @@ -211,39 +208,3 @@ def _delete_connector_credential_pair( f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs." ) return num_docs_deleted - - -def cleanup_connector_credential_pair( - connector_id: int, - credential_id: int, -) -> int: - engine = get_sqlalchemy_engine() - with Session(engine) as db_session: - # validate that the connector / credential pair is deletable - cc_pair = get_connector_credential_pair( - db_session=db_session, - connector_id=connector_id, - credential_id=credential_id, - ) - if not cc_pair or not check_deletion_attempt_is_allowed( - connector_credential_pair=cc_pair - ): - raise ValueError( - "Cannot run deletion attempt - connector_credential_pair is not deletable. " - "This is likely because there is an ongoing / planned indexing attempt OR the " - "connector is not disabled." - ) - - try: - return _delete_connector_credential_pair( - db_session=db_session, - document_index=get_default_document_index(), - cc_pair=cc_pair, - ) - except Exception as e: - logger.exception(f"Failed to run connector_deletion due to {e}") - raise e - - -def get_cleanup_task_id(connector_id: int, credential_id: int) -> str: - return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}" diff --git a/backend/danswer/background/document_set_sync_script.py b/backend/danswer/background/document_set_sync_script.py deleted file mode 100644 index 7761f1f5fbc..00000000000 --- a/backend/danswer/background/document_set_sync_script.py +++ /dev/null @@ -1,55 +0,0 @@ -from celery.result import AsyncResult -from sqlalchemy.orm import Session - -from danswer.background.celery.celery import sync_document_set_task -from danswer.background.utils import interval_run_job -from danswer.db.document_set import ( - fetch_document_sets, -) -from danswer.db.engine import get_sqlalchemy_engine -from danswer.utils.logger import setup_logger - -logger = setup_logger() - - -_ExistingTaskCache: dict[int, AsyncResult] = {} - - -def _document_sync_loop() -> None: - # cleanup tasks - existing_tasks = list(_ExistingTaskCache.items()) - for document_set_id, task in existing_tasks: - if task.ready(): - logger.info( - f"Document set '{document_set_id}' is complete with status " - f"{task.status}. Cleaning up." - ) - del _ExistingTaskCache[document_set_id] - - # kick off new tasks - with Session(get_sqlalchemy_engine()) as db_session: - # check if any document sets are not synced - document_set_info = fetch_document_sets( - db_session=db_session, include_outdated=True - ) - for document_set, _ in document_set_info: - if not document_set.is_up_to_date: - if document_set.id in _ExistingTaskCache: - logger.info( - f"Document set '{document_set.id}' is already syncing. Skipping." - ) - continue - - logger.info( - f"Document set {document_set.id} is not synced. Syncing now!" 
- ) - task = sync_document_set_task.apply_async( - kwargs=dict(document_set_id=document_set.id), - ) - _ExistingTaskCache[document_set.id] = task - - -if __name__ == "__main__": - interval_run_job( - job=_document_sync_loop, delay=5, emit_job_start_log=False - ) # run every 5 seconds diff --git a/backend/danswer/background/file_deletion.py b/backend/danswer/background/file_deletion.py deleted file mode 100644 index fe050d940e0..00000000000 --- a/backend/danswer/background/file_deletion.py +++ /dev/null @@ -1,6 +0,0 @@ -from danswer.background.utils import interval_run_job -from danswer.connectors.file.utils import clean_old_temp_files - - -if __name__ == "__main__": - interval_run_job(clean_old_temp_files, 30 * 60) # run every 30 minutes diff --git a/backend/danswer/background/task_utils.py b/backend/danswer/background/task_utils.py new file mode 100644 index 00000000000..ec7f79f5667 --- /dev/null +++ b/backend/danswer/background/task_utils.py @@ -0,0 +1,6 @@ +def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str: + return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}" + + +def name_document_set_sync_task(document_set_id: int) -> str: + return f"sync_doc_set_{document_set_id}" diff --git a/backend/danswer/background/utils.py b/backend/danswer/background/utils.py deleted file mode 100644 index b822e957193..00000000000 --- a/backend/danswer/background/utils.py +++ /dev/null @@ -1,24 +0,0 @@ -import time -from collections.abc import Callable -from typing import Any - -from danswer.utils.logger import setup_logger - - -logger = setup_logger() - - -def interval_run_job( - job: Callable[[], Any], delay: int | float, emit_job_start_log: bool = True -) -> None: - while True: - start = time.time() - if emit_job_start_log: - logger.info(f"Running '{job.__name__}', current time: {time.ctime(start)}") - try: - job() - except Exception as e: - logger.exception(f"Failed to run update due to {e}") - sleep_time = delay - (time.time() - start) - if sleep_time > 0: - time.sleep(sleep_time) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 8ba4a65c544..f43f4bb5545 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -211,7 +211,7 @@ # fairly large amount of memory in order to increase substantially, since # each worker loads the embedding models into memory. 
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1) - +JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default # Logs every model prompt and output, mostly used for development or exploration purposes LOG_ALL_MODEL_INTERACTIONS = ( os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true" diff --git a/backend/danswer/connectors/file/utils.py b/backend/danswer/connectors/file/utils.py index 0e70e3b40d5..5bd3641babd 100644 --- a/backend/danswer/connectors/file/utils.py +++ b/backend/danswer/connectors/file/utils.py @@ -8,7 +8,6 @@ from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH -_FILE_AGE_CLEANUP_THRESHOLD_HOURS = 24 * 7 # 1 week _VALID_FILE_EXTENSIONS = [".txt", ".zip", ".pdf"] @@ -53,13 +52,3 @@ def write_temp_files( def file_age_in_hours(filepath: str | Path) -> float: return (time.time() - os.path.getmtime(filepath)) / (60 * 60) - - -def clean_old_temp_files( - age_threshold_in_hours: float | int = _FILE_AGE_CLEANUP_THRESHOLD_HOURS, - base_path: Path | str = FILE_CONNECTOR_TMP_STORAGE_PATH, -) -> None: - os.makedirs(base_path, exist_ok=True) - for file in os.listdir(base_path): - if file_age_in_hours(file) > age_threshold_in_hours: - os.remove(Path(base_path) / file) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 40ea0ac6e0f..124aa9b2a5e 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -52,6 +52,14 @@ class DeletionStatus(str, PyEnum): FAILED = "failed" +# Consistent with Celery task statuses +class TaskStatus(str, PyEnum): + PENDING = "PENDING" + STARTED = "STARTED" + SUCCESS = "SUCCESS" + FAILURE = "FAILURE" + + class Base(DeclarativeBase): pass @@ -566,3 +574,22 @@ class SlackBotConfig(Base): ) persona: Mapped[Persona | None] = relationship("Persona") + + +class TaskQueueState(Base): + # Currently refers to Celery Tasks + __tablename__ = "task_queue_jobs" + + id: Mapped[int] = mapped_column(primary_key=True) + # Celery task id + task_id: Mapped[str] = mapped_column(String) + # For any job type, this would be the same + task_name: Mapped[str] = mapped_column(String) + # Note that if the task dies, this won't necessarily be marked FAILED correctly + status: Mapped[TaskStatus] = mapped_column(Enum(TaskStatus)) + start_time: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True) + ) + register_time: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) diff --git a/backend/danswer/db/tasks.py b/backend/danswer/db/tasks.py new file mode 100644 index 00000000000..a12f988613c --- /dev/null +++ b/backend/danswer/db/tasks.py @@ -0,0 +1,85 @@ +from sqlalchemy import desc +from sqlalchemy import func +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.configs.app_configs import JOB_TIMEOUT +from danswer.db.engine import get_db_current_time +from danswer.db.models import TaskQueueState +from danswer.db.models import TaskStatus + + +def get_latest_task( + task_name: str, + db_session: Session, +) -> TaskQueueState | None: + stmt = ( + select(TaskQueueState) + .where(TaskQueueState.task_name == task_name) + .order_by(desc(TaskQueueState.id)) + .limit(1) + ) + + result = db_session.execute(stmt) + latest_task = result.scalars().first() + + return latest_task + + +def register_task( + task_id: str, + task_name: str, + db_session: Session, +) -> TaskQueueState: + new_task = TaskQueueState( + task_id=task_id, task_name=task_name, status=TaskStatus.PENDING + ) + + db_session.add(new_task) + 
db_session.commit() + + return new_task + + +def mark_task_start( + task_name: str, + db_session: Session, +) -> None: + task = get_latest_task(task_name, db_session) + if not task: + raise ValueError(f"No task found with name {task_name}") + + task.start_time = func.now() # type: ignore + db_session.commit() + + +def mark_task_finished( + task_name: str, + db_session: Session, + success: bool = True, +) -> None: + latest_task = get_latest_task(task_name, db_session) + if latest_task is None: + raise ValueError(f"tasks for {task_name} do not exist") + + latest_task.status = TaskStatus.SUCCESS if success else TaskStatus.FAILURE + db_session.commit() + + +def check_live_task_not_timed_out( + task: TaskQueueState, + db_session: Session, + timeout: int = JOB_TIMEOUT, +) -> bool: + # We only care for live tasks to not create new periodic tasks + if task.status in [TaskStatus.SUCCESS, TaskStatus.FAILURE]: + return False + + current_db_time = get_db_current_time(db_session=db_session) + + last_update_time = task.register_time + if task.start_time: + last_update_time = max(task.register_time, task.start_time) + + time_elapsed = current_db_time - last_update_time + return time_elapsed.total_seconds() < timeout diff --git a/backend/danswer/document_set/document_set.py b/backend/danswer/document_set/document_set.py deleted file mode 100644 index e18d995b755..00000000000 --- a/backend/danswer/document_set/document_set.py +++ /dev/null @@ -1,84 +0,0 @@ -from typing import cast - -from sqlalchemy.orm import Session - -from danswer.datastores.document_index import get_default_document_index -from danswer.datastores.interfaces import DocumentIndex -from danswer.datastores.interfaces import UpdateRequest -from danswer.db.document import prepare_to_modify_documents -from danswer.db.document_set import delete_document_set -from danswer.db.document_set import fetch_document_sets_for_documents -from danswer.db.document_set import fetch_documents_for_document_set -from danswer.db.document_set import get_document_set_by_id -from danswer.db.document_set import mark_document_set_as_synced -from danswer.db.engine import get_sqlalchemy_engine -from danswer.db.models import DocumentSet -from danswer.utils.batching import batch_generator -from danswer.utils.logger import setup_logger - -logger = setup_logger() - -_SYNC_BATCH_SIZE = 1000 - - -def _sync_document_batch( - document_ids: list[str], document_index: DocumentIndex -) -> None: - logger.debug(f"Syncing document sets for: {document_ids}") - # begin a transaction, release lock at the end - with Session(get_sqlalchemy_engine()) as db_session: - # acquires a lock on the documents so that no other process can modify them - prepare_to_modify_documents(db_session=db_session, document_ids=document_ids) - - # get current state of document sets for these documents - document_set_map = { - document_id: document_sets - for document_id, document_sets in fetch_document_sets_for_documents( - document_ids=document_ids, db_session=db_session - ) - } - - # update Vespa - document_index.update( - update_requests=[ - UpdateRequest( - document_ids=[document_id], - document_sets=set(document_set_map.get(document_id, [])), - ) - for document_id in document_ids - ] - ) - - -def sync_document_set(document_set_id: int) -> None: - document_index = get_default_document_index() - with Session(get_sqlalchemy_engine()) as db_session: - documents_to_update = fetch_documents_for_document_set( - document_set_id=document_set_id, - db_session=db_session, - current_only=False, - ) - for 
document_batch in batch_generator(documents_to_update, _SYNC_BATCH_SIZE): - _sync_document_batch( - document_ids=[document.id for document in document_batch], - document_index=document_index, - ) - - # if there are no connectors, then delete the document set. Otherwise, just - # mark it as successfully synced. - document_set = cast( - DocumentSet, - get_document_set_by_id( - db_session=db_session, document_set_id=document_set_id - ), - ) # casting since we "know" a document set with this ID exists - if not document_set.connector_credential_pairs: - delete_document_set(document_set_row=document_set, db_session=db_session) - logger.info( - f"Successfully deleted document set with ID: '{document_set_id}'!" - ) - else: - mark_document_set_as_synced( - document_set_id=document_set_id, db_session=db_session - ) - logger.info(f"Document set sync for '{document_set_id}' complete!") diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py index e51878b8c59..03d072341ab 100644 --- a/backend/danswer/server/manage.py +++ b/backend/danswer/server/manage.py @@ -14,11 +14,8 @@ from danswer.auth.users import current_admin_user from danswer.auth.users import current_user -from danswer.background.celery.celery import cleanup_connector_credential_pair_task -from danswer.background.celery.deletion_utils import get_deletion_status -from danswer.background.connector_deletion import ( - get_cleanup_task_id, -) +from danswer.background.celery.celery_utils import get_deletion_status +from danswer.background.task_utils import name_cc_cleanup_task from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY @@ -536,6 +533,8 @@ def create_deletion_attempt_for_connector_id( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> None: + from danswer.background.celery.celery import cleanup_connector_credential_pair_task + connector_id = connector_credential_pair_identifier.connector_id credential_id = connector_credential_pair_identifier.credential_id @@ -559,7 +558,7 @@ def create_deletion_attempt_for_connector_id( "no ongoing / planned indexing attempts.", ) - task_id = get_cleanup_task_id( + task_id = name_cc_cleanup_task( connector_id=connector_id, credential_id=credential_id ) cleanup_connector_credential_pair_task.apply_async( diff --git a/backend/danswer/utils/logger.py b/backend/danswer/utils/logger.py index 86173971d8e..c4dd59742b3 100644 --- a/backend/danswer/utils/logger.py +++ b/backend/danswer/utils/logger.py @@ -1,4 +1,5 @@ import logging +import os from collections.abc import MutableMapping from typing import Any @@ -52,7 +53,9 @@ def process( def setup_logger( - name: str = __name__, log_level: int = get_log_level_from_str() + name: str = __name__, + log_level: int = get_log_level_from_str(), + logfile_name: str | None = None, ) -> logging.LoggerAdapter: logger = logging.getLogger(name) @@ -73,4 +76,12 @@ def setup_logger( logger.addHandler(handler) + if logfile_name: + is_containerized = os.path.exists("/.dockerenv") + file_name_template = ( + "/var/log/{name}.log" if is_containerized else "./log/{name}.log" + ) + file_handler = logging.FileHandler(file_name_template.format(name=logfile_name)) + logger.addHandler(file_handler) + return _IndexAttemptLoggingAdapter(logger) diff --git a/backend/scripts/dev_run_celery.py b/backend/scripts/dev_run_celery.py new file mode 100644 index 00000000000..a83ccee09c6 --- 
/dev/null +++ b/backend/scripts/dev_run_celery.py @@ -0,0 +1,52 @@ +import subprocess +import threading + + +def monitor_process(process_name: str, process: subprocess.Popen) -> None: + assert process.stdout is not None + + while True: + output = process.stdout.readline() + + if output: + print(f"{process_name}: {output.strip()}") + + if process.poll() is not None: + break + + +def run_celery() -> None: + cmd_worker = [ + "celery", + "-A", + "danswer.background.celery", + "worker", + "--loglevel=INFO", + "--concurrency=1", + ] + cmd_beat = ["celery", "-A", "danswer.background.celery", "beat", "--loglevel=INFO"] + + # Redirect stderr to stdout for both processes + worker_process = subprocess.Popen( + cmd_worker, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True + ) + beat_process = subprocess.Popen( + cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True + ) + + # Monitor outputs using threads + worker_thread = threading.Thread( + target=monitor_process, args=("WORKER", worker_process) + ) + beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process)) + + worker_thread.start() + beat_thread.start() + + # Wait for threads to finish + worker_thread.join() + beat_thread.join() + + +if __name__ == "__main__": + run_celery() diff --git a/backend/supervisord.conf b/backend/supervisord.conf index b28fa002455..3613f5f5062 100644 --- a/backend/supervisord.conf +++ b/backend/supervisord.conf @@ -10,28 +10,23 @@ stdout_logfile_maxbytes=52428800 redirect_stderr=true autorestart=true -[program:celery] -command=celery -A danswer.background.celery worker --loglevel=INFO -stdout_logfile=/var/log/celery.log +# Background jobs that must be run async due to long time to completion +[program:celery_worker] +command=celery -A danswer.background.celery worker --loglevel=INFO --logfile=/var/log/celery_worker.log +stdout_logfile=/var/log/celery_worker_supervisor.log stdout_logfile_maxbytes=52428800 redirect_stderr=true autorestart=true -[program:file_deletion] -command=python danswer/background/file_deletion.py -stdout_logfile=/var/log/file_deletion.log +# Job scheduler for periodic tasks +[program:celery_beat] +command=celery -A danswer.background.celery beat --loglevel=INFO --logfile=/var/log/celery_beat.log +stdout_logfile=/var/log/celery_beat_supervisor.log stdout_logfile_maxbytes=52428800 redirect_stderr=true autorestart=true -[program:document_set_sync] -command=python danswer/background/document_set_sync_script.py -stdout_logfile=/var/log/document_set_sync.log -stdout_logfile_maxbytes=52428800 -redirect_stderr=true -autorestart=true - -# Listens for slack messages and responds with answers +# Listens for Slack messages and responds with answers # for all channels that the DanswerBot has been added to. # If not setup, this will just fail 5 times and then stop. # More details on setup here: https://docs.danswer.dev/slack_bot_setup @@ -44,9 +39,9 @@ autorestart=true startretries=5 startsecs=60 -# pushes all logs from the above programs to stdout +# Pushes all logs from the above programs to stdout [program:log-redirect-handler] -command=tail -qF /var/log/update.log /var/log/celery.log /var/log/file_deletion.log /var/log/slack_bot_listener.log /var/log/document_set_sync.log +command=tail -qF /var/log/update.log /var/log/celery_worker.log /var/log/celery_worker_supervisor.log /var/log/celery_beat.log /var/log/celery_beat_supervisor.log /var/log/slack_bot_listener.log stdout_logfile=/dev/stdout stdout_logfile_maxbytes=0 redirect_stderr=true
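
Reviewer note (not part of the patch): the sketch below is a condensed, illustrative example of the task-tracking pattern this change establishes — a beat-scheduled dispatcher that consults get_latest_task / check_live_task_not_timed_out before kicking off a worker task, which then records its lifecycle with mark_task_start / mark_task_finished. The names example_task, example_dispatcher_task, and _TASK_NAME are hypothetical and the module placement is arbitrary; the helpers, signatures, and JOB_TIMEOUT are the ones added above in danswer/db/tasks.py and danswer/configs/app_configs.py.

# Illustrative sketch only -- "example_*" names are hypothetical; the helpers
# are the ones introduced in danswer/db/tasks.py in this diff.
from sqlalchemy.orm import Session

from danswer.background.celery.celery import celery_app
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.tasks import check_live_task_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import mark_task_finished
from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task

_TASK_NAME = "example_periodic_task"  # hypothetical task name


@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def example_task() -> None:
    with Session(get_sqlalchemy_engine()) as db_session:
        # records start_time on the row the dispatcher registered for this name
        mark_task_start(_TASK_NAME, db_session)
        try:
            ...  # the actual work goes here
        except Exception:
            mark_task_finished(_TASK_NAME, db_session, success=False)
            raise
        mark_task_finished(_TASK_NAME, db_session)


@celery_app.task(name="example_dispatcher_task", soft_time_limit=JOB_TIMEOUT)
def example_dispatcher_task() -> None:
    """Would run on a beat schedule; only dispatches if no live run exists."""
    with Session(get_sqlalchemy_engine()) as db_session:
        latest = get_latest_task(_TASK_NAME, db_session)
        if latest and check_live_task_not_timed_out(latest, db_session):
            return  # a previous run is still live and within JOB_TIMEOUT
        result = example_task.apply_async()
        # persists the PENDING row that the worker task updates above
        register_task(result.id, _TASK_NAME, db_session)

This mirrors check_for_document_sets_sync_task / sync_document_set_task in the patch: the dispatcher persists a PENDING task_queue_jobs row keyed by the work task's name, so a worker that dies without marking FAILURE only blocks re-dispatch until JOB_TIMEOUT elapses, as judged by check_live_task_not_timed_out against the database clock.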