Skip to content

Commit

Permalink
Merge pull request #19 from nwaughachukwuma/ensure-idempotency-alt
Browse files Browse the repository at this point in the history
use a status field on audiocast session model
  • Loading branch information
nwaughachukwuma authored Nov 29, 2024
2 parents 68b1c47 + f2f5127 commit 745e170
Show file tree
Hide file tree
Showing 10 changed files with 146 additions and 78 deletions.
9 changes: 8 additions & 1 deletion api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from .utils.custom_sources.save_copied_source import CopiedPasteSourceRequest, save_copied_source
from .utils.custom_sources.save_uploaded_sources import UploadedFiles
from .utils.generate_audiocast import GenerateAudioCastRequest, generate_audiocast
from .utils.generate_audiocast import GenerateAudioCastRequest, GenerateAudiocastException, generate_audiocast
from .utils.generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source
from .utils.get_audiocast import get_audiocast
from .utils.get_session_title import GetSessionTitleModel, get_session_title
Expand Down Expand Up @@ -78,6 +78,7 @@ def chat_endpoint(
db._add_chat(request.chatItem)

def on_finish(text: str):
background_tasks.add_task(db._update, {"status": "collating"})
background_tasks.add_task(
db._add_chat,
SessionChatItem(role="assistant", content=text),
Expand All @@ -92,6 +93,12 @@ def on_finish(text: str):
return StreamingResponse(response, media_type="text/event-stream")


@app.exception_handler(GenerateAudiocastException)
async def generate_audiocast_exception_handler(_, exc):
SessionManager._update_status(exc.session_id, "failed")
return HTTPException(status_code=exc.status_code, detail=str(exc.detail))


@app.post("/audiocast/generate", response_model=str)
async def generate_audiocast_endpoint(
request: GenerateAudioCastRequest,
Expand Down
72 changes: 51 additions & 21 deletions api/src/utils/generate_audiocast.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
import asyncio
import os

from fastapi import BackgroundTasks, HTTPException

from src.services.storage import StorageManager
from src.utils.audio_manager import AudioManager, AudioManagerConfig
from src.utils.audiocast_script_maker import AudioScriptMaker
from src.utils.audiocast_utils import GenerateAudioCastRequest
from src.utils.chat_utils import ContentCategory
from src.utils.custom_sources.base_utils import CustomSourceManager
from src.utils.generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source
from src.utils.session_manager import SessionManager
from src.utils.waveform_utils import WaveformUtils

from .audio_manager import AudioManager, AudioManagerConfig
from .audiocast_script_maker import AudioScriptMaker
from .audiocast_utils import GenerateAudioCastRequest
from .chat_utils import ContentCategory
from .custom_sources.base_utils import CustomSourceManager
from .generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source
from .session_manager import SessionManager
from .waveform_utils import WaveformUtils


class GenerateAudiocastException(HTTPException):
name = "GenerateAudiocastException"

def __init__(self, detail: str, session_id: str, status_code=500):
self.detail = detail
self.status_code = status_code
self.session_id = session_id


def compile_custom_sources(session_id: str):
Expand Down Expand Up @@ -44,25 +55,40 @@ def post_generate_audio(


async def generate_audiocast(request: GenerateAudioCastRequest, background_tasks: BackgroundTasks):
"""## Generate audiocast based on a summary of user's request
### Steps:
1. Generate source content
2. Generate audio script
3. Generate audio
4. a) Store audio. b) Store the audio waveform on GCS
5. Update session
"""
"""## Generate audiocast based on a summary of user's request"""
summary = request.summary
category = request.category
session_id = request.sessionId

db = SessionManager(session_id, category)
session_data = SessionManager.data(session_id)

if not session_data:
raise GenerateAudiocastException(
status_code=404, detail=f"Audiocast data not found for session_id: {session_id}", session_id=session_id
)

if session_data.status == "completed":
return "Audiocast already generated!"
elif session_data.status == "generating":
print(f"Queueing the current audio generation request>>>>\n\nSessionId: {session_id}")

async def retry_generation():
await asyncio.sleep(10)
try:
await generate_audiocast(request, background_tasks)
except Exception as e:
print(f"Retry failed for session {session_id}: {str(e)}")
SessionManager._update_status(session_id, "failed")

background_tasks.add_task(retry_generation)
return "Audiocast generation in progress. Please wait..."

db._update({"status": "generating"})

def update_session_info(info: str):
db._update_info(info)

session_data = SessionManager.data(session_id)
source_content = session_data.metadata.source if session_data and session_data.metadata else None

if not source_content:
Expand All @@ -76,7 +102,9 @@ def update_session_info(info: str):
)

if not source_content:
raise HTTPException(status_code=500, detail="Failed to generate source content")
raise GenerateAudiocastException(
status_code=500, detail="Failed to generate source content", session_id=session_id
)

# get custom sources
update_session_info("Checking for custom sources...")
Expand All @@ -88,7 +116,9 @@ def update_session_info(info: str):
audio_script = script_maker.create(provider="gemini")

if not audio_script:
raise HTTPException(status_code=500, detail="Failed to generate audio script")
raise GenerateAudiocastException(
status_code=500, detail="Failed to generate audio script", session_id=session_id
)

# Generate audio
update_session_info("Generating audio...")
Expand All @@ -102,6 +132,6 @@ def update_session_info(info: str):
audio_path,
audio_script,
)
db._update({"completed": True})
db._update({"status": "completed"})

return "Audiocast generated successfully!"
3 changes: 3 additions & 0 deletions api/src/utils/get_audiocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,7 @@ def get_audiocast(session_id: str):
detail=f"Audiocast not found for session_id: {session_id}",
)

if session_data.status != "completed":
SessionManager._update_status(session_id, "completed")

return session_data.__dict__
11 changes: 9 additions & 2 deletions api/src/utils/session_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, cast
from typing import Callable, Dict, List, Literal, Optional, cast

from src.services.firestore_sdk import (
Collection,
Expand All @@ -19,14 +19,17 @@ class ChatMetadata:
title: Optional[str] = None


SessionStatus = Literal["collating", "completed", "generating", "failed"]


@dataclass
class SessionModel:
id: str
category: ContentCategory
chats: List[SessionChatItem]
metadata: Optional[ChatMetadata]
created_at: Optional[str] = None
completed: Optional[bool] = None
status: Optional[SessionStatus] = None


class SessionManager(DBManager):
Expand Down Expand Up @@ -150,3 +153,7 @@ def on_snapshot(doc_snapshot, _changes, _read_time):
@staticmethod
def _delete_session(doc_id: str):
return DBManager()._delete_document(collections["audiora_sessions"], doc_id)

@staticmethod
def _update_status(doc_id: str, status: SessionStatus):
return DBManager()._update_document(collections["audiora_sessions"], doc_id, data={"status": status})
12 changes: 3 additions & 9 deletions app/src/lib/components/ChatListActionItems.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
export let summary: string;
export let title: string;
const { session$, audioSource$, fetchingSource$, sessionId$, sessionCompleted$ } =
const { session$, audioSource$, fetchingSource$, sessionId$, sessionCompleted$, updateSessionTitle } =
getSessionContext();
async function ongenerate(summary: string) {
Expand Down Expand Up @@ -79,15 +79,9 @@
async function handleStreamingResponse(res: Response) {
if (!res.ok) return;
for await (const chunk of streamingResponse(res)) {
session$.update((session) => {
if (session) {
if (session.title.toLowerCase() === 'untitled') session.title = '';
session.title += chunk;
}
return session;
});
updateSessionTitle(chunk);
}
}
Expand Down
13 changes: 8 additions & 5 deletions app/src/lib/components/Sidebar.svelte
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
<script lang="ts" context="module">
import { browser } from '$app/environment';
import { SESSION_KEY } from '@/stores/sessionContext.svelte';
const ONE_DAY_MS = 24 * 60 * 60 * 1000;
const last24Hrs = Date.now() - ONE_DAY_MS;
const last7Days = Date.now() - 4 * ONE_DAY_MS;
export function getSessionItems() {
if (!browser) return [];
return Object.entries(localStorage)
.filter(([key]) => key.startsWith(SESSION_KEY))
.map(
Expand Down Expand Up @@ -39,17 +42,17 @@
import { page } from '$app/stores';
import NewAudiocastButton from './NewAudiocastButton.svelte';
import { goto } from '$app/navigation';
import { browser } from '$app/environment';
import { env } from '@env';
const dispatch = createEventDispatcher<{ clickItem: void }>();
const { openSettingsDrawer$ } = getAppContext();
const { session$, refreshSidebar$ } = getSessionContext();
$: sessionItems = browser || $session$ ? getSessionItems() : [];
$: sidebarItems = getSidebarItems(sessionItems);
$: if ($refreshSidebar$) sidebarItems = getSidebarItems(getSessionItems());
$: sidebarItems = getSidebarItems(getSessionItems());
$: if ($refreshSidebar$ || $session$?.title) {
sidebarItems = getSidebarItems(getSessionItems());
}
$: inLast24Hrs = sidebarItems.filter((i) => i.nonce > last24Hrs);
$: inLast7Days = sidebarItems.filter((i) => i.nonce < last24Hrs && i.nonce > last7Days);
Expand Down
10 changes: 10 additions & 0 deletions app/src/lib/stores/sessionContext.svelte.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ export function setSessionContext(sessionId: string) {
session.chats = chats;
return session;
});
},
updateSessionTitle: (chunk: string) => {
session$.update((session) => {
if (session) {
if (session.title.toLowerCase() === 'untitled') session.title = '';
session.title += chunk;
}
return session;
});
return session$;
}
});
}
Expand Down
2 changes: 1 addition & 1 deletion app/src/lib/utils/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export interface SessionModel {
id: string;
category: ContentCategory;
chats: Array<SessionChatItem>;
completed?: boolean;
status?: 'collating' | 'completed' | 'generating' | 'failed';
metadata?: ChatMetadata;
created_at?: string;
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
<script>
import { getSessionContext } from '@/stores/sessionContext.svelte';
import { Button } from '@/components/ui/button';
const { session$, sessionModel$ } = getSessionContext();
const { session$, sessionModel$, customSources$ } = getSessionContext();
$: $sessionModel$;
$: $customSources$;
</script>

<!-- TODO: Use only the DB references -->
Expand Down
Loading

0 comments on commit 745e170

Please sign in to comment.