From c85c12f49ed28bf0fa06a114008160279431411b Mon Sep 17 00:00:00 2001
From: Chukwuma Nwaugha <nwaughac@gmail.com>
Date: Tue, 26 Nov 2024 18:26:25 +0000
Subject: [PATCH] define session status field on the db

---
 api/src/main.py                     | 13 ++++-
 api/src/utils/generate_audiocast.py | 77 +++++++++++++++++++++--------
 api/src/utils/session_manager.py    |  7 ++-
 app/src/lib/utils/types.ts          |  2 +-
 4 files changed, 74 insertions(+), 25 deletions(-)

diff --git a/api/src/main.py b/api/src/main.py
index 4d88349..53735ab 100644
--- a/api/src/main.py
+++ b/api/src/main.py
@@ -25,7 +25,7 @@
 )
 from .utils.custom_sources.save_copied_source import CopiedPasteSourceRequest, save_copied_source
 from .utils.custom_sources.save_uploaded_sources import UploadedFiles
-from .utils.generate_audiocast import GenerateAudioCastRequest, generate_audiocast
+from .utils.generate_audiocast import GenerateAudioCastRequest, GenerateAudiocastException, generate_audiocast
 from .utils.generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source
 from .utils.get_audiocast import get_audiocast
 from .utils.get_session_title import GetSessionTitleModel, get_session_title
