diff --git a/api/src/main.py b/api/src/main.py index 4d88349..53735ab 100644 --- a/api/src/main.py +++ b/api/src/main.py @@ -25,7 +25,7 @@ ) from .utils.custom_sources.save_copied_source import CopiedPasteSourceRequest, save_copied_source from .utils.custom_sources.save_uploaded_sources import UploadedFiles -from .utils.generate_audiocast import GenerateAudioCastRequest, generate_audiocast +from .utils.generate_audiocast import GenerateAudioCastRequest, GenerateAudiocastException, generate_audiocast from .utils.generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source from .utils.get_audiocast import get_audiocast from .utils.get_session_title import GetSessionTitleModel, get_session_title @@ -78,6 +78,7 @@ def chat_endpoint( db._add_chat(request.chatItem) def on_finish(text: str): + background_tasks.add_task(db._update, {"status": "collating"}) background_tasks.add_task( db._add_chat, SessionChatItem(role="assistant", content=text), @@ -92,6 +93,16 @@ def on_finish(text: str): return StreamingResponse(response, media_type="text/event-stream") +@app.exception_handler(GenerateAudiocastException) +async def generate_audiocast_exception_handler(request, exc): + print("generate_audiocast_exception_handler>>>>>>") + print("Request: ", request) + + db = SessionManager(exc.session_id, exc.category) + db._update({"status": "failed"}) + return HTTPException(status_code=exc.status_code, detail=str(exc.detail)) + + @app.post("/audiocast/generate", response_model=str) async def generate_audiocast_endpoint( request: GenerateAudioCastRequest, diff --git a/api/src/utils/generate_audiocast.py b/api/src/utils/generate_audiocast.py index 5ba3530..26e35bd 100644 --- a/api/src/utils/generate_audiocast.py +++ b/api/src/utils/generate_audiocast.py @@ -1,14 +1,27 @@ +import asyncio + from fastapi import BackgroundTasks, HTTPException from src.services.storage import StorageManager -from src.utils.audio_manager import AudioManager, AudioManagerConfig -from src.utils.audiocast_script_maker import AudioScriptMaker -from src.utils.audiocast_utils import GenerateAudioCastRequest -from src.utils.chat_utils import ContentCategory -from src.utils.custom_sources.base_utils import CustomSourceManager -from src.utils.generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source -from src.utils.session_manager import SessionManager -from src.utils.waveform_utils import WaveformUtils + +from .audio_manager import AudioManager, AudioManagerConfig +from .audiocast_script_maker import AudioScriptMaker +from .audiocast_utils import GenerateAudioCastRequest +from .chat_utils import ContentCategory +from .custom_sources.base_utils import CustomSourceManager +from .generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source +from .session_manager import SessionManager +from .waveform_utils import WaveformUtils + + +class GenerateAudiocastException(HTTPException): + name = "GenerateAudiocastException" + + def __init__(self, detail: str, session_id: str, category: ContentCategory, status_code=500): + self.detail = detail + self.status_code = status_code + self.session_id = session_id + self.category = category def compile_custom_sources(session_id: str): @@ -39,25 +52,36 @@ def post_generate_audio( async def generate_audiocast(request: GenerateAudioCastRequest, background_tasks: BackgroundTasks): - """## Generate audiocast based on a summary of user's request - - ### Steps: - 1. Generate source content - 2. Generate audio script - 3. Generate audio - 4. a) Store audio. b) Store the audio waveform on GCS - 5. Update session - """ + """## Generate audiocast based on a summary of user's request""" summary = request.summary category = request.category session_id = request.sessionId db = SessionManager(session_id, category) + session_data = SessionManager.data(session_id) + + if not session_data: + raise GenerateAudiocastException( + status_code=404, + detail=f"Audiocast data not found for session_id: {session_id}", + session_id=session_id, + category=category, + ) + + if session_data.status == "completed": + return "Audiocast already generated!" + elif session_data.status == "generating": + print(f"Queueing the current audio generation request>>>>\n\nSessionId: {session_id}") + + background_tasks.add_task(generate_audiocast, request, background_tasks) + await asyncio.sleep(5) + return "Audiocast generation already in progress!" + + db._update({"status": "generating"}) def update_session_info(info: str): db._update_info(info) - session_data = SessionManager.data(session_id) source_content = session_data.metadata.source if session_data and session_data.metadata else None if not source_content: @@ -71,7 +95,12 @@ def update_session_info(info: str): ) if not source_content: - raise HTTPException(status_code=500, detail="Failed to generate source content") + raise GenerateAudiocastException( + status_code=500, + detail="Failed to generate source content", + session_id=session_id, + category=category, + ) # get custom sources update_session_info("Checking for custom sources...") @@ -83,7 +112,12 @@ def update_session_info(info: str): audio_script = script_maker.create(provider="gemini") if not audio_script: - raise HTTPException(status_code=500, detail="Failed to generate audio script") + raise GenerateAudiocastException( + status_code=500, + detail="Failed to generate audio script", + session_id=session_id, + category=category, + ) # Generate audio update_session_info("Generating audio...") @@ -97,6 +131,7 @@ def update_session_info(info: str): audio_path, audio_script, ) - db._update({"completed": True}) + + db._update({"status": "completed"}) return "Audiocast generated successfully!" diff --git a/api/src/utils/session_manager.py b/api/src/utils/session_manager.py index 37f57e5..1145207 100644 --- a/api/src/utils/session_manager.py +++ b/api/src/utils/session_manager.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, cast +from typing import Callable, Dict, List, Literal, Optional, cast from src.services.firestore_sdk import ( Collection, @@ -19,6 +19,9 @@ class ChatMetadata: title: Optional[str] = None +SessionStatus = Literal["collating", "completed", "generating", "failed"] + + @dataclass class SessionModel: id: str @@ -26,7 +29,7 @@ class SessionModel: chats: List[SessionChatItem] metadata: Optional[ChatMetadata] created_at: Optional[str] = None - completed: Optional[bool] = None + status: Optional[SessionStatus] = None class SessionManager(DBManager): diff --git a/app/src/lib/utils/types.ts b/app/src/lib/utils/types.ts index af5cd78..5efe00b 100644 --- a/app/src/lib/utils/types.ts +++ b/app/src/lib/utils/types.ts @@ -25,7 +25,7 @@ export interface SessionModel { id: string; category: ContentCategory; chats: Array; - completed?: boolean; + status?: 'collating' | 'completed' | 'generating' | 'failed'; metadata?: ChatMetadata; created_at?: string; }