diff --git a/.github/workflows/pr-Integration-tests.yml b/.github/workflows/pr-Integration-tests.yml index 283618ff993..a51dc743f33 100644 --- a/.github/workflows/pr-Integration-tests.yml +++ b/.github/workflows/pr-Integration-tests.yml @@ -197,7 +197,8 @@ jobs: -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \ -e TEST_WEB_HOSTNAME=test-runner \ danswer/danswer-integration:test \ - /app/tests/integration/tests + /app/tests/integration/tests \ + /app/tests/integration/connector_job_tests continue-on-error: true id: run_tests diff --git a/.vscode/launch.template.jsonc b/.vscode/launch.template.jsonc index 87875907cd5..1f1faed097d 100644 --- a/.vscode/launch.template.jsonc +++ b/.vscode/launch.template.jsonc @@ -203,7 +203,7 @@ "--loglevel=INFO", "--hostname=light@%n", "-Q", - "vespa_metadata_sync,connector_deletion", + "vespa_metadata_sync,connector_deletion,doc_permissions_upsert", ], "presentation": { "group": "2", @@ -232,7 +232,7 @@ "--loglevel=INFO", "--hostname=heavy@%n", "-Q", - "connector_pruning", + "connector_pruning,connector_doc_permissions_sync,connector_external_group_sync", ], "presentation": { "group": "2", diff --git a/backend/alembic/versions/2daa494a0851_add_group_sync_time.py b/backend/alembic/versions/2daa494a0851_add_group_sync_time.py new file mode 100644 index 00000000000..c8a98f7693e --- /dev/null +++ b/backend/alembic/versions/2daa494a0851_add_group_sync_time.py @@ -0,0 +1,30 @@ +"""add-group-sync-time + +Revision ID: 2daa494a0851 +Revises: c0fd6e4da83a +Create Date: 2024-11-11 10:57:22.991157 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "2daa494a0851" +down_revision = "c0fd6e4da83a" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "connector_credential_pair", + sa.Column( + "last_time_external_group_sync", + sa.DateTime(timezone=True), + nullable=True, + ), + ) + + +def downgrade() -> None: + op.drop_column("connector_credential_pair", "last_time_external_group_sync") diff --git a/backend/danswer/access/models.py b/backend/danswer/access/models.py index 46b9c0efa93..126648eb41e 100644 --- a/backend/danswer/access/models.py +++ b/backend/danswer/access/models.py @@ -16,6 +16,41 @@ class ExternalAccess: is_public: bool +@dataclass(frozen=True) +class DocExternalAccess: + external_access: ExternalAccess + # The document ID + doc_id: str + + def to_dict(self) -> dict: + return { + "external_access": { + "external_user_emails": list(self.external_access.external_user_emails), + "external_user_group_ids": list( + self.external_access.external_user_group_ids + ), + "is_public": self.external_access.is_public, + }, + "doc_id": self.doc_id, + } + + @classmethod + def from_dict(cls, data: dict) -> "DocExternalAccess": + external_access = ExternalAccess( + external_user_emails=set( + data["external_access"].get("external_user_emails", []) + ), + external_user_group_ids=set( + data["external_access"].get("external_user_group_ids", []) + ), + is_public=data["external_access"]["is_public"], + ) + return cls( + external_access=external_access, + doc_id=data["doc_id"], + ) + + @dataclass(frozen=True) class DocumentAccess(ExternalAccess): # User emails for Danswer users, None indicates admin diff --git a/backend/danswer/background/celery/apps/app_base.py b/backend/danswer/background/celery/apps/app_base.py index 79e2e9739ae..d041ce0d2bc 100644 --- a/backend/danswer/background/celery/apps/app_base.py +++ b/backend/danswer/background/celery/apps/app_base.py @@ -24,6 +24,8 @@ from danswer.redis.redis_connector import RedisConnector from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair from danswer.redis.redis_connector_delete import RedisConnectorDelete +from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync +from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync from danswer.redis.redis_connector_prune import RedisConnectorPrune from danswer.redis.redis_document_set import RedisDocumentSet from danswer.redis.redis_pool import get_redis_client @@ -136,6 +138,22 @@ def on_task_postrun( RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r) return + if task_id.startswith(RedisConnectorPermissionSync.SUBTASK_PREFIX): + cc_pair_id = RedisConnector.get_id_from_task_id(task_id) + if cc_pair_id is not None: + RedisConnectorPermissionSync.remove_from_taskset( + int(cc_pair_id), task_id, r + ) + return + + if task_id.startswith(RedisConnectorExternalGroupSync.SUBTASK_PREFIX): + cc_pair_id = RedisConnector.get_id_from_task_id(task_id) + if cc_pair_id is not None: + RedisConnectorExternalGroupSync.remove_from_taskset( + int(cc_pair_id), task_id, r + ) + return + def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None: """The first signal sent on celery worker startup""" diff --git a/backend/danswer/background/celery/apps/heavy.py b/backend/danswer/background/celery/apps/heavy.py index 3f8263267e0..714c91ee421 100644 --- a/backend/danswer/background/celery/apps/heavy.py +++ b/backend/danswer/background/celery/apps/heavy.py @@ -91,5 +91,7 @@ def on_setup_logging( celery_app.autodiscover_tasks( [ "danswer.background.celery.tasks.pruning", + "danswer.background.celery.tasks.doc_permission_syncing", + "danswer.background.celery.tasks.external_group_syncing", ] ) diff --git a/backend/danswer/background/celery/apps/light.py b/backend/danswer/background/celery/apps/light.py index 354257e9a98..17292743f9d 100644 --- a/backend/danswer/background/celery/apps/light.py +++ b/backend/danswer/background/celery/apps/light.py @@ -92,5 +92,6 @@ def on_setup_logging( "danswer.background.celery.tasks.shared", "danswer.background.celery.tasks.vespa", "danswer.background.celery.tasks.connector_deletion", + "danswer.background.celery.tasks.doc_permission_syncing", ] ) diff --git a/backend/danswer/background/celery/apps/primary.py b/backend/danswer/background/celery/apps/primary.py index 14d2b006bb8..5e4145e16d7 100644 --- a/backend/danswer/background/celery/apps/primary.py +++ b/backend/danswer/background/celery/apps/primary.py @@ -20,6 +20,8 @@ from danswer.db.engine import SqlEngine from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair from danswer.redis.redis_connector_delete import RedisConnectorDelete +from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync +from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync from danswer.redis.redis_connector_index import RedisConnectorIndex from danswer.redis.redis_connector_prune import RedisConnectorPrune from danswer.redis.redis_connector_stop import RedisConnectorStop @@ -134,6 +136,10 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None: RedisConnectorStop.reset_all(r) + RedisConnectorPermissionSync.reset_all(r) + + RedisConnectorExternalGroupSync.reset_all(r) + @worker_ready.connect def on_worker_ready(sender: Any, **kwargs: Any) -> None: @@ -233,6 +239,8 @@ def stop(self, worker: Any) -> None: "danswer.background.celery.tasks.connector_deletion", "danswer.background.celery.tasks.indexing", "danswer.background.celery.tasks.periodic", + "danswer.background.celery.tasks.doc_permission_syncing", + "danswer.background.celery.tasks.external_group_syncing", "danswer.background.celery.tasks.pruning", "danswer.background.celery.tasks.shared", "danswer.background.celery.tasks.vespa", diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py index d0df7af02d7..c8a125d11b1 100644 --- a/backend/danswer/background/celery/celery_utils.py +++ b/backend/danswer/background/celery/celery_utils.py @@ -81,7 +81,7 @@ def extract_ids_from_runnable_connector( callback: RunIndexingCallbackInterface | None = None, ) -> set[str]: """ - If the PruneConnector hasnt been implemented for the given connector, just pull + If the SlimConnector hasnt been implemented for the given connector, just pull all docs using the load_from_state and grab out the IDs. Optionally, a callback can be passed to handle the length of each document batch. diff --git a/backend/danswer/background/celery/tasks/beat_schedule.py b/backend/danswer/background/celery/tasks/beat_schedule.py index 6a20c6ba5c1..a6dc693d4d2 100644 --- a/backend/danswer/background/celery/tasks/beat_schedule.py +++ b/backend/danswer/background/celery/tasks/beat_schedule.py @@ -41,6 +41,18 @@ "schedule": timedelta(seconds=5), "options": {"priority": DanswerCeleryPriority.HIGH}, }, + { + "name": "check-for-doc-permissions-sync", + "task": "check_for_doc_permissions_sync", + "schedule": timedelta(seconds=30), + "options": {"priority": DanswerCeleryPriority.HIGH}, + }, + { + "name": "check-for-external-group-sync", + "task": "check_for_external_group_sync", + "schedule": timedelta(seconds=20), + "options": {"priority": DanswerCeleryPriority.HIGH}, + }, ] diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py index 360481015bb..85d27b2ba16 100644 --- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py +++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py @@ -143,6 +143,12 @@ def try_generate_document_cc_pair_cleanup_tasks( f"cc_pair={cc_pair_id}" ) + if redis_connector.permissions.fenced: + raise TaskDependencyError( + f"Connector deletion - Delayed (permissions in progress): " + f"cc_pair={cc_pair_id}" + ) + # add tasks to celery and build up the task set to monitor in redis redis_connector.delete.taskset_clear() diff --git a/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py b/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py new file mode 100644 index 00000000000..e4a715f425a --- /dev/null +++ b/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py @@ -0,0 +1,321 @@ +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from uuid import uuid4 + +from celery import Celery +from celery import shared_task +from celery import Task +from celery.exceptions import SoftTimeLimitExceeded +from redis import Redis + +from danswer.access.models import DocExternalAccess +from danswer.background.celery.apps.app_base import task_logger +from danswer.configs.app_configs import JOB_TIMEOUT +from danswer.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT +from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT +from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX +from danswer.configs.constants import DanswerCeleryPriority +from danswer.configs.constants import DanswerCeleryQueues +from danswer.configs.constants import DanswerRedisLocks +from danswer.configs.constants import DocumentSource +from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id +from danswer.db.engine import get_session_with_tenant +from danswer.db.enums import AccessType +from danswer.db.enums import ConnectorCredentialPairStatus +from danswer.db.models import ConnectorCredentialPair +from danswer.db.users import batch_add_non_web_user_if_not_exists +from danswer.redis.redis_connector import RedisConnector +from danswer.redis.redis_connector_doc_perm_sync import ( + RedisConnectorPermissionSyncData, +) +from danswer.redis.redis_pool import get_redis_client +from danswer.utils.logger import doc_permission_sync_ctx +from danswer.utils.logger import setup_logger +from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs +from ee.danswer.db.document import upsert_document_external_perms +from ee.danswer.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS +from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP + +logger = setup_logger() + + +DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3 + + +# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT +LIGHT_SOFT_TIME_LIMIT = 105 +LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15 + + +def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool: + """Returns boolean indicating if external doc permissions sync is due.""" + + if cc_pair.access_type != AccessType.SYNC: + return False + + # skip doc permissions sync if not active + if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE: + return False + + if cc_pair.status == ConnectorCredentialPairStatus.DELETING: + return False + + # If the last sync is None, it has never been run so we run the sync + last_perm_sync = cc_pair.last_time_perm_sync + if last_perm_sync is None: + return True + + source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source) + + # If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync. + if not source_sync_period: + return True + + # If the last sync is greater than the full fetch period, we run the sync + next_sync = last_perm_sync + timedelta(seconds=source_sync_period) + if datetime.now(timezone.utc) >= next_sync: + return True + + return False + + +@shared_task( + name="check_for_doc_permissions_sync", + soft_time_limit=JOB_TIMEOUT, + bind=True, +) +def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None: + r = get_redis_client(tenant_id=tenant_id) + + lock_beat = r.lock( + DanswerRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK, + timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, + ) + + try: + # these tasks should never overlap + if not lock_beat.acquire(blocking=False): + return + + # get all cc pairs that need to be synced + cc_pair_ids_to_sync: list[int] = [] + with get_session_with_tenant(tenant_id) as db_session: + cc_pairs = get_all_auto_sync_cc_pairs(db_session) + + for cc_pair in cc_pairs: + if _is_external_doc_permissions_sync_due(cc_pair): + cc_pair_ids_to_sync.append(cc_pair.id) + + for cc_pair_id in cc_pair_ids_to_sync: + tasks_created = try_creating_permissions_sync_task( + self.app, cc_pair_id, r, tenant_id + ) + if not tasks_created: + continue + + task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}") + except SoftTimeLimitExceeded: + task_logger.info( + "Soft time limit exceeded, task is being terminated gracefully." + ) + except Exception: + task_logger.exception(f"Unexpected exception: tenant={tenant_id}") + finally: + if lock_beat.owned(): + lock_beat.release() + + +def try_creating_permissions_sync_task( + app: Celery, + cc_pair_id: int, + r: Redis, + tenant_id: str | None, +) -> int | None: + """Returns an int if syncing is needed. The int represents the number of sync tasks generated. + Returns None if no syncing is required.""" + redis_connector = RedisConnector(tenant_id, cc_pair_id) + + LOCK_TIMEOUT = 30 + + lock = r.lock( + DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks", + timeout=LOCK_TIMEOUT, + ) + + acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2) + if not acquired: + return None + + try: + if redis_connector.permissions.fenced: + return None + + if redis_connector.delete.fenced: + return None + + if redis_connector.prune.fenced: + return None + + redis_connector.permissions.generator_clear() + redis_connector.permissions.taskset_clear() + + custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}" + + app.send_task( + "connector_permission_sync_generator_task", + kwargs=dict( + cc_pair_id=cc_pair_id, + tenant_id=tenant_id, + ), + queue=DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, + task_id=custom_task_id, + priority=DanswerCeleryPriority.HIGH, + ) + + # set a basic fence to start + payload = RedisConnectorPermissionSyncData( + started=None, + ) + + redis_connector.permissions.set_fence(payload) + except Exception: + task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}") + return None + finally: + if lock.owned(): + lock.release() + + return 1 + + +@shared_task( + name="connector_permission_sync_generator_task", + acks_late=False, + soft_time_limit=JOB_TIMEOUT, + track_started=True, + trail=False, + bind=True, +) +def connector_permission_sync_generator_task( + self: Task, + cc_pair_id: int, + tenant_id: str | None, +) -> None: + """ + Permission sync task that handles document permission syncing for a given connector credential pair + This task assumes that the task has already been properly fenced + """ + + doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get() + doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id + doc_permission_sync_ctx_dict["request_id"] = self.request.id + doc_permission_sync_ctx.set(doc_permission_sync_ctx_dict) + + redis_connector = RedisConnector(tenant_id, cc_pair_id) + + r = get_redis_client(tenant_id=tenant_id) + + lock = r.lock( + DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX + + f"_{redis_connector.id}", + timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT, + ) + + acquired = lock.acquire(blocking=False) + if not acquired: + task_logger.warning( + f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}" + ) + return None + + try: + with get_session_with_tenant(tenant_id) as db_session: + cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + if cc_pair is None: + raise ValueError( + f"No connector credential pair found for id: {cc_pair_id}" + ) + + source_type = cc_pair.connector.source + + doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type) + if doc_sync_func is None: + raise ValueError(f"No doc sync func found for {source_type}") + + logger.info(f"Syncing docs for {source_type}") + + payload = RedisConnectorPermissionSyncData( + started=datetime.now(timezone.utc), + ) + redis_connector.permissions.set_fence(payload) + + document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair) + + task_logger.info( + f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}" + ) + tasks_generated = redis_connector.permissions.generate_tasks( + self.app, lock, document_external_accesses, source_type + ) + if tasks_generated is None: + return None + + task_logger.info( + f"RedisConnector.permissions.generate_tasks finished. " + f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}" + ) + + redis_connector.permissions.generator_complete = tasks_generated + + except Exception as e: + task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}") + + redis_connector.permissions.generator_clear() + redis_connector.permissions.taskset_clear() + redis_connector.permissions.set_fence(None) + raise e + finally: + if lock.owned(): + lock.release() + + +@shared_task( + name="update_external_document_permissions_task", + soft_time_limit=LIGHT_SOFT_TIME_LIMIT, + time_limit=LIGHT_TIME_LIMIT, + max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES, + bind=True, +) +def update_external_document_permissions_task( + self: Task, + tenant_id: str | None, + serialized_doc_external_access: dict, + source_string: str, +) -> bool: + document_external_access = DocExternalAccess.from_dict( + serialized_doc_external_access + ) + doc_id = document_external_access.doc_id + external_access = document_external_access.external_access + try: + with get_session_with_tenant(tenant_id) as db_session: + # Then we build the update requests to update vespa + batch_add_non_web_user_if_not_exists( + db_session=db_session, + emails=list(external_access.external_user_emails), + ) + upsert_document_external_perms( + db_session=db_session, + doc_id=doc_id, + external_access=external_access, + source_type=DocumentSource(source_string), + ) + + logger.debug( + f"Successfully synced postgres document permissions for {doc_id}" + ) + return True + except Exception: + logger.exception("Error Syncing Document Permissions") + return False diff --git a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py new file mode 100644 index 00000000000..4f7451faf76 --- /dev/null +++ b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py @@ -0,0 +1,265 @@ +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from uuid import uuid4 + +from celery import Celery +from celery import shared_task +from celery import Task +from celery.exceptions import SoftTimeLimitExceeded +from redis import Redis + +from danswer.background.celery.apps.app_base import task_logger +from danswer.configs.app_configs import JOB_TIMEOUT +from danswer.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT +from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT +from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX +from danswer.configs.constants import DanswerCeleryPriority +from danswer.configs.constants import DanswerCeleryQueues +from danswer.configs.constants import DanswerRedisLocks +from danswer.db.connector import mark_cc_pair_as_external_group_synced +from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id +from danswer.db.engine import get_session_with_tenant +from danswer.db.enums import AccessType +from danswer.db.enums import ConnectorCredentialPairStatus +from danswer.db.models import ConnectorCredentialPair +from danswer.redis.redis_connector import RedisConnector +from danswer.redis.redis_pool import get_redis_client +from danswer.utils.logger import setup_logger +from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs +from ee.danswer.db.external_perm import ExternalUserGroup +from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair +from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIOD +from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP + +logger = setup_logger() + + +EXTERNAL_GROUPS_UPDATE_MAX_RETRIES = 3 + + +# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT +LIGHT_SOFT_TIME_LIMIT = 105 +LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15 + + +def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool: + """Returns boolean indicating if external group sync is due.""" + + if cc_pair.access_type != AccessType.SYNC: + return False + + # skip pruning if not active + if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE: + return False + + if cc_pair.status == ConnectorCredentialPairStatus.DELETING: + return False + + # If there is not group sync function for the connector, we don't run the sync + # This is fine because all sources dont necessarily have a concept of groups + if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source): + return False + + # If the last sync is None, it has never been run so we run the sync + last_ext_group_sync = cc_pair.last_time_external_group_sync + if last_ext_group_sync is None: + return True + + source_sync_period = EXTERNAL_GROUP_SYNC_PERIOD + + # If EXTERNAL_GROUP_SYNC_PERIOD is None, we always run the sync. + if not source_sync_period: + return True + + # If the last sync is greater than the full fetch period, we run the sync + next_sync = last_ext_group_sync + timedelta(seconds=source_sync_period) + if datetime.now(timezone.utc) >= next_sync: + return True + + return False + + +@shared_task( + name="check_for_external_group_sync", + soft_time_limit=JOB_TIMEOUT, + bind=True, +) +def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None: + r = get_redis_client(tenant_id=tenant_id) + + lock_beat = r.lock( + DanswerRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK, + timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, + ) + + try: + # these tasks should never overlap + if not lock_beat.acquire(blocking=False): + return + + cc_pair_ids_to_sync: list[int] = [] + with get_session_with_tenant(tenant_id) as db_session: + cc_pairs = get_all_auto_sync_cc_pairs(db_session) + + for cc_pair in cc_pairs: + if _is_external_group_sync_due(cc_pair): + cc_pair_ids_to_sync.append(cc_pair.id) + + for cc_pair_id in cc_pair_ids_to_sync: + tasks_created = try_creating_permissions_sync_task( + self.app, cc_pair_id, r, tenant_id + ) + if not tasks_created: + continue + + task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}") + except SoftTimeLimitExceeded: + task_logger.info( + "Soft time limit exceeded, task is being terminated gracefully." + ) + except Exception: + task_logger.exception(f"Unexpected exception: tenant={tenant_id}") + finally: + if lock_beat.owned(): + lock_beat.release() + + +def try_creating_permissions_sync_task( + app: Celery, + cc_pair_id: int, + r: Redis, + tenant_id: str | None, +) -> int | None: + """Returns an int if syncing is needed. The int represents the number of sync tasks generated. + Returns None if no syncing is required.""" + redis_connector = RedisConnector(tenant_id, cc_pair_id) + + LOCK_TIMEOUT = 30 + + lock = r.lock( + DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks", + timeout=LOCK_TIMEOUT, + ) + + acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2) + if not acquired: + return None + + try: + # Dont kick off a new sync if the previous one is still running + if redis_connector.external_group_sync.fenced: + return None + + redis_connector.external_group_sync.generator_clear() + redis_connector.external_group_sync.taskset_clear() + + custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}" + + _ = app.send_task( + "connector_external_group_sync_generator_task", + kwargs=dict( + cc_pair_id=cc_pair_id, + tenant_id=tenant_id, + ), + queue=DanswerCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, + task_id=custom_task_id, + priority=DanswerCeleryPriority.HIGH, + ) + # set a basic fence to start + redis_connector.external_group_sync.set_fence(True) + + except Exception: + task_logger.exception( + f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}" + ) + return None + finally: + if lock.owned(): + lock.release() + + return 1 + + +@shared_task( + name="connector_external_group_sync_generator_task", + acks_late=False, + soft_time_limit=JOB_TIMEOUT, + track_started=True, + trail=False, + bind=True, +) +def connector_external_group_sync_generator_task( + self: Task, + cc_pair_id: int, + tenant_id: str | None, +) -> None: + """ + Permission sync task that handles document permission syncing for a given connector credential pair + This task assumes that the task has already been properly fenced + """ + + redis_connector = RedisConnector(tenant_id, cc_pair_id) + + r = get_redis_client(tenant_id=tenant_id) + + lock = r.lock( + DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX + + f"_{redis_connector.id}", + timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT, + ) + + try: + acquired = lock.acquire(blocking=False) + if not acquired: + task_logger.warning( + f"External group sync task already running, exiting...: cc_pair={cc_pair_id}" + ) + return None + + with get_session_with_tenant(tenant_id) as db_session: + cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + if cc_pair is None: + raise ValueError( + f"No connector credential pair found for id: {cc_pair_id}" + ) + + source_type = cc_pair.connector.source + + ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type) + if ext_group_sync_func is None: + raise ValueError(f"No external group sync func found for {source_type}") + + logger.info(f"Syncing docs for {source_type}") + + external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair) + + logger.info( + f"Syncing {len(external_user_groups)} external user groups for {source_type}" + ) + + replace_user__ext_group_for_cc_pair( + db_session=db_session, + cc_pair_id=cc_pair.id, + group_defs=external_user_groups, + source=cc_pair.connector.source, + ) + logger.info( + f"Synced {len(external_user_groups)} external user groups for {source_type}" + ) + + mark_cc_pair_as_external_group_synced(db_session, cc_pair.id) + + except Exception as e: + task_logger.exception( + f"Failed to run external group sync: cc_pair={cc_pair_id}" + ) + + redis_connector.external_group_sync.generator_clear() + redis_connector.external_group_sync.taskset_clear() + raise e + finally: + # we always want to clear the fence after the task is done or failed so it doesn't get stuck + redis_connector.external_group_sync.set_fence(False) + if lock.owned(): + lock.release() diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py index af80e6b886c..049840c051a 100644 --- a/backend/danswer/background/celery/tasks/pruning/tasks.py +++ b/backend/danswer/background/celery/tasks/pruning/tasks.py @@ -38,6 +38,35 @@ logger = setup_logger() +def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool: + """Returns boolean indicating if pruning is due.""" + + # skip pruning if no prune frequency is set + # pruning can still be forced via the API which will run a pruning task directly + if not cc_pair.connector.prune_freq: + return False + + # skip pruning if not active + if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE: + return False + + # skip pruning if the next scheduled prune time hasn't been reached yet + last_pruned = cc_pair.last_pruned + if not last_pruned: + if not cc_pair.last_successful_index_time: + # if we've never indexed, we can't prune + return False + + # if never pruned, use the last time the connector indexed successfully + last_pruned = cc_pair.last_successful_index_time + + next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq) + if datetime.now(timezone.utc) < next_prune: + return False + + return True + + @shared_task( name="check_for_pruning", soft_time_limit=JOB_TIMEOUT, @@ -69,7 +98,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None: if not cc_pair: continue - if not is_pruning_due(cc_pair, db_session, r): + if not _is_pruning_due(cc_pair): continue tasks_created = try_creating_prune_generator_task( @@ -90,47 +119,6 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None: lock_beat.release() -def is_pruning_due( - cc_pair: ConnectorCredentialPair, - db_session: Session, - r: Redis, -) -> bool: - """Returns an int if pruning is triggered. - The int represents the number of prune tasks generated (in this case, only one - because the task is a long running generator task.) - Returns None if no pruning is triggered (due to not being needed or - other reasons such as simultaneous pruning restrictions. - - Checks for scheduling related conditions, then delegates the rest of the checks to - try_creating_prune_generator_task. - """ - - # skip pruning if no prune frequency is set - # pruning can still be forced via the API which will run a pruning task directly - if not cc_pair.connector.prune_freq: - return False - - # skip pruning if not active - if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE: - return False - - # skip pruning if the next scheduled prune time hasn't been reached yet - last_pruned = cc_pair.last_pruned - if not last_pruned: - if not cc_pair.last_successful_index_time: - # if we've never indexed, we can't prune - return False - - # if never pruned, use the last time the connector indexed successfully - last_pruned = cc_pair.last_successful_index_time - - next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq) - if datetime.now(timezone.utc) < next_prune: - return False - - return True - - def try_creating_prune_generator_task( celery_app: Celery, cc_pair: ConnectorCredentialPair, @@ -166,10 +154,16 @@ def try_creating_prune_generator_task( return None try: - if redis_connector.prune.fenced: # skip pruning if already pruning + # skip pruning if already pruning + if redis_connector.prune.fenced: + return None + + # skip pruning if the cc_pair is deleting + if redis_connector.delete.fenced: return None - if redis_connector.delete.fenced: # skip pruning if the cc_pair is deleting + # skip pruning if doc permissions sync is running + if redis_connector.permissions.fenced: return None db_session.refresh(cc_pair) diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py index b01a0eac815..00676d7b117 100644 --- a/backend/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/danswer/background/celery/tasks/vespa/tasks.py @@ -27,6 +27,7 @@ from danswer.configs.constants import DanswerCeleryQueues from danswer.configs.constants import DanswerRedisLocks from danswer.db.connector import fetch_connector_by_id +from danswer.db.connector import mark_cc_pair_as_permissions_synced from danswer.db.connector import mark_ccpair_as_pruned from danswer.db.connector_credential_pair import add_deletion_failure_message from danswer.db.connector_credential_pair import ( @@ -58,6 +59,10 @@ from danswer.redis.redis_connector import RedisConnector from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair from danswer.redis.redis_connector_delete import RedisConnectorDelete +from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync +from danswer.redis.redis_connector_doc_perm_sync import ( + RedisConnectorPermissionSyncData, +) from danswer.redis.redis_connector_index import RedisConnectorIndex from danswer.redis.redis_connector_prune import RedisConnectorPrune from danswer.redis.redis_document_set import RedisDocumentSet @@ -546,6 +551,47 @@ def monitor_ccpair_pruning_taskset( redis_connector.prune.set_fence(False) +def monitor_ccpair_permissions_taskset( + tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session +) -> None: + fence_key = key_bytes.decode("utf-8") + cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key) + if cc_pair_id_str is None: + task_logger.warning( + f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}" + ) + return + + cc_pair_id = int(cc_pair_id_str) + + redis_connector = RedisConnector(tenant_id, cc_pair_id) + if not redis_connector.permissions.fenced: + return + + initial = redis_connector.permissions.generator_complete + if initial is None: + return + + remaining = redis_connector.permissions.get_remaining() + task_logger.info( + f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}" + ) + if remaining > 0: + return + + payload: RedisConnectorPermissionSyncData | None = ( + redis_connector.permissions.payload + ) + start_time: datetime | None = payload.started if payload else None + + mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time) + task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}") + + redis_connector.permissions.taskset_clear() + redis_connector.permissions.generator_clear() + redis_connector.permissions.set_fence(None) + + def monitor_ccpair_indexing_taskset( tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session ) -> None: @@ -668,13 +714,17 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: n_pruning = celery_get_queue_length( DanswerCeleryQueues.CONNECTOR_PRUNING, r_celery ) + n_permissions_sync = celery_get_queue_length( + DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery + ) task_logger.info( f"Queue lengths: celery={n_celery} " f"indexing={n_indexing} " f"sync={n_sync} " f"deletion={n_deletion} " - f"pruning={n_pruning}" + f"pruning={n_pruning} " + f"permissions_sync={n_permissions_sync} " ) # do some cleanup before clearing fences @@ -688,20 +738,22 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session) ) - for a in attempts: + for attempt in attempts: # if attempts exist in the db but we don't detect them in redis, mark them as failed fence_key = RedisConnectorIndex.fence_key_with_ids( - a.connector_credential_pair_id, a.search_settings_id + attempt.connector_credential_pair_id, attempt.search_settings_id ) if not r.exists(fence_key): failure_reason = ( f"Unknown index attempt. Might be left over from a process restart: " - f"index_attempt={a.id} " - f"cc_pair={a.connector_credential_pair_id} " - f"search_settings={a.search_settings_id}" + f"index_attempt={attempt.id} " + f"cc_pair={attempt.connector_credential_pair_id} " + f"search_settings={attempt.search_settings_id}" ) task_logger.warning(failure_reason) - mark_attempt_failed(a.id, db_session, failure_reason=failure_reason) + mark_attempt_failed( + attempt.id, db_session, failure_reason=failure_reason + ) lock_beat.reacquire() if r.exists(RedisConnectorCredentialPair.get_fence_key()): @@ -741,6 +793,12 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: with get_session_with_tenant(tenant_id) as db_session: monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session) + lock_beat.reacquire() + for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"): + lock_beat.reacquire() + with get_session_with_tenant(tenant_id) as db_session: + monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session) + # uncomment for debugging if needed # r_celery = celery_app.broker_connection().channel().client # length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery) diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index 35cb080b903..2cd1fc07e2a 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -33,8 +33,8 @@ from danswer.indexing.embedder import DefaultIndexingEmbedder from danswer.indexing.indexing_heartbeat import IndexingHeartbeat from danswer.indexing.indexing_pipeline import build_indexing_pipeline -from danswer.utils.logger import IndexAttemptSingleton from danswer.utils.logger import setup_logger +from danswer.utils.logger import TaskAttemptSingleton from danswer.utils.variable_functionality import global_version logger = setup_logger() @@ -427,7 +427,7 @@ def run_indexing_entrypoint( # set the indexing attempt ID so that all log messages from this process # will have it added as a prefix - IndexAttemptSingleton.set_cc_and_index_id( + TaskAttemptSingleton.set_cc_and_index_id( index_attempt_id, connector_credential_pair_id ) with get_session_with_tenant(tenant_id) as db_session: diff --git a/backend/danswer/background/task_name_builders.py b/backend/danswer/background/task_name_builders.py deleted file mode 100644 index 3e24f2d2afe..00000000000 --- a/backend/danswer/background/task_name_builders.py +++ /dev/null @@ -1,4 +0,0 @@ -def name_sync_external_doc_permissions_task( - cc_pair_id: int, tenant_id: str | None = None -) -> str: - return f"sync_external_doc_permissions_task__{cc_pair_id}" diff --git a/backend/danswer/background/task_utils.py b/backend/danswer/background/task_utils.py index c1c24bf92a1..f4562892460 100644 --- a/backend/danswer/background/task_utils.py +++ b/backend/danswer/background/task_utils.py @@ -14,15 +14,6 @@ from danswer.db.tasks import register_task -def name_cc_prune_task( - connector_id: int | None = None, credential_id: int | None = None -) -> str: - task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}" - if not connector_id or not credential_id: - task_name = "prune_connector_credential_pair" - return task_name - - T = TypeVar("T", bound=Callable) diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 7b3ea8e81bb..028cbf65d35 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -80,6 +80,10 @@ # if we can get callbacks as object bytes download, we could lower this a lot. CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min +CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300 # 5 min + +CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min + DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:" @@ -209,9 +213,17 @@ class PostgresAdvisoryLocks(Enum): class DanswerCeleryQueues: + # Light queue VESPA_METADATA_SYNC = "vespa_metadata_sync" + DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert" CONNECTOR_DELETION = "connector_deletion" + + # Heavy queue CONNECTOR_PRUNING = "connector_pruning" + CONNECTOR_DOC_PERMISSIONS_SYNC = "connector_doc_permissions_sync" + CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync" + + # Indexing queue CONNECTOR_INDEXING = "connector_indexing" @@ -221,8 +233,18 @@ class DanswerRedisLocks: CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat" CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat" CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat" + CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = ( + "da_lock:check_connector_doc_permissions_sync_beat" + ) + CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = ( + "da_lock:check_connector_external_group_sync_beat" + ) MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat" + CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = ( + "da_lock:connector_doc_permissions_sync" + ) + CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX = "da_lock:connector_external_group_sync" PRUNING_LOCK_PREFIX = "da_lock:pruning" INDEXING_METADATA_PREFIX = "da_metadata:indexing" diff --git a/backend/danswer/connectors/confluence/connector.py b/backend/danswer/connectors/confluence/connector.py index 9c93f93f99b..6a376e3d55d 100644 --- a/backend/danswer/connectors/confluence/connector.py +++ b/backend/danswer/connectors/confluence/connector.py @@ -146,7 +146,7 @@ def _convert_object_to_document( # The url and the id are the same object_url = build_confluence_document_id( - self.wiki_base, confluence_object["_links"]["webui"] + self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud ) object_text = None @@ -278,7 +278,9 @@ def retrieve_all_slim_documents( doc_metadata_list.append( SlimDocument( id=build_confluence_document_id( - self.wiki_base, page["_links"]["webui"] + self.wiki_base, + page["_links"]["webui"], + self.is_cloud, ), perm_sync_data=perm_sync_data, ) @@ -293,7 +295,9 @@ def retrieve_all_slim_documents( doc_metadata_list.append( SlimDocument( id=build_confluence_document_id( - self.wiki_base, attachment["_links"]["webui"] + self.wiki_base, + attachment["_links"]["webui"], + self.is_cloud, ), perm_sync_data=perm_sync_data, ) diff --git a/backend/danswer/connectors/confluence/utils.py b/backend/danswer/connectors/confluence/utils.py index beb0465be60..1ae59b8ffd5 100644 --- a/backend/danswer/connectors/confluence/utils.py +++ b/backend/danswer/connectors/confluence/utils.py @@ -153,7 +153,9 @@ def attachment_to_content( return extracted_text -def build_confluence_document_id(base_url: str, content_url: str) -> str: +def build_confluence_document_id( + base_url: str, content_url: str, is_cloud: bool +) -> str: """For confluence, the document id is the page url for a page based document or the attachment download url for an attachment based document @@ -164,6 +166,8 @@ def build_confluence_document_id(base_url: str, content_url: str) -> str: Returns: str: The document id """ + if is_cloud and not base_url.endswith("/wiki"): + base_url += "/wiki" return f"{base_url}{content_url}" diff --git a/backend/danswer/connectors/gmail/connector.py b/backend/danswer/connectors/gmail/connector.py index f9ef995d701..170e1219e11 100644 --- a/backend/danswer/connectors/gmail/connector.py +++ b/backend/danswer/connectors/gmail/connector.py @@ -305,6 +305,7 @@ def _fetch_slim_threads( query = _build_time_range_query(time_range_start, time_range_end) doc_batch = [] for user_email in self._get_all_user_emails(): + logger.info(f"Fetching slim threads for user: {user_email}") gmail_service = get_gmail_service(self.creds, user_email) for thread in execute_paginated_retrieval( retrieval_function=gmail_service.users().threads().list, diff --git a/backend/danswer/db/connector.py b/backend/danswer/db/connector.py index 835f74d437c..767a722eec4 100644 --- a/backend/danswer/db/connector.py +++ b/backend/danswer/db/connector.py @@ -282,3 +282,32 @@ def mark_ccpair_as_pruned(cc_pair_id: int, db_session: Session) -> None: cc_pair.last_pruned = datetime.now(timezone.utc) db_session.commit() + + +def mark_cc_pair_as_permissions_synced( + db_session: Session, cc_pair_id: int, start_time: datetime | None +) -> None: + stmt = select(ConnectorCredentialPair).where( + ConnectorCredentialPair.id == cc_pair_id + ) + cc_pair = db_session.scalar(stmt) + if cc_pair is None: + raise ValueError(f"No cc_pair with ID: {cc_pair_id}") + + cc_pair.last_time_perm_sync = start_time + db_session.commit() + + +def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int) -> None: + stmt = select(ConnectorCredentialPair).where( + ConnectorCredentialPair.id == cc_pair_id + ) + cc_pair = db_session.scalar(stmt) + if cc_pair is None: + raise ValueError(f"No cc_pair with ID: {cc_pair_id}") + + # The sync time can be marked after it ran because all group syncs + # are run in full, not polling for changes. + # If this changes, we need to update this function. + cc_pair.last_time_external_group_sync = datetime.now(timezone.utc) + db_session.commit() diff --git a/backend/danswer/db/document.py b/backend/danswer/db/document.py index 2e142a2c0b5..1797a68bf73 100644 --- a/backend/danswer/db/document.py +++ b/backend/danswer/db/document.py @@ -19,6 +19,7 @@ from sqlalchemy.sql.expression import null from danswer.configs.constants import DEFAULT_BOOST +from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.feedback import delete_document_feedback_for_documents__no_commit @@ -46,13 +47,21 @@ def count_documents_by_needs_sync(session: Session) -> int: """Get the count of all documents where: 1. last_modified is newer than last_synced 2. last_synced is null (meaning we've never synced) + AND the document has a relationship with a connector/credential pair + + TODO: The documents without a relationship with a connector/credential pair + should be cleaned up somehow eventually. This function executes the query and returns the count of documents matching the criteria.""" count = ( - session.query(func.count()) + session.query(func.count(DbDocument.id.distinct())) .select_from(DbDocument) + .join( + DocumentByConnectorCredentialPair, + DbDocument.id == DocumentByConnectorCredentialPair.id, + ) .filter( or_( DbDocument.last_modified > DbDocument.last_synced, @@ -91,6 +100,22 @@ def construct_document_select_for_connector_credential_pair_by_needs_sync( return stmt +def get_all_documents_needing_vespa_sync_for_cc_pair( + db_session: Session, cc_pair_id: int +) -> list[DbDocument]: + cc_pair = get_connector_credential_pair_from_id( + cc_pair_id=cc_pair_id, db_session=db_session + ) + if not cc_pair: + raise ValueError(f"No CC pair found with ID: {cc_pair_id}") + + stmt = construct_document_select_for_connector_credential_pair_by_needs_sync( + cc_pair.connector_id, cc_pair.credential_id + ) + + return list(db_session.scalars(stmt).all()) + + def construct_document_select_for_connector_credential_pair( connector_id: int, credential_id: int | None = None ) -> Select: @@ -104,6 +129,21 @@ def construct_document_select_for_connector_credential_pair( return stmt +def get_documents_for_cc_pair( + db_session: Session, + cc_pair_id: int, +) -> list[DbDocument]: + cc_pair = get_connector_credential_pair_from_id( + cc_pair_id=cc_pair_id, db_session=db_session + ) + if not cc_pair: + raise ValueError(f"No CC pair found with ID: {cc_pair_id}") + stmt = construct_document_select_for_connector_credential_pair( + connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id + ) + return list(db_session.scalars(stmt).all()) + + def get_document_ids_for_connector_credential_pair( db_session: Session, connector_id: int, credential_id: int, limit: int | None = None ) -> list[str]: @@ -268,7 +308,7 @@ def get_access_info_for_documents( return db_session.execute(stmt).all() # type: ignore -def upsert_documents( +def _upsert_documents( db_session: Session, document_metadata_batch: list[DocumentMetadata], initial_boost: int = DEFAULT_BOOST, @@ -306,6 +346,8 @@ def upsert_documents( ] ) + # This does not update the permissions of the document if + # the document already exists. on_conflict_stmt = insert_stmt.on_conflict_do_update( index_elements=["id"], # Conflict target set_={ @@ -322,7 +364,7 @@ def upsert_documents( db_session.commit() -def upsert_document_by_connector_credential_pair( +def _upsert_document_by_connector_credential_pair( db_session: Session, document_metadata_batch: list[DocumentMetadata] ) -> None: """NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause.""" @@ -404,8 +446,8 @@ def upsert_documents_complete( db_session: Session, document_metadata_batch: list[DocumentMetadata], ) -> None: - upsert_documents(db_session, document_metadata_batch) - upsert_document_by_connector_credential_pair(db_session, document_metadata_batch) + _upsert_documents(db_session, document_metadata_batch) + _upsert_document_by_connector_credential_pair(db_session, document_metadata_batch) logger.info( f"Upserted {len(document_metadata_batch)} document store entries into DB" ) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index ce8e21c52e8..f0fd61b0145 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -420,6 +420,9 @@ class ConnectorCredentialPair(Base): last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) + last_time_external_group_sync: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) # Time finished, not used for calculating backend jobs which uses time started (created) last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), default=None diff --git a/backend/danswer/db/users.py b/backend/danswer/db/users.py index 1ff21b71006..2c173f04443 100644 --- a/backend/danswer/db/users.py +++ b/backend/danswer/db/users.py @@ -97,3 +97,18 @@ def batch_add_non_web_user_if_not_exists__no_commit( db_session.flush() # generate ids return found_users + new_users + + +def batch_add_non_web_user_if_not_exists( + db_session: Session, emails: list[str] +) -> list[User]: + found_users, missing_user_emails = get_users_by_emails(db_session, emails) + + new_users: list[User] = [] + for email in missing_user_emails: + new_users.append(_generate_non_web_user(email=email)) + + db_session.add_all(new_users) + db_session.commit() + + return found_users + new_users diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py index 507956ff40f..d2f25d5d377 100644 --- a/backend/danswer/indexing/indexing_pipeline.py +++ b/backend/danswer/indexing/indexing_pipeline.py @@ -56,7 +56,7 @@ def __call__( ... -def upsert_documents_in_db( +def _upsert_documents_in_db( documents: list[Document], index_attempt_metadata: IndexAttemptMetadata, db_session: Session, @@ -243,7 +243,7 @@ def index_doc_batch_prepare( # Create records in the source of truth about these documents, # does not include doc_updated_at which is also used to indicate a successful update - upsert_documents_in_db( + _upsert_documents_in_db( documents=documents, index_attempt_metadata=index_attempt_metadata, db_session=db_session, diff --git a/backend/danswer/redis/redis_connector.py b/backend/danswer/redis/redis_connector.py index df61f986ede..8b52a2fd811 100644 --- a/backend/danswer/redis/redis_connector.py +++ b/backend/danswer/redis/redis_connector.py @@ -1,6 +1,8 @@ import redis from danswer.redis.redis_connector_delete import RedisConnectorDelete +from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync +from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync from danswer.redis.redis_connector_index import RedisConnectorIndex from danswer.redis.redis_connector_prune import RedisConnectorPrune from danswer.redis.redis_connector_stop import RedisConnectorStop @@ -19,6 +21,10 @@ def __init__(self, tenant_id: str | None, id: int) -> None: self.stop = RedisConnectorStop(tenant_id, id, self.redis) self.prune = RedisConnectorPrune(tenant_id, id, self.redis) self.delete = RedisConnectorDelete(tenant_id, id, self.redis) + self.permissions = RedisConnectorPermissionSync(tenant_id, id, self.redis) + self.external_group_sync = RedisConnectorExternalGroupSync( + tenant_id, id, self.redis + ) def new_index(self, search_settings_id: int) -> RedisConnectorIndex: return RedisConnectorIndex( diff --git a/backend/danswer/redis/redis_connector_doc_perm_sync.py b/backend/danswer/redis/redis_connector_doc_perm_sync.py new file mode 100644 index 00000000000..357523cf85c --- /dev/null +++ b/backend/danswer/redis/redis_connector_doc_perm_sync.py @@ -0,0 +1,187 @@ +import time +from datetime import datetime +from typing import cast +from uuid import uuid4 + +import redis +from celery import Celery +from pydantic import BaseModel + +from danswer.access.models import DocExternalAccess +from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT +from danswer.configs.constants import DanswerCeleryPriority +from danswer.configs.constants import DanswerCeleryQueues + + +class RedisConnectorPermissionSyncData(BaseModel): + started: datetime | None + + +class RedisConnectorPermissionSync: + """Manages interactions with redis for doc permission sync tasks. Should only be accessed + through RedisConnector.""" + + PREFIX = "connectordocpermissionsync" + + FENCE_PREFIX = f"{PREFIX}_fence" + + # phase 1 - geneartor task and progress signals + GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectorpermissions+generator + GENERATOR_PROGRESS_PREFIX = ( + PREFIX + "_generator_progress" + ) # connectorpermissions_generator_progress + GENERATOR_COMPLETE_PREFIX = ( + PREFIX + "_generator_complete" + ) # connectorpermissions_generator_complete + + TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpermissions_taskset + SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpermissions+sub + + def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None: + self.tenant_id: str | None = tenant_id + self.id = id + self.redis = redis + + self.fence_key: str = f"{self.FENCE_PREFIX}_{id}" + self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}" + self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}" + self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}" + + self.taskset_key = f"{self.TASKSET_PREFIX}_{id}" + + self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}" + + def taskset_clear(self) -> None: + self.redis.delete(self.taskset_key) + + def generator_clear(self) -> None: + self.redis.delete(self.generator_progress_key) + self.redis.delete(self.generator_complete_key) + + def get_remaining(self) -> int: + remaining = cast(int, self.redis.scard(self.taskset_key)) + return remaining + + def get_active_task_count(self) -> int: + """Count of active permission sync tasks""" + count = 0 + for _ in self.redis.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"): + count += 1 + return count + + @property + def fenced(self) -> bool: + if self.redis.exists(self.fence_key): + return True + + return False + + @property + def payload(self) -> RedisConnectorPermissionSyncData | None: + # read related data and evaluate/print task progress + fence_bytes = cast(bytes, self.redis.get(self.fence_key)) + if fence_bytes is None: + return None + + fence_str = fence_bytes.decode("utf-8") + payload = RedisConnectorPermissionSyncData.model_validate_json( + cast(str, fence_str) + ) + + return payload + + def set_fence( + self, + payload: RedisConnectorPermissionSyncData | None, + ) -> None: + if not payload: + self.redis.delete(self.fence_key) + return + + self.redis.set(self.fence_key, payload.model_dump_json()) + + @property + def generator_complete(self) -> int | None: + """the fence payload is an int representing the starting number of + permission sync tasks to be processed ... just after the generator completes.""" + fence_bytes = self.redis.get(self.generator_complete_key) + if fence_bytes is None: + return None + + if fence_bytes == b"None": + return None + + fence_int = int(cast(bytes, fence_bytes).decode()) + return fence_int + + @generator_complete.setter + def generator_complete(self, payload: int | None) -> None: + """Set the payload to an int to set the fence, otherwise if None it will + be deleted""" + if payload is None: + self.redis.delete(self.generator_complete_key) + return + + self.redis.set(self.generator_complete_key, payload) + + def generate_tasks( + self, + celery_app: Celery, + lock: redis.lock.Lock | None, + new_permissions: list[DocExternalAccess], + source_string: str, + ) -> int | None: + last_lock_time = time.monotonic() + async_results = [] + + # Create a task for each document permission sync + for doc_perm in new_permissions: + current_time = time.monotonic() + if lock and current_time - last_lock_time >= ( + CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 + ): + lock.reacquire() + last_lock_time = current_time + # Add task for document permissions sync + custom_task_id = f"{self.subtask_prefix}_{uuid4()}" + self.redis.sadd(self.taskset_key, custom_task_id) + + result = celery_app.send_task( + "update_external_document_permissions_task", + kwargs=dict( + tenant_id=self.tenant_id, + serialized_doc_external_access=doc_perm.to_dict(), + source_string=source_string, + ), + queue=DanswerCeleryQueues.DOC_PERMISSIONS_UPSERT, + task_id=custom_task_id, + priority=DanswerCeleryPriority.MEDIUM, + ) + async_results.append(result) + + return len(async_results) + + @staticmethod + def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None: + taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}" + r.srem(taskset_key, task_id) + return + + @staticmethod + def reset_all(r: redis.Redis) -> None: + """Deletes all redis values for all connectors""" + for key in r.scan_iter(RedisConnectorPermissionSync.TASKSET_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter( + RedisConnectorPermissionSync.GENERATOR_COMPLETE_PREFIX + "*" + ): + r.delete(key) + + for key in r.scan_iter( + RedisConnectorPermissionSync.GENERATOR_PROGRESS_PREFIX + "*" + ): + r.delete(key) + + for key in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"): + r.delete(key) diff --git a/backend/danswer/redis/redis_connector_ext_group_sync.py b/backend/danswer/redis/redis_connector_ext_group_sync.py new file mode 100644 index 00000000000..dadc00d0b0a --- /dev/null +++ b/backend/danswer/redis/redis_connector_ext_group_sync.py @@ -0,0 +1,133 @@ +from typing import cast + +import redis +from celery import Celery +from sqlalchemy.orm import Session + + +class RedisConnectorExternalGroupSync: + """Manages interactions with redis for external group syncing tasks. Should only be accessed + through RedisConnector.""" + + PREFIX = "connectorexternalgroupsync" + + FENCE_PREFIX = f"{PREFIX}_fence" + + # phase 1 - geneartor task and progress signals + GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectorexternalgroupsync+generator + GENERATOR_PROGRESS_PREFIX = ( + PREFIX + "_generator_progress" + ) # connectorexternalgroupsync_generator_progress + GENERATOR_COMPLETE_PREFIX = ( + PREFIX + "_generator_complete" + ) # connectorexternalgroupsync_generator_complete + + TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorexternalgroupsync_taskset + SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorexternalgroupsync+sub + + def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None: + self.tenant_id: str | None = tenant_id + self.id = id + self.redis = redis + + self.fence_key: str = f"{self.FENCE_PREFIX}_{id}" + self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}" + self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}" + self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}" + + self.taskset_key = f"{self.TASKSET_PREFIX}_{id}" + + self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}" + + def taskset_clear(self) -> None: + self.redis.delete(self.taskset_key) + + def generator_clear(self) -> None: + self.redis.delete(self.generator_progress_key) + self.redis.delete(self.generator_complete_key) + + def get_remaining(self) -> int: + # todo: move into fence + remaining = cast(int, self.redis.scard(self.taskset_key)) + return remaining + + def get_active_task_count(self) -> int: + """Count of active external group syncing tasks""" + count = 0 + for _ in self.redis.scan_iter( + RedisConnectorExternalGroupSync.FENCE_PREFIX + "*" + ): + count += 1 + return count + + @property + def fenced(self) -> bool: + if self.redis.exists(self.fence_key): + return True + + return False + + def set_fence(self, value: bool) -> None: + if not value: + self.redis.delete(self.fence_key) + return + + self.redis.set(self.fence_key, 0) + + @property + def generator_complete(self) -> int | None: + """the fence payload is an int representing the starting number of + external group syncing tasks to be processed ... just after the generator completes. + """ + fence_bytes = self.redis.get(self.generator_complete_key) + if fence_bytes is None: + return None + + if fence_bytes == b"None": + return None + + fence_int = int(cast(bytes, fence_bytes).decode()) + return fence_int + + @generator_complete.setter + def generator_complete(self, payload: int | None) -> None: + """Set the payload to an int to set the fence, otherwise if None it will + be deleted""" + if payload is None: + self.redis.delete(self.generator_complete_key) + return + + self.redis.set(self.generator_complete_key, payload) + + def generate_tasks( + self, + celery_app: Celery, + db_session: Session, + lock: redis.lock.Lock | None, + ) -> int | None: + pass + + @staticmethod + def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None: + taskset_key = f"{RedisConnectorExternalGroupSync.TASKSET_PREFIX}_{id}" + r.srem(taskset_key, task_id) + return + + @staticmethod + def reset_all(r: redis.Redis) -> None: + """Deletes all redis values for all connectors""" + for key in r.scan_iter(RedisConnectorExternalGroupSync.TASKSET_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter( + RedisConnectorExternalGroupSync.GENERATOR_COMPLETE_PREFIX + "*" + ): + r.delete(key) + + for key in r.scan_iter( + RedisConnectorExternalGroupSync.GENERATOR_PROGRESS_PREFIX + "*" + ): + r.delete(key) + + for key in r.scan_iter(RedisConnectorExternalGroupSync.FENCE_PREFIX + "*"): + r.delete(key) diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index cc64ac563f6..4b38eec7f71 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -12,13 +12,13 @@ from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot +from danswer.background.celery.tasks.doc_permission_syncing.tasks import ( + try_creating_permissions_sync_task, +) from danswer.background.celery.tasks.pruning.tasks import ( try_creating_prune_generator_task, ) from danswer.background.celery.versioned_apps.primary import app as primary_app -from danswer.background.task_name_builders import ( - name_sync_external_doc_permissions_task, -) from danswer.db.connector_credential_pair import add_credential_to_connector from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id from danswer.db.connector_credential_pair import remove_credential_from_connector @@ -26,6 +26,7 @@ update_connector_credential_pair_from_id, ) from danswer.db.document import get_document_counts_for_cc_pairs +from danswer.db.document import get_documents_for_cc_pair from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR from danswer.db.engine import get_current_tenant_id from danswer.db.engine import get_session @@ -38,15 +39,13 @@ from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id from danswer.db.models import User from danswer.db.search_settings import get_current_search_settings -from danswer.db.tasks import check_task_is_live_and_not_timed_out -from danswer.db.tasks import get_latest_task from danswer.redis.redis_connector import RedisConnector from danswer.redis.redis_pool import get_redis_client from danswer.server.documents.models import CCPairFullInfo from danswer.server.documents.models import CCStatusUpdateRequest -from danswer.server.documents.models import CeleryTaskStatus from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorCredentialPairMetadata +from danswer.server.documents.models import DocumentSyncStatus from danswer.server.documents.models import PaginatedIndexAttempts from danswer.server.models import StatusResponse from danswer.utils.logger import setup_logger @@ -288,12 +287,12 @@ def prune_cc_pair( ) -@router.get("/admin/cc-pair/{cc_pair_id}/sync") +@router.get("/admin/cc-pair/{cc_pair_id}/sync-permissions") def get_cc_pair_latest_sync( cc_pair_id: int, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), -) -> CeleryTaskStatus: +) -> datetime | None: cc_pair = get_connector_credential_pair_from_id( cc_pair_id=cc_pair_id, db_session=db_session, @@ -303,34 +302,20 @@ def get_cc_pair_latest_sync( if not cc_pair: raise HTTPException( status_code=400, - detail="Connection not found for current user's permissions", - ) - - # look up the last sync task for this connector (if it exists) - sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id) - last_sync_task = get_latest_task(sync_task_name, db_session) - if not last_sync_task: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail="No sync task found.", + detail="cc_pair not found for current user's permissions", ) - return CeleryTaskStatus( - id=last_sync_task.task_id, - name=last_sync_task.task_name, - status=last_sync_task.status, - start_time=last_sync_task.start_time, - register_time=last_sync_task.register_time, - ) + return cc_pair.last_time_perm_sync -@router.post("/admin/cc-pair/{cc_pair_id}/sync") +@router.post("/admin/cc-pair/{cc_pair_id}/sync-permissions") def sync_cc_pair( cc_pair_id: int, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), + tenant_id: str | None = Depends(get_current_tenant_id), ) -> StatusResponse[list[int]]: - # avoiding circular refs + """Triggers permissions sync on a particular cc_pair immediately""" cc_pair = get_connector_credential_pair_from_id( cc_pair_id=cc_pair_id, @@ -344,35 +329,47 @@ def sync_cc_pair( detail="Connection not found for current user's permissions", ) - sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id) - last_sync_task = get_latest_task(sync_task_name, db_session) + r = get_redis_client(tenant_id=tenant_id) - if last_sync_task and check_task_is_live_and_not_timed_out( - last_sync_task, db_session - ): + redis_connector = RedisConnector(tenant_id, cc_pair_id) + if redis_connector.permissions.fenced: raise HTTPException( status_code=HTTPStatus.CONFLICT, - detail="Sync task already in progress.", + detail="Doc permissions sync task already in progress.", ) - logger.info(f"Syncing the {cc_pair.connector.name} connector.") - sync_external_doc_permissions_task = fetch_ee_implementation_or_noop( - "danswer.background.celery.apps.primary", - "sync_external_doc_permissions_task", - None, + logger.info( + f"Doc permissions sync cc_pair={cc_pair_id} " + f"connector_id={cc_pair.connector_id} " + f"credential_id={cc_pair.credential_id} " + f"{cc_pair.connector.name} connector." ) - - if sync_external_doc_permissions_task: - sync_external_doc_permissions_task.apply_async( - kwargs=dict( - cc_pair_id=cc_pair_id, tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get() - ), + tasks_created = try_creating_permissions_sync_task( + primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get() + ) + if not tasks_created: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Doc permissions sync task creation failed.", ) return StatusResponse( success=True, - message="Successfully created the sync task.", + message="Successfully created the doc permissions sync task.", + ) + + +@router.get("/admin/cc-pair/{cc_pair_id}/get-docs-sync-status") +def get_docs_sync_status( + cc_pair_id: int, + _: User = Depends(current_curator_or_admin_user), + db_session: Session = Depends(get_session), +) -> list[DocumentSyncStatus]: + all_docs_for_cc_pair = get_documents_for_cc_pair( + db_session=db_session, + cc_pair_id=cc_pair_id, ) + return [DocumentSyncStatus.from_model(doc) for doc in all_docs_for_cc_pair] @router.put("/connector/{connector_id}/credential/{credential_id}") diff --git a/backend/danswer/server/documents/models.py b/backend/danswer/server/documents/models.py index fbd7c1e59e4..a541ae92c48 100644 --- a/backend/danswer/server/documents/models.py +++ b/backend/danswer/server/documents/models.py @@ -14,6 +14,7 @@ from danswer.db.models import Connector from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Credential +from danswer.db.models import Document as DbDocument from danswer.db.models import IndexAttempt from danswer.db.models import IndexAttemptError as DbIndexAttemptError from danswer.db.models import IndexingStatus @@ -21,6 +22,20 @@ from danswer.server.utils import mask_credential_dict +class DocumentSyncStatus(BaseModel): + doc_id: str + last_synced: datetime | None + last_modified: datetime | None + + @classmethod + def from_model(cls, doc: DbDocument) -> "DocumentSyncStatus": + return DocumentSyncStatus( + doc_id=doc.id, + last_synced=doc.last_synced, + last_modified=doc.last_modified, + ) + + class DocumentInfo(BaseModel): num_chunks: int num_tokens: int diff --git a/backend/danswer/utils/logger.py b/backend/danswer/utils/logger.py index bd784513898..b9c335be9c6 100644 --- a/backend/danswer/utils/logger.py +++ b/backend/danswer/utils/logger.py @@ -21,8 +21,12 @@ "pruning_ctx", default=dict() ) +doc_permission_sync_ctx: contextvars.ContextVar[ + dict[str, Any] +] = contextvars.ContextVar("doc_permission_sync_ctx", default=dict()) -class IndexAttemptSingleton: + +class TaskAttemptSingleton: """Used to tell if this process is an indexing job, and if so what is the unique identifier for this indexing attempt. For things like the API server, main background job (scheduler), etc. this will not be used.""" @@ -66,9 +70,10 @@ def process( ) -> tuple[str, MutableMapping[str, Any]]: # If this is an indexing job, add the attempt ID to the log message # This helps filter the logs for this specific indexing - index_attempt_id = IndexAttemptSingleton.get_index_attempt_id() - cc_pair_id = IndexAttemptSingleton.get_connector_credential_pair_id() + index_attempt_id = TaskAttemptSingleton.get_index_attempt_id() + cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id() + doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get() pruning_ctx_dict = pruning_ctx.get() if len(pruning_ctx_dict) > 0: if "request_id" in pruning_ctx_dict: @@ -76,6 +81,9 @@ def process( if "cc_pair_id" in pruning_ctx_dict: msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}" + elif len(doc_permission_sync_ctx_dict) > 0: + if "request_id" in doc_permission_sync_ctx_dict: + msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}" else: if index_attempt_id is not None: msg = f"[Index Attempt: {index_attempt_id}] {msg}" diff --git a/backend/ee/danswer/background/celery/apps/primary.py b/backend/ee/danswer/background/celery/apps/primary.py index 98dd66be0a3..21644228484 100644 --- a/backend/ee/danswer/background/celery/apps/primary.py +++ b/backend/ee/danswer/background/celery/apps/primary.py @@ -1,32 +1,12 @@ from danswer.background.celery.apps.primary import celery_app -from danswer.background.task_name_builders import ( - name_sync_external_doc_permissions_task, -) from danswer.background.task_utils import build_celery_task_wrapper from danswer.configs.app_configs import JOB_TIMEOUT from danswer.db.chat import delete_chat_sessions_older_than from danswer.db.engine import get_session_with_tenant from danswer.server.settings.store import load_settings from danswer.utils.logger import setup_logger -from danswer.utils.variable_functionality import global_version from ee.danswer.background.celery_utils import should_perform_chat_ttl_check -from ee.danswer.background.celery_utils import ( - should_perform_external_doc_permissions_check, -) -from ee.danswer.background.celery_utils import ( - should_perform_external_group_permissions_check, -) from ee.danswer.background.task_name_builders import name_chat_ttl_task -from ee.danswer.background.task_name_builders import ( - name_sync_external_group_permissions_task, -) -from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs -from ee.danswer.external_permissions.permission_sync import ( - run_external_doc_permission_sync, -) -from ee.danswer.external_permissions.permission_sync import ( - run_external_group_permission_sync, -) from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report from shared_configs.configs import MULTI_TENANT from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR @@ -34,25 +14,6 @@ logger = setup_logger() # mark as EE for all tasks in this file -global_version.set_ee() - - -@build_celery_task_wrapper(name_sync_external_doc_permissions_task) -@celery_app.task(soft_time_limit=JOB_TIMEOUT) -def sync_external_doc_permissions_task( - cc_pair_id: int, *, tenant_id: str | None -) -> None: - with get_session_with_tenant(tenant_id) as db_session: - run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id) - - -@build_celery_task_wrapper(name_sync_external_group_permissions_task) -@celery_app.task(soft_time_limit=JOB_TIMEOUT) -def sync_external_group_permissions_task( - cc_pair_id: int, *, tenant_id: str | None -) -> None: - with get_session_with_tenant(tenant_id) as db_session: - run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id) @build_celery_task_wrapper(name_chat_ttl_task) @@ -67,38 +28,6 @@ def perform_ttl_management_task( ##### # Periodic Tasks ##### -@celery_app.task( - name="check_sync_external_doc_permissions_task", - soft_time_limit=JOB_TIMEOUT, -) -def check_sync_external_doc_permissions_task(*, tenant_id: str | None) -> None: - """Runs periodically to sync external permissions""" - with get_session_with_tenant(tenant_id) as db_session: - cc_pairs = get_all_auto_sync_cc_pairs(db_session) - for cc_pair in cc_pairs: - if should_perform_external_doc_permissions_check( - cc_pair=cc_pair, db_session=db_session - ): - sync_external_doc_permissions_task.apply_async( - kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id), - ) - - -@celery_app.task( - name="check_sync_external_group_permissions_task", - soft_time_limit=JOB_TIMEOUT, -) -def check_sync_external_group_permissions_task(*, tenant_id: str | None) -> None: - """Runs periodically to sync external group permissions""" - with get_session_with_tenant(tenant_id) as db_session: - cc_pairs = get_all_auto_sync_cc_pairs(db_session) - for cc_pair in cc_pairs: - if should_perform_external_group_permissions_check( - cc_pair=cc_pair, db_session=db_session - ): - sync_external_group_permissions_task.apply_async( - kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id), - ) @celery_app.task( diff --git a/backend/ee/danswer/background/celery/tasks/beat_schedule.py b/backend/ee/danswer/background/celery/tasks/beat_schedule.py index 05e2b92bde1..86680e60c7f 100644 --- a/backend/ee/danswer/background/celery/tasks/beat_schedule.py +++ b/backend/ee/danswer/background/celery/tasks/beat_schedule.py @@ -6,16 +6,6 @@ ) ee_tasks_to_schedule = [ - { - "name": "sync-external-doc-permissions", - "task": "check_sync_external_doc_permissions_task", - "schedule": timedelta(seconds=30), # TODO: optimize this - }, - { - "name": "sync-external-group-permissions", - "task": "check_sync_external_group_permissions_task", - "schedule": timedelta(seconds=60), # TODO: optimize this - }, { "name": "autogenerate_usage_report", "task": "autogenerate_usage_report_task", diff --git a/backend/ee/danswer/background/celery_utils.py b/backend/ee/danswer/background/celery_utils.py index facad66db23..f6fff26cf41 100644 --- a/backend/ee/danswer/background/celery_utils.py +++ b/backend/ee/danswer/background/celery_utils.py @@ -1,46 +1,13 @@ -from datetime import datetime -from datetime import timezone - from sqlalchemy.orm import Session -from danswer.background.task_name_builders import ( - name_sync_external_doc_permissions_task, -) -from danswer.db.enums import AccessType -from danswer.db.models import ConnectorCredentialPair from danswer.db.tasks import check_task_is_live_and_not_timed_out from danswer.db.tasks import get_latest_task from danswer.utils.logger import setup_logger from ee.danswer.background.task_name_builders import name_chat_ttl_task -from ee.danswer.background.task_name_builders import ( - name_sync_external_group_permissions_task, -) -from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS logger = setup_logger() -def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool: - source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source) - - # If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync. - if not source_sync_period: - return True - - # If the last sync is None, it has never been run so we run the sync - if cc_pair.last_time_perm_sync is None: - return True - - last_sync = cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc) - current_time = datetime.now(timezone.utc) - - # If the last sync is greater than the full fetch period, we run the sync - if (current_time - last_sync).total_seconds() > source_sync_period: - return True - - return False - - def should_perform_chat_ttl_check( retention_limit_days: int | None, db_session: Session ) -> bool: @@ -57,47 +24,3 @@ def should_perform_chat_ttl_check( logger.debug(f"{task_name} is already being performed. Skipping.") return False return True - - -def should_perform_external_doc_permissions_check( - cc_pair: ConnectorCredentialPair, db_session: Session -) -> bool: - if cc_pair.access_type != AccessType.SYNC: - return False - - task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair.id) - - latest_task = get_latest_task(task_name, db_session) - if not latest_task: - return True - - if check_task_is_live_and_not_timed_out(latest_task, db_session): - logger.debug(f"{task_name} is already being performed. Skipping.") - return False - - if not _is_time_to_run_sync(cc_pair): - return False - - return True - - -def should_perform_external_group_permissions_check( - cc_pair: ConnectorCredentialPair, db_session: Session -) -> bool: - if cc_pair.access_type != AccessType.SYNC: - return False - - task_name = name_sync_external_group_permissions_task(cc_pair_id=cc_pair.id) - - latest_task = get_latest_task(task_name, db_session) - if not latest_task: - return True - - if check_task_is_live_and_not_timed_out(latest_task, db_session): - logger.debug(f"{task_name} is already being performed. Skipping.") - return False - - if not _is_time_to_run_sync(cc_pair): - return False - - return True diff --git a/backend/ee/danswer/background/task_name_builders.py b/backend/ee/danswer/background/task_name_builders.py index 39e2cf252ed..c218cdd3b59 100644 --- a/backend/ee/danswer/background/task_name_builders.py +++ b/backend/ee/danswer/background/task_name_builders.py @@ -1,8 +1,2 @@ def name_chat_ttl_task(retention_limit_days: int, tenant_id: str | None = None) -> str: return f"chat_ttl_{retention_limit_days}_days" - - -def name_sync_external_group_permissions_task( - cc_pair_id: int, tenant_id: str | None = None -) -> str: - return f"sync_external_group_permissions_task__{cc_pair_id}" diff --git a/backend/ee/danswer/db/document.py b/backend/ee/danswer/db/document.py index d67bc0e57e7..e061db6c75b 100644 --- a/backend/ee/danswer/db/document.py +++ b/backend/ee/danswer/db/document.py @@ -1,3 +1,6 @@ +from datetime import datetime +from datetime import timezone + from sqlalchemy import select from sqlalchemy.orm import Session @@ -45,3 +48,53 @@ def upsert_document_external_perms__no_commit( document.external_user_emails = list(external_access.external_user_emails) document.external_user_group_ids = prefixed_external_groups document.is_public = external_access.is_public + + +def upsert_document_external_perms( + db_session: Session, + doc_id: str, + external_access: ExternalAccess, + source_type: DocumentSource, +) -> None: + """ + This sets the permissions for a document in postgres. + NOTE: this will replace any existing external access, it will not do a union + """ + document = db_session.scalars( + select(DbDocument).where(DbDocument.id == doc_id) + ).first() + + prefixed_external_groups: set[str] = { + prefix_group_w_source( + ext_group_name=group_id, + source=source_type, + ) + for group_id in external_access.external_user_group_ids + } + + if not document: + # If the document does not exist, still store the external access + # So that if the document is added later, the external access is already stored + # The upsert function in the indexing pipeline does not overwrite the permissions fields + document = DbDocument( + id=doc_id, + semantic_id="", + external_user_emails=external_access.external_user_emails, + external_user_group_ids=prefixed_external_groups, + is_public=external_access.is_public, + ) + db_session.add(document) + db_session.commit() + return + + # If the document exists, we need to check if the external access has changed + if ( + external_access.external_user_emails != set(document.external_user_emails or []) + or prefixed_external_groups != set(document.external_user_group_ids or []) + or external_access.is_public != document.is_public + ): + document.external_user_emails = list(external_access.external_user_emails) + document.external_user_group_ids = list(prefixed_external_groups) + document.is_public = external_access.is_public + document.last_modified = datetime.now(timezone.utc) + db_session.commit() diff --git a/backend/ee/danswer/db/external_perm.py b/backend/ee/danswer/db/external_perm.py index 25881df55d3..1ccb30db5e2 100644 --- a/backend/ee/danswer/db/external_perm.py +++ b/backend/ee/danswer/db/external_perm.py @@ -9,11 +9,12 @@ from danswer.access.utils import prefix_group_w_source from danswer.configs.constants import DocumentSource from danswer.db.models import User__ExternalUserGroupId +from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit class ExternalUserGroup(BaseModel): id: str - user_ids: list[UUID] + user_emails: list[str] def delete_user__ext_group_for_user__no_commit( @@ -38,7 +39,7 @@ def delete_user__ext_group_for_cc_pair__no_commit( ) -def replace_user__ext_group_for_cc_pair__no_commit( +def replace_user__ext_group_for_cc_pair( db_session: Session, cc_pair_id: int, group_defs: list[ExternalUserGroup], @@ -46,24 +47,44 @@ def replace_user__ext_group_for_cc_pair__no_commit( ) -> None: """ This function clears all existing external user group relations for a given cc_pair_id - and replaces them with the new group definitions. + and replaces them with the new group definitions and commits the changes. """ delete_user__ext_group_for_cc_pair__no_commit( db_session=db_session, cc_pair_id=cc_pair_id, ) - new_external_permissions = [ - User__ExternalUserGroupId( - user_id=user_id, - external_user_group_id=prefix_group_w_source(external_group.id, source), - cc_pair_id=cc_pair_id, - ) - for external_group in group_defs - for user_id in external_group.user_ids - ] + # collect all emails from all groups to batch add all users at once for efficiency + all_group_member_emails = set() + for external_group in group_defs: + for user_email in external_group.user_emails: + all_group_member_emails.add(user_email) + + # batch add users if they don't exist and get their ids + all_group_members = batch_add_non_web_user_if_not_exists__no_commit( + db_session=db_session, emails=list(all_group_member_emails) + ) + + # map emails to ids + email_id_map = {user.email: user.id for user in all_group_members} + + # use these ids to create new external user group relations relating group_id to user_ids + new_external_permissions = [] + for external_group in group_defs: + for user_email in external_group.user_emails: + user_id = email_id_map[user_email] + new_external_permissions.append( + User__ExternalUserGroupId( + user_id=user_id, + external_user_group_id=prefix_group_w_source( + external_group.id, source + ), + cc_pair_id=cc_pair_id, + ) + ) db_session.add_all(new_external_permissions) + db_session.commit() def fetch_external_groups_for_user( diff --git a/backend/ee/danswer/external_permissions/confluence/doc_sync.py b/backend/ee/danswer/external_permissions/confluence/doc_sync.py index a7bc898b8b7..18fa27ea428 100644 --- a/backend/ee/danswer/external_permissions/confluence/doc_sync.py +++ b/backend/ee/danswer/external_permissions/confluence/doc_sync.py @@ -4,17 +4,14 @@ """ from typing import Any -from sqlalchemy.orm import Session - +from danswer.access.models import DocExternalAccess from danswer.access.models import ExternalAccess from danswer.connectors.confluence.connector import ConfluenceConnector from danswer.connectors.confluence.onyx_confluence import OnyxConfluence from danswer.connectors.confluence.utils import get_user_email_from_username__server from danswer.connectors.models import SlimDocument from danswer.db.models import ConnectorCredentialPair -from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.utils.logger import setup_logger -from ee.danswer.db.document import upsert_document_external_perms__no_commit logger = setup_logger() @@ -163,7 +160,13 @@ def _extract_read_access_restrictions( f"Email for user {user['username']} not found in Confluence" ) else: - logger.warning(f"User {user} does not have an email or username") + if user.get("email") is not None: + logger.warning(f"Cant find email for user {user.get('displayName')}") + logger.warning( + "This user needs to make their email accessible in Confluence Settings" + ) + + logger.warning(f"no user email or username for {user}") # Extract the groups with read access read_access_group = read_access_restrictions.get("group", {}) @@ -190,12 +193,12 @@ def _fetch_all_page_restrictions_for_space( confluence_client: OnyxConfluence, slim_docs: list[SlimDocument], space_permissions_by_space_key: dict[str, ExternalAccess], -) -> dict[str, ExternalAccess]: +) -> list[DocExternalAccess]: """ For all pages, if a page has restrictions, then use those restrictions. Otherwise, use the space's restrictions. """ - document_restrictions: dict[str, ExternalAccess] = {} + document_restrictions: list[DocExternalAccess] = [] for slim_doc in slim_docs: if slim_doc.perm_sync_data is None: @@ -207,21 +210,34 @@ def _fetch_all_page_restrictions_for_space( restrictions=slim_doc.perm_sync_data.get("restrictions", {}), ) if restrictions: - document_restrictions[slim_doc.id] = restrictions - else: - space_key = slim_doc.perm_sync_data.get("space_key") - if space_permissions := space_permissions_by_space_key.get(space_key): - document_restrictions[slim_doc.id] = space_permissions - else: - logger.warning(f"No permissions found for document {slim_doc.id}") + document_restrictions.append( + DocExternalAccess( + doc_id=slim_doc.id, + external_access=restrictions, + ) + ) + # If there are restrictions, then we don't need to use the space's restrictions + continue + + space_key = slim_doc.perm_sync_data.get("space_key") + if space_permissions := space_permissions_by_space_key.get(space_key): + # If there are no restrictions, then use the space's restrictions + document_restrictions.append( + DocExternalAccess( + doc_id=slim_doc.id, + external_access=space_permissions, + ) + ) + continue + + logger.warning(f"No permissions found for document {slim_doc.id}") return document_restrictions def confluence_doc_sync( - db_session: Session, cc_pair: ConnectorCredentialPair, -) -> None: +) -> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres if the document doesn't already exists in postgres, we create @@ -247,20 +263,8 @@ def confluence_doc_sync( for doc_batch in confluence_connector.retrieve_all_slim_documents(): slim_docs.extend(doc_batch) - permissions_by_doc_id = _fetch_all_page_restrictions_for_space( + return _fetch_all_page_restrictions_for_space( confluence_client=confluence_client, slim_docs=slim_docs, space_permissions_by_space_key=space_permissions_by_space_key, ) - - all_emails = set() - for doc_id, page_specific_access in permissions_by_doc_id.items(): - upsert_document_external_perms__no_commit( - db_session=db_session, - doc_id=doc_id, - external_access=page_specific_access, - source_type=cc_pair.connector.source, - ) - all_emails.update(page_specific_access.external_user_emails) - - batch_add_non_web_user_if_not_exists__no_commit(db_session, list(all_emails)) diff --git a/backend/ee/danswer/external_permissions/confluence/group_sync.py b/backend/ee/danswer/external_permissions/confluence/group_sync.py index a55bb777bc5..db5bd9c01cc 100644 --- a/backend/ee/danswer/external_permissions/confluence/group_sync.py +++ b/backend/ee/danswer/external_permissions/confluence/group_sync.py @@ -1,15 +1,11 @@ from typing import Any -from sqlalchemy.orm import Session - from danswer.connectors.confluence.onyx_confluence import OnyxConfluence from danswer.connectors.confluence.utils import build_confluence_client from danswer.connectors.confluence.utils import get_user_email_from_username__server from danswer.db.models import ConnectorCredentialPair -from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.utils.logger import setup_logger from ee.danswer.db.external_perm import ExternalUserGroup -from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit logger = setup_logger() @@ -40,9 +36,8 @@ def _get_group_members_email_paginated( def confluence_group_sync( - db_session: Session, cc_pair: ConnectorCredentialPair, -) -> None: +) -> list[ExternalUserGroup]: is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False) confluence_client = build_confluence_client( credentials_json=cc_pair.credential.credential_json, @@ -63,20 +58,13 @@ def confluence_group_sync( group_member_emails = _get_group_members_email_paginated( confluence_client, group_name ) - group_members = batch_add_non_web_user_if_not_exists__no_commit( - db_session=db_session, emails=list(group_member_emails) - ) - if group_members: - danswer_groups.append( - ExternalUserGroup( - id=group_name, - user_ids=[user.id for user in group_members], - ) + if not group_member_emails: + continue + danswer_groups.append( + ExternalUserGroup( + id=group_name, + user_emails=list(group_member_emails), ) + ) - replace_user__ext_group_for_cc_pair__no_commit( - db_session=db_session, - cc_pair_id=cc_pair.id, - group_defs=danswer_groups, - source=cc_pair.connector.source, - ) + return danswer_groups diff --git a/backend/ee/danswer/external_permissions/gmail/doc_sync.py b/backend/ee/danswer/external_permissions/gmail/doc_sync.py index 2748443f022..6b72e7ba116 100644 --- a/backend/ee/danswer/external_permissions/gmail/doc_sync.py +++ b/backend/ee/danswer/external_permissions/gmail/doc_sync.py @@ -1,15 +1,12 @@ from datetime import datetime from datetime import timezone -from sqlalchemy.orm import Session - +from danswer.access.models import DocExternalAccess from danswer.access.models import ExternalAccess from danswer.connectors.gmail.connector import GmailConnector from danswer.connectors.interfaces import GenerateSlimDocumentOutput from danswer.db.models import ConnectorCredentialPair -from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.utils.logger import setup_logger -from ee.danswer.db.document import upsert_document_external_perms__no_commit logger = setup_logger() @@ -31,9 +28,8 @@ def _get_slim_doc_generator( def gmail_doc_sync( - db_session: Session, cc_pair: ConnectorCredentialPair, -) -> None: +) -> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres if the document doesn't already exists in postgres, we create @@ -45,6 +41,7 @@ def gmail_doc_sync( slim_doc_generator = _get_slim_doc_generator(cc_pair, gmail_connector) + document_external_access: list[DocExternalAccess] = [] for slim_doc_batch in slim_doc_generator: for slim_doc in slim_doc_batch: if slim_doc.perm_sync_data is None: @@ -56,13 +53,11 @@ def gmail_doc_sync( external_user_group_ids=set(), is_public=False, ) - batch_add_non_web_user_if_not_exists__no_commit( - db_session=db_session, - emails=list(ext_access.external_user_emails), - ) - upsert_document_external_perms__no_commit( - db_session=db_session, - doc_id=slim_doc.id, - external_access=ext_access, - source_type=cc_pair.connector.source, + document_external_access.append( + DocExternalAccess( + doc_id=slim_doc.id, + external_access=ext_access, + ) ) + + return document_external_access diff --git a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py index fddb0e72171..2e421ad6c95 100644 --- a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py +++ b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py @@ -2,8 +2,7 @@ from datetime import timezone from typing import Any -from sqlalchemy.orm import Session - +from danswer.access.models import DocExternalAccess from danswer.access.models import ExternalAccess from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval @@ -11,9 +10,7 @@ from danswer.connectors.interfaces import GenerateSlimDocumentOutput from danswer.connectors.models import SlimDocument from danswer.db.models import ConnectorCredentialPair -from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.utils.logger import setup_logger -from ee.danswer.db.document import upsert_document_external_perms__no_commit logger = setup_logger() @@ -126,9 +123,8 @@ def _get_permissions_from_slim_doc( def gdrive_doc_sync( - db_session: Session, cc_pair: ConnectorCredentialPair, -) -> None: +) -> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres if the document doesn't already exists in postgres, we create @@ -142,19 +138,17 @@ def gdrive_doc_sync( slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector) + document_external_accesses = [] for slim_doc_batch in slim_doc_generator: for slim_doc in slim_doc_batch: ext_access = _get_permissions_from_slim_doc( google_drive_connector=google_drive_connector, slim_doc=slim_doc, ) - batch_add_non_web_user_if_not_exists__no_commit( - db_session=db_session, - emails=list(ext_access.external_user_emails), - ) - upsert_document_external_perms__no_commit( - db_session=db_session, - doc_id=slim_doc.id, - external_access=ext_access, - source_type=cc_pair.connector.source, + document_external_accesses.append( + DocExternalAccess( + external_access=ext_access, + doc_id=slim_doc.id, + ) ) + return document_external_accesses diff --git a/backend/ee/danswer/external_permissions/google_drive/group_sync.py b/backend/ee/danswer/external_permissions/google_drive/group_sync.py index e9ca40b3dcb..0f421f371b9 100644 --- a/backend/ee/danswer/external_permissions/google_drive/group_sync.py +++ b/backend/ee/danswer/external_permissions/google_drive/group_sync.py @@ -1,21 +1,16 @@ -from sqlalchemy.orm import Session - from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval from danswer.connectors.google_utils.resources import get_admin_service from danswer.db.models import ConnectorCredentialPair -from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.utils.logger import setup_logger from ee.danswer.db.external_perm import ExternalUserGroup -from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit logger = setup_logger() def gdrive_group_sync( - db_session: Session, cc_pair: ConnectorCredentialPair, -) -> None: +) -> list[ExternalUserGroup]: google_drive_connector = GoogleDriveConnector( **cc_pair.connector.connector_specific_config ) @@ -44,20 +39,14 @@ def gdrive_group_sync( ): group_member_emails.append(member["email"]) - # Add group members to DB and get their IDs - group_members = batch_add_non_web_user_if_not_exists__no_commit( - db_session=db_session, emails=group_member_emails - ) - if group_members: - danswer_groups.append( - ExternalUserGroup( - id=group_email, user_ids=[user.id for user in group_members] - ) + if not group_member_emails: + continue + + danswer_groups.append( + ExternalUserGroup( + id=group_email, + user_emails=list(group_member_emails), ) + ) - replace_user__ext_group_for_cc_pair__no_commit( - db_session=db_session, - cc_pair_id=cc_pair.id, - group_defs=danswer_groups, - source=cc_pair.connector.source, - ) + return danswer_groups diff --git a/backend/ee/danswer/external_permissions/permission_sync.py b/backend/ee/danswer/external_permissions/permission_sync.py deleted file mode 100644 index 94a0b4bfa8e..00000000000 --- a/backend/ee/danswer/external_permissions/permission_sync.py +++ /dev/null @@ -1,115 +0,0 @@ -from datetime import datetime -from datetime import timezone - -from sqlalchemy.orm import Session - -from danswer.access.access import get_access_for_documents -from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id -from danswer.db.document import get_document_ids_for_connector_credential_pair -from danswer.document_index.factory import get_current_primary_default_document_index -from danswer.document_index.interfaces import UpdateRequest -from danswer.utils.logger import setup_logger -from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP -from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP - -logger = setup_logger() - - -def run_external_group_permission_sync( - db_session: Session, - cc_pair_id: int, -) -> None: - cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) - if cc_pair is None: - raise ValueError(f"No connector credential pair found for id: {cc_pair_id}") - - source_type = cc_pair.connector.source - group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type) - - if group_sync_func is None: - # Not all sync connectors support group permissions so this is fine - return - - try: - # This function updates: - # - the user_email <-> external_user_group_id mapping - # in postgres without committing - logger.debug(f"Syncing groups for {source_type}") - if group_sync_func is not None: - group_sync_func( - db_session, - cc_pair, - ) - - # update postgres - db_session.commit() - except Exception: - logger.exception("Error Syncing Group Permissions") - db_session.rollback() - - -def run_external_doc_permission_sync( - db_session: Session, - cc_pair_id: int, -) -> None: - cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) - if cc_pair is None: - raise ValueError(f"No connector credential pair found for id: {cc_pair_id}") - - source_type = cc_pair.connector.source - - doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type) - last_time_perm_sync = cc_pair.last_time_perm_sync - - if doc_sync_func is None: - raise ValueError( - f"No permission sync function found for source type: {source_type}" - ) - - try: - # This function updates: - # - the user_email <-> document mapping - # - the external_user_group_id <-> document mapping - # in postgres without committing - logger.info(f"Syncing docs for {source_type}") - doc_sync_func( - db_session, - cc_pair, - ) - - # Get the document ids for the cc pair - document_ids_for_cc_pair = get_document_ids_for_connector_credential_pair( - db_session=db_session, - connector_id=cc_pair.connector_id, - credential_id=cc_pair.credential_id, - ) - - # This function fetches the updated access for the documents - # and returns a dictionary of document_ids and access - # This is the access we want to update vespa with - docs_access = get_access_for_documents( - document_ids=document_ids_for_cc_pair, - db_session=db_session, - ) - - # Then we build the update requests to update vespa - update_reqs = [ - UpdateRequest(document_ids=[doc_id], access=doc_access) - for doc_id, doc_access in docs_access.items() - ] - - # Don't bother sync-ing secondary, it will be sync-ed after switch anyway - document_index = get_current_primary_default_document_index(db_session) - - # update vespa - document_index.update(update_reqs) - - cc_pair.last_time_perm_sync = datetime.now(timezone.utc) - - # update postgres - db_session.commit() - logger.info(f"Successfully synced docs for {source_type}") - except Exception: - logger.exception("Error Syncing Document Permissions") - cc_pair.last_time_perm_sync = last_time_perm_sync - db_session.rollback() diff --git a/backend/ee/danswer/external_permissions/slack/doc_sync.py b/backend/ee/danswer/external_permissions/slack/doc_sync.py index b5f6e9695db..24c565fc4e5 100644 --- a/backend/ee/danswer/external_permissions/slack/doc_sync.py +++ b/backend/ee/danswer/external_permissions/slack/doc_sync.py @@ -1,16 +1,12 @@ from slack_sdk import WebClient -from sqlalchemy.orm import Session +from danswer.access.models import DocExternalAccess from danswer.access.models import ExternalAccess -from danswer.connectors.factory import instantiate_connector -from danswer.connectors.interfaces import SlimConnector -from danswer.connectors.models import InputType from danswer.connectors.slack.connector import get_channels from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries +from danswer.connectors.slack.connector import SlackPollConnector from danswer.db.models import ConnectorCredentialPair -from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.utils.logger import setup_logger -from ee.danswer.db.document import upsert_document_external_perms__no_commit from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map @@ -18,22 +14,15 @@ def _get_slack_document_ids_and_channels( - db_session: Session, cc_pair: ConnectorCredentialPair, ) -> dict[str, list[str]]: - # Get all document ids that need their permissions updated - runnable_connector = instantiate_connector( - db_session=db_session, - source=cc_pair.connector.source, - input_type=InputType.SLIM_RETRIEVAL, - connector_specific_config=cc_pair.connector.connector_specific_config, - credential=cc_pair.credential, - ) + slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config) + slack_connector.load_credentials(cc_pair.credential.credential_json) - assert isinstance(runnable_connector, SlimConnector) + slim_doc_generator = slack_connector.retrieve_all_slim_documents() channel_doc_map: dict[str, list[str]] = {} - for doc_metadata_batch in runnable_connector.retrieve_all_slim_documents(): + for doc_metadata_batch in slim_doc_generator: for doc_metadata in doc_metadata_batch: if doc_metadata.perm_sync_data is None: continue @@ -46,13 +35,11 @@ def _get_slack_document_ids_and_channels( def _fetch_workspace_permissions( - db_session: Session, user_id_to_email_map: dict[str, str], ) -> ExternalAccess: user_emails = set() for email in user_id_to_email_map.values(): user_emails.add(email) - batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails)) return ExternalAccess( external_user_emails=user_emails, # No group<->document mapping for slack @@ -63,7 +50,6 @@ def _fetch_workspace_permissions( def _fetch_channel_permissions( - db_session: Session, slack_client: WebClient, workspace_permissions: ExternalAccess, user_id_to_email_map: dict[str, str], @@ -113,9 +99,6 @@ def _fetch_channel_permissions( # If no email is found, we skip the user continue user_id_to_email_map[member_id] = member_email - batch_add_non_web_user_if_not_exists__no_commit( - db_session, [member_email] - ) member_emails.add(member_email) @@ -131,9 +114,8 @@ def _fetch_channel_permissions( def slack_doc_sync( - db_session: Session, cc_pair: ConnectorCredentialPair, -) -> None: +) -> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres if the document doesn't already exists in postgres, we create @@ -145,19 +127,18 @@ def slack_doc_sync( ) user_id_to_email_map = fetch_user_id_to_email_map(slack_client) channel_doc_map = _get_slack_document_ids_and_channels( - db_session=db_session, cc_pair=cc_pair, ) workspace_permissions = _fetch_workspace_permissions( - db_session=db_session, user_id_to_email_map=user_id_to_email_map, ) channel_permissions = _fetch_channel_permissions( - db_session=db_session, slack_client=slack_client, workspace_permissions=workspace_permissions, user_id_to_email_map=user_id_to_email_map, ) + + document_external_accesses = [] for channel_id, ext_access in channel_permissions.items(): doc_ids = channel_doc_map.get(channel_id) if not doc_ids: @@ -165,9 +146,10 @@ def slack_doc_sync( continue for doc_id in doc_ids: - upsert_document_external_perms__no_commit( - db_session=db_session, - doc_id=doc_id, - external_access=ext_access, - source_type=cc_pair.connector.source, + document_external_accesses.append( + DocExternalAccess( + external_access=ext_access, + doc_id=doc_id, + ) ) + return document_external_accesses diff --git a/backend/ee/danswer/external_permissions/slack/group_sync.py b/backend/ee/danswer/external_permissions/slack/group_sync.py index 80838895219..780f619e464 100644 --- a/backend/ee/danswer/external_permissions/slack/group_sync.py +++ b/backend/ee/danswer/external_permissions/slack/group_sync.py @@ -5,14 +5,11 @@ THERE IS NO USERGROUP <-> DOCUMENT PERMISSION MAPPING """ from slack_sdk import WebClient -from sqlalchemy.orm import Session from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries from danswer.db.models import ConnectorCredentialPair -from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.utils.logger import setup_logger from ee.danswer.db.external_perm import ExternalUserGroup -from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map logger = setup_logger() @@ -29,7 +26,6 @@ def _get_slack_group_ids( def _get_slack_group_members_email( - db_session: Session, slack_client: WebClient, group_name: str, user_id_to_email_map: dict[str, str], @@ -49,18 +45,14 @@ def _get_slack_group_members_email( # If no email is found, we skip the user continue user_id_to_email_map[member_id] = member_email - batch_add_non_web_user_if_not_exists__no_commit( - db_session, [member_email] - ) group_member_emails.append(member_email) return group_member_emails def slack_group_sync( - db_session: Session, cc_pair: ConnectorCredentialPair, -) -> None: +) -> list[ExternalUserGroup]: slack_client = WebClient( token=cc_pair.credential.credential_json["slack_bot_token"] ) @@ -69,24 +61,13 @@ def slack_group_sync( danswer_groups: list[ExternalUserGroup] = [] for group_name in _get_slack_group_ids(slack_client): group_member_emails = _get_slack_group_members_email( - db_session=db_session, slack_client=slack_client, group_name=group_name, user_id_to_email_map=user_id_to_email_map, ) - group_members = batch_add_non_web_user_if_not_exists__no_commit( - db_session=db_session, emails=group_member_emails + if not group_member_emails: + continue + danswer_groups.append( + ExternalUserGroup(id=group_name, user_emails=group_member_emails) ) - if group_members: - danswer_groups.append( - ExternalUserGroup( - id=group_name, user_ids=[user.id for user in group_members] - ) - ) - - replace_user__ext_group_for_cc_pair__no_commit( - db_session=db_session, - cc_pair_id=cc_pair.id, - group_defs=danswer_groups, - source=cc_pair.connector.source, - ) + return danswer_groups diff --git a/backend/ee/danswer/external_permissions/sync_params.py b/backend/ee/danswer/external_permissions/sync_params.py index 1fd09ca1509..fb81ab35035 100644 --- a/backend/ee/danswer/external_permissions/sync_params.py +++ b/backend/ee/danswer/external_permissions/sync_params.py @@ -1,9 +1,9 @@ from collections.abc import Callable -from sqlalchemy.orm import Session - +from danswer.access.models import DocExternalAccess from danswer.configs.constants import DocumentSource from danswer.db.models import ConnectorCredentialPair +from ee.danswer.db.external_perm import ExternalUserGroup from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync from ee.danswer.external_permissions.gmail.doc_sync import gmail_doc_sync @@ -12,12 +12,18 @@ from ee.danswer.external_permissions.slack.doc_sync import slack_doc_sync # Defining the input/output types for the sync functions -SyncFuncType = Callable[ +DocSyncFuncType = Callable[ [ - Session, ConnectorCredentialPair, ], - None, + list[DocExternalAccess], +] + +GroupSyncFuncType = Callable[ + [ + ConnectorCredentialPair, + ], + list[ExternalUserGroup], ] # These functions update: @@ -25,7 +31,7 @@ # - the external_user_group_id <-> document mapping # in postgres without committing # THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK -DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = { +DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = { DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync, DocumentSource.CONFLUENCE: confluence_doc_sync, DocumentSource.SLACK: slack_doc_sync, @@ -36,19 +42,21 @@ # - the user_email <-> external_user_group_id mapping # in postgres without committing # THIS ONE IS OPTIONAL ON AN APP BY APP BASIS -GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = { +GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = { DocumentSource.GOOGLE_DRIVE: gdrive_group_sync, DocumentSource.CONFLUENCE: confluence_group_sync, } # If nothing is specified here, we run the doc_sync every time the celery beat runs -PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = { +DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = { # Polling is not supported so we fetch all doc permissions every 5 minutes DocumentSource.CONFLUENCE: 5 * 60, DocumentSource.SLACK: 5 * 60, } +EXTERNAL_GROUP_SYNC_PERIOD: int = 30 # 30 seconds + def check_if_valid_sync_source(source_type: DocumentSource) -> bool: return source_type in DOC_PERMISSIONS_FUNC_MAP diff --git a/backend/scripts/dev_run_background_jobs.py b/backend/scripts/dev_run_background_jobs.py index 1ca823e0935..6abb5fad8a1 100644 --- a/backend/scripts/dev_run_background_jobs.py +++ b/backend/scripts/dev_run_background_jobs.py @@ -42,7 +42,7 @@ def run_jobs() -> None: "--loglevel=INFO", "--hostname=light@%n", "-Q", - "vespa_metadata_sync,connector_deletion", + "vespa_metadata_sync,connector_deletion,doc_permissions_upsert", ] cmd_worker_heavy = [ @@ -56,7 +56,7 @@ def run_jobs() -> None: "--loglevel=INFO", "--hostname=heavy@%n", "-Q", - "connector_pruning", + "connector_pruning,connector_doc_permissions_sync,connector_external_group_sync", ] cmd_worker_indexing = [ diff --git a/backend/supervisord.conf b/backend/supervisord.conf index 93472161854..c4a431b1e3a 100644 --- a/backend/supervisord.conf +++ b/backend/supervisord.conf @@ -33,7 +33,7 @@ stopasgroup=true command=celery -A danswer.background.celery.versioned_apps.light worker --loglevel=INFO --hostname=light@%%n - -Q vespa_metadata_sync,connector_deletion + -Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert stdout_logfile=/var/log/celery_worker_light.log stdout_logfile_maxbytes=16MB redirect_stderr=true @@ -45,7 +45,7 @@ stopasgroup=true command=celery -A danswer.background.celery.versioned_apps.heavy worker --loglevel=INFO --hostname=heavy@%%n - -Q connector_pruning + -Q connector_pruning,connector_doc_permissions_sync,connector_external_group_sync stdout_logfile=/var/log/celery_worker_heavy.log stdout_logfile_maxbytes=16MB redirect_stderr=true diff --git a/backend/tests/integration/common_utils/managers/cc_pair.py b/backend/tests/integration/common_utils/managers/cc_pair.py index 68469d144b9..d32f7694cb5 100644 --- a/backend/tests/integration/common_utils/managers/cc_pair.py +++ b/backend/tests/integration/common_utils/managers/cc_pair.py @@ -8,11 +8,10 @@ from danswer.connectors.models import InputType from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus -from danswer.db.enums import TaskStatus -from danswer.server.documents.models import CeleryTaskStatus from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorIndexingStatus from danswer.server.documents.models import DocumentSource +from danswer.server.documents.models import DocumentSyncStatus from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.constants import MAX_DELAY @@ -328,56 +327,127 @@ def sync( user_performing_action: DATestUser | None = None, ) -> None: result = requests.post( - url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync", + url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions", headers=user_performing_action.headers if user_performing_action else GENERAL_HEADERS, ) - result.raise_for_status() + # + if result.status_code != 409: + result.raise_for_status() @staticmethod def get_sync_task( cc_pair: DATestCCPair, user_performing_action: DATestUser | None = None, - ) -> CeleryTaskStatus: + ) -> datetime | None: + response = requests.get( + url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + response_str = response.json() + + # If the response itself is a datetime string, parse it + if not isinstance(response_str, str): + return None + + try: + return datetime.fromisoformat(response_str) + except ValueError: + return None + + @staticmethod + def get_doc_sync_statuses( + cc_pair: DATestCCPair, + user_performing_action: DATestUser | None = None, + ) -> list[DocumentSyncStatus]: response = requests.get( - url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync", + url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/get-docs-sync-status", headers=user_performing_action.headers if user_performing_action else GENERAL_HEADERS, ) response.raise_for_status() - return CeleryTaskStatus(**response.json()) + doc_sync_statuses: list[DocumentSyncStatus] = [] + for doc_sync_status in response.json(): + doc_sync_statuses.append( + DocumentSyncStatus( + doc_id=doc_sync_status["doc_id"], + last_synced=datetime.fromisoformat(doc_sync_status["last_synced"]), + last_modified=datetime.fromisoformat( + doc_sync_status["last_modified"] + ), + ) + ) + + return doc_sync_statuses @staticmethod def wait_for_sync( cc_pair: DATestCCPair, after: datetime, timeout: float = MAX_DELAY, + number_of_updated_docs: int = 0, user_performing_action: DATestUser | None = None, ) -> None: """after: The task register time must be after this time.""" start = time.monotonic() while True: - task = CCPairManager.get_sync_task(cc_pair, user_performing_action) - if not task: - raise ValueError("Sync task not found.") + last_synced = CCPairManager.get_sync_task(cc_pair, user_performing_action) + if last_synced and last_synced > after: + print(f"last_synced: {last_synced}") + print(f"sync command start time: {after}") + print(f"permission sync complete: cc_pair={cc_pair.id}") + break - if not task.register_time or task.register_time < after: - raise ValueError("Sync task register time is too early.") + elapsed = time.monotonic() - start + if elapsed > timeout: + raise TimeoutError( + f"Permission sync was not completed within {timeout} seconds" + ) - if task.status == TaskStatus.SUCCESS: - # Sync succeeded - return + print( + f"Waiting for CC sync to complete. elapsed={elapsed:.2f} timeout={timeout}" + ) + time.sleep(5) + + # TODO: remove this sleep, this shouldnt be necessary but + # time.sleep(5) + + print("waiting for vespa sync") + # wait for the vespa sync to complete once the permission sync is complete + start = time.monotonic() + while True: + doc_sync_statuses = CCPairManager.get_doc_sync_statuses( + cc_pair=cc_pair, + user_performing_action=user_performing_action, + ) + synced_docs = 0 + for doc_sync_status in doc_sync_statuses: + if ( + doc_sync_status.last_synced is not None + and doc_sync_status.last_modified is not None + and doc_sync_status.last_synced >= doc_sync_status.last_modified + and doc_sync_status.last_synced >= after + and doc_sync_status.last_modified >= after + ): + synced_docs += 1 + + if synced_docs >= number_of_updated_docs: + print(f"all docs synced: cc_pair={cc_pair.id}") + break elapsed = time.monotonic() - start if elapsed > timeout: raise TimeoutError( - f"CC pair syncing was not completed within {timeout} seconds" + f"Vespa sync was not completed within {timeout} seconds" ) print( - f"Waiting for CC syncing to complete. elapsed={elapsed:.2f} timeout={timeout}" + f"Waiting for vespa sync to complete. elapsed={elapsed:.2f} timeout={timeout}" ) time.sleep(5) diff --git a/backend/tests/integration/connector_job_tests/slack/conftest.py b/backend/tests/integration/connector_job_tests/slack/conftest.py index 03d99737ce7..38b851de809 100644 --- a/backend/tests/integration/connector_job_tests/slack/conftest.py +++ b/backend/tests/integration/connector_job_tests/slack/conftest.py @@ -6,6 +6,10 @@ from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager +# from tests.load_env_vars import load_env_vars + +# load_env_vars() + @pytest.fixture() def slack_test_setup() -> Generator[tuple[dict[str, Any], dict[str, Any]], None, None]: diff --git a/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py b/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py index d64986ea826..cff506eece7 100644 --- a/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py +++ b/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py @@ -96,11 +96,13 @@ def test_slack_permission_sync( private_message = "Sara's favorite number is 346794" # Add messages to channels + print(f"\n Adding public message to channel: {public_message}") SlackManager.add_message_to_channel( slack_client=slack_client, channel=public_channel, message=public_message, ) + print(f"\n Adding private message to channel: {private_message}") SlackManager.add_message_to_channel( slack_client=slack_client, channel=private_channel, @@ -117,7 +119,6 @@ def test_slack_permission_sync( ) # Run permission sync - before = datetime.now(timezone.utc) CCPairManager.sync( cc_pair=cc_pair, user_performing_action=admin_user, @@ -125,26 +126,33 @@ def test_slack_permission_sync( CCPairManager.wait_for_sync( cc_pair=cc_pair, after=before, + number_of_updated_docs=2, user_performing_action=admin_user, ) # Search as admin with access to both channels + print("\nSearching as admin user") danswer_doc_message_strings = DocumentSearchManager.search_documents( query="favorite number", user_performing_action=admin_user, ) + print( + "\n documents retrieved by admin user: ", + danswer_doc_message_strings, + ) # Ensure admin user can see messages from both channels assert public_message in danswer_doc_message_strings assert private_message in danswer_doc_message_strings # Search as test_user_2 with access to only the public channel + print("\n Searching as test_user_2") danswer_doc_message_strings = DocumentSearchManager.search_documents( query="favorite number", user_performing_action=test_user_2, ) print( - "\ntop_documents content before removing from private channel for test_user_2: ", + "\n documents retrieved by test_user_2: ", danswer_doc_message_strings, ) @@ -153,12 +161,13 @@ def test_slack_permission_sync( assert private_message not in danswer_doc_message_strings # Search as test_user_1 with access to both channels + print("\n Searching as test_user_1") danswer_doc_message_strings = DocumentSearchManager.search_documents( query="favorite number", user_performing_action=test_user_1, ) print( - "\ntop_documents content before removing from private channel for test_user_1: ", + "\n documents retrieved by test_user_1 before being removed from private channel: ", danswer_doc_message_strings, ) @@ -167,7 +176,8 @@ def test_slack_permission_sync( assert private_message in danswer_doc_message_strings # ----------------------MAKE THE CHANGES-------------------------- - print("\nRemoving test_user_1 from the private channel") + print("\n Removing test_user_1 from the private channel") + before = datetime.now(timezone.utc) # Remove test_user_1 from the private channel desired_channel_members = [admin_user] SlackManager.set_channel_members( @@ -185,18 +195,20 @@ def test_slack_permission_sync( CCPairManager.wait_for_sync( cc_pair=cc_pair, after=before, + number_of_updated_docs=1, user_performing_action=admin_user, ) # ----------------------------VERIFY THE CHANGES--------------------------- # Ensure test_user_1 can no longer see messages from the private channel # Search as test_user_1 with access to only the public channel + danswer_doc_message_strings = DocumentSearchManager.search_documents( query="favorite number", user_performing_action=test_user_1, ) print( - "\ntop_documents content after removing from private channel for test_user_1: ", + "\n documents retrieved by test_user_1 after being removed from private channel: ", danswer_doc_message_strings, ) diff --git a/deployment/cloud_kubernetes/workers/heavy_worker.yaml b/deployment/cloud_kubernetes/workers/heavy_worker.yaml index 7488b0e9a39..e34c9a5185e 100644 --- a/deployment/cloud_kubernetes/workers/heavy_worker.yaml +++ b/deployment/cloud_kubernetes/workers/heavy_worker.yaml @@ -25,7 +25,7 @@ spec: "--loglevel=INFO", "--hostname=heavy@%n", "-Q", - "connector_pruning", + "connector_pruning,connector_doc_permissions_sync,connector_external_group_sync", ] env: - name: REDIS_PASSWORD diff --git a/deployment/cloud_kubernetes/workers/light_worker.yaml b/deployment/cloud_kubernetes/workers/light_worker.yaml index b16c24a9402..64e8d255a72 100644 --- a/deployment/cloud_kubernetes/workers/light_worker.yaml +++ b/deployment/cloud_kubernetes/workers/light_worker.yaml @@ -25,7 +25,7 @@ spec: "--loglevel=INFO", "--hostname=light@%n", "-Q", - "vespa_metadata_sync,connector_deletion", + "vespa_metadata_sync,connector_deletion,doc_permissions_upsert", "--prefetch-multiplier=1", "--concurrency=10", ]