Skip to content

Commit

Permalink
define session status field on the db
Browse files Browse the repository at this point in the history
  • Loading branch information
nwaughachukwuma committed Nov 26, 2024
1 parent 24204a8 commit c85c12f
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 25 deletions.
13 changes: 12 additions & 1 deletion api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from .utils.custom_sources.save_copied_source import CopiedPasteSourceRequest, save_copied_source
from .utils.custom_sources.save_uploaded_sources import UploadedFiles
from .utils.generate_audiocast import GenerateAudioCastRequest, generate_audiocast
from .utils.generate_audiocast import GenerateAudioCastRequest, GenerateAudiocastException, generate_audiocast
from .utils.generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source
from .utils.get_audiocast import get_audiocast
from .utils.get_session_title import GetSessionTitleModel, get_session_title
Expand Down Expand Up @@ -78,6 +78,7 @@ def chat_endpoint(
db._add_chat(request.chatItem)

def on_finish(text: str):
background_tasks.add_task(db._update, {"status": "collating"})
background_tasks.add_task(
db._add_chat,
SessionChatItem(role="assistant", content=text),
Expand All @@ -92,6 +93,16 @@ def on_finish(text: str):
return StreamingResponse(response, media_type="text/event-stream")


@app.exception_handler(GenerateAudiocastException)
async def generate_audiocast_exception_handler(request, exc):
print("generate_audiocast_exception_handler>>>>>>")
print("Request: ", request)

db = SessionManager(exc.session_id, exc.category)
db._update({"status": "failed"})
return HTTPException(status_code=exc.status_code, detail=str(exc.detail))


@app.post("/audiocast/generate", response_model=str)
async def generate_audiocast_endpoint(
request: GenerateAudioCastRequest,
Expand Down
77 changes: 56 additions & 21 deletions api/src/utils/generate_audiocast.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
import asyncio

from fastapi import BackgroundTasks, HTTPException

from src.services.storage import StorageManager
from src.utils.audio_manager import AudioManager, AudioManagerConfig
from src.utils.audiocast_script_maker import AudioScriptMaker
from src.utils.audiocast_utils import GenerateAudioCastRequest
from src.utils.chat_utils import ContentCategory
from src.utils.custom_sources.base_utils import CustomSourceManager
from src.utils.generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source
from src.utils.session_manager import SessionManager
from src.utils.waveform_utils import WaveformUtils

from .audio_manager import AudioManager, AudioManagerConfig
from .audiocast_script_maker import AudioScriptMaker
from .audiocast_utils import GenerateAudioCastRequest
from .chat_utils import ContentCategory
from .custom_sources.base_utils import CustomSourceManager
from .generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source
from .session_manager import SessionManager
from .waveform_utils import WaveformUtils


class GenerateAudiocastException(HTTPException):
name = "GenerateAudiocastException"

def __init__(self, detail: str, session_id: str, category: ContentCategory, status_code=500):
self.detail = detail
self.status_code = status_code
self.session_id = session_id
self.category = category


def compile_custom_sources(session_id: str):
Expand Down Expand Up @@ -39,25 +52,36 @@ def post_generate_audio(


async def generate_audiocast(request: GenerateAudioCastRequest, background_tasks: BackgroundTasks):
"""## Generate audiocast based on a summary of user's request
### Steps:
1. Generate source content
2. Generate audio script
3. Generate audio
4. a) Store audio. b) Store the audio waveform on GCS
5. Update session
"""
"""## Generate audiocast based on a summary of user's request"""
summary = request.summary
category = request.category
session_id = request.sessionId

db = SessionManager(session_id, category)
session_data = SessionManager.data(session_id)

if not session_data:
raise GenerateAudiocastException(
status_code=404,
detail=f"Audiocast data not found for session_id: {session_id}",
session_id=session_id,
category=category,
)

if session_data.status == "completed":
return "Audiocast already generated!"
elif session_data.status == "generating":
print(f"Queueing the current audio generation request>>>>\n\nSessionId: {session_id}")

background_tasks.add_task(generate_audiocast, request, background_tasks)
await asyncio.sleep(5)
return "Audiocast generation already in progress!"

db._update({"status": "generating"})

def update_session_info(info: str):
db._update_info(info)

session_data = SessionManager.data(session_id)
source_content = session_data.metadata.source if session_data and session_data.metadata else None

if not source_content:
Expand All @@ -71,7 +95,12 @@ def update_session_info(info: str):
)

if not source_content:
raise HTTPException(status_code=500, detail="Failed to generate source content")
raise GenerateAudiocastException(
status_code=500,
detail="Failed to generate source content",
session_id=session_id,
category=category,
)

# get custom sources
update_session_info("Checking for custom sources...")
Expand All @@ -83,7 +112,12 @@ def update_session_info(info: str):
audio_script = script_maker.create(provider="gemini")

if not audio_script:
raise HTTPException(status_code=500, detail="Failed to generate audio script")
raise GenerateAudiocastException(
status_code=500,
detail="Failed to generate audio script",
session_id=session_id,
category=category,
)

# Generate audio
update_session_info("Generating audio...")
Expand All @@ -97,6 +131,7 @@ def update_session_info(info: str):
audio_path,
audio_script,
)
db._update({"completed": True})

db._update({"status": "completed"})

return "Audiocast generated successfully!"
7 changes: 5 additions & 2 deletions api/src/utils/session_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, cast
from typing import Callable, Dict, List, Literal, Optional, cast

from src.services.firestore_sdk import (
Collection,
Expand All @@ -19,14 +19,17 @@ class ChatMetadata:
title: Optional[str] = None


SessionStatus = Literal["collating", "completed", "generating", "failed"]


@dataclass
class SessionModel:
id: str
category: ContentCategory
chats: List[SessionChatItem]
metadata: Optional[ChatMetadata]
created_at: Optional[str] = None
completed: Optional[bool] = None
status: Optional[SessionStatus] = None


class SessionManager(DBManager):
Expand Down
2 changes: 1 addition & 1 deletion app/src/lib/utils/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export interface SessionModel {
id: string;
category: ContentCategory;
chats: Array<SessionChatItem>;
completed?: boolean;
status?: 'collating' | 'completed' | 'generating' | 'failed';
metadata?: ChatMetadata;
created_at?: string;
}

0 comments on commit c85c12f

Please sign in to comment.