Skip to content

Commit

Permalink
save attachments in background tasks as link custom sources
Browse files Browse the repository at this point in the history
  • Loading branch information
nwaughachukwuma committed Dec 6, 2024
1 parent d10c2a5 commit 85087e7
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 17 deletions.
23 changes: 8 additions & 15 deletions api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,14 @@
GetCustomSourcesRequest,
generate_custom_source,
)
from .utils.custom_sources.manage_attachments import ManageAttachments
from .utils.custom_sources.save_copied_source import CopiedPasteSourceRequest, save_copied_source
from .utils.custom_sources.save_uploaded_sources import UploadedFiles
from .utils.decorators import use_cache_manager
from .utils.detect_content_category import DetectContentCategoryRequest, detect_content_category
from .utils.generate_audiocast import GenerateAudioCastRequest, GenerateAudiocastException, generate_audiocast
from .utils.generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source
from .utils.get_audiocast import get_audiocast
from .utils.get_session_title import GetSessionTitleModel, get_session_title
from .utils.make_seed import get_hash
from .utils.session_manager import SessionManager, SessionModel
from .utils.summarize_custom_sources import SummarizeCustomSourcesRequest, summarize_custom_sources

Expand Down Expand Up @@ -81,29 +80,23 @@ async def chat_endpoint(
"""Chat endpoint"""
category = request.contentCategory
attachments = request.attachments
db = SessionManager(session_id, category)

sources_summary: str | None = None
if attachments:
attachments.sort(key=lambda x: x.lower())

@use_cache_manager(get_hash(attachments))
async def _get_attachments_summary():
summary = await summarize_custom_sources(attachments)
db._update_source(summary)
return summary

sources_summary = await _get_attachments_summary()

db = SessionManager(session_id, category)
db._add_chat(request.chatItem)

attachment_manager = ManageAttachments(session_id)
sources_summary = await attachment_manager.get_attachments_summary(db, attachments)

def on_finish(text: str):
background_tasks.add_task(db._update, {"status": "collating"})
background_tasks.add_task(
db._add_chat,
SessionChatItem(role="assistant", content=text),
)

if attachments:
background_tasks.add_task(attachment_manager.store_attachments, attachments)

response = chat_request(
content_category=category,
previous_messages=db._get_chats(),
Expand Down
12 changes: 12 additions & 0 deletions api/src/utils/custom_sources/base_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,15 @@ def _get_custom_sources(self) -> list[CustomSourceModelDict]:

def _delete_custom_source(self, source_id: str):
return self._get_doc_ref(source_id).delete()

def _get_custom_source_by_url(self, url: str):
self._check_document()
try:
session_ref = self._get_collection(self.collection).document(self.doc_id)
docs = session_ref.collection(self.sub_collection).where("url", "==", url).get()
for doc in docs:
if doc.exists:
return cast(CustomSourceModel, self._safe_to_dict(doc.to_dict()))
except Exception as e:
print(f"Error getting custom sources for Session: {self.doc_id}", e)
return None
8 changes: 6 additions & 2 deletions api/src/utils/custom_sources/generate_url_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class DeleteCustomSourcesRequest(BaseModel):
sourceId: str


def generate_custom_source(request: GenerateCustomSourceRequest, background_tasks: BackgroundTasks):
def generate_custom_source(request: GenerateCustomSourceRequest, background_tasks: BackgroundTasks | None = None):
extractor = ExtractURLContent()
content = extractor._extract(request.url)

Expand All @@ -33,5 +33,9 @@ def save_to_firestore():
manager = CustomSourceManager(request.sessionId)
manager._set_custom_source(custom_source)

background_tasks.add_task(save_to_firestore)
if background_tasks:
background_tasks.add_task(save_to_firestore)
else:
save_to_firestore()

return content.model_dump()
47 changes: 47 additions & 0 deletions api/src/utils/custom_sources/manage_attachments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import asyncio

from src.utils.decorators import use_cache_manager
from src.utils.make_seed import get_hash
from src.utils.session_manager import SessionManager
from src.utils.summarize_custom_sources import summarize_custom_sources

from .base_utils import CustomSourceManager
from .generate_url_source import GenerateCustomSourceRequest, generate_custom_source


class ManageAttachments:
def __init__(self, session_id: str):
self.session_id = session_id

async def get_attachments_summary(self, db: SessionManager, attachments: list[str] | None):
"""
Manage custom sources uploaded by the user
"""
sources_summary: str | None = None
if attachments:
attachments.sort(key=lambda x: x.lower())

@use_cache_manager(get_hash(attachments))
async def handler():
summary = await summarize_custom_sources(attachments)
db._update_source(summary)
return summary

sources_summary = await handler()

return sources_summary

async def store_attachments(self, attachments: list[str]):
"""
Store attachments as custom sources of type links
"""
manager = CustomSourceManager(self.session_id)

async def _handler(url: str):
custom_source = manager._get_custom_source_by_url(url)
if not custom_source:
request = GenerateCustomSourceRequest(url=url, sessionId=self.session_id)
generate_custom_source(request)

await asyncio.gather(*[_handler(url) for url in attachments])
return True

0 comments on commit 85087e7

Please sign in to comment.