backport double check
LostVector committed Nov 22, 2024
1 parent b11a3ad commit ac236a9
Showing 4 changed files with 165 additions and 79 deletions.
4 changes: 3 additions & 1 deletion backend/danswer/background/celery/apps/primary.py
@@ -14,7 +14,9 @@
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.tasks.vespa.tasks import get_unfenced_index_attempt_ids
from danswer.background.celery.tasks.indexing.tasks import (
get_unfenced_index_attempt_ids,
)
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
155 changes: 140 additions & 15 deletions backend/danswer/background/celery/tasks/indexing/tasks.py
@@ -3,6 +3,7 @@
from http import HTTPStatus
from time import sleep

import redis
import sentry_sdk
from celery import Celery
from celery import shared_task
@@ -33,6 +34,8 @@
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_failed
@@ -45,6 +48,7 @@
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_index import RedisConnectorIndexPayload
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
@@ -97,6 +101,54 @@ def progress(self, amount: int) -> None:
self.redis_client.incrby(self.generator_progress_key, amount)


def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []

# Inner/outer/inner double-check pattern to avoid race conditions when checking
# for a bad state (a generic sketch of this pattern follows this function).
# inner = the index attempt is in a non-terminal state
# outer = the fence key is down

# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)

for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)

# if the fence is up, the attempt is healthy; a missing fence is a possible
# error, but not yet confirmed
if r.exists(fence_key):
continue

# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.

# We need to double-check that the index attempt is still in a non-terminal state
# and matches the original state, which confirms we really are in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue

if attempt.status != attempt_2.status:
continue

unfenced_attempts.append(attempt.id)

return unfenced_attempts
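
# Editor's note - an illustrative sketch, not part of the commit: the
# inner/outer/inner double check above in generic form. Without the second
# read of the source of truth, a row that finishes (and drops its fence)
# between the first read and the fence check would be misreported as bad.
# All names below are hypothetical.
from typing import Callable

def find_confirmed_unfenced(
    rows: list[dict],
    reread: Callable[[int], dict | None],  # re-fetches a row by id
    r: redis.Redis,
) -> list[int]:
    bad_ids: list[int] = []
    for row in rows:  # inner: rows read while in a non-terminal state
        if r.exists(f"fence_{row['id']}"):  # outer: fence up -> healthy
            continue
        row_2 = reread(row["id"])  # inner again: re-check the source of truth
        if row_2 is not None and row_2["status"] == row["status"]:
            bad_ids.append(row["id"])  # state unchanged -> confirmed bad
    return bad_ids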


@shared_task(
name="check_for_indexing",
soft_time_limit=300,
@@ -107,7 +159,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:

r = get_redis_client(tenant_id=tenant_id)

lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -117,6 +169,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not lock_beat.acquire(blocking=False):
return None

# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
@@ -135,13 +188,18 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
embedding_model=embedding_model,
)

# gather cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)

# kick off index attempts
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()

redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
@@ -198,6 +256,29 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"search_settings={search_settings_instance.id} "
)
tasks_created += 1

# Fail any index attempts in the DB that don't have fences.
# This should never happen!
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
lock_beat.reacquire()

attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue

failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.error(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)

except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -207,6 +288,11 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)

return tasks_created
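
# Editor's note - an illustrative sketch, not part of the commit: the
# short-TTL beat-lock pattern used by check_for_indexing above. reacquire()
# resets the TTL inside long loops, and the lock is released only if still
# owned. The key name, timeout, and loop body below are hypothetical.
def beat_task_skeleton(tenant_id: str | None) -> None:
    r = get_redis_client(tenant_id=tenant_id)
    lock: RedisLock = r.lock("example_beat_lock", timeout=60)
    if not lock.acquire(blocking=False):
        return  # another worker already holds this beat lock
    try:
        for _ in range(100):  # stand-in for the real work loop
            lock.reacquire()  # reset the TTL while work is in progress
    finally:
        if lock.owned():
            lock.release()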

@@ -311,10 +397,11 @@ def try_creating_indexing_task(
"""

LOCK_TIMEOUT = 30
index_attempt_id: int | None = None

# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
@@ -365,6 +452,8 @@ def try_creating_indexing_task(

custom_task_id = redis_connector_index.generate_generator_task_id()

# When the task is sent, we have not yet finished setting up the fence;
# therefore, the task must contain code that blocks until the fence is ready
# (see the hypothetical wait_for_fence sketch at the end of this file).
result = celery_app.send_task(
"connector_indexing_proxy_task",
kwargs=dict(
Expand All @@ -385,13 +474,16 @@ def try_creating_indexing_task(
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
redis_connector_index.set_fence(None)
task_logger.exception(
f"Unexpected exception: "
f"try_creating_indexing_task - Unexpected exception: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)

if index_attempt_id is not None:
delete_index_attempt(db_session, index_attempt_id)
redis_connector_index.set_fence(None)
return None
finally:
if lock.owned():
@@ -409,15 +501,15 @@ def connector_indexing_proxy_task(
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing proxy - starting: attempt={index_attempt_id} "
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
client = SimpleJobClient()

job = client.submit(
connector_indexing_task,
connector_indexing_task_wrapper,
index_attempt_id,
cc_pair_id,
search_settings_id,
@@ -428,15 +520,15 @@ def connector_indexing_proxy_task(

if not job:
task_logger.info(
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return

task_logger.info(
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -460,7 +552,7 @@ def connector_indexing_proxy_task(

if job.status == "error":
task_logger.error(
f"Indexing proxy - spawned task exceptioned: "
f"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
@@ -472,14 +564,46 @@ def connector_indexing_proxy_task(
break

task_logger.info(
f"Indexing proxy - finished: attempt={index_attempt_id} "
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return
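
# Editor's note - an illustrative sketch, not part of the commit: the
# fork-vs-spawn rationale behind the watchdog. A "spawn" start method boots a
# fresh interpreter rather than inheriting the forked parent's threads and
# locks, which is what makes the spawned task stable.
import multiprocessing as mp

def _indexing_child(index_attempt_id: int) -> None:
    print(f"indexing in a clean interpreter: attempt={index_attempt_id}")

def run_in_spawned_process(index_attempt_id: int) -> None:
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=_indexing_child, args=(index_attempt_id,))
    p.start()
    p.join()  # the parent waits/polls, much like the watchdog loop above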


def connector_indexing_task_wrapper(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Just wraps connector_indexing_task so we can log any exceptions before
re-raising it."""
result: int | None = None

try:
result = connector_indexing_task(
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
is_ee,
)
except Exception:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise

return result
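
# Editor's note - an illustrative sketch, not part of the commit: the
# wrap-log-reraise pattern above in generic form. In a spawned process the
# traceback can otherwise be lost, so it is logged in-process before the
# exception propagates.
from typing import Any, Callable

def log_and_reraise(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    try:
        return fn(*args, **kwargs)
    except Exception:
        logger.exception(f"{fn.__name__} exceptioned")
        raise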


def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
@@ -534,25 +658,26 @@ def connector_indexing_task(
if redis_connector.delete.fenced:
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}"
)

if redis_connector.stop.fenced:
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}"
)

while True:
# wait for the fence to come up
if not redis_connector_index.fenced:
if not redis_connector_index.fenced: # The fence must exist
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
)

payload = redis_connector_index.payload
payload = redis_connector_index.payload # The payload must exist
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")

Expand All @@ -575,7 +700,7 @@ def connector_indexing_task(
)
break

lock = r.lock(
lock: RedisLock = r.lock(
redis_connector_index.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
)
@@ -584,7 +709,7 @@ def connector_indexing_task(
if not acquired:
logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
return None
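
# Editor's note - an illustrative sketch, not part of the commit: the
# "block until the fence is ready" behavior that try_creating_indexing_task's
# comment requires of the spawned task. The 60-second timeout is an
# assumption; the attribute names follow the RedisConnectorIndex usage above.
from time import monotonic

def wait_for_fence(
    redis_connector_index: RedisConnectorIndex, timeout: int = 60
) -> RedisConnectorIndexPayload:
    start = monotonic()
    while True:
        if monotonic() - start > timeout:
            raise ValueError(
                f"fence was never set: fence={redis_connector_index.fence_key}"
            )
        payload = redis_connector_index.payload
        if payload and payload.celery_task_id is not None:
            return payload  # the dispatcher finished setting up the fence
        sleep(1)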

(Diffs for the remaining 2 changed files in this commit were not loaded.)