Feature/kill indexing #3213

Open · wants to merge 10 commits into main
@@ -5,7 +5,6 @@
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

@@ -37,7 +36,7 @@ class TaskDependencyError(RuntimeError):
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)

lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -60,7 +59,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
self.app, cc_pair_id, db_session, lock_beat, tenant_id
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
@@ -86,7 +85,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,
cc_pair_id: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
@@ -8,6 +8,7 @@
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock

from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
@@ -27,7 +28,7 @@
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
@@ -138,7 +139,7 @@ def try_creating_permissions_sync_task(

LOCK_TIMEOUT = 30

lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
)
@@ -162,7 +163,7 @@ def try_creating_permissions_sync_task(

custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"

app.send_task(
result = app.send_task(
"connector_permission_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
@@ -174,8 +175,8 @@ def try_creating_permissions_sync_task(
)

# set a basic fence to start
payload = RedisConnectorPermissionSyncData(
started=None,
payload = RedisConnectorPermissionSyncPayload(
started=None, celery_task_id=result.id
)

redis_connector.permissions.set_fence(payload)
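
For reference, a minimal sketch of what the renamed fence payload could look like — only the field names (`started`, `celery_task_id`) come from this diff; the Pydantic base class and optional typing are assumptions:

```python
# Hypothetical sketch only; field names are taken from the diff, the rest is assumed.
from datetime import datetime

from pydantic import BaseModel


class RedisConnectorPermissionSyncPayload(BaseModel):
    # None until connector_permission_sync_generator_task stamps the start time
    started: datetime | None
    # id returned by app.send_task(), so the in-flight sync can be found
    # (and potentially cancelled) later by its Celery task id
    celery_task_id: str | None
```

Capturing `result = app.send_task(...)` above fits this reading: the Celery task id has to exist before the fence can carry it.
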
@@ -245,9 +246,11 @@ def connector_permission_sync_generator_task(

logger.info(f"Syncing docs for {source_type}")

payload = RedisConnectorPermissionSyncData(
started=datetime.now(timezone.utc),
)
payload = redis_connector.permissions.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")

payload.started = datetime.now(timezone.utc)
redis_connector.permissions.set_fence(payload)

document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
@@ -8,6 +8,7 @@
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock

from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
@@ -24,6 +25,9 @@
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
@@ -107,7 +111,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
cc_pair_ids_to_sync.append(cc_pair.id)

for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_permissions_sync_task(
tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
@@ -125,7 +129,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()


def try_creating_permissions_sync_task(
def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
@@ -156,7 +160,7 @@ def try_creating_permissions_sync_task(

custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"

_ = app.send_task(
result = app.send_task(
"connector_external_group_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
@@ -166,8 +170,13 @@ def try_creating_permissions_sync_task(
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
# set a basic fence to start
redis_connector.external_group_sync.set_fence(True)

payload = RedisConnectorExternalGroupSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=result.id,
)

redis_connector.external_group_sync.set_fence(payload)

except Exception:
task_logger.exception(
@@ -203,7 +212,7 @@ def connector_external_group_sync_generator_task(

r = get_redis_client(tenant_id=tenant_id)

lock = r.lock(
lock: RedisLock = r.lock(
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
@@ -249,7 +258,6 @@ def connector_external_group_sync_generator_task(
)

mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)

except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
@@ -260,6 +268,6 @@ def connector_external_group_sync_generator_task(
raise e
finally:
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
redis_connector.external_group_sync.set_fence(False)
redis_connector.external_group_sync.set_fence(None)
if lock.owned():
lock.release()
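
The external group sync fence also changes from a boolean (`set_fence(True)` / `set_fence(False)`) to a payload object cleared with `set_fence(None)`. A rough sketch of that pattern, assuming a JSON-in-Redis serialization that this diff does not actually show:

```python
# Hypothetical sketch of a payload-carrying fence; only set_fence(payload)
# and set_fence(None) are visible in the diff, the storage format is assumed.
from pydantic import BaseModel
from redis import Redis


def set_fence(r: Redis, fence_key: str, payload: BaseModel | None) -> None:
    # None clears the fence, replacing the old set_fence(False) convention
    if payload is None:
        r.delete(fence_key)
        return
    r.set(fence_key, payload.model_dump_json())


def read_fence(r: Redis, fence_key: str, model: type[BaseModel]) -> BaseModel | None:
    raw = r.get(fence_key)
    if raw is None:
        return None
    return model.model_validate_json(raw)
```
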
42 changes: 28 additions & 14 deletions backend/danswer/background/celery/tasks/indexing/tasks.py
@@ -41,8 +41,8 @@
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
@@ -205,17 +205,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:

redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]

# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)

for search_settings_instance in search_settings:
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
for search_settings_instance in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
@@ -235,7 +228,7 @@
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings) > 1,
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
continue
@@ -495,8 +488,11 @@ def try_creating_indexing_task(
return index_attempt_id


@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
@shared_task(
name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True
)
def connector_indexing_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
@@ -509,6 +505,10 @@ def connector_indexing_proxy_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)

if not self.request.id:
task_logger.error("self.request.id is None!")

client = SimpleJobClient()

job = client.submit(
@@ -537,9 +537,23 @@ def connector_indexing_proxy_task(
f"search_settings={search_settings_id}"
)

redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)

while True:
sleep(10)

if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing proxy - termination signal detected: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
)
job.cancel()
break

# do nothing for ongoing jobs that haven't been stopped
if not job.done():
with get_session_with_tenant(tenant_id) as db_session:
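This polling loop is the core of the kill-indexing change: the proxy task (now `bind=True`, so it knows its own `self.request.id`) checks a per-task termination flag every ten seconds and cancels the spawned job when it is set. The diff does not show how `terminating()` is implemented or who sets the flag; below is a minimal sketch under the assumption that it is a short-lived Redis key written by whatever admin action stops the run:

```python
# Hypothetical sketch; the key naming and the Redis-key mechanism are assumptions.
# Only the terminating(task_id) check and job.cancel() appear in the diff.
from redis import Redis


class ConnectorIndexTerminationSketch:
    def __init__(self, redis: Redis, cc_pair_id: int, search_settings_id: int) -> None:
        self.redis = redis
        self.prefix = f"connectorindexing_terminate_{cc_pair_id}_{search_settings_id}"

    def set_terminate(self, celery_task_id: str) -> None:
        # called by the "stop indexing" action, keyed by the proxy task's own id
        self.redis.set(f"{self.prefix}_{celery_task_id}", 1, ex=600)

    def terminating(self, celery_task_id: str) -> bool:
        # polled every ~10s by connector_indexing_proxy_task
        return bool(self.redis.exists(f"{self.prefix}_{celery_task_id}"))
```
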
11 changes: 5 additions & 6 deletions backend/danswer/background/celery/tasks/vespa/tasks.py
@@ -58,7 +58,7 @@
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
@@ -588,17 +588,15 @@ def monitor_ccpair_permissions_taskset(
if remaining > 0:
return

payload: RedisConnectorPermissionSyncData | None = (
payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None

mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")

redis_connector.permissions.taskset_clear()
redis_connector.permissions.generator_clear()
redis_connector.permissions.set_fence(None)
redis_connector.permissions.reset()


def monitor_ccpair_indexing_taskset(
@@ -692,6 +690,7 @@ def monitor_ccpair_indexing_taskset(
task_logger.info(
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -724,7 +723,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:

# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r)
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)
4 changes: 3 additions & 1 deletion backend/danswer/background/celery/versioned_apps/primary.py
@@ -1,8 +1,10 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery

from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable

set_is_ee_based_on_env_variable()
app = fetch_versioned_implementation(
app: Celery = fetch_versioned_implementation(
"danswer.background.celery.apps.primary", "celery_app"
)
19 changes: 19 additions & 0 deletions backend/danswer/db/search_settings.py
@@ -143,6 +143,25 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
return latest_settings


def get_active_search_settings(db_session: Session) -> list[SearchSettings]:
"""Returns active search settings. The first entry will always be the current search
settings. If there are new search settings that are being migrated to, those will be
the second entry."""
search_settings_list: list[SearchSettings] = []

# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings_list.append(primary_search_settings)

# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings_list.append(secondary_search_settings)

return search_settings_list
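
A small usage sketch of the new helper, mirroring how check_for_indexing consumes it above; the function name here is illustrative, not part of the PR:

```python
from sqlalchemy.orm import Session


def is_migrating_to_new_search_settings(db_session: Session) -> bool:
    # True only while a second (in-migration) SearchSettings entry exists,
    # matching the secondary_index_building flag derived in check_for_indexing
    return len(get_active_search_settings(db_session)) > 1
```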


def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
query = select(SearchSettings).order_by(SearchSettings.id.desc())
result = db_session.execute(query)