Skip to content

Commit

Permalink
reworked drive+confluence frontend and implied backend changes (#3143)
Browse files Browse the repository at this point in the history
* reworked drive+confluence frontend and implied backend changes

* fixed oauth admin tests

* fixed service account tests

* frontend cleanup

* copy change

* details!

* added key

* so good

* whoops!

* fixed mnore treljsertjoslijt

* has issue with boolean form

* should be done
  • Loading branch information
hagen-danswer authored Nov 16, 2024
1 parent 259fc04 commit 6e83fe3
Show file tree
Hide file tree
Showing 16 changed files with 1,102 additions and 405 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pr-python-connector-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ env:
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
# Google
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1 }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
Expand Down
8 changes: 4 additions & 4 deletions backend/danswer/connectors/confluence/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ def __init__(
if cql_query:
# if a cql_query is provided, we will use it to fetch the pages
cql_page_query = cql_query
elif space:
# if no cql_query is provided, we will use the space to fetch the pages
cql_page_query += f" and space='{quote(space)}'"
elif page_id:
# if a cql_query is not provided, we will use the page_id to fetch the page
if index_recursively:
cql_page_query += f" and ancestor='{page_id}'"
else:
# if neither a space nor a cql_query is provided, we will use the page_id to fetch the page
cql_page_query += f" and id='{page_id}'"
elif space:
# if no cql_query or page_id is provided, we will use the space to fetch the pages
cql_page_query += f" and space='{quote(space)}'"

self.cql_page_query = cql_page_query
self.cql_time_filter = ""
Expand Down
194 changes: 146 additions & 48 deletions backend/danswer/connectors/google_drive/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
convert_drive_item_to_document,
)
from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files
from danswer.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from danswer.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from danswer.connectors.google_drive.models import GoogleDriveFileType
Expand Down Expand Up @@ -82,12 +83,31 @@ def _process_files_batch(
yield doc_batch


def _clean_requested_drive_ids(
requested_drive_ids: set[str],
requested_folder_ids: set[str],
all_drive_ids_available: set[str],
) -> tuple[set[str], set[str]]:
invalid_requested_drive_ids = requested_drive_ids - all_drive_ids_available
filtered_folder_ids = requested_folder_ids - all_drive_ids_available
if invalid_requested_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_requested_drive_ids}"
)
logger.warning("Checking for folder access instead...")
filtered_folder_ids.update(invalid_requested_drive_ids)

valid_requested_drive_ids = requested_drive_ids - invalid_requested_drive_ids
return valid_requested_drive_ids, filtered_folder_ids


class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
include_shared_drives: bool = True,
include_shared_drives: bool = False,
include_my_drives: bool = False,
include_files_shared_with_me: bool = False,
shared_drive_urls: str | None = None,
include_my_drives: bool = True,
my_drive_emails: str | None = None,
shared_folder_urls: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
Expand Down Expand Up @@ -120,22 +140,36 @@ def __init__(
if (
not include_shared_drives
and not include_my_drives
and not include_files_shared_with_me
and not shared_folder_urls
and not my_drive_emails
and not shared_drive_urls
):
raise ValueError(
"At least one of include_shared_drives, include_my_drives,"
" or shared_folder_urls must be true"
"Nothing to index. Please specify at least one of the following: "
"include_shared_drives, include_my_drives, include_files_shared_with_me, "
"shared_folder_urls, or my_drive_emails"
)

self.batch_size = batch_size

self.include_shared_drives = include_shared_drives
specific_requests_made = False
if bool(shared_drive_urls) or bool(my_drive_emails) or bool(shared_folder_urls):
specific_requests_made = True

self.include_files_shared_with_me = (
False if specific_requests_made else include_files_shared_with_me
)
self.include_my_drives = False if specific_requests_made else include_my_drives
self.include_shared_drives = (
False if specific_requests_made else include_shared_drives
)

shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls)
self._requested_shared_drive_ids = set(
_extract_ids_from_urls(shared_drive_url_list)
)

self.include_my_drives = include_my_drives
self._requested_my_drive_emails = set(
_extract_str_list_from_comma_str(my_drive_emails)
)
Expand Down Expand Up @@ -225,26 +259,20 @@ def _get_all_drive_ids(self) -> set[str]:
creds=self.creds,
user_email=self.primary_admin_email,
)
is_service_account = isinstance(self.creds, ServiceAccountCredentials)
all_drive_ids = set()
# We don't want to fail if we're using OAuth because you can
# access your my drive as a non admin user in an org still
ignore_fetch_failure = isinstance(self.creds, OAuthCredentials)
for drive in execute_paginated_retrieval(
retrieval_function=primary_drive_service.drives().list,
list_key="drives",
continue_on_404_or_403=ignore_fetch_failure,
useDomainAdminAccess=True,
useDomainAdminAccess=is_service_account,
fields="drives(id)",
):
all_drive_ids.add(drive["id"])

