Skip to content

Commit

Permalink
define category as both a data model field and a class property for s…
Browse files Browse the repository at this point in the history
…ession_manager
  • Loading branch information
nwaughachukwuma committed Nov 12, 2024
1 parent 4495303 commit 5ee3e29
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 36 deletions.
6 changes: 3 additions & 3 deletions api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ async def root():
@app.post("/chat/{session_id}", response_model=Generator[str, Any, None])
async def chat_endpoint(session_id: str, request: SessionChatRequest, background_tasks: BackgroundTasks):
"""Chat endpoint"""
content_category = request.contentCategory
db = SessionManager(session_id)
category = request.contentCategory
db = SessionManager(session_id, category)
db._add_chat(request.chatItem)

def on_finish(text: str):
Expand All @@ -71,7 +71,7 @@ def on_finish(text: str):
)

response = chat_request(
content_category=content_category,
content_category=category,
previous_messages=db._get_chats(),
on_finish=on_finish,
)
Expand Down
7 changes: 4 additions & 3 deletions api/src/utils/audiocast_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import TypedDict
from typing import List, TypedDict

from pydantic import BaseModel

from src.utils.chat_utils import ContentCategory
from src.utils.chat_utils import ContentCategory, SessionChatItem


class GenerateAudioCastRequest(BaseModel):
Expand All @@ -12,9 +12,10 @@ class GenerateAudioCastRequest(BaseModel):


class GenerateAudioCastResponse(BaseModel):
url: str
script: str
source_content: str
chats: List[SessionChatItem]
title: str | None
created_at: str | None


Expand Down
14 changes: 12 additions & 2 deletions api/src/utils/generate_audiocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def generate_audiocast(request: GenerateAudioCastRequest, background_tasks
background_tasks,
)

db = SessionManager(session_id)
db = SessionManager(session_id, category)

def update_session_info(info: str):
background_tasks.add_task(db._update_info, info)
Expand Down Expand Up @@ -78,9 +78,19 @@ def _run_on_background():

background_tasks.add_task(_run_on_background)

session_data = SessionManager.data(session_id)
if not session_data:
raise HTTPException(
status_code=404,
detail=f"Failed to get audiocast from the DB for session_id: {session_id}",
)

title = session_data.metadata.title if session_data.metadata and session_data.metadata.title else "Untitled"

return GenerateAudioCastResponse(
url=audio_path,
script=audio_script,
source_content=source_content,
created_at=datetime.now().strftime("%Y-%m-%d %H:%M"),
chats=session_data.chats,
title=title,
)
20 changes: 5 additions & 15 deletions api/src/utils/get_audiocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,15 @@

from fastapi import HTTPException

from src.services.storage import StorageManager
from src.utils.generate_audiocast import (
GenerateAudioCastResponse,
)
from src.utils.generate_audiocast import GenerateAudioCastResponse
from src.utils.session_manager import SessionManager


def get_audiocast(session_id: str):
"""
Get audiocast based on session id
"""
try:
storage_manager = StorageManager()
filepath = storage_manager.download_from_gcs(session_id)
except Exception as e:
raise HTTPException(
status_code=404,
detail=f"Audiocast asset not found for session_id: {session_id}",
)

session_data = SessionManager(session_id).data()
session_data = SessionManager.data(session_id)
if not session_data:
raise HTTPException(
status_code=404,
Expand All @@ -32,14 +20,16 @@ def get_audiocast(session_id: str):
metadata = session_data.metadata
source = metadata.source if metadata else ""
transcript = metadata.transcript if metadata else ""
title = metadata.title if metadata and metadata.title else "Untitled"

created_at = None
if session_data.created_at:
created_at = datetime.fromisoformat(session_data.created_at).strftime("%Y-%m-%d %H:%M")

return GenerateAudioCastResponse(
url=filepath,
script=transcript,
source_content=source,
created_at=created_at,
chats=session_data.chats,
title=title,
)
5 changes: 1 addition & 4 deletions api/src/utils/get_audiocast_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,12 @@ async def get_audiocast_source(request: GetAudiocastSourceModel, background_task

@use_cache_manager(cache_key)
async def _handler():
db = SessionManager(session_id)

def update_session_info(info: str):
db = SessionManager(session_id, category)
background_tasks.add_task(db._update_info, info)

update_session_info("Generating source content...")

source_content = generate_source_content(category, summary)

return source_content

return await _handler()
2 changes: 1 addition & 1 deletion api/src/utils/get_session_title.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def get_session_title(request: GetSessionTitleModel, background_tasks: Bac

def _on_finish(title: str):
print(f"Generated title: {title}")
db = SessionManager(session_id)
db = SessionManager(session_id, request.category)
background_tasks.add_task(db._update_title, title)

response = generate_content(
Expand Down
30 changes: 22 additions & 8 deletions api/src/utils/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
arrayUnion,
collections,
)
from src.utils.chat_utils import SessionChatItem
from src.utils.chat_utils import ContentCategory, SessionChatItem


@dataclass
Expand All @@ -22,30 +22,41 @@ class ChatMetadata:
@dataclass
class SessionModel:
id: str
category: ContentCategory
chats: List[SessionChatItem]
metadata: Optional[ChatMetadata]
created_at: Optional[str] = None


class SessionManager(DBManager):
collection: Collection = collections["audiora_sessions"]
category: ContentCategory

def __init__(self, session_id: str):
def __init__(self, session_id: str, category: ContentCategory):
super().__init__()

self.doc_id = session_id
session_doc = self._get_document(self.collection, self.doc_id)
# if the collection does not exist, create it
if not session_doc.exists:
payload = SessionModel(id=self.doc_id, chats=[], metadata=None)
self.category = category
self._init_document()

def _init_document(self):
"""if the collection does not exist, create it"""
try:
session_doc = self._get_document(self.collection, self.doc_id)
except Exception:
session_doc = None

if not session_doc or not session_doc.exists:
payload = SessionModel(id=self.doc_id, chats=[], metadata=None, category=self.category)
self._set_document(self.collection, self.doc_id, payload.__dict__)

def _update(self, data: Dict):
return self._update_document(self.collection, self.doc_id, data)

def data(self) -> SessionModel | None:
@classmethod
def data(cls, doc_id: str) -> SessionModel | None:
"""Get session data"""
doc = self._get_document(self.collection, self.doc_id)
doc = DBManager()._get_document(collections["audiora_sessions"], doc_id)

data = doc.to_dict()
if not doc.exists or not data:
Expand All @@ -55,10 +66,13 @@ def data(self) -> SessionModel | None:

return SessionModel(
id=data["id"],
category=data["category"],
chats=data["chats"],
metadata=ChatMetadata(
source=metadata.get("source", ""),
transcript=metadata.get("transcript", ""),
info=metadata.get("info"),
title=metadata.get("title"),
),
created_at=str(data["created_at"]),
)
Expand Down

0 comments on commit 5ee3e29

Please sign in to comment.