Skip to content

Commit

Permalink
fresh indexing feature branch (#2790)
Browse files Browse the repository at this point in the history
* fresh indexing feature branch

* cherry pick test

* Revert "cherry pick test"

This reverts commit 2a62422.

* set multitenant so that vespa fields match when indexing

* cleanup pass

* mypy

* pass through env var to control celery indexing concurrency

* comments on task kickoff and some logging improvements

* use get_session_with_tenant

* comment out all of update.py

* rename to RedisConnectorIndexingFenceData

* first check num_indexing_workers

* refactor RedisConnectorIndexingFenceData

* comment out on_worker_process_init

* fix where num_indexing_workers falls back

* remove extra brace
  • Loading branch information
rkuo-danswer authored Oct 18, 2024
1 parent 12cbbe6 commit 6913efe
Show file tree
Hide file tree
Showing 29 changed files with 1,677 additions and 763 deletions.
109 changes: 101 additions & 8 deletions backend/danswer/background/celery/celery_app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import multiprocessing
import time
from datetime import timedelta
from typing import Any
Expand All @@ -12,6 +13,7 @@
from celery import Task
from celery.exceptions import WorkerShutdown
from celery.signals import beat_init
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
Expand All @@ -21,23 +23,32 @@

from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.update import get_all_tenant_ids
from danswer.background.celery.celery_utils import get_all_tenant_ids
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import SqlEngine
from danswer.db.search_settings import get_current_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SENTRY_DSN

logger = setup_logger()
Expand All @@ -62,8 +73,20 @@
) # Load configuration from 'celeryconfig.py'


@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    """Handler connected to celery's task_prerun signal.

    Intentionally a no-op for now: every argument supplied by the signal is
    accepted and ignored.
    """
    return None


@signals.task_postrun.connect
def celery_task_postrun(
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
Expand All @@ -80,6 +103,9 @@ def celery_task_postrun(
This function runs after any task completes (both success and failure)
Note that this signal does not fire on a task that failed to complete and is going
to be retried.
This also does not fire if a worker with acks_late=False crashes (which all of our
long running workers are)
"""
if not task:
return
Expand All @@ -101,32 +127,38 @@ def celery_task_postrun(
if task_id.startswith(RedisDocumentSet.PREFIX):
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(document_set_id)
rds = RedisDocumentSet(int(document_set_id))
r.srem(rds.taskset_key, task_id)
return

if task_id.startswith(RedisUserGroup.PREFIX):
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(usergroup_id)
rug = RedisUserGroup(int(usergroup_id))
r.srem(rug.taskset_key, task_id)
return

if task_id.startswith(RedisConnectorDeletion.PREFIX):
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcd = RedisConnectorDeletion(cc_pair_id)
rcd = RedisConnectorDeletion(int(cc_pair_id))
r.srem(rcd.taskset_key, task_id)
return

if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcp = RedisConnectorPruning(cc_pair_id)
rcp = RedisConnectorPruning(int(cc_pair_id))
r.srem(rcp.taskset_key, task_id)
return


@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
    """The first signal sent on celery worker startup.

    Sets the multiprocessing start method to "spawn" because fork is unsafe.

    multiprocessing.set_start_method raises RuntimeError when the start method
    has already been fixed for this process. Rather than letting that abort
    worker startup, fall back to force=True, but only when the method that is
    currently in effect is something other than "spawn".
    """
    try:
        multiprocessing.set_start_method("spawn")  # fork is unsafe, set to spawn
    except RuntimeError:
        current_method = multiprocessing.get_start_method(allow_none=True)
        if current_method == "spawn":
            # already what we want; nothing left to do
            return

        logger.warning(
            f"Multiprocessing start method was already set to {current_method}. "
            "Retrying set_start_method('spawn') with force=True."
        )
        multiprocessing.set_start_method("spawn", force=True)


@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
Expand All @@ -135,6 +167,9 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None:

@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

# decide some initial startup settings based on the celery worker's hostname
# (set at the command line)
hostname = sender.hostname
Expand All @@ -144,6 +179,30 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
elif hostname.startswith("heavy"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
elif hostname.startswith("indexing"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)

# TODO: why is this necessary for the indexer to do?
with get_session_with_tenant(tenant_id) as db_session:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)

# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed

if search_settings.provider_type is None:
logger.notice("Running a first inference to warm up embedding model")
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)

warm_up_bi_encoder(
embedding_model=embedding_model,
)
logger.notice("First inference complete.")
else:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
Expand Down Expand Up @@ -234,6 +293,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:

sender.primary_worker_lock = lock

# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)

Expand Down Expand Up @@ -270,6 +331,31 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)


# @worker_process_init.connect
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
# """This only runs inside child processes when the worker is in pool=prefork mode.
# This may be technically unnecessary since we're finding prefork pools to be
# unstable and currently aren't planning on using them."""
# logger.info("worker_process_init signal received.")
# SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# SqlEngine.init_engine(pool_size=5, max_overflow=0)

# # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
# SqlEngine.get_engine().dispose(close=False)


@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -318,7 +404,7 @@ def on_setup_logging(
# TODO: could unhardcode format and colorize and accept these as options from
# celery's config

# reformats celery's worker logger
# reformats the root logger
root_logger = logging.getLogger()

root_handler = logging.StreamHandler() # Set up a handler for the root logger
Expand Down Expand Up @@ -441,6 +527,7 @@ def stop(self, worker: Any) -> None:
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.indexing",
"danswer.background.celery.tasks.periodic",
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.shared",
Expand All @@ -467,9 +554,15 @@ def stop(self, worker: Any) -> None:
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_prune_task_2",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
Expand Down
Loading

0 comments on commit 6913efe

Please sign in to comment.