From 85087e7c00998ebe81ca99c646808441df639d58 Mon Sep 17 00:00:00 2001 From: Chukwuma Nwaugha Date: Fri, 6 Dec 2024 20:59:32 +0000 Subject: [PATCH] save attachments in background tasks as link custom sources --- api/src/main.py | 23 ++++----- api/src/utils/custom_sources/base_utils.py | 12 +++++ .../custom_sources/generate_url_source.py | 8 +++- .../custom_sources/manage_attachments.py | 47 +++++++++++++++++++ 4 files changed, 73 insertions(+), 17 deletions(-) create mode 100644 api/src/utils/custom_sources/manage_attachments.py diff --git a/api/src/main.py b/api/src/main.py index b66cb59..8bb0555 100644 --- a/api/src/main.py +++ b/api/src/main.py @@ -25,15 +25,14 @@ GetCustomSourcesRequest, generate_custom_source, ) +from .utils.custom_sources.manage_attachments import ManageAttachments from .utils.custom_sources.save_copied_source import CopiedPasteSourceRequest, save_copied_source from .utils.custom_sources.save_uploaded_sources import UploadedFiles -from .utils.decorators import use_cache_manager from .utils.detect_content_category import DetectContentCategoryRequest, detect_content_category from .utils.generate_audiocast import GenerateAudioCastRequest, GenerateAudiocastException, generate_audiocast from .utils.generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source from .utils.get_audiocast import get_audiocast from .utils.get_session_title import GetSessionTitleModel, get_session_title -from .utils.make_seed import get_hash from .utils.session_manager import SessionManager, SessionModel from .utils.summarize_custom_sources import SummarizeCustomSourcesRequest, summarize_custom_sources @@ -81,22 +80,13 @@ async def chat_endpoint( """Chat endpoint""" category = request.contentCategory attachments = request.attachments - db = SessionManager(session_id, category) - - sources_summary: str | None = None - if attachments: - attachments.sort(key=lambda x: x.lower()) - - @use_cache_manager(get_hash(attachments)) - async def _get_attachments_summary(): - summary = await summarize_custom_sources(attachments) - db._update_source(summary) - return summary - - sources_summary = await _get_attachments_summary() + db = SessionManager(session_id, category) db._add_chat(request.chatItem) + attachment_manager = ManageAttachments(session_id) + sources_summary = await attachment_manager.get_attachments_summary(db, attachments) + def on_finish(text: str): background_tasks.add_task(db._update, {"status": "collating"}) background_tasks.add_task( @@ -104,6 +94,9 @@ def on_finish(text: str): SessionChatItem(role="assistant", content=text), ) + if attachments: + background_tasks.add_task(attachment_manager.store_attachments, attachments) + response = chat_request( content_category=category, previous_messages=db._get_chats(), diff --git a/api/src/utils/custom_sources/base_utils.py b/api/src/utils/custom_sources/base_utils.py index 6291f6e..255a0dd 100644 --- a/api/src/utils/custom_sources/base_utils.py +++ b/api/src/utils/custom_sources/base_utils.py @@ -103,3 +103,15 @@ def _get_custom_sources(self) -> list[CustomSourceModelDict]: def _delete_custom_source(self, source_id: str): return self._get_doc_ref(source_id).delete() + + def _get_custom_source_by_url(self, url: str): + self._check_document() + try: + session_ref = self._get_collection(self.collection).document(self.doc_id) + docs = session_ref.collection(self.sub_collection).where("url", "==", url).get() + for doc in docs: + if doc.exists: + return cast(CustomSourceModel, self._safe_to_dict(doc.to_dict())) + except Exception as e: + print(f"Error getting custom sources for Session: {self.doc_id}", e) + return None diff --git a/api/src/utils/custom_sources/generate_url_source.py b/api/src/utils/custom_sources/generate_url_source.py index 8a61860..af57e12 100644 --- a/api/src/utils/custom_sources/generate_url_source.py +++ b/api/src/utils/custom_sources/generate_url_source.py @@ -20,7 +20,7 @@ class DeleteCustomSourcesRequest(BaseModel): sourceId: str -def generate_custom_source(request: GenerateCustomSourceRequest, background_tasks: BackgroundTasks): +def generate_custom_source(request: GenerateCustomSourceRequest, background_tasks: BackgroundTasks | None = None): extractor = ExtractURLContent() content = extractor._extract(request.url) @@ -33,5 +33,9 @@ def save_to_firestore(): manager = CustomSourceManager(request.sessionId) manager._set_custom_source(custom_source) - background_tasks.add_task(save_to_firestore) + if background_tasks: + background_tasks.add_task(save_to_firestore) + else: + save_to_firestore() + return content.model_dump() diff --git a/api/src/utils/custom_sources/manage_attachments.py b/api/src/utils/custom_sources/manage_attachments.py new file mode 100644 index 0000000..13d06af --- /dev/null +++ b/api/src/utils/custom_sources/manage_attachments.py @@ -0,0 +1,47 @@ +import asyncio + +from src.utils.decorators import use_cache_manager +from src.utils.make_seed import get_hash +from src.utils.session_manager import SessionManager +from src.utils.summarize_custom_sources import summarize_custom_sources + +from .base_utils import CustomSourceManager +from .generate_url_source import GenerateCustomSourceRequest, generate_custom_source + + +class ManageAttachments: + def __init__(self, session_id: str): + self.session_id = session_id + + async def get_attachments_summary(self, db: SessionManager, attachments: list[str] | None): + """ + Manage custom sources uploaded by the user + """ + sources_summary: str | None = None + if attachments: + attachments.sort(key=lambda x: x.lower()) + + @use_cache_manager(get_hash(attachments)) + async def handler(): + summary = await summarize_custom_sources(attachments) + db._update_source(summary) + return summary + + sources_summary = await handler() + + return sources_summary + + async def store_attachments(self, attachments: list[str]): + """ + Store attachments as custom sources of type links + """ + manager = CustomSourceManager(self.session_id) + + async def _handler(url: str): + custom_source = manager._get_custom_source_by_url(url) + if not custom_source: + request = GenerateCustomSourceRequest(url=url, sessionId=self.session_id) + generate_custom_source(request) + + await asyncio.gather(*[_handler(url) for url in attachments]) + return True