Celery Beat (#575)
yuhongsun96 authored Oct 16, 2023
1 parent a7ddb22 commit b5982c1
Showing 19 changed files with 507 additions and 299 deletions.
7 changes: 4 additions & 3 deletions backend/alembic/env.py
@@ -7,6 +7,7 @@
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
@@ -21,7 +22,7 @@
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata
target_metadata = [Base.metadata, ResultModelBase.metadata]

# other values from the config, defined by the needs of env.py,
# can be acquired:
@@ -44,7 +45,7 @@ def run_migrations_offline() -> None:
url = build_connection_string()
context.configure(
url=url,
target_metadata=target_metadata,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
@@ -54,7 +55,7 @@ def run_migrations_offline() -> None:


def do_run_migrations(connection: Connection) -> None:
context.configure(connection=connection, target_metadata=target_metadata)
context.configure(connection=connection, target_metadata=target_metadata) # type: ignore

with context.begin_transaction():
context.run_migrations()
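
For context on the env.py change above: Alembic's target_metadata accepts either a single MetaData object or a sequence of them, so listing both the application models and Celery's database result backend lets autogenerate track the Celery-created tables as well. A minimal standalone sketch of the same pattern follows; Base here is a stand-in for the application's declarative base, and only the Celery import mirrors the diff.

# Sketch: pointing Alembic autogenerate at several MetaData objects at once.
# ResultModelBase is the declarative base for Celery's database result backend;
# "Base" is a placeholder for the application's own declarative base.
from celery.backends.database.session import ResultModelBase  # type: ignore
from sqlalchemy.orm import declarative_base

Base = declarative_base()

# context.configure(target_metadata=...) accepts a single MetaData or a
# sequence, so passing both makes autogenerate diff app and Celery tables.
target_metadata = [Base.metadata, ResultModelBase.metadata]
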
48 changes: 48 additions & 0 deletions backend/alembic/versions/78dbe7e38469_task_tracking.py
@@ -0,0 +1,48 @@
"""Task Tracking
Revision ID: 78dbe7e38469
Revises: 7ccea01261f6
Create Date: 2023-10-15 23:40:50.593262
"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "78dbe7e38469"
down_revision = "7ccea01261f6"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
"task_queue_jobs",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("task_id", sa.String(), nullable=False),
sa.Column("task_name", sa.String(), nullable=False),
sa.Column(
"status",
sa.Enum(
"PENDING",
"STARTED",
"SUCCESS",
"FAILURE",
name="taskstatus",
native_enum=False,
),
nullable=False,
),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"register_time",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)


def downgrade() -> None:
op.drop_table("task_queue_jobs")
230 changes: 218 additions & 12 deletions backend/danswer/background/celery/celery.py
@@ -1,9 +1,39 @@
from celery import Celery
import os
from datetime import timedelta
from pathlib import Path
from typing import cast

from danswer.background.connector_deletion import cleanup_connector_credential_pair
from celery import Celery # type: ignore
from celery.result import AsyncResult
from sqlalchemy.orm import Session

from danswer.background.connector_deletion import _delete_connector_credential_pair
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.connectors.file.utils import file_age_in_hours
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import UpdateRequest
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.document_set import fetch_documents_for_document_set
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import SYNC_DB_API
from danswer.document_set.document_set import sync_document_set
from danswer.db.models import DocumentSet
from danswer.db.tasks import check_live_task_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import mark_task_finished
from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger

logger = setup_logger()
@@ -13,17 +43,193 @@
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)


@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit
_ExistingTaskCache: dict[int, AsyncResult] = {}
_SYNC_BATCH_SIZE = 1000


#####
# Tasks that need to be run in job queue, registered via APIs
#
# If imports from this module are needed, use local imports to avoid circular importing
#####
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def cleanup_connector_credential_pair_task(
connector_id: int, credential_id: int
connector_id: int,
credential_id: int,
) -> int:
return cleanup_connector_credential_pair(connector_id, credential_id)
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
or updating the ACL"""
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
# validate that the connector / credential pair is deletable
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair or not check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair
):
raise ValueError(
"Cannot run deletion attempt - connector_credential_pair is not deletable. "
"This is likely because there is an ongoing / planned indexing attempt OR the "
"connector is not disabled."
)

try:
# The bulk of the work is in here, updates Postgres and Vespa
return _delete_connector_credential_pair(
db_session=db_session,
document_index=get_default_document_index(),
cc_pair=cc_pair,
)
except Exception as e:
logger.exception(f"Failed to run connector_deletion due to {e}")
raise e

@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit

@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_document_set_task(document_set_id: int) -> None:
try:
return sync_document_set(document_set_id=document_set_id)
except Exception:
logger.exception("Failed to sync document set %s", document_set_id)
raise
"""For document sets marked as not up to date, sync the state from postgres
into the datastore. Also handles deletions."""

def _sync_document_batch(
document_ids: list[str], document_index: DocumentIndex
) -> None:
logger.debug(f"Syncing document sets for: {document_ids}")
# begin a transaction, release lock at the end
with Session(get_sqlalchemy_engine()) as db_session:
# acquires a lock on the documents so that no other process can modify them
prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
)

# get current state of document sets for these documents
document_set_map = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=document_ids, db_session=db_session
)
}

# update Vespa
document_index.update(
update_requests=[
UpdateRequest(
document_ids=[document_id],
document_sets=set(document_set_map.get(document_id, [])),
)
for document_id in document_ids
]
)

with Session(get_sqlalchemy_engine()) as db_session:
task_name = name_document_set_sync_task(document_set_id)
mark_task_start(task_name, db_session)

try:
document_index = get_default_document_index()
documents_to_update = fetch_documents_for_document_set(
document_set_id=document_set_id,
db_session=db_session,
current_only=False,
)
for document_batch in batch_generator(
documents_to_update, _SYNC_BATCH_SIZE
):
_sync_document_batch(
document_ids=[document.id for document in document_batch],
document_index=document_index,
)

# if there are no connectors, then delete the document set. Otherwise, just
# mark it as successfully synced.
document_set = cast(
DocumentSet,
get_document_set_by_id(
db_session=db_session, document_set_id=document_set_id
),
) # casting since we "know" a document set with this ID exists
if not document_set.connector_credential_pairs:
delete_document_set(
document_set_row=document_set, db_session=db_session
)
logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(
document_set_id=document_set_id, db_session=db_session
)
logger.info(f"Document set sync for '{document_set_id}' complete!")

except Exception:
logger.exception("Failed to sync document set %s", document_set_id)
mark_task_finished(task_name, db_session, success=False)
raise

mark_task_finished(task_name, db_session)


#####
# Periodic Tasks
#####
@celery_app.task(
name="check_for_document_sets_sync_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_document_sets_sync_task() -> None:
"""Runs periodically to check if any document sets are out of sync
Creates a task to sync the set if needed"""
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
if not document_set.is_up_to_date:
task_name = name_document_set_sync_task(document_set.id)
latest_sync = get_latest_task(task_name, db_session)

if latest_sync and check_live_task_not_timed_out(
latest_sync, db_session
):
logger.info(
f"Document set '{document_set.id}' is already syncing. Skipping."
)
continue

logger.info(f"Document set {document_set.id} syncing now!")
task = sync_document_set_task.apply_async(
kwargs=dict(document_set_id=document_set.id),
)
register_task(task.id, task_name, db_session)


@celery_app.task(name="clean_old_temp_files_task", soft_time_limit=JOB_TIMEOUT)
def clean_old_temp_files_task(
age_threshold_in_hours: float | int = 24 * 7, # 1 week,
base_path: Path | str = FILE_CONNECTOR_TMP_STORAGE_PATH,
) -> None:
"""Files added via the File connector need to be deleted after ingestion
Currently handled async of the indexing job"""
os.makedirs(base_path, exist_ok=True)
for file in os.listdir(base_path):
if file_age_in_hours(file) > age_threshold_in_hours:
os.remove(Path(base_path) / file)


#####
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"check-for-document-set-sync": {
"task": "check_for_document_sets_sync_task",
"schedule": timedelta(seconds=5),
},
"clean-old-temp-files": {
"task": "clean_old_temp_files_task",
"schedule": timedelta(minutes=30),
},
}
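
The beat_schedule mapping above is what makes these tasks periodic: Celery Beat reads it and enqueues the task named under "task" at each "schedule" interval, and a regular worker picks the jobs up. Below is a minimal self-contained sketch of the same wiring, assuming a placeholder broker URL and task body; only the task and entry names mirror the diff.

# Minimal Celery Beat sketch; broker URL and task body are placeholders.
from datetime import timedelta

from celery import Celery

celery_app = Celery(__name__, broker="redis://localhost:6379/0")


@celery_app.task(name="check_for_document_sets_sync_task", soft_time_limit=60)
def check_for_document_sets_sync_task() -> None:
    # Stand-in body; the real task queries Postgres and enqueues per-set syncs.
    print("checking for out-of-date document sets")


# Beat enqueues the task named in "task" on every tick of "schedule";
# the outer key is only a label for the schedule entry.
celery_app.conf.beat_schedule = {
    "check-for-document-set-sync": {
        "task": "check_for_document_sets_sync_task",
        "schedule": timedelta(seconds=5),
    },
}

# Both a worker and a beat process need to run, e.g.:
#   celery -A some_module worker --loglevel=INFO
#   celery -A some_module beat --loglevel=INFO
# (or, for development only, a single worker started with --beat).
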
38 changes: 38 additions & 0 deletions backend/danswer/background/celery/celery_utils.py
@@ -1,11 +1,15 @@
import json
from typing import cast

from celery.result import AsyncResult
from sqlalchemy import text
from sqlalchemy.orm import Session

from danswer.background.celery.celery import celery_app
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DeletionStatus
from danswer.server.models import DeletionAttemptSnapshot


def get_celery_task(task_id: str) -> AsyncResult:
@@ -35,3 +39,37 @@ def get_celery_task_status(task_id: str) -> str | None:
return task.status

return None


def get_deletion_status(
connector_id: int, credential_id: int
) -> DeletionAttemptSnapshot | None:
cleanup_task_id = name_cc_cleanup_task(
connector_id=connector_id, credential_id=credential_id
)
deletion_task = get_celery_task(task_id=cleanup_task_id)
deletion_task_status = get_celery_task_status(task_id=cleanup_task_id)

deletion_status = None
error_msg = None
num_docs_deleted = 0
if deletion_task_status == "SUCCESS":
deletion_status = DeletionStatus.SUCCESS
num_docs_deleted = cast(int, deletion_task.get(propagate=False))
elif deletion_task_status == "FAILURE":
deletion_status = DeletionStatus.FAILED
error_msg = deletion_task.get(propagate=False)
elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING":
deletion_status = DeletionStatus.IN_PROGRESS

return (
DeletionAttemptSnapshot(
connector_id=connector_id,
credential_id=credential_id,
status=deletion_status,
error_msg=str(error_msg),
num_docs_deleted=num_docs_deleted,
)
if deletion_status
else None
)
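
get_deletion_status leans on a Celery behavior worth noting: AsyncResult.get(propagate=False) returns the task's return value on success and the exception object on failure, which is why the same call can feed either num_docs_deleted or error_msg. A short usage sketch follows; the task id is a placeholder, whereas in the diff it comes from name_cc_cleanup_task.

# Sketch of the AsyncResult behavior the helper above relies on.
from celery.result import AsyncResult

from danswer.background.celery.celery import celery_app

task = AsyncResult("some-task-id", app=celery_app)  # placeholder task id

if task.status == "SUCCESS":
    num_docs_deleted = task.get(propagate=False)  # the task's return value
elif task.status == "FAILURE":
    error = task.get(propagate=False)  # the exception the task raised
elif task.status in ("STARTED", "PENDING"):
    # still running; Celery also reports unknown task ids as PENDING
    pass
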