diff --git a/app.py b/app.py index 7bfdf36..d8b0108 100644 --- a/app.py +++ b/app.py @@ -17,7 +17,7 @@ async def main(): # Sidebar for content type selection st.sidebar.title("Audiocast Info") - init_session_state() + session_id = init_session_state() if st.session_state.content_category: st.sidebar.subheader( @@ -32,9 +32,9 @@ async def main(): uichat = st.empty() if not st.session_state.user_specification: with uichat.container(): - await chatui(uichat) + await chatui(session_id, uichat) else: - await audioui(uichat) + await audioui(session_id, uichat) if __name__ == "__main__": diff --git a/pages/audiocast.py b/pages/audiocast.py index 1666a9c..cfc7a6f 100644 --- a/pages/audiocast.py +++ b/pages/audiocast.py @@ -1,9 +1,11 @@ import asyncio from pathlib import Path +import pyperclip import streamlit as st -from src.utils.main_utils import get_audiocast_uri +from src.env_var import APP_URL +from src.utils.main_utils import get_audiocast def navigate_to_home(): @@ -14,27 +16,52 @@ def navigate_to_home(): async def render_audiocast_page(): st.set_page_config(page_title="Audiora | Share Page", page_icon="🎧") - audiocast_id = st.query_params.get("uuid") + session_id = st.query_params.get("session_id") - if audiocast_id: + if session_id: # Display audiocast content - st.title("🎧 Audiocast Player") - st.write(f"Playing audiocast: {audiocast_id}") + st.title("🎧 Audiora") + st.subheader("Share Page ") + + st.markdown(f"#### Viewing audiocast: {session_id}") try: with st.spinner("Loading audiocast..."): - audio_path = get_audiocast_uri(audiocast_id) - st.audio(audio_path) + audiocast = get_audiocast(session_id) + + # Audio player + st.audio(audiocast["url"]) + + # Transcript + with st.expander("Show Transcript"): + st.write(audiocast["script"]) + + # Metadata + st.sidebar.subheader("Audiocast Source") + st.sidebar.markdown(audiocast["source_content"]) + + share_url = f"{APP_URL}/audiocast?session_id={session_id}" + st.text_input("Share this audiocast:", share_url) + + share_col, restart_row = st.columns(2, vertical_alignment="bottom") + + with share_col: + if st.button("Copy Share link", use_container_width=True): + pyperclip.copy(share_url) + st.session_state.show_copy_success = True + + with restart_row: + if st.button("Create your Audiocast", use_container_width=True): + navigate_to_home() - # TODO: Fetch audiocast metadata from the database - st.subheader("Audiocast Details") - st.write("Created: 2024-03-20") + if audiocast["created_at"]: + st.markdown(f"> Created: {audiocast["created_at"]}") except Exception as e: st.error(f"Error loading audiocast: {str(e)}") else: st.warning( - "Audiocast ID is missing in the URL. Expected URL format: ?uuid=your-audiocast-id" + "Audiocast ID is missing in the URL. Expected URL format: ?session_id=your-audiocast-id" ) st.markdown("---") diff --git a/src/services/firestore_sdk.py b/src/services/firestore_sdk.py new file mode 100644 index 0000000..4ca0e7b --- /dev/null +++ b/src/services/firestore_sdk.py @@ -0,0 +1,58 @@ +import logging +from typing import Dict, Literal + +from firebase_admin.firestore import client, firestore + +firestore_client = client() +server_timestamp = firestore.SERVER_TIMESTAMP +increment = firestore.Increment +arrayUnion = firestore.ArrayUnion +arrayRemove = firestore.ArrayRemove + + +Collection = Literal["audiora_sessions", "audiora_audiocasts"] + +collections: Dict[Collection, Collection] = { + "audiora_sessions": "audiora_sessions", + "audiora_audiocasts": "audiora_audiocasts", +} + + +class DBManager: + def __init__(self, scope: str): + self.logger = logging.getLogger(scope) + + @property + def timestamp(self): + return server_timestamp + + def _get_collection(self, collection: Collection): + return firestore_client.collection(collections[collection]) + + def _create_document(self, collection: Collection, data: Dict): + return self._get_collection(collection).add( + {**data, "created_at": self.timestamp, "updated_at": self.timestamp} + ) + + def _set_document(self, collection: Collection, doc_id: str, data: Dict): + return ( + self._get_collection(collection) + .document(doc_id) + .set({**data, "created_at": self.timestamp, "updated_at": self.timestamp}) + ) + + def _update_document(self, collection: Collection, doc_id: str, data: Dict): + return ( + self._get_collection(collection) + .document(doc_id) + .update({**data, "updated_at": self.timestamp}) + ) + + def _delete_document(self, collection: Collection, doc_id: str): + return self._get_collection(collection).document(doc_id).delete() + + def _get_document(self, collection: Collection, doc_id: str): + return self._get_collection(collection).document(doc_id).get() + + def _get_documents(self, collection: Collection): + return self._get_collection(collection).stream() diff --git a/src/services/storage.py b/src/services/storage.py index 1580a9c..1e14bca 100644 --- a/src/services/storage.py +++ b/src/services/storage.py @@ -1,3 +1,4 @@ +import os from dataclasses import dataclass from io import BytesIO from pathlib import Path @@ -5,6 +6,7 @@ from uuid import uuid4 from google.cloud import storage +from pydub import AudioSegment from src.env_var import BUCKET_NAME @@ -70,8 +72,15 @@ def download_from_gcs(self, filename: str): """ blobname = f"{BLOB_BASE_URI}/{filename}" blob = bucket.blob(blobname) - tmp_file_path = f"/tmp/{str(uuid4())}" - blob.download_to_filename(tmp_file_path) + tmp_file_path = f"/tmp/{filename}" + if os.path.exists(tmp_file_path): + try: + audio = AudioSegment.from_file(tmp_file_path) + if audio.duration_seconds > 0: + return tmp_file_path + except Exception: + os.remove(tmp_file_path) + blob.download_to_filename(tmp_file_path) return tmp_file_path diff --git a/src/uis/audioui.py b/src/uis/audioui.py index d3b8188..1680323 100644 --- a/src/uis/audioui.py +++ b/src/uis/audioui.py @@ -5,7 +5,7 @@ from src.utils.render_audiocast import render_audiocast -async def audioui(uichat: DeltaGenerator): +async def audioui(session_id: str, uichat: DeltaGenerator): """ Audiocast interface """ @@ -17,7 +17,7 @@ async def audioui(uichat: DeltaGenerator): summary = st.session_state.user_specification content_category = st.session_state.content_category - await use_audiocast_request(summary, content_category) + await use_audiocast_request(session_id, summary, content_category) else: st.info("Audiocast generation completed!") - render_audiocast() + render_audiocast(session_id) diff --git a/src/uis/chatui.py b/src/uis/chatui.py index 31bf729..02dcfe1 100644 --- a/src/uis/chatui.py +++ b/src/uis/chatui.py @@ -10,7 +10,7 @@ from src.utils.render_chat import render_chat_history -async def chatui(uichat: DeltaGenerator): +async def chatui(session_id: str, uichat: DeltaGenerator): """ Chat interface """ @@ -27,12 +27,13 @@ async def chatui(uichat: DeltaGenerator): content_category = st.session_state.content_category if st.session_state.example_prompt: - handle_example_prompt(content_category) + prompt = st.session_state.example_prompt + handle_example_prompt(session_id, prompt, content_category) if st.session_state.prompt: prompt = st.session_state.prompt st.session_state.prompt = None - ai_message = handle_user_prompt(prompt, content_category) + ai_message = handle_user_prompt(session_id, prompt, content_category) if isinstance(ai_message, str): await evaluate_final_response(ai_message, content_category) diff --git a/src/utils/audio_manager_utils.py b/src/utils/audio_manager_utils.py index 1c73551..fe30fdc 100644 --- a/src/utils/audio_manager_utils.py +++ b/src/utils/audio_manager_utils.py @@ -41,9 +41,7 @@ def __init__(self) -> None: def _create_voice_mapping(self, tags: List[str], voices: List[Any]): """Create mapping of tags to voices""" - available_voices = voices[: len(tags)] - if len(available_voices) < len(tags): - available_voices = list(islice(cycle(voices), len(tags))) + available_voices = list(islice(cycle(voices), len(tags))) return dict(zip(tags, available_voices)) def _prepare_speech_jobs( @@ -120,8 +118,8 @@ def split_content(self, content: str, tags: List[str]) -> List[Tuple[str, str]]: # Regular expression pattern to match Tag0, Tag1, ..., TagN speaker dialogues matches = re.findall(r"<(Speaker\d+)>(.*?)", content, re.DOTALL) return [ - (str(person), " ".join(content.split()).strip()) - for person, content in matches + (str(speaker), " ".join(content_part.split()).strip()) + for speaker, content_part in matches ] @staticmethod diff --git a/src/utils/chat_thread.py b/src/utils/chat_thread.py index 904287a..78b3561 100644 --- a/src/utils/chat_thread.py +++ b/src/utils/chat_thread.py @@ -14,10 +14,14 @@ termination_suffix = "Please click the button below to start generating the audiocast." -def generate_stream_response(prompt: str, content_category: ContentCategory): +def generate_stream_response( + session_id: str, + prompt: str, + content_category: ContentCategory, +): with st.spinner("Generating response..."): response_generator = chat( - st.session_state.chat_session_id, + session_id, SessionChatRequest( message=SessionChatMessage(role="user", content=prompt), content_category=content_category, @@ -27,12 +31,17 @@ def generate_stream_response(prompt: str, content_category: ContentCategory): return response_generator -def handle_example_prompt(content_category: ContentCategory): +def handle_example_prompt( + session_id: str, + prompt: str, + content_category: ContentCategory, +): """Handle selected example prompt""" - prompt = st.session_state.example_prompt with st.chat_message("assistant"): - response_generator = generate_stream_response(prompt, content_category) + response_generator = generate_stream_response( + session_id, prompt, content_category + ) ai_message = st.write_stream(response_generator) st.session_state.example_prompt = None @@ -45,12 +54,20 @@ def handle_example_prompt(content_category: ContentCategory): st.error("Failed to generate AI response. Please try again.") -def handle_user_prompt(prompt: str, content_category: ContentCategory): +def handle_user_prompt( + session_id: str, + prompt: str, + content_category: ContentCategory, +): """ Handle user input prompt """ with st.chat_message("assistant"): - response_generator = generate_stream_response(prompt, content_category) + response_generator = generate_stream_response( + session_id, + prompt, + content_category, + ) ai_message = st.write_stream(response_generator) if ai_message: @@ -110,7 +127,11 @@ def onclick(summary: str): st.rerun() -async def use_audiocast_request(summary: str, content_category: ContentCategory): +async def use_audiocast_request( + session_id: str, + summary: str, + content_category: ContentCategory, +): """ Call audiocast creating workflow @@ -121,7 +142,11 @@ async def use_audiocast_request(summary: str, content_category: ContentCategory) try: with st.spinner("Generating your audiocast..."): audiocast_response = await generate_audiocast( - GenerateAudioCastRequest(summary=summary, category=content_category) + GenerateAudioCastRequest( + sessionId=session_id, + summary=summary, + category=content_category, + ) ) print(f"Generate AudioCast Response: {audiocast_response}") diff --git a/src/utils/chat_utils.py b/src/utils/chat_utils.py index 6942bfb..e627b98 100644 --- a/src/utils/chat_utils.py +++ b/src/utils/chat_utils.py @@ -1,7 +1,8 @@ +import uuid from typing import Dict, List, Literal import streamlit as st -from pydantic import BaseModel +from pydantic import BaseModel, Field ContentCategory = Literal[ "podcast", @@ -49,8 +50,9 @@ class SessionChatMessage(BaseModel): - role: Literal["user", "assistant"] + id: str = Field(default_factory=lambda: str(uuid.uuid4())) content: str + role: Literal["user", "assistant"] class SessionChatRequest(BaseModel): diff --git a/src/utils/content_generator.py b/src/utils/content_generator.py deleted file mode 100644 index 6f289ef..0000000 --- a/src/utils/content_generator.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Dict, List - -from langchain.chains import LLMChain -from langchain.llms import OpenAI -from langchain.prompts import PromptTemplate - - -class ContentGenerator: - def __init__(self): - self.llm = OpenAI(temperature=0.7) - self.prompt_templates = { - "story": PromptTemplate( - input_variables=["query"], - template="""Create an engaging story about {query}. - Make it captivating and suitable for audio narration. - Include vivid descriptions and natural dialogue.""", - ), - "podcast": PromptTemplate( - input_variables=["query"], - template="""Create an informative podcast script about {query}. - Structure it like a professional podcast with clear sections, - engaging facts, and natural transitions.""", - ), - "sermon": PromptTemplate( - input_variables=["query"], - template="""Create an inspiring sermon about {query}. - Include spiritual insights, relevant scriptures, - and practical applications for daily life.""", - ), - "science": PromptTemplate( - input_variables=["query"], - template="""Create an educational scientific explanation about {query}. - Make it engaging and accessible while maintaining accuracy. - Include recent research and fascinating details.""", - ), - } - - def generate_content( - self, query: str, content_category: str, chat_history: List[Dict] - ) -> str: - # Get the appropriate prompt template - prompt_template = self.prompt_templates.get(content_category) - if not prompt_template: - raise ValueError(f"Invalid content type: {content_category}") - - # Create and run the chain - chain = LLMChain(llm=self.llm, prompt=prompt_template) - response = chain.run(query=query) - - return response - - def refine_with_chat_history(self, content: str, chat_history: List[Dict]) -> str: - # Use chat history to refine the content if needed - relevant_context = "\n".join( - [ - f"{msg['role']}: {msg['content']}" - for msg in chat_history[-3:] # Use last 3 messages for context - ] - ) - - refine_prompt = PromptTemplate( - input_variables=["content", "context"], - template="""Given this conversation context: - {context} - - Please refine this content to better match the user's needs: - {content} - - Refined content:""", - ) - - chain = LLMChain(llm=self.llm, prompt=refine_prompt) - refined_content = chain.run(content=content, context=relevant_context) - - return refined_content diff --git a/src/utils/main_utils.py b/src/utils/main_utils.py index ac75136..c9c3537 100644 --- a/src/utils/main_utils.py +++ b/src/utils/main_utils.py @@ -1,13 +1,12 @@ -import uuid -from pathlib import Path -from typing import Dict, List +from datetime import datetime import streamlit as st from pydantic import BaseModel from src.services.storage import StorageManager from src.utils.audio_manager import AudioManager -from src.utils.audio_synthesizer import AudioSynthesizer + +# from src.utils.audio_synthesizer import AudioSynthesizer from src.utils.audiocast_request import AudioScriptMaker, generate_source_content from src.utils.chat_request import chat_request from src.utils.chat_utils import ( @@ -15,54 +14,45 @@ SessionChatRequest, content_categories, ) +from src.utils.session_manager import SessionManager class GenerateAudioCastRequest(BaseModel): + sessionId: str summary: str category: str class GenerateAudioCastResponse(BaseModel): - uuid: str url: str script: str source_content: str - - -# Store chat sessions (in-memory for now, should be moved to a database in production) -chat_sessions: Dict[str, List[SessionChatMessage]] = {} + created_at: str | None def chat(session_id: str, request: SessionChatRequest): - message = request.message content_category = request.content_category - - if session_id not in chat_sessions: - chat_sessions[session_id] = [] - - chat_sessions[session_id].append(message) + db = SessionManager(session_id) + db._add_chat(request.message) def on_finish(text: str): - chat_sessions[session_id].append( - SessionChatMessage(role="assistant", content=text) - ) - # log text and other metadata to database + db._add_chat(SessionChatMessage(role="assistant", content=text)) - generator = chat_request( + return chat_request( content_category=content_category, - previous_messages=chat_sessions[session_id], + previous_messages=db._get_chats(), on_finish=on_finish, ) - return generator - async def generate_audiocast(request: GenerateAudioCastRequest): """ Generate an audiocast based on a summary of user's request """ + session_id = request.sessionId summary = request.summary category = request.category + if category not in content_categories: raise Exception("Invalid content category") @@ -93,37 +83,60 @@ async def generate_audiocast(request: GenerateAudioCastRequest): container.info("Generating audio...") output_file = await AudioManager().generate_speech(audio_script) - container.info("Enhancing audio quality...") - AudioSynthesizer().enhance_audio_minimal(Path(output_file)) + # container.info("Enhancing audio quality...") + # AudioSynthesizer().enhance_audio_minimal(Path(output_file)) print(f"output_file: {output_file}") - # unique ID for the audiocast - uniq_id = str(uuid.uuid4()) - # TODO: Use a background service # STEP 4: Ingest audio file to a storage service (e.g., GCS, S3) with container.container(): try: container.info("Storing a copy of your audiocast...") storage_manager = StorageManager() - storage_manager.upload_audio_to_gcs(output_file, uniq_id) + storage_manager.upload_audio_to_gcs(output_file, session_id) except Exception as e: print(f"Error while storing audiocast: {str(e)}") + db = SessionManager(session_id) + db._update_source(source_content) + db._update_transcript(audio_script) + response = GenerateAudioCastResponse( - uuid=uniq_id, url=output_file, script=audio_script, source_content=source_content, + created_at=datetime.now().strftime("%Y-%m-%d %H:%M"), ) return response.model_dump() -def get_audiocast_uri(uuid: str): +def get_audiocast(session_id: str): """ Get the URI for the audiocast """ storage_manager = StorageManager() - filepath = storage_manager.download_from_gcs(uuid) - return filepath + filepath = storage_manager.download_from_gcs(session_id) + + session_data = SessionManager(session_id).data() + if not session_data: + raise Exception(f"Audiocast not found for session_id: {session_id}") + + metadata = session_data.metadata + source = metadata.source if metadata else "" + transcript = metadata.transcript if metadata else "" + + created_at: str | None = None + if session_data.created_at: + created_at = datetime.fromisoformat(session_data.created_at).strftime( + "%Y-%m-%d %H:%M" + ) + + response = GenerateAudioCastResponse( + url=filepath, + script=transcript, + source_content=source, + created_at=created_at, + ) + + return response.model_dump() diff --git a/src/utils/render_audiocast.py b/src/utils/render_audiocast.py index 5daa902..b7000e0 100644 --- a/src/utils/render_audiocast.py +++ b/src/utils/render_audiocast.py @@ -8,13 +8,13 @@ class GenerateAudiocastDict(TypedDict): - uuid: str url: str script: str source_content: str + created_at: str | None -def render_audiocast(): +def render_audiocast(session_id: str): """ Render the audiocast based on the user's preferences - Display current audiocast if available @@ -33,7 +33,7 @@ def render_audiocast(): st.sidebar.subheader("Audiocast Source") st.sidebar.markdown(current_audiocast["source_content"]) - share_url = f"{APP_URL}/audiocast?uuid={current_audiocast['uuid']}" + share_url = f"{APP_URL}/audiocast?session_id={session_id}" st.text_input("Share this audiocast:", share_url) share_col, restart_row = st.columns(2, vertical_alignment="bottom") diff --git a/src/utils/render_chat.py b/src/utils/render_chat.py index 2569b37..b891136 100644 --- a/src/utils/render_chat.py +++ b/src/utils/render_chat.py @@ -20,7 +20,7 @@ def on_value_change(): with col1: st.selectbox( "Select Content Category", - content_categories, + ["", *content_categories], format_func=lambda x: x.title(), key="selected_content_category", on_change=on_value_change, diff --git a/src/utils/session_manager.py b/src/utils/session_manager.py new file mode 100644 index 0000000..3d9fc94 --- /dev/null +++ b/src/utils/session_manager.py @@ -0,0 +1,112 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, cast + +from src.services.firestore_sdk import ( + Collection, + DBManager, + arrayRemove, + arrayUnion, + collections, +) +from src.utils.chat_utils import SessionChatMessage + + +@dataclass +class ChatMetadata: + source: str + transcript: str + + +@dataclass +class SessionModel: + id: str + chats: List[SessionChatMessage] + metadata: Optional[ChatMetadata] + created_at: Optional[str] = None + + +class SessionManager(DBManager): + collection: Collection = collections["audiora_sessions"] + + def __init__(self, session_id: str): + super().__init__(scope="ChatManager") + + self.doc_id = session_id + session_doc = self._get_document(self.collection, self.doc_id) + # if the collection does not exist, create it + if not session_doc.exists: + payload = SessionModel(id=self.doc_id, chats=[], metadata=None) + self._set_document(self.collection, self.doc_id, payload.__dict__) + + def _update(self, data: Dict): + return self._update_document(self.collection, self.doc_id, data) + + def data(self) -> SessionModel | None: + doc = self._get_document(self.collection, self.doc_id) + + data = doc.to_dict() + if not doc.exists or not data: + return None + + metadata = data["metadata"] or {} + + return SessionModel( + id=data["id"], + chats=data["chats"], + metadata=ChatMetadata( + source=metadata.get("source", ""), + transcript=metadata.get("transcript", ""), + ), + created_at=str(data["created_at"]), + ) + + def _update_source(self, source: str): + return self._update({"metadata.source": source}) + + def _update_transcript(self, transcript: str): + return self._update({"metadata.transcript": transcript}) + + def _add_chat(self, chat: SessionChatMessage): + return self._update_document( + self.collection, self.doc_id, {"chats": arrayUnion([chat.__dict__])} + ) + + def _delete_chat(self, chat_id: str): + doc = self._get_document(self.collection, self.doc_id) + if not doc.exists: + return + + chat_to_remove = [chat for chat in doc.get("chats") if chat.id == chat_id] + self._update_document( + self.collection, + self.doc_id, + {"chats": arrayRemove([chat_to_remove.__dict__])}, + ) + + def _get_chat(self, chat_id: str) -> SessionChatMessage | None: + doc = self._get_document(self.collection, self.doc_id) + if not doc.exists: + return None + + item = [chat for chat in doc.get("chats") if chat.id == chat_id][0] + if item: + return SessionChatMessage( + content=item["content"], + id=item["id"], + role=item["role"], + ) + + def _get_chats(self) -> List[SessionChatMessage]: + doc = self._get_document(self.collection, self.doc_id) + if not doc.exists: + return [] + + chats = cast(Dict, doc.get("chats")) + return [ + SessionChatMessage( + content=chat["content"], + id=chat["id"], + role=chat["role"], + ) + for chat in chats + ] diff --git a/src/utils/session_state.py b/src/utils/session_state.py index 1386e1e..d19164d 100644 --- a/src/utils/session_state.py +++ b/src/utils/session_state.py @@ -33,6 +33,8 @@ def init_session_state(): if "current_audiocast" not in st.session_state: st.session_state.current_audiocast = None + return cast(str, st.session_state.chat_session_id) + def reset_session(): """ diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29