if not all_drive_ids:
logger.warning(
"No drives found. This is likely because oauth user "
"is not an admin and cannot view all drive IDs. "
"Continuing with only the shared drive IDs specified in the config."
"No drives found even though we are indexing shared drives was requested."
)
all_drive_ids = set(self._requested_shared_drive_ids)

return all_drive_ids

Expand All @@ -261,14 +289,9 @@ def _impersonate_user_for_retrieval(

# if we are including my drives, try to get the current user's my
# drive if any of the following are true:
# - no specific emails were requested
# - include_my_drives is true
# - the current user's email is in the requested emails
# - we are using OAuth (in which case we assume that is the only email we will try)
if self.include_my_drives and (
not self._requested_my_drive_emails
or user_email in self._requested_my_drive_emails
or isinstance(self.creds, OAuthCredentials)
):
if self.include_my_drives or user_email in self._requested_my_drive_emails:
yield from get_all_files_in_my_drive(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
Expand Down Expand Up @@ -299,7 +322,7 @@ def _impersonate_user_for_retrieval(
end=end,
)

def _fetch_drive_items(
def _manage_service_account_retrieval(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
Expand All @@ -309,29 +332,16 @@ def _fetch_drive_items(

all_drive_ids: set[str] = self._get_all_drive_ids()

# remove drive ids from the folder ids because they are queried differently
filtered_folder_ids = self._requested_folder_ids - all_drive_ids

# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
invalid_drive_ids = self._requested_shared_drive_ids - all_drive_ids
if invalid_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
if self._requested_shared_drive_ids or self._requested_folder_ids:
drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids(
requested_drive_ids=self._requested_shared_drive_ids,
requested_folder_ids=self._requested_folder_ids,
all_drive_ids_available=all_drive_ids,
)
logger.warning("Checking for folder access instead...")
filtered_folder_ids.update(invalid_drive_ids)

# If including shared drives, use the requested IDs if provided,
# otherwise use all drive IDs
filtered_drive_ids = set()
if self.include_shared_drives:
if self._requested_shared_drive_ids:
# Remove invalid drive IDs from requested IDs
filtered_drive_ids = (
self._requested_shared_drive_ids - invalid_drive_ids
)
else:
filtered_drive_ids = all_drive_ids
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids

# Process users in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
Expand All @@ -340,8 +350,8 @@ def _fetch_drive_items(
self._impersonate_user_for_retrieval,
email,
is_slim,
filtered_drive_ids,
filtered_folder_ids,
drive_ids_to_retrieve,
folder_ids_to_retrieve,
start,
end,
): email
Expand All @@ -353,13 +363,101 @@ def _fetch_drive_items(
yield from future.result()

remaining_folders = (
filtered_drive_ids | filtered_folder_ids
drive_ids_to_retrieve | folder_ids_to_retrieve
) - self._retrieved_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
)

def _manage_oauth_retrieval(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
drive_service = get_drive_service(self.creds, self.primary_admin_email)

if self.include_files_shared_with_me or self.include_my_drives:
yield from get_all_files_for_oauth(
service=drive_service,
include_files_shared_with_me=self.include_files_shared_with_me,
include_my_drives=self.include_my_drives,
include_shared_drives=self.include_shared_drives,
is_slim=is_slim,
start=start,
end=end,
)

all_requested = (
self.include_files_shared_with_me
and self.include_my_drives
and self.include_shared_drives
)
if all_requested:
# If all 3 are true, we already yielded from get_all_files_for_oauth
return

all_drive_ids = self._get_all_drive_ids()
drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
if self._requested_shared_drive_ids or self._requested_folder_ids:
drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids(
requested_drive_ids=self._requested_shared_drive_ids,
requested_folder_ids=self._requested_folder_ids,
all_drive_ids_available=all_drive_ids,
)
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids

for drive_id in drive_ids_to_retrieve:
yield from get_files_in_shared_drive(
service=drive_service,
drive_id=drive_id,
is_slim=is_slim,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)

# Even if no folders were requested, we still check if any drives were requested
# that could be folders.
remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
for folder_id in remaining_folders:
yield from crawl_folders_for_files(
service=drive_service,
parent_id=folder_id,
traversed_parent_ids=self._retrieved_ids,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)

remaining_folders = (
drive_ids_to_retrieve | folder_ids_to_retrieve
) - self._retrieved_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
)

def _fetch_drive_items(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
retrieval_method = (
self._manage_service_account_retrieval
if isinstance(self.creds, ServiceAccountCredentials)
else self._manage_oauth_retrieval
)
return retrieval_method(
is_slim=is_slim,
start=start,
end=end,
)

def _extract_docs_from_google_drive(
self,
start: SecondsSinceUnixEpoch | None = None,
Expand Down
Loading

0 comments on commit 6e83fe3

Please sign in to comment.