diff --git a/backend/danswer/background/celery/apps/primary.py b/backend/danswer/background/celery/apps/primary.py index af8ee833d27..d427c7ee0b8 100644 --- a/backend/danswer/background/celery/apps/primary.py +++ b/backend/danswer/background/celery/apps/primary.py @@ -14,7 +14,9 @@ import danswer.background.celery.apps.app_base as app_base from danswer.background.celery.apps.app_base import task_logger from danswer.background.celery.celery_utils import celery_is_worker_primary -from danswer.background.celery.tasks.vespa.tasks import get_unfenced_index_attempt_ids +from danswer.background.celery.tasks.indexing.tasks import ( + get_unfenced_index_attempt_ids, +) from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT from danswer.configs.constants import DanswerRedisLocks from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py index 3bcb650e7c7..4e86ccd9e5d 100644 --- a/backend/danswer/background/celery/tasks/indexing/tasks.py +++ b/backend/danswer/background/celery/tasks/indexing/tasks.py @@ -3,6 +3,7 @@ from http import HTTPStatus from time import sleep +import redis import sentry_sdk from celery import Celery from celery import shared_task @@ -33,6 +34,8 @@ from danswer.db.enums import IndexingStatus from danswer.db.enums import IndexModelStatus from danswer.db.index_attempt import create_index_attempt +from danswer.db.index_attempt import delete_index_attempt +from danswer.db.index_attempt import get_all_index_attempts_by_status from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import get_last_attempt_for_cc_pair from danswer.db.index_attempt import mark_attempt_failed @@ -45,6 +48,7 @@ from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder from danswer.redis.redis_connector import RedisConnector +from danswer.redis.redis_connector_index import RedisConnectorIndex from danswer.redis.redis_connector_index import RedisConnectorIndexPayload from danswer.redis.redis_pool import get_redis_client from danswer.utils.logger import setup_logger @@ -97,6 +101,54 @@ def progress(self, amount: int) -> None: self.redis_client.incrby(self.generator_progress_key, amount) +def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]: + """Gets a list of unfenced index attempts. Should not be possible, so we'd typically + want to clean them up. + + Unfenced = attempt not in terminal state and fence does not exist. + """ + unfenced_attempts: list[int] = [] + + # inner/outer/inner double check pattern to avoid race conditions when checking for + # bad state + # inner = index_attempt in non terminal state + # outer = r.fence_key down + + # check the db for index attempts in a non terminal state + attempts: list[IndexAttempt] = [] + attempts.extend( + get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session) + ) + attempts.extend( + get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session) + ) + + for attempt in attempts: + fence_key = RedisConnectorIndex.fence_key_with_ids( + attempt.connector_credential_pair_id, attempt.search_settings_id + ) + + # if the fence is down / doesn't exist, possible error but not confirmed + if r.exists(fence_key): + continue + + # Between the time the attempts are first looked up and the time we see the fence down, + # the attempt may have completed and taken down the fence normally. + + # We need to double check that the index attempt is still in a non terminal state + # and matches the original state, which confirms we are really in a bad state. + attempt_2 = get_index_attempt(db_session, attempt.id) + if not attempt_2: + continue + + if attempt.status != attempt_2.status: + continue + + unfenced_attempts.append(attempt.id) + + return unfenced_attempts + + @shared_task( name="check_for_indexing", soft_time_limit=300, @@ -107,7 +159,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: r = get_redis_client(tenant_id=tenant_id) - lock_beat = r.lock( + lock_beat: RedisLock = r.lock( DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK, timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, ) @@ -117,6 +169,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: if not lock_beat.acquire(blocking=False): return None + # check for search settings swap with get_session_with_tenant(tenant_id=tenant_id) as db_session: old_search_settings = check_index_swap(db_session=db_session) current_search_settings = get_current_search_settings(db_session) @@ -135,13 +188,18 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: embedding_model=embedding_model, ) + # gather cc_pair_ids cc_pair_ids: list[int] = [] with get_session_with_tenant(tenant_id) as db_session: + lock_beat.reacquire() cc_pairs = fetch_connector_credential_pairs(db_session) for cc_pair_entry in cc_pairs: cc_pair_ids.append(cc_pair_entry.id) + # kick off index attempts for cc_pair_id in cc_pair_ids: + lock_beat.reacquire() + redis_connector = RedisConnector(tenant_id, cc_pair_id) with get_session_with_tenant(tenant_id) as db_session: # Get the primary search settings @@ -198,6 +256,29 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: f"search_settings={search_settings_instance.id} " ) tasks_created += 1 + + # Fail any index attempts in the DB that don't have fences + # This shouldn't ever happen! + with get_session_with_tenant(tenant_id) as db_session: + unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r) + for attempt_id in unfenced_attempt_ids: + lock_beat.reacquire() + + attempt = get_index_attempt(db_session, attempt_id) + if not attempt: + continue + + failure_reason = ( + f"Unfenced index attempt found in DB: " + f"index_attempt={attempt.id} " + f"cc_pair={attempt.connector_credential_pair_id} " + f"search_settings={attempt.search_settings_id}" + ) + task_logger.error(failure_reason) + mark_attempt_failed( + attempt.id, db_session, failure_reason=failure_reason + ) + except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." @@ -207,6 +288,11 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: finally: if lock_beat.owned(): lock_beat.release() + else: + task_logger.error( + "check_for_indexing - Lock not owned on completion: " + f"tenant={tenant_id}" + ) return tasks_created @@ -311,10 +397,11 @@ def try_creating_indexing_task( """ LOCK_TIMEOUT = 30 + index_attempt_id: int | None = None # we need to serialize any attempt to trigger indexing since it can be triggered # either via celery beat or manually (API call) - lock = r.lock( + lock: RedisLock = r.lock( DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task", timeout=LOCK_TIMEOUT, ) @@ -365,6 +452,8 @@ def try_creating_indexing_task( custom_task_id = redis_connector_index.generate_generator_task_id() + # when the task is sent, we have yet to finish setting up the fence + # therefore, the task must contain code that blocks until the fence is ready result = celery_app.send_task( "connector_indexing_proxy_task", kwargs=dict( @@ -385,13 +474,16 @@ def try_creating_indexing_task( payload.celery_task_id = result.id redis_connector_index.set_fence(payload) except Exception: - redis_connector_index.set_fence(None) task_logger.exception( - f"Unexpected exception: " + f"try_creating_indexing_task - Unexpected exception: " f"tenant={tenant_id} " f"cc_pair={cc_pair.id} " f"search_settings={search_settings.id}" ) + + if index_attempt_id is not None: + delete_index_attempt(db_session, index_attempt_id) + redis_connector_index.set_fence(None) return None finally: if lock.owned(): @@ -409,7 +501,7 @@ def connector_indexing_proxy_task( ) -> None: """celery tasks are forked, but forking is unstable. This proxies work to a spawned task.""" task_logger.info( - f"Indexing proxy - starting: attempt={index_attempt_id} " + f"Indexing watchdog - starting: attempt={index_attempt_id} " f"tenant={tenant_id} " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id}" @@ -417,7 +509,7 @@ def connector_indexing_proxy_task( client = SimpleJobClient() job = client.submit( - connector_indexing_task, + connector_indexing_task_wrapper, index_attempt_id, cc_pair_id, search_settings_id, @@ -428,7 +520,7 @@ def connector_indexing_proxy_task( if not job: task_logger.info( - f"Indexing proxy - spawn failed: attempt={index_attempt_id} " + f"Indexing watchdog - spawn failed: attempt={index_attempt_id} " f"tenant={tenant_id} " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id}" @@ -436,7 +528,7 @@ def connector_indexing_proxy_task( return task_logger.info( - f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} " + f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} " f"tenant={tenant_id} " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id}" @@ -460,7 +552,7 @@ def connector_indexing_proxy_task( if job.status == "error": task_logger.error( - f"Indexing proxy - spawned task exceptioned: " + f"Indexing watchdog - spawned task exceptioned: " f"attempt={index_attempt_id} " f"tenant={tenant_id} " f"cc_pair={cc_pair_id} " @@ -472,7 +564,7 @@ def connector_indexing_proxy_task( break task_logger.info( - f"Indexing proxy - finished: attempt={index_attempt_id} " + f"Indexing watchdog - finished: attempt={index_attempt_id} " f"tenant={tenant_id} " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id}" @@ -480,6 +572,38 @@ def connector_indexing_proxy_task( return +def connector_indexing_task_wrapper( + index_attempt_id: int, + cc_pair_id: int, + search_settings_id: int, + tenant_id: str | None, + is_ee: bool, +) -> int | None: + """Just wraps connector_indexing_task so we can log any exceptions before + re-raising it.""" + result: int | None = None + + try: + result = connector_indexing_task( + index_attempt_id, + cc_pair_id, + search_settings_id, + tenant_id, + is_ee, + ) + except: + logger.exception( + f"connector_indexing_task exceptioned: " + f"tenant={tenant_id} " + f"index_attempt={index_attempt_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id}" + ) + raise + + return result + + def connector_indexing_task( index_attempt_id: int, cc_pair_id: int, @@ -534,6 +658,7 @@ def connector_indexing_task( if redis_connector.delete.fenced: raise RuntimeError( f"Indexing will not start because connector deletion is in progress: " + f"attempt={index_attempt_id} " f"cc_pair={cc_pair_id} " f"fence={redis_connector.delete.fence_key}" ) @@ -541,18 +666,18 @@ def connector_indexing_task( if redis_connector.stop.fenced: raise RuntimeError( f"Indexing will not start because a connector stop signal was detected: " + f"attempt={index_attempt_id} " f"cc_pair={cc_pair_id} " f"fence={redis_connector.stop.fence_key}" ) while True: - # wait for the fence to come up - if not redis_connector_index.fenced: + if not redis_connector_index.fenced: # The fence must exist raise ValueError( f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}" ) - payload = redis_connector_index.payload + payload = redis_connector_index.payload # The payload must exist if not payload: raise ValueError("connector_indexing_task: payload invalid or not found") @@ -575,7 +700,7 @@ def connector_indexing_task( ) break - lock = r.lock( + lock: RedisLock = r.lock( redis_connector_index.generator_lock_key, timeout=CELERY_INDEXING_LOCK_TIMEOUT, ) @@ -584,7 +709,7 @@ def connector_indexing_task( if not acquired: logger.warning( f"Indexing task already running, exiting...: " - f"cc_pair={cc_pair_id} search_settings={search_settings_id}" + f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}" ) return None diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py index 12a1fe30d0e..ec7f52bc03c 100644 --- a/backend/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/danswer/background/celery/tasks/vespa/tasks.py @@ -5,7 +5,6 @@ from typing import cast import httpx -import redis from celery import Celery from celery import shared_task from celery import Task @@ -47,13 +46,10 @@ from danswer.db.document_set import get_document_set_by_id from danswer.db.document_set import mark_document_set_as_synced from danswer.db.engine import get_session_with_tenant -from danswer.db.enums import IndexingStatus from danswer.db.index_attempt import delete_index_attempts -from danswer.db.index_attempt import get_all_index_attempts_by_status from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import mark_attempt_failed from danswer.db.models import DocumentSet -from danswer.db.models import IndexAttempt from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import VespaDocumentFields @@ -649,20 +645,26 @@ def monitor_ccpair_indexing_taskset( # the task is still setting up return - # Read result state BEFORE generator_complete_key to avoid a race condition # never use any blocking methods on the result from inside a task! result: AsyncResult = AsyncResult(payload.celery_task_id) - result_state = result.state + # inner/outer/inner double check pattern to avoid race conditions when checking for + # bad state + + # inner = get_completion / generator_complete not signaled + # outer = result.state in READY state status_int = redis_connector_index.get_completion() - if status_int is None: # completion signal not set ... check for errors - # If we get here, and then the task both sets the completion signal and finishes, - # we will incorrectly abort the task. We must check result state, then check - # get_completion again to avoid the race condition. - if result_state in READY_STATES: + if status_int is None: # inner signal not set ... possible error + result_state = result.state + if ( + result_state in READY_STATES + ): # outer signal in terminal state ... possible error + # Now double check! if redis_connector_index.get_completion() is None: - # IF the task state is READY, THEN generator_complete should be set - # if it isn't, then the worker crashed + # inner signal still not set (and cannot change when outer result_state is READY) + # Task is finished but generator complete isn't set. + # We have a problem! Worker may have crashed. + msg = ( f"Connector indexing aborted or exceptioned: " f"attempt={payload.index_attempt_id} " @@ -697,37 +699,6 @@ def monitor_ccpair_indexing_taskset( redis_connector_index.reset() -def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]: - """Gets a list of unfenced index attempts. Should not be possible, so we'd typically - want to clean them up. - - Unfenced = attempt not in terminal state and fence does not exist. - """ - unfenced_attempts: list[int] = [] - - # do some cleanup before clearing fences - # check the db for any outstanding index attempts - attempts: list[IndexAttempt] = [] - attempts.extend( - get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session) - ) - attempts.extend( - get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session) - ) - - for attempt in attempts: - # if attempts exist in the db but we don't detect them in redis, mark them as failed - fence_key = RedisConnectorIndex.fence_key_with_ids( - attempt.connector_credential_pair_id, attempt.search_settings_id - ) - if r.exists(fence_key): - continue - - unfenced_attempts.append(attempt.id) - - return unfenced_attempts - - @shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True) def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: """This is a celery beat task that monitors and finalizes metadata sync tasksets. @@ -779,25 +750,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: f"permissions_sync={n_permissions_sync} " ) - # Fail any index attempts in the DB that don't have fences - with get_session_with_tenant(tenant_id) as db_session: - unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r) - for attempt_id in unfenced_attempt_ids: - attempt = get_index_attempt(db_session, attempt_id) - if not attempt: - continue - - failure_reason = ( - f"Unfenced index attempt found in DB: " - f"index_attempt={attempt.id} " - f"cc_pair={attempt.connector_credential_pair_id} " - f"search_settings={attempt.search_settings_id}" - ) - task_logger.warning(failure_reason) - mark_attempt_failed( - attempt.id, db_session, failure_reason=failure_reason - ) - lock_beat.reacquire() if r.exists(RedisConnectorCredentialPair.get_fence_key()): monitor_connector_taskset(r) diff --git a/backend/danswer/db/index_attempt.py b/backend/danswer/db/index_attempt.py index b9c3d9d4ca2..c0d28060ad5 100644 --- a/backend/danswer/db/index_attempt.py +++ b/backend/danswer/db/index_attempt.py @@ -67,6 +67,13 @@ def create_index_attempt( return new_attempt.id +def delete_index_attempt(db_session: Session, index_attempt_id: int) -> None: + index_attempt = get_index_attempt(db_session, index_attempt_id) + if index_attempt: + db_session.delete(index_attempt) + db_session.commit() + + def mock_successful_index_attempt( connector_credential_pair_id: int, search_settings_id: int,