@@ -78,6 +78,7 @@ def chat_endpoint(
     db._add_chat(request.chatItem)
 
     def on_finish(text: str):
+        background_tasks.add_task(db._update, {"status": "collating"})
         background_tasks.add_task(
             db._add_chat,
             SessionChatItem(role="assistant", content=text),
@@ -92,6 +93,16 @@ def on_finish(text: str):
     return StreamingResponse(response, media_type="text/event-stream")
 
 
+@app.exception_handler(GenerateAudiocastException)
+async def generate_audiocast_exception_handler(request, exc):
+    print("generate_audiocast_exception_handler>>>>>>")
+    print("Request: ", request)
+
+    db = SessionManager(exc.session_id, exc.category)
+    db._update({"status": "failed"})
+    return HTTPException(status_code=exc.status_code, detail=str(exc.detail))
+
+
 @app.post("/audiocast/generate", response_model=str)
 async def generate_audiocast_endpoint(
     request: GenerateAudioCastRequest,
diff --git a/api/src/utils/generate_audiocast.py b/api/src/utils/generate_audiocast.py
index 5ba3530..26e35bd 100644
--- a/api/src/utils/generate_audiocast.py
+++ b/api/src/utils/generate_audiocast.py
@@ -1,14 +1,27 @@
+import asyncio
+
 from fastapi import BackgroundTasks, HTTPException
 
 from src.services.storage import StorageManager
-from src.utils.audio_manager import AudioManager, AudioManagerConfig
-from src.utils.audiocast_script_maker import AudioScriptMaker
-from src.utils.audiocast_utils import GenerateAudioCastRequest
-from src.utils.chat_utils import ContentCategory
-from src.utils.custom_sources.base_utils import CustomSourceManager
-from src.utils.generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source
-from src.utils.session_manager import SessionManager
-from src.utils.waveform_utils import WaveformUtils
+
+from .audio_manager import AudioManager, AudioManagerConfig
+from .audiocast_script_maker import AudioScriptMaker
+from .audiocast_utils import GenerateAudioCastRequest
+from .chat_utils import ContentCategory
+from .custom_sources.base_utils import CustomSourceManager
+from .generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source
+from .session_manager import SessionManager
+from .waveform_utils import WaveformUtils
+
+
+class GenerateAudiocastException(HTTPException):
+    name = "GenerateAudiocastException"
+
+    def __init__(self, detail: str, session_id: str, category: ContentCategory, status_code=500):
+        self.detail = detail
+        self.status_code = status_code
+        self.session_id = session_id
+        self.category = category
 
 
 def compile_custom_sources(session_id: str):
@@ -39,25 +52,36 @@ def post_generate_audio(
 
 
 async def generate_audiocast(request: GenerateAudioCastRequest, background_tasks: BackgroundTasks):
-    """## Generate audiocast based on a summary of user's request
-
-    ### Steps:
-    1. Generate source content
-    2. Generate audio script
-    3. Generate audio
-    4. a) Store audio. b) Store the audio waveform on GCS
-    5. Update session
-    """
+    """## Generate audiocast based on a summary of user's request"""
     summary = request.summary
     category = request.category
     session_id = request.sessionId
 
     db = SessionManager(session_id, category)
+    session_data = SessionManager.data(session_id)
+
+    if not session_data:
+        raise GenerateAudiocastException(
+            status_code=404,
+            detail=f"Audiocast data not found for session_id: {session_id}",
+            session_id=session_id,
+            category=category,
+        )
+
+    if session_data.status == "completed":
+        return "Audiocast already generated!"
+    elif session_data.status == "generating":
+        print(f"Queueing the current audio generation request>>>>\n\nSessionId: {session_id}")
+
+        background_tasks.add_task(generate_audiocast, request, background_tasks)
+        await asyncio.sleep(5)
+        return "Audiocast generation already in progress!"
+
+    db._update({"status": "generating"})
 
     def update_session_info(info: str):
         db._update_info(info)
 
-    session_data = SessionManager.data(session_id)
     source_content = session_data.metadata.source if session_data and session_data.metadata else None
 
     if not source_content:
@@ -71,7 +95,12 @@ def update_session_info(info: str):
         )
 
     if not source_content:
-        raise HTTPException(status_code=500, detail="Failed to generate source content")
+        raise GenerateAudiocastException(
+            status_code=500,
+            detail="Failed to generate source content",
+            session_id=session_id,
+            category=category,
+        )
 
     # get custom sources
     update_session_info("Checking for custom sources...")
@@ -83,7 +112,12 @@ def update_session_info(info: str):
     audio_script = script_maker.create(provider="gemini")
 
     if not audio_script:
-        raise HTTPException(status_code=500, detail="Failed to generate audio script")
+        raise GenerateAudiocastException(
+            status_code=500,
+            detail="Failed to generate audio script",
+            session_id=session_id,
+            category=category,
+        )
 
     # Generate audio
     update_session_info("Generating audio...")
@@ -97,6 +131,7 @@ def update_session_info(info: str):
         audio_path,
         audio_script,
     )
-    db._update({"completed": True})
+
+    db._update({"status": "completed"})
 
     return "Audiocast generated successfully!"
diff --git a/api/src/utils/session_manager.py b/api/src/utils/session_manager.py
index 37f57e5..1145207 100644
--- a/api/src/utils/session_manager.py
+++ b/api/src/utils/session_manager.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, cast
+from typing import Callable, Dict, List, Literal, Optional, cast
 
 from src.services.firestore_sdk import (
     Collection,
@@ -19,6 +19,9 @@ class ChatMetadata:
     title: Optional[str] = None
 
 
+SessionStatus = Literal["collating", "completed", "generating", "failed"]
+
+
 @dataclass
 class SessionModel:
     id: str
@@ -26,7 +29,7 @@ class SessionModel:
     chats: List[SessionChatItem]
     metadata: Optional[ChatMetadata]
     created_at: Optional[str] = None
-    completed: Optional[bool] = None
+    status: Optional[SessionStatus] = None
 
 
 class SessionManager(DBManager):
diff --git a/app/src/lib/utils/types.ts b/app/src/lib/utils/types.ts
index af5cd78..5efe00b 100644
--- a/app/src/lib/utils/types.ts
+++ b/app/src/lib/utils/types.ts
@@ -25,7 +25,7 @@ export interface SessionModel {
 	id: string;
 	category: ContentCategory;
 	chats: Array<SessionChatItem>;
-	completed?: boolean;
+	status?: 'collating' | 'completed' | 'generating' | 'failed';
 	metadata?: ChatMetadata;
 	created_at?: string;
 }