From 71d4fb98d3031173a17e19531aba34501cce734e Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 31 Oct 2024 19:25:00 -0700 Subject: [PATCH] Refactored Google Drive Connector + Permission Syncing (#2945) * refactoring changes * everything working for service account * works with service account * combined scopes * copy change * oauth prep * Works for oauth and service account credentials * mypy * merge fixes * Refactor Google Drive connector * finished backend * auth changes * if its stupid but it works, its not stupid * npm run dev fixes * addressed change requests * string fix * minor fixes and cleanup * spacing cleanup * Update connector.py * everything done * testing! * Delete backend/tests/daily/connectors/google_drive/file_generator.py * cleaned up --------- Co-authored-by: Chris Weaver <25087905+Weves@users.noreply.github.com> --- .../workflows/pr-python-connector-tests.yml | 3 + backend/danswer/configs/app_configs.py | 3 - .../connectors/confluence/connector.py | 7 +- .../connectors/google_drive/connector.py | 775 ++++++------------ .../connectors/google_drive/connector_auth.py | 125 +-- .../connectors/google_drive/constants.py | 41 +- .../connectors/google_drive/doc_conversion.py | 115 +++ .../connectors/google_drive/file_retrieval.py | 192 +++++ .../connectors/google_drive/google_utils.py | 35 + .../danswer/connectors/google_drive/models.py | 18 + backend/danswer/connectors/interfaces.py | 6 +- .../connectors/salesforce/connector.py | 6 +- backend/danswer/connectors/slack/connector.py | 6 +- backend/danswer/db/credentials.py | 6 +- backend/danswer/server/documents/connector.py | 34 +- backend/danswer/server/documents/models.py | 6 +- .../ee/danswer/background/celery/apps/beat.py | 4 +- .../google_drive/doc_sync.py | 225 +++-- .../google_drive/group_sync.py | 134 +-- .../external_permissions/permission_sync.py | 2 + .../daily/connectors/google_drive/conftest.py | 98 +++ .../daily/connectors/google_drive/helpers.py | 164 ++++ .../google_drive/test_google_drive_oauth.py | 246 ++++++ .../test_google_drive_service_acct.py | 257 ++++++ .../test_google_drive_slim_docs.py | 174 ++++ .../[connector]/AddConnectorPage.tsx | 6 + .../pages/ConnectorInput/ListInput.tsx | 16 +- .../pages/DynamicConnectorCreationForm.tsx | 116 ++- .../[connector]/pages/gdrive/Credential.tsx | 34 +- .../pages/gdrive/GoogleDrivePage.tsx | 7 +- .../admin/connectors/AccessTypeForm.tsx | 37 +- .../admin/connectors/ConnectorTitle.tsx | 15 - .../lib/connectors/AutoSyncOptionFields.tsx | 33 +- web/src/lib/connectors/connectors.tsx | 78 +- web/src/lib/connectors/credentials.ts | 5 +- 35 files changed, 2045 insertions(+), 984 deletions(-) create mode 100644 backend/danswer/connectors/google_drive/doc_conversion.py create mode 100644 backend/danswer/connectors/google_drive/file_retrieval.py create mode 100644 backend/danswer/connectors/google_drive/google_utils.py create mode 100644 backend/danswer/connectors/google_drive/models.py create mode 100644 backend/tests/daily/connectors/google_drive/conftest.py create mode 100644 backend/tests/daily/connectors/google_drive/helpers.py create mode 100644 backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py create mode 100644 backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py create mode 100644 backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py diff --git a/.github/workflows/pr-python-connector-tests.yml b/.github/workflows/pr-python-connector-tests.yml index 108012100b3..fa7df201b5e 100644 --- a/.github/workflows/pr-python-connector-tests.yml +++ b/.github/workflows/pr-python-connector-tests.yml @@ -18,6 +18,9 @@ env: # Jira JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} + # Google + GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }} + GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }} jobs: connectors-check: diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 6d6cc6b9639..fb6b4996363 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -251,9 +251,6 @@ # for some connectors ENABLE_EXPENSIVE_EXPERT_CALLS = False -GOOGLE_DRIVE_INCLUDE_SHARED = False -GOOGLE_DRIVE_FOLLOW_SHORTCUTS = False -GOOGLE_DRIVE_ONLY_ORG_PUBLIC = False # TODO these should be available for frontend configuration, via advanced options expandable WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get( diff --git a/backend/danswer/connectors/confluence/connector.py b/backend/danswer/connectors/confluence/connector.py index f0945547e57..9c93f93f99b 100644 --- a/backend/danswer/connectors/confluence/connector.py +++ b/backend/danswer/connectors/confluence/connector.py @@ -17,6 +17,7 @@ from danswer.connectors.interfaces import GenerateSlimDocumentOutput from danswer.connectors.interfaces import LoadConnector from danswer.connectors.interfaces import PollConnector +from danswer.connectors.interfaces import SecondsSinceUnixEpoch from danswer.connectors.interfaces import SlimConnector from danswer.connectors.models import BasicExpertInfo from danswer.connectors.models import ConnectorMissingCredentialError @@ -249,7 +250,11 @@ def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput: self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'" return self._fetch_document_batches() - def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput: + def retrieve_all_slim_documents( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateSlimDocumentOutput: if self.confluence_client is None: raise ConnectorMissingCredentialError("Confluence") diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index c48df5bb741..4ddd51f749f 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -1,556 +1,305 @@ -import io from collections.abc import Iterator -from collections.abc import Sequence -from datetime import datetime -from datetime import timezone -from enum import Enum -from itertools import chain from typing import Any from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore -from googleapiclient import discovery # type: ignore -from googleapiclient.errors import HttpError # type: ignore +from googleapiclient.discovery import build # type: ignore +from googleapiclient.discovery import Resource # type: ignore -from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE -from danswer.configs.app_configs import GOOGLE_DRIVE_FOLLOW_SHORTCUTS -from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED -from danswer.configs.app_configs import GOOGLE_DRIVE_ONLY_ORG_PUBLIC from danswer.configs.app_configs import INDEX_BATCH_SIZE -from danswer.configs.constants import DocumentSource -from danswer.configs.constants import IGNORE_FOR_QA -from danswer.connectors.google_drive.connector_auth import get_google_drive_creds -from danswer.connectors.google_drive.constants import ( - DB_CREDENTIALS_DICT_DELEGATED_USER_KEY, +from danswer.connectors.google_drive.connector_auth import ( + DB_CREDENTIALS_PRIMARY_ADMIN_KEY, ) -from danswer.connectors.google_drive.constants import ( - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, +from danswer.connectors.google_drive.connector_auth import get_google_drive_creds +from danswer.connectors.google_drive.constants import MISSING_SCOPES_ERROR_STR +from danswer.connectors.google_drive.constants import ONYX_SCOPE_INSTRUCTIONS +from danswer.connectors.google_drive.constants import SCOPE_DOC_URL +from danswer.connectors.google_drive.constants import SLIM_BATCH_SIZE +from danswer.connectors.google_drive.constants import USER_FIELDS +from danswer.connectors.google_drive.doc_conversion import ( + convert_drive_item_to_document, ) +from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files +from danswer.connectors.google_drive.file_retrieval import get_files_in_my_drive +from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive +from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval +from danswer.connectors.google_drive.models import GoogleDriveFileType from danswer.connectors.interfaces import GenerateDocumentsOutput +from danswer.connectors.interfaces import GenerateSlimDocumentOutput from danswer.connectors.interfaces import LoadConnector from danswer.connectors.interfaces import PollConnector from danswer.connectors.interfaces import SecondsSinceUnixEpoch -from danswer.connectors.models import Document -from danswer.connectors.models import Section -from danswer.file_processing.extract_file_text import docx_to_text -from danswer.file_processing.extract_file_text import pptx_to_text -from danswer.file_processing.extract_file_text import read_pdf_file -from danswer.file_processing.unstructured import get_unstructured_api_key -from danswer.file_processing.unstructured import unstructured_to_text -from danswer.utils.batching import batch_generator +from danswer.connectors.interfaces import SlimConnector +from danswer.connectors.models import SlimDocument from danswer.utils.logger import setup_logger -from danswer.utils.retry_wrapper import retry_builder logger = setup_logger() -DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder" -DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut" -UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now - - -class GDriveMimeType(str, Enum): - DOC = "application/vnd.google-apps.document" - SPREADSHEET = "application/vnd.google-apps.spreadsheet" - PDF = "application/pdf" - WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - PPT = "application/vnd.google-apps.presentation" - POWERPOINT = ( - "application/vnd.openxmlformats-officedocument.presentationml.presentation" - ) - PLAIN_TEXT = "text/plain" - MARKDOWN = "text/markdown" - - -GoogleDriveFileType = dict[str, Any] - -# Google Drive APIs are quite flakey and may 500 for an -# extended period of time. Trying to combat here by adding a very -# long retry period (~20 minutes of trying every minute) -add_retries = retry_builder(tries=50, max_delay=30) - - -def _run_drive_file_query( - service: discovery.Resource, - query: str, - continue_on_failure: bool, - include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED, - follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS, - batch_size: int = INDEX_BATCH_SIZE, -) -> Iterator[GoogleDriveFileType]: - next_page_token = "" - while next_page_token is not None: - logger.debug(f"Running Google Drive fetch with query: {query}") - results = add_retries( - lambda: ( - service.files() - .list( - corpora="allDrives" - if include_shared - else "user", # needed to search through shared drives - pageSize=batch_size, - supportsAllDrives=include_shared, - includeItemsFromAllDrives=include_shared, - fields=( - "nextPageToken, files(mimeType, id, name, permissions, " - "modifiedTime, webViewLink, shortcutDetails)" - ), - pageToken=next_page_token, - q=query, - ) - .execute() - ) - )() - next_page_token = results.get("nextPageToken") - files = results["files"] - for file in files: - if follow_shortcuts and "shortcutDetails" in file: - try: - file_shortcut_points_to = add_retries( - lambda: ( - service.files() - .get( - fileId=file["shortcutDetails"]["targetId"], - supportsAllDrives=include_shared, - fields="mimeType, id, name, modifiedTime, webViewLink, permissions, shortcutDetails", - ) - .execute() - ) - )() - yield file_shortcut_points_to - except HttpError: - logger.error( - f"Failed to follow shortcut with details: {file['shortcutDetails']}" - ) - if continue_on_failure: - continue - raise - else: - yield file - - -def _get_folder_id( - service: discovery.Resource, - parent_id: str, - folder_name: str, - include_shared: bool, - follow_shortcuts: bool, -) -> str | None: - """ - Get the ID of a folder given its name and the ID of its parent folder. - """ - query = f"'{parent_id}' in parents and name='{folder_name}' and " - if follow_shortcuts: - query += f"(mimeType='{DRIVE_FOLDER_TYPE}' or mimeType='{DRIVE_SHORTCUT_TYPE}')" - else: - query += f"mimeType='{DRIVE_FOLDER_TYPE}'" - - # TODO: support specifying folder path in shared drive rather than just `My Drive` - results = add_retries( - lambda: ( - service.files() - .list( - q=query, - spaces="drive", - fields="nextPageToken, files(id, name, shortcutDetails)", - supportsAllDrives=include_shared, - includeItemsFromAllDrives=include_shared, - ) - .execute() - ) - )() - items = results.get("files", []) - - folder_id = None - if items: - if follow_shortcuts and "shortcutDetails" in items[0]: - folder_id = items[0]["shortcutDetails"]["targetId"] - else: - folder_id = items[0]["id"] - return folder_id - - -def _get_folders( - service: discovery.Resource, - continue_on_failure: bool, - folder_id: str | None = None, # if specified, only fetches files within this folder - include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED, - follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS, - batch_size: int = INDEX_BATCH_SIZE, -) -> Iterator[GoogleDriveFileType]: - query = f"mimeType = '{DRIVE_FOLDER_TYPE}' " - if follow_shortcuts: - query = "(" + query + f" or mimeType = '{DRIVE_SHORTCUT_TYPE}'" + ") " - - if folder_id: - query += f"and '{folder_id}' in parents " - query = query.rstrip() # remove the trailing space(s) - - for file in _run_drive_file_query( - service=service, - query=query, - continue_on_failure=continue_on_failure, - include_shared=include_shared, - follow_shortcuts=follow_shortcuts, - batch_size=batch_size, - ): - # Need to check this since file may have been a target of a shortcut - # and not necessarily a folder - if file["mimeType"] == DRIVE_FOLDER_TYPE: - yield file - else: - pass - - -def _get_files( - service: discovery.Resource, - continue_on_failure: bool, - time_range_start: SecondsSinceUnixEpoch | None = None, - time_range_end: SecondsSinceUnixEpoch | None = None, - folder_id: str | None = None, # if specified, only fetches files within this folder - include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED, - follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS, - batch_size: int = INDEX_BATCH_SIZE, -) -> Iterator[GoogleDriveFileType]: - query = f"mimeType != '{DRIVE_FOLDER_TYPE}' " - if time_range_start is not None: - time_start = datetime.utcfromtimestamp(time_range_start).isoformat() + "Z" - query += f"and modifiedTime >= '{time_start}' " - if time_range_end is not None: - time_stop = datetime.utcfromtimestamp(time_range_end).isoformat() + "Z" - query += f"and modifiedTime <= '{time_stop}' " - if folder_id: - query += f"and '{folder_id}' in parents " - query = query.rstrip() # remove the trailing space(s) - - files = _run_drive_file_query( - service=service, - query=query, - continue_on_failure=continue_on_failure, - include_shared=include_shared, - follow_shortcuts=follow_shortcuts, - batch_size=batch_size, - ) - - return files - - -def get_all_files_batched( - service: discovery.Resource, - continue_on_failure: bool, - include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED, - follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS, - batch_size: int = INDEX_BATCH_SIZE, - time_range_start: SecondsSinceUnixEpoch | None = None, - time_range_end: SecondsSinceUnixEpoch | None = None, - folder_id: str | None = None, # if specified, only fetches files within this folder - # if True, will fetch files in sub-folders of the specified folder ID. - # Only applies if folder_id is specified. - traverse_subfolders: bool = True, - folder_ids_traversed: list[str] | None = None, -) -> Iterator[list[GoogleDriveFileType]]: - """Gets all files matching the criteria specified by the args from Google Drive - in batches of size `batch_size`. - """ - found_files = _get_files( - service=service, - continue_on_failure=continue_on_failure, - time_range_start=time_range_start, - time_range_end=time_range_end, - folder_id=folder_id, - include_shared=include_shared, - follow_shortcuts=follow_shortcuts, - batch_size=batch_size, - ) - yield from batch_generator( - items=found_files, - batch_size=batch_size, - pre_batch_yield=lambda batch_files: logger.debug( - f"Parseable Documents in batch: {[file['name'] for file in batch_files]}" - ), - ) - - if traverse_subfolders and folder_id is not None: - folder_ids_traversed = folder_ids_traversed or [] - subfolders = _get_folders( - service=service, - folder_id=folder_id, - continue_on_failure=continue_on_failure, - include_shared=include_shared, - follow_shortcuts=follow_shortcuts, - batch_size=batch_size, - ) - for subfolder in subfolders: - if subfolder["id"] not in folder_ids_traversed: - logger.info("Fetching all files in subfolder: " + subfolder["name"]) - folder_ids_traversed.append(subfolder["id"]) - yield from get_all_files_batched( - service=service, - continue_on_failure=continue_on_failure, - include_shared=include_shared, - follow_shortcuts=follow_shortcuts, - batch_size=batch_size, - time_range_start=time_range_start, - time_range_end=time_range_end, - folder_id=subfolder["id"], - traverse_subfolders=traverse_subfolders, - folder_ids_traversed=folder_ids_traversed, - ) - else: - logger.debug( - "Skipping subfolder since already traversed: " + subfolder["name"] - ) - -def extract_text(file: dict[str, str], service: discovery.Resource) -> str: - mime_type = file["mimeType"] - - if mime_type not in set(item.value for item in GDriveMimeType): - # Unsupported file types can still have a title, finding this way is still useful - return UNSUPPORTED_FILE_TYPE_CONTENT - - if mime_type in [ - GDriveMimeType.DOC.value, - GDriveMimeType.PPT.value, - GDriveMimeType.SPREADSHEET.value, - ]: - export_mime_type = ( - "text/plain" - if mime_type != GDriveMimeType.SPREADSHEET.value - else "text/csv" - ) - return ( - service.files() - .export(fileId=file["id"], mimeType=export_mime_type) - .execute() - .decode("utf-8") - ) - elif mime_type in [ - GDriveMimeType.PLAIN_TEXT.value, - GDriveMimeType.MARKDOWN.value, - ]: - return service.files().get_media(fileId=file["id"]).execute().decode("utf-8") - if mime_type in [ - GDriveMimeType.WORD_DOC.value, - GDriveMimeType.POWERPOINT.value, - GDriveMimeType.PDF.value, - ]: - response = service.files().get_media(fileId=file["id"]).execute() - if get_unstructured_api_key(): - return unstructured_to_text( - file=io.BytesIO(response), file_name=file.get("name", file["id"]) - ) +def _extract_str_list_from_comma_str(string: str | None) -> list[str]: + if not string: + return [] + return [s.strip() for s in string.split(",") if s.strip()] - if mime_type == GDriveMimeType.WORD_DOC.value: - return docx_to_text(file=io.BytesIO(response)) - elif mime_type == GDriveMimeType.PDF.value: - text, _ = read_pdf_file(file=io.BytesIO(response)) - return text - elif mime_type == GDriveMimeType.POWERPOINT.value: - return pptx_to_text(file=io.BytesIO(response)) - return UNSUPPORTED_FILE_TYPE_CONTENT +def _extract_ids_from_urls(urls: list[str]) -> list[str]: + return [url.split("/")[-1] for url in urls] -class GoogleDriveConnector(LoadConnector, PollConnector): +class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): def __init__( self, - # optional list of folder paths e.g. "[My Folder/My Subfolder]" - # if specified, will only index files in these folders - folder_paths: list[str] | None = None, + include_shared_drives: bool = True, + shared_drive_urls: str | None = None, + include_my_drives: bool = True, + my_drive_emails: str | None = None, + shared_folder_urls: str | None = None, batch_size: int = INDEX_BATCH_SIZE, - include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED, - follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS, - only_org_public: bool = GOOGLE_DRIVE_ONLY_ORG_PUBLIC, - continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE, + # OLD PARAMETERS + folder_paths: list[str] | None = None, + include_shared: bool | None = None, + follow_shortcuts: bool | None = None, + only_org_public: bool | None = None, + continue_on_failure: bool | None = None, ) -> None: - self.folder_paths = folder_paths or [] + # Check for old input parameters + if ( + folder_paths is not None + or include_shared is not None + or follow_shortcuts is not None + or only_org_public is not None + or continue_on_failure is not None + ): + logger.exception( + "Google Drive connector received old input parameters. " + "Please visit the docs for help with the new setup: " + f"{SCOPE_DOC_URL}" + ) + raise ValueError( + "Google Drive connector received old input parameters. " + "Please visit the docs for help with the new setup: " + f"{SCOPE_DOC_URL}" + ) + + if ( + not include_shared_drives + and not include_my_drives + and not shared_folder_urls + ): + raise ValueError( + "At least one of include_shared_drives, include_my_drives," + " or shared_folder_urls must be true" + ) + self.batch_size = batch_size - self.include_shared = include_shared - self.follow_shortcuts = follow_shortcuts - self.only_org_public = only_org_public - self.continue_on_failure = continue_on_failure + + self.include_shared_drives = include_shared_drives + shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls) + self.shared_drive_ids = _extract_ids_from_urls(shared_drive_url_list) + + self.include_my_drives = include_my_drives + self.my_drive_emails = _extract_str_list_from_comma_str(my_drive_emails) + + shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls) + self.shared_folder_ids = _extract_ids_from_urls(shared_folder_url_list) + + self.primary_admin_email: str | None = None + self.google_domain: str | None = None + self.creds: OAuthCredentials | ServiceAccountCredentials | None = None - @staticmethod - def _process_folder_paths( - service: discovery.Resource, - folder_paths: list[str], - include_shared: bool, - follow_shortcuts: bool, - ) -> list[str]: - """['Folder/Sub Folder'] -> ['']""" - folder_ids: list[str] = [] - for path in folder_paths: - folder_names = path.split("/") - parent_id = "root" - for folder_name in folder_names: - found_parent_id = _get_folder_id( - service=service, - parent_id=parent_id, - folder_name=folder_name, - include_shared=include_shared, - follow_shortcuts=follow_shortcuts, - ) - if found_parent_id is None: - raise ValueError( - ( - f"Folder '{folder_name}' in path '{path}' " - "not found in Google Drive" - ) - ) - parent_id = found_parent_id - folder_ids.append(parent_id) - - return folder_ids + self._TRAVERSED_PARENT_IDS: set[str] = set() + + def _update_traversed_parent_ids(self, folder_id: str) -> None: + self._TRAVERSED_PARENT_IDS.add(folder_id) def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: - """Checks for two different types of credentials. - (1) A credential which holds a token acquired via a user going thorough - the Google OAuth flow. - (2) A credential which holds a service account key JSON file, which - can then be used to impersonate any user in the workspace. - """ - creds, new_creds_dict = get_google_drive_creds(credentials) - self.creds = creds + primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] + self.google_domain = primary_admin_email.split("@")[1] + self.primary_admin_email = primary_admin_email + + self.creds, new_creds_dict = get_google_drive_creds(credentials) return new_creds_dict - def _fetch_docs_from_drive( + def get_google_resource( self, + service_name: str = "drive", + service_version: str = "v3", + user_email: str | None = None, + ) -> Resource: + if isinstance(self.creds, ServiceAccountCredentials): + creds = self.creds.with_subject(user_email or self.primary_admin_email) + service = build(service_name, service_version, credentials=creds) + elif isinstance(self.creds, OAuthCredentials): + service = build(service_name, service_version, credentials=self.creds) + else: + raise PermissionError("No credentials found") + + return service + + def _get_all_user_emails(self) -> list[str]: + admin_service = self.get_google_resource("admin", "directory_v1") + emails = [] + for user in execute_paginated_retrieval( + retrieval_function=admin_service.users().list, + list_key="users", + fields=USER_FIELDS, + domain=self.google_domain, + ): + if email := user.get("primaryEmail"): + emails.append(email) + return emails + + def _fetch_drive_items( + self, + is_slim: bool, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, - ) -> GenerateDocumentsOutput: - if self.creds is None: - raise PermissionError("Not logged into Google Drive") - - service = discovery.build("drive", "v3", credentials=self.creds) - folder_ids: Sequence[str | None] = self._process_folder_paths( - service, self.folder_paths, self.include_shared, self.follow_shortcuts - ) - if not folder_ids: - folder_ids = [None] - - file_batches = chain( - *[ - get_all_files_batched( - service=service, - continue_on_failure=self.continue_on_failure, - include_shared=self.include_shared, - follow_shortcuts=self.follow_shortcuts, - batch_size=self.batch_size, - time_range_start=start, - time_range_end=end, - folder_id=folder_id, - traverse_subfolders=True, + ) -> Iterator[GoogleDriveFileType]: + primary_drive_service = self.get_google_resource() + + if self.include_shared_drives: + shared_drive_urls = self.shared_drive_ids + if not shared_drive_urls: + # if no parent ids are specified, get all shared drives using the admin account + for drive in execute_paginated_retrieval( + retrieval_function=primary_drive_service.drives().list, + list_key="drives", + useDomainAdminAccess=True, + fields="drives(id)", + ): + shared_drive_urls.append(drive["id"]) + + # For each shared drive, retrieve all files + for shared_drive_id in shared_drive_urls: + for file in get_files_in_shared_drive( + service=primary_drive_service, + drive_id=shared_drive_id, + is_slim=is_slim, + cache_folders=bool(self.shared_folder_ids), + update_traversed_ids_func=self._update_traversed_parent_ids, + start=start, + end=end, + ): + yield file + + if self.shared_folder_ids: + # Crawl all the shared parent ids for files + for folder_id in self.shared_folder_ids: + yield from crawl_folders_for_files( + service=primary_drive_service, + parent_id=folder_id, + personal_drive=False, + traversed_parent_ids=self._TRAVERSED_PARENT_IDS, + update_traversed_ids_func=self._update_traversed_parent_ids, + start=start, + end=end, ) - for folder_id in folder_ids - ] - ) - for files_batch in file_batches: - doc_batch = [] - for file in files_batch: - try: - # Skip files that are shortcuts - if file.get("mimeType") == DRIVE_SHORTCUT_TYPE: - logger.info("Ignoring Drive Shortcut Filetype") - continue - - if self.only_org_public: - if "permissions" not in file: - continue - if not any( - permission["type"] == "domain" - for permission in file["permissions"] - ): - continue - try: - text_contents = extract_text(file, service) or "" - except HttpError as e: - reason = ( - e.error_details[0]["reason"] - if e.error_details - else e.reason - ) - message = ( - e.error_details[0]["message"] - if e.error_details - else e.reason - ) - - # these errors don't represent a failure in the connector, but simply files - # that can't / shouldn't be indexed - ERRORS_TO_CONTINUE_ON = [ - "cannotExportFile", - "exportSizeLimitExceeded", - "cannotDownloadFile", - ] - if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON: - logger.warning( - f"Could not export file '{file['name']}' due to '{message}', skipping..." - ) - continue - - raise - - doc_batch.append( - Document( - id=file["webViewLink"], - sections=[ - Section(link=file["webViewLink"], text=text_contents) - ], - source=DocumentSource.GOOGLE_DRIVE, - semantic_identifier=file["name"], - doc_updated_at=datetime.fromisoformat( - file["modifiedTime"] - ).astimezone(timezone.utc), - metadata={} if text_contents else {IGNORE_FOR_QA: "True"}, - additional_info=file.get("id"), - ) - ) - except Exception as e: - if not self.continue_on_failure: - raise e - - logger.exception( - "Ran into exception when pulling a file from Google Drive" - ) - - yield doc_batch + + all_user_emails = [] + # get all personal docs from each users' personal drive + if self.include_my_drives: + if isinstance(self.creds, ServiceAccountCredentials): + all_user_emails = self.my_drive_emails or [] + + # If using service account and no emails specified, fetch all users + if not all_user_emails: + all_user_emails = self._get_all_user_emails() + + elif self.primary_admin_email: + # If using OAuth, only fetch the primary admin email + all_user_emails = [self.primary_admin_email] + + for email in all_user_emails: + logger.info(f"Fetching personal files for user: {email}") + user_drive_service = self.get_google_resource(user_email=email) + + yield from get_files_in_my_drive( + service=user_drive_service, + email=email, + is_slim=is_slim, + start=start, + end=end, + ) + + def _extract_docs_from_google_drive( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateDocumentsOutput: + doc_batch = [] + for file in self._fetch_drive_items( + is_slim=False, + start=start, + end=end, + ): + user_email = file.get("owners", [{}])[0].get("emailAddress") + service = self.get_google_resource(user_email=user_email) + if doc := convert_drive_item_to_document( + file=file, + service=service, + ): + doc_batch.append(doc) + if len(doc_batch) >= self.batch_size: + yield doc_batch + doc_batch = [] + + yield doc_batch def load_from_state(self) -> GenerateDocumentsOutput: - yield from self._fetch_docs_from_drive() + try: + yield from self._extract_docs_from_google_drive() + except Exception as e: + if MISSING_SCOPES_ERROR_STR in str(e): + raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e + raise e def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: - # need to subtract 10 minutes from start time to account for modifiedTime - # propogation if a document is modified, it takes some time for the API to - # reflect these changes if we do not have an offset, then we may "miss" the - # update when polling - yield from self._fetch_docs_from_drive(start, end) - - -if __name__ == "__main__": - import json - import os - - service_account_json_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH") - if not service_account_json_path: - raise ValueError( - "Please set GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH environment variable" - ) - with open(service_account_json_path) as f: - creds = json.load(f) - - credentials_dict = { - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: json.dumps(creds), - } - delegated_user = os.environ.get("GOOGLE_DRIVE_DELEGATED_USER") - if delegated_user: - credentials_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user - - connector = GoogleDriveConnector(include_shared=True, follow_shortcuts=True) - connector.load_credentials(credentials_dict) - document_batch_generator = connector.load_from_state() - for document_batch in document_batch_generator: - print(document_batch) - break + try: + yield from self._extract_docs_from_google_drive(start, end) + except Exception as e: + if MISSING_SCOPES_ERROR_STR in str(e): + raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e + raise e + + def _extract_slim_docs_from_google_drive( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateSlimDocumentOutput: + slim_batch = [] + for file in self._fetch_drive_items( + is_slim=True, + start=start, + end=end, + ): + slim_batch.append( + SlimDocument( + id=file["webViewLink"], + perm_sync_data={ + "doc_id": file.get("id"), + "permissions": file.get("permissions", []), + "permission_ids": file.get("permissionIds", []), + "name": file.get("name"), + "owner_email": file.get("owners", [{}])[0].get("emailAddress"), + }, + ) + ) + if len(slim_batch) >= SLIM_BATCH_SIZE: + yield slim_batch + slim_batch = [] + yield slim_batch + + def retrieve_all_slim_documents( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateSlimDocumentOutput: + try: + yield from self._extract_slim_docs_from_google_drive(start, end) + except Exception as e: + if MISSING_SCOPES_ERROR_STR in str(e): + raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e + raise e diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py index 777deae990a..80cbda6772a 100644 --- a/backend/danswer/connectors/google_drive/connector_auth.py +++ b/backend/danswer/connectors/google_drive/connector_auth.py @@ -8,24 +8,16 @@ from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore +from googleapiclient.discovery import build # type: ignore from sqlalchemy.orm import Session -from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.constants import DocumentSource from danswer.configs.constants import KV_CRED_KEY from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY -from danswer.connectors.google_drive.constants import BASE_SCOPES -from danswer.connectors.google_drive.constants import ( - DB_CREDENTIALS_DICT_DELEGATED_USER_KEY, -) -from danswer.connectors.google_drive.constants import ( - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, -) -from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY -from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES -from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES +from danswer.connectors.google_drive.constants import MISSING_SCOPES_ERROR_STR +from danswer.connectors.google_drive.constants import ONYX_SCOPE_INSTRUCTIONS from danswer.db.credentials import update_credential_json from danswer.db.models import User from danswer.key_value_store.factory import get_kv_store @@ -36,15 +28,14 @@ logger = setup_logger() - -def build_gdrive_scopes() -> list[str]: - base_scopes: list[str] = BASE_SCOPES - permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES - groups_scopes: list[str] = FETCH_GROUPS_SCOPES - - if ENTERPRISE_EDITION_ENABLED: - return base_scopes + permissions_scopes + groups_scopes - return base_scopes + permissions_scopes +GOOGLE_DRIVE_SCOPES = [ + "https://www.googleapis.com/auth/drive.readonly", + "https://www.googleapis.com/auth/drive.metadata.readonly", + "https://www.googleapis.com/auth/admin.directory.group.readonly", + "https://www.googleapis.com/auth/admin.directory.user.readonly", +] +DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens" +DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_drive_primary_admin" def _build_frontend_google_drive_redirect() -> str: @@ -52,7 +43,7 @@ def _build_frontend_google_drive_redirect() -> str: def get_google_drive_creds_for_authorized_user( - token_json_str: str, scopes: list[str] = build_gdrive_scopes() + token_json_str: str, scopes: list[str] ) -> OAuthCredentials | None: creds_json = json.loads(token_json_str) creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes) @@ -72,21 +63,15 @@ def get_google_drive_creds_for_authorized_user( return None -def _get_google_drive_creds_for_service_account( - service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes() -) -> ServiceAccountCredentials | None: - service_account_key = json.loads(service_account_key_json_str) - creds = ServiceAccountCredentials.from_service_account_info( - service_account_key, scopes=scopes - ) - if not creds.valid or not creds.expired: - creds.refresh(Request()) - return creds if creds.valid else None - - def get_google_drive_creds( - credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes() + credentials: dict[str, str], scopes: list[str] = GOOGLE_DRIVE_SCOPES ) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]: + """Checks for two different types of credentials. + (1) A credential which holds a token acquired via a user going thorough + the Google OAuth flow. + (2) A credential which holds a service account key JSON file, which + can then be used to impersonate any user in the workspace. + """ oauth_creds = None service_creds = None new_creds_dict = None @@ -100,26 +85,27 @@ def get_google_drive_creds( # (e.g. the token has been refreshed) new_creds_json_str = oauth_creds.to_json() if oauth_creds else "" if new_creds_json_str != access_token_json_str: - new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str} - - elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials: - service_account_key_json_str = credentials[ - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY - ] - service_creds = _get_google_drive_creds_for_service_account( - service_account_key_json_str=service_account_key_json_str, - scopes=scopes, + new_creds_dict = { + DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str, + DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[ + DB_CREDENTIALS_PRIMARY_ADMIN_KEY + ], + } + + elif KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY in credentials: + service_account_key_json_str = credentials[KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY] + service_account_key = json.loads(service_account_key_json_str) + + service_creds = ServiceAccountCredentials.from_service_account_info( + service_account_key, scopes=scopes ) - # "Impersonate" a user if one is specified - delegated_user_email = cast( - str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY) - ) - if delegated_user_email: - service_creds = ( - service_creds.with_subject(delegated_user_email) - if service_creds - else None + if not service_creds.valid or not service_creds.expired: + service_creds.refresh(Request()) + + if not service_creds.valid: + raise PermissionError( + "Unable to access Google Drive - service account credentials are invalid." ) creds: ServiceAccountCredentials | OAuthCredentials | None = ( @@ -146,7 +132,7 @@ def get_auth_url(credential_id: int) -> str: credential_json = json.loads(creds_str) flow = InstalledAppFlow.from_client_config( credential_json, - scopes=build_gdrive_scopes(), + scopes=GOOGLE_DRIVE_SCOPES, redirect_uri=_build_frontend_google_drive_redirect(), ) auth_url, _ = flow.authorization_url(prompt="consent") @@ -169,13 +155,34 @@ def update_credential_access_tokens( app_credentials = get_google_app_cred() flow = InstalledAppFlow.from_client_config( app_credentials.model_dump(), - scopes=build_gdrive_scopes(), + scopes=GOOGLE_DRIVE_SCOPES, redirect_uri=_build_frontend_google_drive_redirect(), ) flow.fetch_token(code=auth_code) creds = flow.credentials token_json_str = creds.to_json() - new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str} + + # Get user email from Google API so we know who + # the primary admin is for this connector + try: + admin_service = build("drive", "v3", credentials=creds) + user_info = ( + admin_service.about() + .get( + fields="user(emailAddress)", + ) + .execute() + ) + email = user_info.get("user", {}).get("emailAddress") + except Exception as e: + if MISSING_SCOPES_ERROR_STR in str(e): + raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e + raise e + + new_creds_dict = { + DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str, + DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email, + } if not update_credential_json(credential_id, new_creds_dict, user, db_session): return None @@ -184,15 +191,15 @@ def update_credential_access_tokens( def build_service_account_creds( source: DocumentSource, - delegated_user_email: str | None = None, + primary_admin_email: str | None = None, ) -> CredentialBase: service_account_key = get_service_account_key() credential_dict = { - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(), + KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY: service_account_key.json(), } - if delegated_user_email: - credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email + if primary_admin_email: + credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email return CredentialBase( credential_json=credential_dict, diff --git a/backend/danswer/connectors/google_drive/constants.py b/backend/danswer/connectors/google_drive/constants.py index 0cca65c13df..848a21fffe6 100644 --- a/backend/danswer/connectors/google_drive/constants.py +++ b/backend/danswer/connectors/google_drive/constants.py @@ -1,7 +1,36 @@ -DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens" -DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key" -DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user" +UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now +DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder" +DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut" +DRIVE_FILE_TYPE = "application/vnd.google-apps.file" -BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"] -FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"] -FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"] +FILE_FIELDS = ( + "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, " + "shortcutDetails, owners(emailAddress))" +) +SLIM_FILE_FIELDS = ( + "nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), " + "permissionIds, webViewLink, owners(emailAddress))" +) +FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)" +USER_FIELDS = "nextPageToken, users(primaryEmail)" + +# these errors don't represent a failure in the connector, but simply files +# that can't / shouldn't be indexed +ERRORS_TO_CONTINUE_ON = [ + "cannotExportFile", + "exportSizeLimitExceeded", + "cannotDownloadFile", +] + +# Error message substrings +MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested" + +# Documentation and error messages +SCOPE_DOC_URL = "https://docs.danswer.dev/connectors/google_drive/overview" +ONYX_SCOPE_INSTRUCTIONS = ( + "You have upgraded Danswer without updating the Google Drive scopes. " + f"Please refer to the documentation to learn how to update the scopes: {SCOPE_DOC_URL}" +) + +# Batch sizes +SLIM_BATCH_SIZE = 500 diff --git a/backend/danswer/connectors/google_drive/doc_conversion.py b/backend/danswer/connectors/google_drive/doc_conversion.py new file mode 100644 index 00000000000..688190c2267 --- /dev/null +++ b/backend/danswer/connectors/google_drive/doc_conversion.py @@ -0,0 +1,115 @@ +import io +from datetime import datetime +from datetime import timezone + +from googleapiclient.discovery import Resource # type: ignore +from googleapiclient.errors import HttpError # type: ignore + +from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE +from danswer.configs.constants import DocumentSource +from danswer.configs.constants import IGNORE_FOR_QA +from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE +from danswer.connectors.google_drive.constants import ERRORS_TO_CONTINUE_ON +from danswer.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT +from danswer.connectors.google_drive.models import GDriveMimeType +from danswer.connectors.google_drive.models import GoogleDriveFileType +from danswer.connectors.models import Document +from danswer.connectors.models import Section +from danswer.file_processing.extract_file_text import docx_to_text +from danswer.file_processing.extract_file_text import pptx_to_text +from danswer.file_processing.extract_file_text import read_pdf_file +from danswer.file_processing.unstructured import get_unstructured_api_key +from danswer.file_processing.unstructured import unstructured_to_text +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def _extract_text(file: dict[str, str], service: Resource) -> str: + mime_type = file["mimeType"] + + if mime_type not in set(item.value for item in GDriveMimeType): + # Unsupported file types can still have a title, finding this way is still useful + return UNSUPPORTED_FILE_TYPE_CONTENT + + if mime_type in [ + GDriveMimeType.DOC.value, + GDriveMimeType.PPT.value, + GDriveMimeType.SPREADSHEET.value, + ]: + export_mime_type = ( + "text/plain" + if mime_type != GDriveMimeType.SPREADSHEET.value + else "text/csv" + ) + return ( + service.files() + .export(fileId=file["id"], mimeType=export_mime_type) + .execute() + .decode("utf-8") + ) + elif mime_type in [ + GDriveMimeType.PLAIN_TEXT.value, + GDriveMimeType.MARKDOWN.value, + ]: + return service.files().get_media(fileId=file["id"]).execute().decode("utf-8") + if mime_type in [ + GDriveMimeType.WORD_DOC.value, + GDriveMimeType.POWERPOINT.value, + GDriveMimeType.PDF.value, + ]: + response = service.files().get_media(fileId=file["id"]).execute() + if get_unstructured_api_key(): + return unstructured_to_text( + file=io.BytesIO(response), file_name=file.get("name", file["id"]) + ) + + if mime_type == GDriveMimeType.WORD_DOC.value: + return docx_to_text(file=io.BytesIO(response)) + elif mime_type == GDriveMimeType.PDF.value: + text, _ = read_pdf_file(file=io.BytesIO(response)) + return text + elif mime_type == GDriveMimeType.POWERPOINT.value: + return pptx_to_text(file=io.BytesIO(response)) + + return UNSUPPORTED_FILE_TYPE_CONTENT + + +def convert_drive_item_to_document( + file: GoogleDriveFileType, service: Resource +) -> Document | None: + try: + # Skip files that are shortcuts + if file.get("mimeType") == DRIVE_SHORTCUT_TYPE: + logger.info("Ignoring Drive Shortcut Filetype") + return None + try: + text_contents = _extract_text(file, service) or "" + except HttpError as e: + reason = e.error_details[0]["reason"] if e.error_details else e.reason + message = e.error_details[0]["message"] if e.error_details else e.reason + if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON: + logger.warning( + f"Could not export file '{file['name']}' due to '{message}', skipping..." + ) + return None + + raise + + return Document( + id=file["webViewLink"], + sections=[Section(link=file["webViewLink"], text=text_contents)], + source=DocumentSource.GOOGLE_DRIVE, + semantic_identifier=file["name"], + doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone( + timezone.utc + ), + metadata={} if text_contents else {IGNORE_FOR_QA: "True"}, + additional_info=file.get("id"), + ) + except Exception as e: + if not CONTINUE_ON_CONNECTOR_FAILURE: + raise e + + logger.exception("Ran into exception when pulling a file from Google Drive") + return None diff --git a/backend/danswer/connectors/google_drive/file_retrieval.py b/backend/danswer/connectors/google_drive/file_retrieval.py new file mode 100644 index 00000000000..ea4e7d49466 --- /dev/null +++ b/backend/danswer/connectors/google_drive/file_retrieval.py @@ -0,0 +1,192 @@ +from collections.abc import Callable +from collections.abc import Iterator +from datetime import datetime + +from googleapiclient.discovery import Resource # type: ignore + +from danswer.connectors.google_drive.constants import DRIVE_FOLDER_TYPE +from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE +from danswer.connectors.google_drive.constants import FILE_FIELDS +from danswer.connectors.google_drive.constants import FOLDER_FIELDS +from danswer.connectors.google_drive.constants import SLIM_FILE_FIELDS +from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval +from danswer.connectors.google_drive.models import GoogleDriveFileType +from danswer.connectors.interfaces import SecondsSinceUnixEpoch +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def _generate_time_range_filter( + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, +) -> str: + time_range_filter = "" + if start is not None: + time_start = datetime.utcfromtimestamp(start).isoformat() + "Z" + time_range_filter += f" and modifiedTime >= '{time_start}'" + if end is not None: + time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z" + time_range_filter += f" and modifiedTime <= '{time_stop}'" + return time_range_filter + + +def _get_folders_in_parent( + service: Resource, + parent_id: str | None = None, + personal_drive: bool = False, +) -> Iterator[GoogleDriveFileType]: + # Follow shortcuts to folders + query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')" + query += " and trashed = false" + + if parent_id: + query += f" and '{parent_id}' in parents" + + for file in execute_paginated_retrieval( + retrieval_function=service.files().list, + list_key="files", + corpora="user" if personal_drive else "allDrives", + supportsAllDrives=not personal_drive, + includeItemsFromAllDrives=not personal_drive, + fields=FOLDER_FIELDS, + q=query, + ): + yield file + + +def _get_files_in_parent( + service: Resource, + parent_id: str, + personal_drive: bool, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + is_slim: bool = False, +) -> Iterator[GoogleDriveFileType]: + query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents" + query += " and trashed = false" + query += _generate_time_range_filter(start, end) + + for file in execute_paginated_retrieval( + retrieval_function=service.files().list, + list_key="files", + corpora="user" if personal_drive else "allDrives", + supportsAllDrives=not personal_drive, + includeItemsFromAllDrives=not personal_drive, + fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS, + q=query, + ): + yield file + + +def crawl_folders_for_files( + service: Resource, + parent_id: str, + personal_drive: bool, + traversed_parent_ids: set[str], + update_traversed_ids_func: Callable[[str], None], + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, +) -> Iterator[GoogleDriveFileType]: + """ + This function starts crawling from any folder. It is slower though. + """ + if parent_id in traversed_parent_ids: + print(f"Skipping subfolder since already traversed: {parent_id}") + return + + update_traversed_ids_func(parent_id) + + yield from _get_files_in_parent( + service=service, + personal_drive=personal_drive, + start=start, + end=end, + parent_id=parent_id, + ) + + for subfolder in _get_folders_in_parent( + service=service, + parent_id=parent_id, + personal_drive=personal_drive, + ): + logger.info("Fetching all files in subfolder: " + subfolder["name"]) + yield from crawl_folders_for_files( + service=service, + parent_id=subfolder["id"], + personal_drive=personal_drive, + traversed_parent_ids=traversed_parent_ids, + update_traversed_ids_func=update_traversed_ids_func, + start=start, + end=end, + ) + + +def get_files_in_shared_drive( + service: Resource, + drive_id: str, + is_slim: bool = False, + cache_folders: bool = True, + update_traversed_ids_func: Callable[[str], None] = lambda _: None, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, +) -> Iterator[GoogleDriveFileType]: + # If we know we are going to folder crawl later, we can cache the folders here + if cache_folders: + # Get all folders being queried and add them to the traversed set + query = f"mimeType = '{DRIVE_FOLDER_TYPE}'" + query += " and trashed = false" + for file in execute_paginated_retrieval( + retrieval_function=service.files().list, + list_key="files", + corpora="drive", + driveId=drive_id, + supportsAllDrives=True, + includeItemsFromAllDrives=True, + fields="nextPageToken, files(id)", + q=query, + ): + update_traversed_ids_func(file["id"]) + + # Get all files in the shared drive + query = f"mimeType != '{DRIVE_FOLDER_TYPE}'" + query += " and trashed = false" + query += _generate_time_range_filter(start, end) + for file in execute_paginated_retrieval( + retrieval_function=service.files().list, + list_key="files", + corpora="drive", + driveId=drive_id, + supportsAllDrives=True, + includeItemsFromAllDrives=True, + fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS, + q=query, + ): + yield file + + +def get_files_in_my_drive( + service: Resource, + email: str, + is_slim: bool = False, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, +) -> Iterator[GoogleDriveFileType]: + query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{email}' in owners" + query += " and trashed = false" + query += _generate_time_range_filter(start, end) + for file in execute_paginated_retrieval( + retrieval_function=service.files().list, + list_key="files", + corpora="user", + fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS, + q=query, + ): + yield file + + +# Just in case we need to get the root folder id +def get_root_folder_id(service: Resource) -> str: + # we dont paginate here because there is only one root folder per user + # https://developers.google.com/drive/api/guides/v2-to-v3-reference + return service.files().get(fileId="root", fields="id").execute()["id"] diff --git a/backend/danswer/connectors/google_drive/google_utils.py b/backend/danswer/connectors/google_drive/google_utils.py new file mode 100644 index 00000000000..5f772e5ad63 --- /dev/null +++ b/backend/danswer/connectors/google_drive/google_utils.py @@ -0,0 +1,35 @@ +from collections.abc import Callable +from collections.abc import Iterator +from typing import Any + +from danswer.connectors.google_drive.models import GoogleDriveFileType +from danswer.utils.retry_wrapper import retry_builder + + +# Google Drive APIs are quite flakey and may 500 for an +# extended period of time. Trying to combat here by adding a very +# long retry period (~20 minutes of trying every minute) +add_retries = retry_builder(tries=50, max_delay=30) + + +def execute_paginated_retrieval( + retrieval_function: Callable, + list_key: str, + **kwargs: Any, +) -> Iterator[GoogleDriveFileType]: + """Execute a paginated retrieval from Google Drive API + Args: + retrieval_function: The specific list function to call (e.g., service.files().list) + **kwargs: Arguments to pass to the list function + """ + next_page_token = "" + while next_page_token is not None: + request_kwargs = kwargs.copy() + if next_page_token: + request_kwargs["pageToken"] = next_page_token + + results = add_retries(lambda: retrieval_function(**request_kwargs).execute())() + + next_page_token = results.get("nextPageToken") + for item in results.get(list_key, []): + yield item diff --git a/backend/danswer/connectors/google_drive/models.py b/backend/danswer/connectors/google_drive/models.py new file mode 100644 index 00000000000..5bb06f3c206 --- /dev/null +++ b/backend/danswer/connectors/google_drive/models.py @@ -0,0 +1,18 @@ +from enum import Enum +from typing import Any + + +class GDriveMimeType(str, Enum): + DOC = "application/vnd.google-apps.document" + SPREADSHEET = "application/vnd.google-apps.spreadsheet" + PDF = "application/pdf" + WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + PPT = "application/vnd.google-apps.presentation" + POWERPOINT = ( + "application/vnd.openxmlformats-officedocument.presentationml.presentation" + ) + PLAIN_TEXT = "text/plain" + MARKDOWN = "text/markdown" + + +GoogleDriveFileType = dict[str, Any] diff --git a/backend/danswer/connectors/interfaces.py b/backend/danswer/connectors/interfaces.py index 4734212147e..c53b3de5f2f 100644 --- a/backend/danswer/connectors/interfaces.py +++ b/backend/danswer/connectors/interfaces.py @@ -56,7 +56,11 @@ def poll_source( class SlimConnector(BaseConnector): @abc.abstractmethod - def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput: + def retrieve_all_slim_documents( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateSlimDocumentOutput: raise NotImplementedError diff --git a/backend/danswer/connectors/salesforce/connector.py b/backend/danswer/connectors/salesforce/connector.py index 78d73d44766..1e0fe9e1d3a 100644 --- a/backend/danswer/connectors/salesforce/connector.py +++ b/backend/danswer/connectors/salesforce/connector.py @@ -251,7 +251,11 @@ def poll_source( end_datetime = datetime.utcfromtimestamp(end) return self._fetch_from_salesforce(start=start_datetime, end=end_datetime) - def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput: + def retrieve_all_slim_documents( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateSlimDocumentOutput: if self.sf_client is None: raise ConnectorMissingCredentialError("Salesforce") doc_metadata_list: list[SlimDocument] = [] diff --git a/backend/danswer/connectors/slack/connector.py b/backend/danswer/connectors/slack/connector.py index 92b5a0b7558..22ace603bd4 100644 --- a/backend/danswer/connectors/slack/connector.py +++ b/backend/danswer/connectors/slack/connector.py @@ -391,7 +391,11 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None self.client = WebClient(token=bot_token) return None - def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput: + def retrieve_all_slim_documents( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> GenerateSlimDocumentOutput: if self.client is None: raise ConnectorMissingCredentialError("Slack") diff --git a/backend/danswer/db/credentials.py b/backend/danswer/db/credentials.py index 80ebe1b1538..58be604a724 100644 --- a/backend/danswer/db/credentials.py +++ b/backend/danswer/db/credentials.py @@ -10,12 +10,10 @@ from danswer.auth.schemas import UserRole from danswer.configs.constants import DocumentSource +from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY from danswer.connectors.gmail.constants import ( GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) -from danswer.connectors.google_drive.constants import ( - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, -) from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Credential from danswer.db.models import Credential__UserGroup @@ -442,7 +440,7 @@ def delete_google_drive_service_account_credentials( ) -> None: credentials = fetch_credentials(db_session=db_session, user=user) for credential in credentials: - if credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY): + if credential.credential_json.get(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY): db_session.delete(credential) db_session.commit() diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index 1ba0ab13e2c..b0866a826c1 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -9,6 +9,7 @@ from fastapi import Request from fastapi import Response from fastapi import UploadFile +from google.oauth2.credentials import Credentials # type: ignore from pydantic import BaseModel from sqlalchemy.orm import Session @@ -35,6 +36,7 @@ ) from danswer.connectors.gmail.connector_auth import upsert_google_app_gmail_cred from danswer.connectors.google_drive.connector_auth import build_service_account_creds +from danswer.connectors.google_drive.connector_auth import DB_CREDENTIALS_DICT_TOKEN_KEY from danswer.connectors.google_drive.connector_auth import delete_google_app_cred from danswer.connectors.google_drive.connector_auth import delete_service_account_key from danswer.connectors.google_drive.connector_auth import get_auth_url @@ -43,13 +45,13 @@ get_google_drive_creds_for_authorized_user, ) from danswer.connectors.google_drive.connector_auth import get_service_account_key +from danswer.connectors.google_drive.connector_auth import GOOGLE_DRIVE_SCOPES from danswer.connectors.google_drive.connector_auth import ( update_credential_access_tokens, ) from danswer.connectors.google_drive.connector_auth import upsert_google_app_cred from danswer.connectors.google_drive.connector_auth import upsert_service_account_key from danswer.connectors.google_drive.connector_auth import verify_csrf -from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY from danswer.db.connector import create_connector from danswer.db.connector import delete_connector from danswer.db.connector import fetch_connector_by_id @@ -294,7 +296,7 @@ def upsert_service_account_credential( try: credential_base = build_service_account_creds( DocumentSource.GOOGLE_DRIVE, - delegated_user_email=service_account_credential_request.google_drive_delegated_user, + primary_admin_email=service_account_credential_request.google_drive_primary_admin, ) except KvKeyNotFoundError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -320,7 +322,7 @@ def upsert_gmail_service_account_credential( try: credential_base = build_service_account_creds( DocumentSource.GMAIL, - delegated_user_email=service_account_credential_request.gmail_delegated_user, + primary_admin_email=service_account_credential_request.gmail_delegated_user, ) except KvKeyNotFoundError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -348,27 +350,14 @@ def check_drive_tokens( return AuthStatus(authenticated=False) token_json_str = str(db_credentials.credential_json[DB_CREDENTIALS_DICT_TOKEN_KEY]) google_drive_creds = get_google_drive_creds_for_authorized_user( - token_json_str=token_json_str + token_json_str=token_json_str, + scopes=GOOGLE_DRIVE_SCOPES, ) if google_drive_creds is None: return AuthStatus(authenticated=False) return AuthStatus(authenticated=True) -@router.get("/admin/connector/google-drive/authorize/{credential_id}") -def admin_google_drive_auth( - response: Response, credential_id: str, _: User = Depends(current_admin_user) -) -> AuthUrl: - # set a cookie that we can read in the callback (used for `verify_csrf`) - response.set_cookie( - key=_GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME, - value=credential_id, - httponly=True, - max_age=600, - ) - return AuthUrl(auth_url=get_auth_url(credential_id=int(credential_id))) - - @router.post("/admin/connector/file/upload") def upload_files( files: list[UploadFile], @@ -951,10 +940,11 @@ def google_drive_callback( ) credential_id = int(credential_id_cookie) verify_csrf(credential_id, callback.state) - if ( - update_credential_access_tokens(callback.code, credential_id, user, db_session) - is None - ): + + credentials: Credentials | None = update_credential_access_tokens( + callback.code, credential_id, user, db_session + ) + if credentials is None: raise HTTPException( status_code=500, detail="Unable to fetch Google Drive access tokens" ) diff --git a/backend/danswer/server/documents/models.py b/backend/danswer/server/documents/models.py index fcbc0a76a12..e45d6eabff0 100644 --- a/backend/danswer/server/documents/models.py +++ b/backend/danswer/server/documents/models.py @@ -377,16 +377,16 @@ class GoogleServiceAccountKey(BaseModel): class GoogleServiceAccountCredentialRequest(BaseModel): - google_drive_delegated_user: str | None = None # email of user to impersonate + google_drive_primary_admin: str | None = None # email of user to impersonate gmail_delegated_user: str | None = None # email of user to impersonate @model_validator(mode="after") def check_user_delegation(self) -> "GoogleServiceAccountCredentialRequest": - if (self.google_drive_delegated_user is None) == ( + if (self.google_drive_primary_admin is None) == ( self.gmail_delegated_user is None ): raise ValueError( - "Exactly one of google_drive_delegated_user or gmail_delegated_user must be set" + "Exactly one of google_drive_primary_admin or gmail_delegated_user must be set" ) return self diff --git a/backend/ee/danswer/background/celery/apps/beat.py b/backend/ee/danswer/background/celery/apps/beat.py index bee219e2471..980eb5e3214 100644 --- a/backend/ee/danswer/background/celery/apps/beat.py +++ b/backend/ee/danswer/background/celery/apps/beat.py @@ -13,12 +13,12 @@ { "name": "sync-external-doc-permissions", "task": "check_sync_external_doc_permissions_task", - "schedule": timedelta(seconds=5), # TODO: optimize this + "schedule": timedelta(seconds=30), # TODO: optimize this }, { "name": "sync-external-group-permissions", "task": "check_sync_external_group_permissions_task", - "schedule": timedelta(seconds=5), # TODO: optimize this + "schedule": timedelta(seconds=60), # TODO: optimize this }, { "name": "autogenerate_usage_report", diff --git a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py index f1e805d46d7..d1df0cb0846 100644 --- a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py +++ b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py @@ -1,144 +1,119 @@ -from collections.abc import Iterator from datetime import datetime from datetime import timezone from typing import Any -from typing import cast -from googleapiclient.discovery import build # type: ignore -from googleapiclient.errors import HttpError # type: ignore from sqlalchemy.orm import Session from danswer.access.models import ExternalAccess -from danswer.connectors.factory import instantiate_connector -from danswer.connectors.google_drive.connector_auth import ( - get_google_drive_creds, -) -from danswer.connectors.interfaces import PollConnector -from danswer.connectors.models import InputType +from danswer.connectors.google_drive.connector import GoogleDriveConnector +from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval +from danswer.connectors.interfaces import GenerateSlimDocumentOutput +from danswer.connectors.models import SlimDocument from danswer.db.models import ConnectorCredentialPair from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.utils.logger import setup_logger -from danswer.utils.retry_wrapper import retry_builder from ee.danswer.db.document import upsert_document_external_perms__no_commit -# Google Drive APIs are quite flakey and may 500 for an -# extended period of time. Trying to combat here by adding a very -# long retry period (~20 minutes of trying every minute) -add_retries = retry_builder(tries=5, delay=5, max_delay=30) - - logger = setup_logger() +_PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {} -def _get_docs_with_additional_info( - db_session: Session, - cc_pair: ConnectorCredentialPair, -) -> dict[str, Any]: - # Get all document ids that need their permissions updated - runnable_connector = instantiate_connector( - db_session=db_session, - source=cc_pair.connector.source, - input_type=InputType.POLL, - connector_specific_config=cc_pair.connector.connector_specific_config, - credential=cc_pair.credential, - ) - - assert isinstance(runnable_connector, PollConnector) +def _get_slim_doc_generator( + cc_pair: ConnectorCredentialPair, + google_drive_connector: GoogleDriveConnector, +) -> GenerateSlimDocumentOutput: current_time = datetime.now(timezone.utc) start_time = ( cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp() if cc_pair.last_time_perm_sync else 0.0 ) - cc_pair.last_time_perm_sync = current_time - doc_batch_generator = runnable_connector.poll_source( + return google_drive_connector.retrieve_all_slim_documents( start=start_time, end=current_time.timestamp() ) - docs_with_additional_info = { - doc.id: doc.additional_info - for doc_batch in doc_batch_generator - for doc in doc_batch - } - - return docs_with_additional_info - - -def _fetch_permissions_paginated( - drive_service: Any, drive_file_id: str -) -> Iterator[dict[str, Any]]: - next_token = None - - # Get paginated permissions for the file id - while True: - try: - permissions_resp: dict[str, Any] = add_retries( - lambda: ( - drive_service.permissions() - .list( - fileId=drive_file_id, - fields="permissions(emailAddress, type, domain)", - supportsAllDrives=True, - pageToken=next_token, - ) - .execute() - ) - )() - except HttpError as e: - if e.resp.status == 404: - logger.warning(f"Document with id {drive_file_id} not found: {e}") - break - elif e.resp.status == 403: - logger.warning( - f"Access denied for retrieving document permissions: {e}" - ) - break - else: - logger.error(f"Failed to fetch permissions: {e}") - raise - - for permission in permissions_resp.get("permissions", []): - yield permission - - next_token = permissions_resp.get("nextPageToken") - if not next_token: - break - - -def _fetch_google_permissions_for_document_id( - db_session: Session, - drive_file_id: str, - credentials_json: dict[str, str], - company_google_domains: list[str], -) -> ExternalAccess: - # Authenticate and construct service - google_drive_creds, _ = get_google_drive_creds( - credentials_json, + +def _fetch_permissions_for_permission_ids( + google_drive_connector: GoogleDriveConnector, + permission_ids: list[str], + permission_info: dict[str, Any], +) -> list[dict[str, Any]]: + doc_id = permission_info.get("doc_id") + if not permission_info or not doc_id: + return [] + + # Check cache first for all permission IDs + permissions = [ + _PERMISSION_ID_PERMISSION_MAP[pid] + for pid in permission_ids + if pid in _PERMISSION_ID_PERMISSION_MAP + ] + + # If we found all permissions in cache, return them + if len(permissions) == len(permission_ids): + return permissions + + owner_email = permission_info.get("owner_email") + drive_service = google_drive_connector.get_google_resource(user_email=owner_email) + + # Otherwise, fetch all permissions and update cache + fetched_permissions = execute_paginated_retrieval( + retrieval_function=drive_service.permissions().list, + list_key="permissions", + fileId=doc_id, + fields="permissions(id, emailAddress, type, domain)", + supportsAllDrives=True, ) - if not google_drive_creds.valid: - raise ValueError("Invalid Google Drive credentials") - drive_service = build("drive", "v3", credentials=google_drive_creds) + permissions_for_doc_id = [] + # Update cache and return all permissions + for permission in fetched_permissions: + permissions_for_doc_id.append(permission) + _PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission + + return permissions_for_doc_id + +def _get_permissions_from_slim_doc( + google_drive_connector: GoogleDriveConnector, + slim_doc: SlimDocument, +) -> ExternalAccess: + permission_info = slim_doc.perm_sync_data or {} + + permissions_list = permission_info.get("permissions", []) + if not permissions_list: + if permission_ids := permission_info.get("permission_ids"): + permissions_list = _fetch_permissions_for_permission_ids( + google_drive_connector=google_drive_connector, + permission_ids=permission_ids, + permission_info=permission_info, + ) + if not permissions_list: + logger.warning(f"No permissions found for document {slim_doc.id}") + return ExternalAccess( + external_user_emails=set(), + external_user_group_ids=set(), + is_public=False, + ) + + company_domain = google_drive_connector.google_domain user_emails: set[str] = set() group_emails: set[str] = set() public = False - for permission in _fetch_permissions_paginated(drive_service, drive_file_id): + for permission in permissions_list: permission_type = permission["type"] if permission_type == "user": user_emails.add(permission["emailAddress"]) elif permission_type == "group": group_emails.add(permission["emailAddress"]) - elif permission_type == "domain": - if permission["domain"] in company_google_domains: + elif permission_type == "domain" and company_domain: + if permission["domain"] == company_domain: public = True elif permission_type == "anyone": public = True - batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails)) - return ExternalAccess( external_user_emails=user_emails, external_user_group_ids=group_emails, @@ -156,32 +131,26 @@ def gdrive_doc_sync( it in postgres so that when it gets created later, the permissions are already populated """ - sync_details = cc_pair.auto_sync_options - if sync_details is None: - logger.error("Sync details not found for Google Drive") - raise ValueError("Sync details not found for Google Drive") - - # Here we run the connector to grab all the ids - # this may grab ids before they are indexed but that is fine because - # we create a document in postgres to hold the permissions info - # until the indexing job has a chance to run - docs_with_additional_info = _get_docs_with_additional_info( - db_session=db_session, - cc_pair=cc_pair, + google_drive_connector = GoogleDriveConnector( + **cc_pair.connector.connector_specific_config ) - - for doc_id, doc_additional_info in docs_with_additional_info.items(): - ext_access = _fetch_google_permissions_for_document_id( - db_session=db_session, - drive_file_id=doc_additional_info, - credentials_json=cc_pair.credential.credential_json, - company_google_domains=[ - cast(dict[str, str], sync_details)["company_domain"] - ], - ) - upsert_document_external_perms__no_commit( - db_session=db_session, - doc_id=doc_id, - external_access=ext_access, - source_type=cc_pair.connector.source, - ) + google_drive_connector.load_credentials(cc_pair.credential.credential_json) + + slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector) + + for slim_doc_batch in slim_doc_generator: + for slim_doc in slim_doc_batch: + ext_access = _get_permissions_from_slim_doc( + google_drive_connector=google_drive_connector, + slim_doc=slim_doc, + ) + batch_add_non_web_user_if_not_exists__no_commit( + db_session=db_session, + emails=list(ext_access.external_user_emails), + ) + upsert_document_external_perms__no_commit( + db_session=db_session, + doc_id=slim_doc.id, + external_access=ext_access, + source_type=cc_pair.connector.source, + ) diff --git a/backend/ee/danswer/external_permissions/google_drive/group_sync.py b/backend/ee/danswer/external_permissions/google_drive/group_sync.py index ab0f62a886b..c3afa962392 100644 --- a/backend/ee/danswer/external_permissions/google_drive/group_sync.py +++ b/backend/ee/danswer/external_permissions/google_drive/group_sync.py @@ -1,136 +1,48 @@ -from collections.abc import Iterator -from typing import Any - -from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore -from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore -from googleapiclient.discovery import build # type: ignore -from googleapiclient.errors import HttpError # type: ignore from sqlalchemy.orm import Session -from danswer.connectors.google_drive.connector_auth import ( - get_google_drive_creds, -) -from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES +from danswer.connectors.google_drive.connector import GoogleDriveConnector +from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval from danswer.db.models import ConnectorCredentialPair from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.utils.logger import setup_logger -from danswer.utils.retry_wrapper import retry_builder from ee.danswer.db.external_perm import ExternalUserGroup from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit logger = setup_logger() -# Google Drive APIs are quite flakey and may 500 for an -# extended period of time. Trying to combat here by adding a very -# long retry period (~20 minutes of trying every minute) -add_retries = retry_builder(tries=5, delay=5, max_delay=30) - - -def _fetch_groups_paginated( - google_drive_creds: ServiceAccountCredentials | OAuthCredentials, - identity_source: str | None = None, - customer_id: str | None = None, -) -> Iterator[dict[str, Any]]: - # Note that Google Drive does not use of update the user_cache as the user email - # comes directly with the call to fetch the groups, therefore this is not a valid - # place to save on requests - if identity_source is None and customer_id is None: - raise ValueError( - "Either identity_source or customer_id must be provided to fetch groups" - ) - - cloud_identity_service = build( - "cloudidentity", "v1", credentials=google_drive_creds - ) - parent = ( - f"identitysources/{identity_source}" - if identity_source - else f"customers/{customer_id}" - ) - - while True: - try: - groups_resp: dict[str, Any] = add_retries( - lambda: (cloud_identity_service.groups().list(parent=parent).execute()) - )() - for group in groups_resp.get("groups", []): - yield group - - next_token = groups_resp.get("nextPageToken") - if not next_token: - break - except HttpError as e: - if e.resp.status == 404 or e.resp.status == 403: - break - logger.error(f"Error fetching groups: {e}") - raise - - -def _fetch_group_members_paginated( - google_drive_creds: ServiceAccountCredentials | OAuthCredentials, - group_name: str, -) -> Iterator[dict[str, Any]]: - cloud_identity_service = build( - "cloudidentity", "v1", credentials=google_drive_creds - ) - next_token = None - while True: - try: - membership_info = add_retries( - lambda: ( - cloud_identity_service.groups() - .memberships() - .searchTransitiveMemberships( - parent=group_name, pageToken=next_token - ) - .execute() - ) - )() - - for member in membership_info.get("memberships", []): - yield member - - next_token = membership_info.get("nextPageToken") - if not next_token: - break - except HttpError as e: - if e.resp.status == 404 or e.resp.status == 403: - break - logger.error(f"Error fetching group members: {e}") - raise - - def gdrive_group_sync( db_session: Session, cc_pair: ConnectorCredentialPair, ) -> None: - sync_details = cc_pair.auto_sync_options - if sync_details is None: - logger.error("Sync details not found for Google Drive") - raise ValueError("Sync details not found for Google Drive") - - google_drive_creds, _ = get_google_drive_creds( - cc_pair.credential.credential_json, - scopes=FETCH_GROUPS_SCOPES, + google_drive_connector = GoogleDriveConnector( + **cc_pair.connector.connector_specific_config ) + google_drive_connector.load_credentials(cc_pair.credential.credential_json) + + admin_service = google_drive_connector.get_google_resource("admin", "directory_v1") danswer_groups: list[ExternalUserGroup] = [] - for group in _fetch_groups_paginated( - google_drive_creds, - identity_source=sync_details.get("identity_source"), - customer_id=sync_details.get("customer_id"), + for group in execute_paginated_retrieval( + admin_service.groups().list, + list_key="groups", + domain=google_drive_connector.google_domain, + fields="groups(email)", ): # The id is the group email - group_email = group["groupKey"]["id"] + group_email = group["email"] + # Gather group member emails group_member_emails: list[str] = [] - for member in _fetch_group_members_paginated(google_drive_creds, group["name"]): - member_keys = member["preferredMemberKey"] - member_emails = [member_key["id"] for member_key in member_keys] - for member_email in member_emails: - group_member_emails.append(member_email) - + for member in execute_paginated_retrieval( + admin_service.members().list, + list_key="members", + groupKey=group_email, + fields="members(email)", + ): + group_member_emails.append(member["email"]) + + # Add group members to DB and get their IDs group_members = batch_add_non_web_user_if_not_exists__no_commit( db_session=db_session, emails=group_member_emails ) diff --git a/backend/ee/danswer/external_permissions/permission_sync.py b/backend/ee/danswer/external_permissions/permission_sync.py index ba5bbbd4921..94a0b4bfa8e 100644 --- a/backend/ee/danswer/external_permissions/permission_sync.py +++ b/backend/ee/danswer/external_permissions/permission_sync.py @@ -59,6 +59,7 @@ def run_external_doc_permission_sync( source_type = cc_pair.connector.source doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type) + last_time_perm_sync = cc_pair.last_time_perm_sync if doc_sync_func is None: raise ValueError( @@ -110,4 +111,5 @@ def run_external_doc_permission_sync( logger.info(f"Successfully synced docs for {source_type}") except Exception: logger.exception("Error Syncing Document Permissions") + cc_pair.last_time_perm_sync = last_time_perm_sync db_session.rollback() diff --git a/backend/tests/daily/connectors/google_drive/conftest.py b/backend/tests/daily/connectors/google_drive/conftest.py new file mode 100644 index 00000000000..0b516d0359c --- /dev/null +++ b/backend/tests/daily/connectors/google_drive/conftest.py @@ -0,0 +1,98 @@ +import json +import os +from collections.abc import Callable + +import pytest + +from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY +from danswer.connectors.google_drive.connector import GoogleDriveConnector +from danswer.connectors.google_drive.connector_auth import DB_CREDENTIALS_DICT_TOKEN_KEY +from danswer.connectors.google_drive.connector_auth import ( + DB_CREDENTIALS_PRIMARY_ADMIN_KEY, +) + + +def load_env_vars(env_file: str = ".env") -> None: + current_dir = os.path.dirname(os.path.abspath(__file__)) + env_path = os.path.join(current_dir, env_file) + try: + with open(env_path, "r") as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): + key, value = line.split("=", 1) + os.environ[key] = value.strip() + print("Successfully loaded environment variables") + except FileNotFoundError: + print(f"File {env_file} not found") + + +# Load environment variables at the module level +load_env_vars() + + +@pytest.fixture +def google_drive_oauth_connector_factory() -> Callable[..., GoogleDriveConnector]: + def _connector_factory( + primary_admin_email: str = "admin@onyx-test.com", + include_shared_drives: bool = True, + shared_drive_urls: str | None = None, + include_my_drives: bool = True, + my_drive_emails: str | None = None, + shared_folder_urls: str | None = None, + ) -> GoogleDriveConnector: + connector = GoogleDriveConnector( + include_shared_drives=include_shared_drives, + shared_drive_urls=shared_drive_urls, + include_my_drives=include_my_drives, + my_drive_emails=my_drive_emails, + shared_folder_urls=shared_folder_urls, + ) + + json_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"] + refried_json_string = json.loads(json_string) + + credentials_json = { + DB_CREDENTIALS_DICT_TOKEN_KEY: refried_json_string, + DB_CREDENTIALS_PRIMARY_ADMIN_KEY: primary_admin_email, + } + connector.load_credentials(credentials_json) + return connector + + return _connector_factory + + +@pytest.fixture +def google_drive_service_acct_connector_factory() -> ( + Callable[..., GoogleDriveConnector] +): + def _connector_factory( + primary_admin_email: str = "admin@onyx-test.com", + include_shared_drives: bool = True, + shared_drive_urls: str | None = None, + include_my_drives: bool = True, + my_drive_emails: str | None = None, + shared_folder_urls: str | None = None, + ) -> GoogleDriveConnector: + print("Creating GoogleDriveConnector with service account credentials") + connector = GoogleDriveConnector( + include_shared_drives=include_shared_drives, + shared_drive_urls=shared_drive_urls, + include_my_drives=include_my_drives, + my_drive_emails=my_drive_emails, + shared_folder_urls=shared_folder_urls, + ) + + json_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"] + refried_json_string = json.loads(json_string) + + # Load Service Account Credentials + connector.load_credentials( + { + KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY: refried_json_string, + DB_CREDENTIALS_PRIMARY_ADMIN_KEY: primary_admin_email, + } + ) + return connector + + return _connector_factory diff --git a/backend/tests/daily/connectors/google_drive/helpers.py b/backend/tests/daily/connectors/google_drive/helpers.py new file mode 100644 index 00000000000..a1bc8feec38 --- /dev/null +++ b/backend/tests/daily/connectors/google_drive/helpers.py @@ -0,0 +1,164 @@ +from collections.abc import Sequence + +from danswer.connectors.models import Document + +ALL_FILES = list(range(0, 60)) +SHARED_DRIVE_FILES = list(range(20, 25)) + + +_ADMIN_FILE_IDS = list(range(0, 5)) +_TEST_USER_1_FILE_IDS = list(range(5, 10)) +_TEST_USER_2_FILE_IDS = list(range(10, 15)) +_TEST_USER_3_FILE_IDS = list(range(15, 20)) +_SHARED_DRIVE_1_FILE_IDS = list(range(20, 25)) +_FOLDER_1_FILE_IDS = list(range(25, 30)) +_FOLDER_1_1_FILE_IDS = list(range(30, 35)) +_FOLDER_1_2_FILE_IDS = list(range(35, 40)) +_SHARED_DRIVE_2_FILE_IDS = list(range(40, 45)) +_FOLDER_2_FILE_IDS = list(range(45, 50)) +_FOLDER_2_1_FILE_IDS = list(range(50, 55)) +_FOLDER_2_2_FILE_IDS = list(range(55, 60)) + +_PUBLIC_FOLDER_RANGE = _FOLDER_1_2_FILE_IDS +_PUBLIC_FILE_IDS = list(range(55, 57)) +PUBLIC_RANGE = _PUBLIC_FOLDER_RANGE + _PUBLIC_FILE_IDS + +_SHARED_DRIVE_1_URL = "https://drive.google.com/drive/folders/0AC_OJ4BkMd4kUk9PVA" +# Group 1 is given access to this folder +_FOLDER_1_URL = ( + "https://drive.google.com/drive/folders/1d3I7U3vUZMDziF1OQqYRkB8Jp2s_GWUn" +) +_FOLDER_1_1_URL = ( + "https://drive.google.com/drive/folders/1aR33-zwzl_mnRAwH55GgtWTE-4A4yWWI" +) +_FOLDER_1_2_URL = ( + "https://drive.google.com/drive/folders/1IO0X55VhvLXf4mdxzHxuKf4wxrDBB6jq" +) +_SHARED_DRIVE_2_URL = "https://drive.google.com/drive/folders/0ABKspIh7P4f4Uk9PVA" +_FOLDER_2_URL = ( + "https://drive.google.com/drive/folders/1lNpCJ1teu8Se0louwL0oOHK9nEalskof" +) +_FOLDER_2_1_URL = ( + "https://drive.google.com/drive/folders/1XeDOMWwxTDiVr9Ig2gKum3Zq_Wivv6zY" +) +_FOLDER_2_2_URL = ( + "https://drive.google.com/drive/folders/1RKlsexA8h7NHvBAWRbU27MJotic7KXe3" +) + +_ADMIN_EMAIL = "admin@onyx-test.com" +_TEST_USER_1_EMAIL = "test_user_1@onyx-test.com" +_TEST_USER_2_EMAIL = "test_user_2@onyx-test.com" +_TEST_USER_3_EMAIL = "test_user_3@onyx-test.com" + +# Dictionary for ranges +DRIVE_ID_MAPPING: dict[str, list[int]] = { + "ADMIN": _ADMIN_FILE_IDS, + "TEST_USER_1": _TEST_USER_1_FILE_IDS, + "TEST_USER_2": _TEST_USER_2_FILE_IDS, + "TEST_USER_3": _TEST_USER_3_FILE_IDS, + "SHARED_DRIVE_1": _SHARED_DRIVE_1_FILE_IDS, + "FOLDER_1": _FOLDER_1_FILE_IDS, + "FOLDER_1_1": _FOLDER_1_1_FILE_IDS, + "FOLDER_1_2": _FOLDER_1_2_FILE_IDS, + "SHARED_DRIVE_2": _SHARED_DRIVE_2_FILE_IDS, + "FOLDER_2": _FOLDER_2_FILE_IDS, + "FOLDER_2_1": _FOLDER_2_1_FILE_IDS, + "FOLDER_2_2": _FOLDER_2_2_FILE_IDS, +} + +# Dictionary for emails +EMAIL_MAPPING: dict[str, str] = { + "ADMIN": _ADMIN_EMAIL, + "TEST_USER_1": _TEST_USER_1_EMAIL, + "TEST_USER_2": _TEST_USER_2_EMAIL, + "TEST_USER_3": _TEST_USER_3_EMAIL, +} + +# Dictionary for URLs +URL_MAPPING: dict[str, str] = { + "SHARED_DRIVE_1": _SHARED_DRIVE_1_URL, + "FOLDER_1": _FOLDER_1_URL, + "FOLDER_1_1": _FOLDER_1_1_URL, + "FOLDER_1_2": _FOLDER_1_2_URL, + "SHARED_DRIVE_2": _SHARED_DRIVE_2_URL, + "FOLDER_2": _FOLDER_2_URL, + "FOLDER_2_1": _FOLDER_2_1_URL, + "FOLDER_2_2": _FOLDER_2_2_URL, +} + +# Dictionary for access permissions +# All users have access to their own My Drive as well as public files +ACCESS_MAPPING: dict[str, list[int]] = { + # Admin has access to everything in shared + "ADMIN": ( + _ADMIN_FILE_IDS + + _SHARED_DRIVE_1_FILE_IDS + + _FOLDER_1_FILE_IDS + + _FOLDER_1_1_FILE_IDS + + _FOLDER_1_2_FILE_IDS + + _SHARED_DRIVE_2_FILE_IDS + + _FOLDER_2_FILE_IDS + + _FOLDER_2_1_FILE_IDS + + _FOLDER_2_2_FILE_IDS + ), + # This user has access to drive 1 + # This user has redundant access to folder 1 because of group access + # This user has been given individual access to files in Admin's My Drive + "TEST_USER_1": ( + _TEST_USER_1_FILE_IDS + + _SHARED_DRIVE_1_FILE_IDS + + _FOLDER_1_FILE_IDS + + _FOLDER_1_1_FILE_IDS + + _FOLDER_1_2_FILE_IDS + + list(range(0, 2)) + ), + # Group 1 includes this user, giving access to folder 1 + # This user has also been given access to folder 2-1 + # This user has also been given individual access to files in folder 2 + "TEST_USER_2": ( + _TEST_USER_2_FILE_IDS + + _FOLDER_1_FILE_IDS + + _FOLDER_1_1_FILE_IDS + + _FOLDER_1_2_FILE_IDS + + _FOLDER_2_1_FILE_IDS + + list(range(45, 47)) + ), + # This user can only see his own files and public files + "TEST_USER_3": _TEST_USER_3_FILE_IDS, +} + + +file_name_template = "file_{}.txt" +file_text_template = "This is file {}" + + +def print_discrepencies(expected: set[str], retrieved: set[str]) -> None: + if expected != retrieved: + print(expected) + print(retrieved) + print("Extra:") + print(retrieved - expected) + print("Missing:") + print(expected - retrieved) + + +def assert_retrieved_docs_match_expected( + retrieved_docs: list[Document], expected_file_ids: Sequence[int] +) -> None: + expected_file_names = { + file_name_template.format(file_id) for file_id in expected_file_ids + } + expected_file_texts = { + file_text_template.format(file_id) for file_id in expected_file_ids + } + + retrieved_file_names = set([doc.semantic_identifier for doc in retrieved_docs]) + retrieved_texts = set([doc.sections[0].text for doc in retrieved_docs]) + + # Check file names + print_discrepencies(expected_file_names, retrieved_file_names) + assert expected_file_names == retrieved_file_names + + # Check file texts + print_discrepencies(expected_file_texts, retrieved_texts) + assert expected_file_texts == retrieved_texts diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py b/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py new file mode 100644 index 00000000000..f39b15600b4 --- /dev/null +++ b/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py @@ -0,0 +1,246 @@ +import time +from collections.abc import Callable +from unittest.mock import MagicMock +from unittest.mock import patch + +from danswer.connectors.google_drive.connector import GoogleDriveConnector +from danswer.connectors.models import Document +from tests.daily.connectors.google_drive.helpers import ( + assert_retrieved_docs_match_expected, +) +from tests.daily.connectors.google_drive.helpers import DRIVE_ID_MAPPING +from tests.daily.connectors.google_drive.helpers import EMAIL_MAPPING +from tests.daily.connectors.google_drive.helpers import URL_MAPPING + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_include_all( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_include_all") + connector = google_drive_oauth_connector_factory( + include_shared_drives=True, + include_my_drives=True, + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + # Should get everything in shared and admin's My Drive with oauth + expected_file_ids = ( + DRIVE_ID_MAPPING["ADMIN"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_2"] + + DRIVE_ID_MAPPING["FOLDER_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_include_shared_drives_only( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_include_shared_drives_only") + connector = google_drive_oauth_connector_factory( + include_shared_drives=True, + include_my_drives=False, + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + # Should only get shared drives + expected_file_ids = ( + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_2"] + + DRIVE_ID_MAPPING["FOLDER_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_include_my_drives_only( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_include_my_drives_only") + connector = google_drive_oauth_connector_factory( + include_shared_drives=False, + include_my_drives=True, + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + # Should only get everyone's My Drives + expected_file_ids = DRIVE_ID_MAPPING["ADMIN"] + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_drive_one_only( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_drive_one_only") + drive_urls = [ + URL_MAPPING["SHARED_DRIVE_1"], + ] + connector = google_drive_oauth_connector_factory( + include_shared_drives=True, + include_my_drives=False, + shared_drive_urls=",".join([str(url) for url in drive_urls]), + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + # We ignore shared_drive_urls if include_shared_drives is False + expected_file_ids = ( + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_folder_and_shared_drive( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_folder_and_shared_drive") + drive_urls = [URL_MAPPING["SHARED_DRIVE_1"]] + folder_urls = [URL_MAPPING["FOLDER_2"]] + connector = google_drive_oauth_connector_factory( + include_shared_drives=True, + include_my_drives=True, + shared_drive_urls=",".join([str(url) for url in drive_urls]), + shared_folder_urls=",".join([str(url) for url in folder_urls]), + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + # Should + expected_file_ids = ( + DRIVE_ID_MAPPING["ADMIN"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + + DRIVE_ID_MAPPING["FOLDER_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_folders_only( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_folders_only") + folder_urls = [ + URL_MAPPING["FOLDER_1_1"], + URL_MAPPING["FOLDER_1_2"], + URL_MAPPING["FOLDER_2_1"], + URL_MAPPING["FOLDER_2_2"], + ] + connector = google_drive_oauth_connector_factory( + include_shared_drives=False, + include_my_drives=False, + shared_folder_urls=",".join([str(url) for url in folder_urls]), + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + expected_file_ids = ( + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_specific_emails( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_specific_emails") + my_drive_emails = [ + EMAIL_MAPPING["TEST_USER_1"], + EMAIL_MAPPING["TEST_USER_3"], + ] + connector = google_drive_oauth_connector_factory( + include_shared_drives=False, + include_my_drives=True, + my_drive_emails=",".join([str(email) for email in my_drive_emails]), + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + # No matter who is specified, when using oauth, if include_my_drives is True, + # we will get all the files from the admin's My Drive + expected_file_ids = DRIVE_ID_MAPPING["ADMIN"] + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py b/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py new file mode 100644 index 00000000000..b36a53b30f6 --- /dev/null +++ b/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py @@ -0,0 +1,257 @@ +import time +from collections.abc import Callable +from unittest.mock import MagicMock +from unittest.mock import patch + +from danswer.connectors.google_drive.connector import GoogleDriveConnector +from danswer.connectors.models import Document +from tests.daily.connectors.google_drive.helpers import ( + assert_retrieved_docs_match_expected, +) +from tests.daily.connectors.google_drive.helpers import DRIVE_ID_MAPPING +from tests.daily.connectors.google_drive.helpers import EMAIL_MAPPING +from tests.daily.connectors.google_drive.helpers import URL_MAPPING + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_include_all( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_include_all") + connector = google_drive_service_acct_connector_factory( + include_shared_drives=True, + include_my_drives=True, + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + # Should get everything + expected_file_ids = ( + DRIVE_ID_MAPPING["ADMIN"] + + DRIVE_ID_MAPPING["TEST_USER_1"] + + DRIVE_ID_MAPPING["TEST_USER_2"] + + DRIVE_ID_MAPPING["TEST_USER_3"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_2"] + + DRIVE_ID_MAPPING["FOLDER_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_include_shared_drives_only( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_include_shared_drives_only") + connector = google_drive_service_acct_connector_factory( + include_shared_drives=True, + include_my_drives=False, + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + # Should only get shared drives + expected_file_ids = ( + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_2"] + + DRIVE_ID_MAPPING["FOLDER_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_include_my_drives_only( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_include_my_drives_only") + connector = google_drive_service_acct_connector_factory( + include_shared_drives=False, + include_my_drives=True, + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + # Should only get everyone's My Drives + expected_file_ids = ( + DRIVE_ID_MAPPING["ADMIN"] + + DRIVE_ID_MAPPING["TEST_USER_1"] + + DRIVE_ID_MAPPING["TEST_USER_2"] + + DRIVE_ID_MAPPING["TEST_USER_3"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_drive_one_only( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_drive_one_only") + urls = [URL_MAPPING["SHARED_DRIVE_1"]] + connector = google_drive_service_acct_connector_factory( + include_shared_drives=True, + include_my_drives=False, + shared_drive_urls=",".join([str(url) for url in urls]), + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + # We ignore shared_drive_urls if include_shared_drives is False + expected_file_ids = ( + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_folder_and_shared_drive( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_folder_and_shared_drive") + drive_urls = [ + URL_MAPPING["SHARED_DRIVE_1"], + ] + folder_urls = [URL_MAPPING["FOLDER_2"]] + connector = google_drive_service_acct_connector_factory( + include_shared_drives=True, + include_my_drives=True, + shared_drive_urls=",".join([str(url) for url in drive_urls]), + shared_folder_urls=",".join([str(url) for url in folder_urls]), + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + # Should + expected_file_ids = ( + DRIVE_ID_MAPPING["ADMIN"] + + DRIVE_ID_MAPPING["TEST_USER_1"] + + DRIVE_ID_MAPPING["TEST_USER_2"] + + DRIVE_ID_MAPPING["TEST_USER_3"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + + DRIVE_ID_MAPPING["FOLDER_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_folders_only( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_folders_only") + folder_urls = [ + URL_MAPPING["FOLDER_1_1"], + URL_MAPPING["FOLDER_1_2"], + URL_MAPPING["FOLDER_2_1"], + URL_MAPPING["FOLDER_2_2"], + ] + connector = google_drive_service_acct_connector_factory( + include_shared_drives=False, + include_my_drives=False, + shared_folder_urls=",".join([str(url) for url in folder_urls]), + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + expected_file_ids = ( + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_specific_emails( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_specific_emails") + my_drive_emails = [ + EMAIL_MAPPING["TEST_USER_1"], + EMAIL_MAPPING["TEST_USER_3"], + ] + connector = google_drive_service_acct_connector_factory( + include_shared_drives=False, + include_my_drives=True, + my_drive_emails=",".join([str(email) for email in my_drive_emails]), + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + expected_file_ids = ( + DRIVE_ID_MAPPING["TEST_USER_1"] + DRIVE_ID_MAPPING["TEST_USER_3"] + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py b/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py new file mode 100644 index 00000000000..e731c8b27ce --- /dev/null +++ b/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py @@ -0,0 +1,174 @@ +import time +from collections.abc import Callable +from unittest.mock import MagicMock +from unittest.mock import patch + +from danswer.access.models import ExternalAccess +from danswer.connectors.google_drive.connector import GoogleDriveConnector +from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval +from ee.danswer.external_permissions.google_drive.doc_sync import ( + _get_permissions_from_slim_doc, +) +from tests.daily.connectors.google_drive.helpers import ACCESS_MAPPING +from tests.daily.connectors.google_drive.helpers import DRIVE_ID_MAPPING +from tests.daily.connectors.google_drive.helpers import EMAIL_MAPPING +from tests.daily.connectors.google_drive.helpers import file_name_template +from tests.daily.connectors.google_drive.helpers import print_discrepencies +from tests.daily.connectors.google_drive.helpers import PUBLIC_RANGE + + +def get_keys_available_to_user_from_access_map( + user_email: str, + group_map: dict[str, list[str]], + access_map: dict[str, ExternalAccess], +) -> list[str]: + """ + Extracts the names of the files available to the user from the access map + through their own email or group memberships or public access + """ + group_emails_for_user = [] + for group_email, user_in_group_email_list in group_map.items(): + if user_email in user_in_group_email_list: + group_emails_for_user.append(group_email) + + accessible_file_names_for_user = [] + for file_name, external_access in access_map.items(): + if external_access.is_public: + accessible_file_names_for_user.append(file_name) + elif user_email in external_access.external_user_emails: + accessible_file_names_for_user.append(file_name) + elif any( + group_email in external_access.external_user_group_ids + for group_email in group_emails_for_user + ): + accessible_file_names_for_user.append(file_name) + return accessible_file_names_for_user + + +def assert_correct_access_for_user( + user_email: str, + expected_access_ids: list[int], + group_map: dict[str, list[str]], + retrieved_access_map: dict[str, ExternalAccess], +) -> None: + """ + compares the expected access range of the user to the keys available to the user + retrieved from the source + """ + retrieved_keys_available_to_user = get_keys_available_to_user_from_access_map( + user_email, group_map, retrieved_access_map + ) + retrieved_file_names = set(retrieved_keys_available_to_user) + + # Combine public and user-specific access IDs + all_accessible_ids = expected_access_ids + PUBLIC_RANGE + expected_file_names = {file_name_template.format(i) for i in all_accessible_ids} + + print_discrepencies(expected_file_names, retrieved_file_names) + + assert expected_file_names == retrieved_file_names + + +# This function is supposed to map to the group_sync.py file for the google drive connector +# TODO: Call it directly +def get_group_map(google_drive_connector: GoogleDriveConnector) -> dict[str, list[str]]: + admin_service = google_drive_connector.get_google_resource("admin", "directory_v1") + + group_map: dict[str, list[str]] = {} + for group in execute_paginated_retrieval( + admin_service.groups().list, + list_key="groups", + domain=google_drive_connector.google_domain, + fields="groups(email)", + ): + # The id is the group email + group_email = group["email"] + + # Gather group member emails + group_member_emails: list[str] = [] + for member in execute_paginated_retrieval( + admin_service.members().list, + list_key="members", + groupKey=group_email, + fields="members(email)", + ): + group_member_emails.append(member["email"]) + group_map[group_email] = group_member_emails + return group_map + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_all_permissions( + mock_get_api_key: MagicMock, + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + google_drive_connector = google_drive_service_acct_connector_factory( + include_shared_drives=True, + include_my_drives=True, + ) + + access_map: dict[str, ExternalAccess] = {} + for slim_doc_batch in google_drive_connector.retrieve_all_slim_documents( + 0, time.time() + ): + for slim_doc in slim_doc_batch: + access_map[ + (slim_doc.perm_sync_data or {})["name"] + ] = _get_permissions_from_slim_doc( + google_drive_connector=google_drive_connector, + slim_doc=slim_doc, + ) + + for file_name, external_access in access_map.items(): + print(file_name, external_access) + + expected_file_range = ( + DRIVE_ID_MAPPING["ADMIN"] + + DRIVE_ID_MAPPING["TEST_USER_1"] + + DRIVE_ID_MAPPING["TEST_USER_2"] + + DRIVE_ID_MAPPING["TEST_USER_3"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_1"] + + DRIVE_ID_MAPPING["FOLDER_1"] + + DRIVE_ID_MAPPING["FOLDER_1_1"] + + DRIVE_ID_MAPPING["FOLDER_1_2"] + + DRIVE_ID_MAPPING["SHARED_DRIVE_2"] + + DRIVE_ID_MAPPING["FOLDER_2"] + + DRIVE_ID_MAPPING["FOLDER_2_1"] + + DRIVE_ID_MAPPING["FOLDER_2_2"] + ) + + # Should get everything + assert len(access_map) == len(expected_file_range) + + group_map = get_group_map(google_drive_connector) + + print("groups:\n", group_map) + + assert_correct_access_for_user( + user_email=EMAIL_MAPPING["ADMIN"], + expected_access_ids=ACCESS_MAPPING["ADMIN"], + group_map=group_map, + retrieved_access_map=access_map, + ) + assert_correct_access_for_user( + user_email=EMAIL_MAPPING["TEST_USER_1"], + expected_access_ids=ACCESS_MAPPING["TEST_USER_1"], + group_map=group_map, + retrieved_access_map=access_map, + ) + + assert_correct_access_for_user( + user_email=EMAIL_MAPPING["TEST_USER_2"], + expected_access_ids=ACCESS_MAPPING["TEST_USER_2"], + group_map=group_map, + retrieved_access_map=access_map, + ) + assert_correct_access_for_user( + user_email=EMAIL_MAPPING["TEST_USER_3"], + expected_access_ids=ACCESS_MAPPING["TEST_USER_3"], + group_map=group_map, + retrieved_access_map=access_map, + ) diff --git a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx index a5b6b3bd1ee..7f0e4c664f3 100644 --- a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx +++ b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx @@ -431,6 +431,12 @@ export default function AddConnector({ setSelectedFiles={setSelectedFiles} selectedFiles={selectedFiles} connector={connector} + currentCredential={ + currentCredential || + liveGDriveCredential || + liveGmailCredential || + null + } /> )} diff --git a/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/ListInput.tsx b/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/ListInput.tsx index 956e0c24597..05deec472a6 100644 --- a/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/ListInput.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/ListInput.tsx @@ -4,18 +4,22 @@ import { TextArrayField } from "@/components/admin/connectors/Field"; import { useFormikContext } from "formik"; interface ListInputProps { - field: ListOption; + name: string; + label: string | ((credential: any) => string); + description: string | ((credential: any) => string); } -const ListInput: React.FC = ({ field }) => { +const ListInput: React.FC = ({ name, label, description }) => { const { values } = useFormikContext(); return ( ); }; diff --git a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx index 85237df2c7d..a6ac93441e6 100644 --- a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx @@ -13,6 +13,8 @@ import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle"; import { AccessTypeForm } from "@/components/admin/connectors/AccessTypeForm"; import { AccessTypeGroupSelector } from "@/components/admin/connectors/AccessTypeGroupSelector"; import { ConfigurableSources } from "@/lib/types"; +import { Credential } from "@/lib/connectors/credentials"; +import CollapsibleSection from "@/app/admin/assistants/CollapsibleSection"; export interface DynamicConnectionFormProps { config: ConnectionConfiguration; @@ -20,19 +22,44 @@ export interface DynamicConnectionFormProps { setSelectedFiles: Dispatch>; values: any; connector: ConfigurableSources; + currentCredential: Credential | null; } -const DynamicConnectionForm: FC = ({ - config, +interface RenderFieldProps { + field: any; + values: any; + selectedFiles: File[]; + setSelectedFiles: Dispatch>; + connector: ConfigurableSources; + currentCredential: Credential | null; +} + +const RenderField: FC = ({ + field, + values, selectedFiles, setSelectedFiles, - values, connector, + currentCredential, }) => { - const [showAdvancedOptions, setShowAdvancedOptions] = useState(false); + if ( + field.visibleCondition && + !field.visibleCondition(values, currentCredential) + ) { + return null; + } + + const label = + typeof field.label === "function" + ? field.label(currentCredential) + : field.label; + const description = + typeof field.description === "function" + ? field.description(currentCredential) + : field.description; - const renderField = (field: any) => ( -
+ const fieldContent = ( + <> {field.type === "file" ? ( = ({ ) : field.type === "zip" ? ( ) : field.type === "list" ? ( - + ) : field.type === "select" ? ( ) : field.type === "number" ? ( ) : field.type === "checkbox" ? ( ) : ( )} -
+ ); + if ( + field.visibleCondition && + field.visibleCondition(values, currentCredential) + ) { + return ( + + {fieldContent} + + ); + } else { + return
{fieldContent}
; + } +}; + +const DynamicConnectionForm: FC = ({ + config, + selectedFiles, + setSelectedFiles, + values, + connector, + currentCredential, +}) => { + const [showAdvancedOptions, setShowAdvancedOptions] = useState(false); + return ( <>

{config.description}

@@ -97,7 +149,20 @@ const DynamicConnectionForm: FC = ({ name={"name"} /> - {config.values.map((field) => !field.hidden && renderField(field))} + {config.values.map( + (field) => + !field.hidden && ( + + ) + )} @@ -108,7 +173,18 @@ const DynamicConnectionForm: FC = ({ showAdvancedOptions={showAdvancedOptions} setShowAdvancedOptions={setShowAdvancedOptions} /> - {showAdvancedOptions && config.advanced_values.map(renderField)} + {showAdvancedOptions && + config.advanced_values.map((field) => ( + + ))} )} diff --git a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx index 371bbef6dd1..03a73fe23e6 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx @@ -10,6 +10,7 @@ import { GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME } from "@/lib/constants"; import Cookies from "js-cookie"; import { TextFormField } from "@/components/admin/connectors/Field"; import { Form, Formik } from "formik"; +import { User } from "@/lib/types"; import { Button as TremorButton } from "@tremor/react"; import { Credential, @@ -157,6 +158,7 @@ export const DriveJsonUploadSection = ({ isAdmin, }: DriveJsonUploadSectionProps) => { const { mutate } = useSWRConfig(); + const router = useRouter(); if (serviceAccountCredentialData?.service_account_email) { return ( @@ -190,6 +192,7 @@ export const DriveJsonUploadSection = ({ message: "Successfully deleted service account key", type: "success", }); + router.refresh(); } else { const errorMsg = await response.text(); setPopup({ @@ -307,9 +310,10 @@ interface DriveCredentialSectionProps { setPopup: (popupSpec: PopupSpec | null) => void; refreshCredentials: () => void; connectorExists: boolean; + user: User | null; } -export const DriveOAuthSection = ({ +export const DriveAuthSection = ({ googleDrivePublicCredential, googleDriveServiceAccountCredential, serviceAccountKeyData, @@ -317,6 +321,7 @@ export const DriveOAuthSection = ({ setPopup, refreshCredentials, connectorExists, + user, }: DriveCredentialSectionProps) => { const router = useRouter(); @@ -356,23 +361,23 @@ export const DriveOAuthSection = ({ return (

- When using a Google Drive Service Account, you can either have Danswer - act as the service account itself OR you can specify an account for - the service account to impersonate. + When using a Google Drive Service Account, you must specify the email + of the primary admin that you would like the service account to + impersonate.

- If you want to use the service account itself, leave the{" "} - 'User email to impersonate' field blank when - submitting. If you do choose this option, make sure you have shared - the documents you want to index with the service account. + Ideally, this account should be an owner/admin of the Google + Organization that owns the Google Drive(s) you want to index.

{ formikHelpers.setSubmitting(true); @@ -384,8 +389,7 @@ export const DriveOAuthSection = ({ "Content-Type": "application/json", }, body: JSON.stringify({ - google_drive_delegated_user: - values.google_drive_delegated_user, + google_drive_primary_admin: values.google_drive_primary_admin, }), } ); @@ -408,9 +412,9 @@ export const DriveOAuthSection = ({ {({ isSubmitting }) => (
diff --git a/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx b/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx index d8a14db03a1..518f9709d94 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx @@ -12,7 +12,7 @@ import { useConnectorCredentialIndexingStatus, } from "@/lib/hooks"; import { Title } from "@tremor/react"; -import { DriveJsonUploadSection, DriveOAuthSection } from "./Credential"; +import { DriveJsonUploadSection, DriveAuthSection } from "./Credential"; import { Credential, GoogleDriveCredentialJson, @@ -22,7 +22,7 @@ import { GoogleDriveConfig } from "@/lib/connectors/connectors"; import { useUser } from "@/components/user/UserProvider"; const GDriveMain = ({}: {}) => { - const { isLoadingUser, isAdmin } = useUser(); + const { isLoadingUser, isAdmin, user } = useUser(); const { data: appCredentialData, @@ -135,7 +135,7 @@ const GDriveMain = ({}: {}) => { Step 2: Authenticate with Danswer - { appCredentialData={appCredentialData} serviceAccountKeyData={serviceAccountKeyData} connectorExists={googleDriveConnectorIndexingStatuses.length > 0} + user={user} /> )} diff --git a/web/src/components/admin/connectors/AccessTypeForm.tsx b/web/src/components/admin/connectors/AccessTypeForm.tsx index 108b657b568..8993e28cdb3 100644 --- a/web/src/components/admin/connectors/AccessTypeForm.tsx +++ b/web/src/components/admin/connectors/AccessTypeForm.tsx @@ -9,6 +9,7 @@ import { useUser } from "@/components/user/UserProvider"; import { useField } from "formik"; import { AutoSyncOptions } from "./AutoSyncOptions"; import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; +import { useEffect } from "react"; function isValidAutoSyncSource( value: ConfigurableSources @@ -28,6 +29,21 @@ export function AccessTypeForm({ const isAutoSyncSupported = isValidAutoSyncSource(connector); const { isLoadingUser, isAdmin } = useUser(); + useEffect(() => { + if (!isPaidEnterpriseEnabled) { + access_type_helpers.setValue("public"); + } else if (isAutoSyncSupported) { + access_type_helpers.setValue("sync"); + } else { + access_type_helpers.setValue("private"); + } + }, [ + isAutoSyncSupported, + isAdmin, + isPaidEnterpriseEnabled, + access_type_helpers, + ]); + const options = [ { name: "Private", @@ -46,9 +62,9 @@ export function AccessTypeForm({ }); } - if (isAutoSyncSupported && isAdmin) { + if (isAutoSyncSupported && isAdmin && isPaidEnterpriseEnabled) { options.push({ - name: "Auto Sync", + name: "Auto Sync Permissions", value: "sync", description: "We will automatically sync permissions from the source. A document will be searchable in Danswer if and only if the user performing the search has permission to access the document in the source.", @@ -59,12 +75,13 @@ export function AccessTypeForm({ <> {isPaidEnterpriseEnabled && isAdmin && ( <> -
+
+

+ Control who has access to the documents indexed by this connector. +

-

- Control who has access to the documents indexed by this connector. -

+ {access_type.value === "sync" && isAutoSyncSupported && ( -
- -
+ )} )} diff --git a/web/src/components/admin/connectors/ConnectorTitle.tsx b/web/src/components/admin/connectors/ConnectorTitle.tsx index 269c72e905f..6e2da252aec 100644 --- a/web/src/components/admin/connectors/ConnectorTitle.tsx +++ b/web/src/components/admin/connectors/ConnectorTitle.tsx @@ -64,21 +64,6 @@ export const ConnectorTitle = ({ "Jira Project URL", typedConnector.connector_specific_config.jira_project_url ); - } else if (connector.source === "google_drive") { - const typedConnector = connector as Connector; - if ( - typedConnector.connector_specific_config?.folder_paths && - typedConnector.connector_specific_config?.folder_paths.length > 0 - ) { - additionalMetadata.set( - "Folders", - typedConnector.connector_specific_config.folder_paths.join(", ") - ); - } - - if (!isPublic && owner) { - additionalMetadata.set("Owner", owner); - } } else if (connector.source === "slack") { const typedConnector = connector as Connector; if ( diff --git a/web/src/lib/connectors/AutoSyncOptionFields.tsx b/web/src/lib/connectors/AutoSyncOptionFields.tsx index f6866a16991..4a8b44868e6 100644 --- a/web/src/lib/connectors/AutoSyncOptionFields.tsx +++ b/web/src/lib/connectors/AutoSyncOptionFields.tsx @@ -12,37 +12,6 @@ export const autoSyncConfigBySource: Record< > > = { confluence: {}, - google_drive: { - customer_id: { - label: "Google Workspace Customer ID", - subtext: ( - <> - The unique identifier for your Google Workspace account. To find this, - checkout the{" "} - - guide from Google - - . - - ), - }, - company_domain: { - label: "Google Workspace Company Domain", - subtext: ( - <> - The email domain for your Google Workspace account. -
-
- For example, if your email provided through Google Workspace looks - something like chris@danswer.ai, then your company domain is{" "} - danswer.ai - - ), - }, - }, + google_drive: {}, slack: {}, }; diff --git a/web/src/lib/connectors/connectors.tsx b/web/src/lib/connectors/connectors.tsx index 4e7df383599..6e7ef1ad97f 100644 --- a/web/src/lib/connectors/connectors.tsx +++ b/web/src/lib/connectors/connectors.tsx @@ -2,6 +2,7 @@ import * as Yup from "yup"; import { IsPublicGroupSelectorFormType } from "@/components/IsPublicGroupSelector"; import { ConfigurableSources, ValidInputTypes, ValidSources } from "../types"; import { AccessTypeGroupSelectorFormType } from "@/components/admin/connectors/AccessTypeGroupSelector"; +import { Credential } from "@/lib/connectors/credentials"; // Import Credential type export function isLoadState(connector_name: string): boolean { // TODO: centralize connector metadata like this somewhere instead of hardcoding it here @@ -29,12 +30,18 @@ export type StringWithDescription = { }; export interface Option { - label: string; + label: string | ((currentCredential: Credential | null) => string); name: string; - description?: string; + description?: + | string + | ((currentCredential: Credential | null) => string); query?: string; optional?: boolean; hidden?: boolean; + visibleCondition?: ( + values: any, + currentCredential: Credential | null + ) => boolean; } export interface SelectOption extends Option { @@ -204,38 +211,59 @@ export const connectorConfigs: Record< description: "Configure Google Drive connector", values: [ { - type: "list", - query: "Enter folder paths:", - label: "Folder Paths", - name: "folder_paths", + type: "checkbox", + label: "Include shared drives?", + description: + "This will allow Danswer to index everything in your shared drives.", + name: "include_shared_drives", optional: true, + default: true, }, { - type: "checkbox", - query: "Include shared files?", - label: "Include Shared", - name: "include_shared", - optional: false, - default: false, + type: "text", + description: + "Enter a comma separated list of the URLs of the shared drives to index. Leave blank to index all shared drives.", + label: "Shared Drive URLs", + name: "shared_drive_urls", + visibleCondition: (values) => values.include_shared_drives, + optional: true, }, { type: "checkbox", - query: "Follow shortcuts?", - label: "Follow Shortcuts", - name: "follow_shortcuts", - optional: false, - default: false, + label: (currentCredential) => + currentCredential?.credential_json?.google_drive_tokens + ? "Include My Drive?" + : "Include Everyone's My Drive?", + description: (currentCredential) => + currentCredential?.credential_json?.google_drive_tokens + ? "This will allow Danswer to index everything in your My Drive." + : "This will allow Danswer to index everything in everyone's My Drives.", + name: "include_my_drives", + optional: true, + default: true, }, { - type: "checkbox", - query: "Only include organization public files?", - label: "Only Org Public", - name: "only_org_public", - optional: false, - default: false, + type: "text", + description: + "Enter a comma separated list of the emails of the users whose MyDrive you want to index. Leave blank to index all MyDrives.", + label: "My Drive Emails", + name: "my_drive_emails", + visibleCondition: (values, currentCredential) => + values.include_my_drives && + !currentCredential?.credential_json?.google_drive_tokens, + optional: true, + }, + ], + advanced_values: [ + { + type: "text", + description: + "Enter a comma separated list of the URLs of the folders located in Shared Drives to index. The files located in these folders (and all subfolders) will be indexed. Note: This will be in addition to the 'Include Shared Drives' and 'Shared Drive URLs' settings, so leave those blank if you only want to index the folders specified here.", + label: "Folder URLs", + name: "shared_folder_urls", + optional: true, }, ], - advanced_values: [], }, gmail: { description: "Configure Gmail connector", @@ -1025,7 +1053,7 @@ export interface GitlabConfig { } export interface GoogleDriveConfig { - folder_paths?: string[]; + parent_urls?: string[]; include_shared?: boolean; follow_shortcuts?: boolean; only_org_public?: boolean; diff --git a/web/src/lib/connectors/credentials.ts b/web/src/lib/connectors/credentials.ts index 532f8f6de76..73c788d3a96 100644 --- a/web/src/lib/connectors/credentials.ts +++ b/web/src/lib/connectors/credentials.ts @@ -58,6 +58,7 @@ export interface GmailCredentialJson { export interface GoogleDriveCredentialJson { google_drive_tokens: string; + google_drive_primary_admin: string; } export interface GmailServiceAccountCredentialJson { @@ -67,7 +68,7 @@ export interface GmailServiceAccountCredentialJson { export interface GoogleDriveServiceAccountCredentialJson { google_drive_service_account_key: string; - google_drive_delegated_user: string; + google_drive_primary_admin: string; } export interface SlabCredentialJson { @@ -331,7 +332,7 @@ export const credentialDisplayNames: Record = { // Google Drive Service Account google_drive_service_account_key: "Google Drive Service Account Key", - google_drive_delegated_user: "Google Drive Delegated User", + google_drive_primary_admin: "Google Drive Delegated User", // Slab slab_bot_token: "Slab Bot Token",