Skip to content

Commit

Permalink
Integrate firestore (#4)
Browse files Browse the repository at this point in the history
* add firestore_sdk ad session_manager

* save user chats on firestore

* pass down session_id for a deterministic workflow

* handle conversion of chat object to/fro a dict

* remove references to langchain

* reuse a previously downloaded audiofile if it's processable

* render audiocast metdata on share page

* cleanup

* temp remove audio_enchancement
  • Loading branch information
nwaughachukwuma authored Oct 31, 2024
1 parent 3973781 commit 82db371
Show file tree
Hide file tree
Showing 16 changed files with 322 additions and 150 deletions.
6 changes: 3 additions & 3 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def main():
# Sidebar for content type selection
st.sidebar.title("Audiocast Info")

init_session_state()
session_id = init_session_state()

if st.session_state.content_category:
st.sidebar.subheader(
Expand All @@ -32,9 +32,9 @@ async def main():
uichat = st.empty()
if not st.session_state.user_specification:
with uichat.container():
await chatui(uichat)
await chatui(session_id, uichat)
else:
await audioui(uichat)
await audioui(session_id, uichat)


if __name__ == "__main__":
Expand Down
49 changes: 38 additions & 11 deletions pages/audiocast.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
from pathlib import Path

import pyperclip
import streamlit as st

from src.utils.main_utils import get_audiocast_uri
from src.env_var import APP_URL
from src.utils.main_utils import get_audiocast


def navigate_to_home():
Expand All @@ -14,27 +16,52 @@ def navigate_to_home():
async def render_audiocast_page():
st.set_page_config(page_title="Audiora | Share Page", page_icon="🎧")

audiocast_id = st.query_params.get("uuid")
session_id = st.query_params.get("session_id")

if audiocast_id:
if session_id:
# Display audiocast content
st.title("🎧 Audiocast Player")
st.write(f"Playing audiocast: {audiocast_id}")
st.title("🎧 Audiora")
st.subheader("Share Page ")

st.markdown(f"#### Viewing audiocast: {session_id}")

try:
with st.spinner("Loading audiocast..."):
audio_path = get_audiocast_uri(audiocast_id)
st.audio(audio_path)
audiocast = get_audiocast(session_id)

# Audio player
st.audio(audiocast["url"])

# Transcript
with st.expander("Show Transcript"):
st.write(audiocast["script"])

# Metadata
st.sidebar.subheader("Audiocast Source")
st.sidebar.markdown(audiocast["source_content"])

share_url = f"{APP_URL}/audiocast?session_id={session_id}"
st.text_input("Share this audiocast:", share_url)

share_col, restart_row = st.columns(2, vertical_alignment="bottom")

with share_col:
if st.button("Copy Share link", use_container_width=True):
pyperclip.copy(share_url)
st.session_state.show_copy_success = True

with restart_row:
if st.button("Create your Audiocast", use_container_width=True):
navigate_to_home()

# TODO: Fetch audiocast metadata from the database
st.subheader("Audiocast Details")
st.write("Created: 2024-03-20")
if audiocast["created_at"]:
st.markdown(f"> Created: {audiocast["created_at"]}")

except Exception as e:
st.error(f"Error loading audiocast: {str(e)}")
else:
st.warning(
"Audiocast ID is missing in the URL. Expected URL format: ?uuid=your-audiocast-id"
"Audiocast ID is missing in the URL. Expected URL format: ?session_id=your-audiocast-id"
)

st.markdown("---")
Expand Down
58 changes: 58 additions & 0 deletions src/services/firestore_sdk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import logging
from typing import Dict, Literal

from firebase_admin.firestore import client, firestore

firestore_client = client()
server_timestamp = firestore.SERVER_TIMESTAMP
increment = firestore.Increment
arrayUnion = firestore.ArrayUnion
arrayRemove = firestore.ArrayRemove


Collection = Literal["audiora_sessions", "audiora_audiocasts"]

collections: Dict[Collection, Collection] = {
"audiora_sessions": "audiora_sessions",
"audiora_audiocasts": "audiora_audiocasts",
}


class DBManager:
def __init__(self, scope: str):
self.logger = logging.getLogger(scope)

@property
def timestamp(self):
return server_timestamp

def _get_collection(self, collection: Collection):
return firestore_client.collection(collections[collection])

def _create_document(self, collection: Collection, data: Dict):
return self._get_collection(collection).add(
{**data, "created_at": self.timestamp, "updated_at": self.timestamp}
)

def _set_document(self, collection: Collection, doc_id: str, data: Dict):
return (
self._get_collection(collection)
.document(doc_id)
.set({**data, "created_at": self.timestamp, "updated_at": self.timestamp})
)

def _update_document(self, collection: Collection, doc_id: str, data: Dict):
return (
self._get_collection(collection)
.document(doc_id)
.update({**data, "updated_at": self.timestamp})
)

def _delete_document(self, collection: Collection, doc_id: str):
return self._get_collection(collection).document(doc_id).delete()

def _get_document(self, collection: Collection, doc_id: str):
return self._get_collection(collection).document(doc_id).get()

def _get_documents(self, collection: Collection):
return self._get_collection(collection).stream()
13 changes: 11 additions & 2 deletions src/services/storage.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import Any, Dict
from uuid import uuid4

from google.cloud import storage
from pydub import AudioSegment

from src.env_var import BUCKET_NAME

Expand Down Expand Up @@ -70,8 +72,15 @@ def download_from_gcs(self, filename: str):
"""
blobname = f"{BLOB_BASE_URI}/{filename}"
blob = bucket.blob(blobname)
tmp_file_path = f"/tmp/{str(uuid4())}"

blob.download_to_filename(tmp_file_path)
tmp_file_path = f"/tmp/{filename}"
if os.path.exists(tmp_file_path):
try:
audio = AudioSegment.from_file(tmp_file_path)
if audio.duration_seconds > 0:
return tmp_file_path
except Exception:
os.remove(tmp_file_path)

blob.download_to_filename(tmp_file_path)
return tmp_file_path
6 changes: 3 additions & 3 deletions src/uis/audioui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from src.utils.render_audiocast import render_audiocast


async def audioui(uichat: DeltaGenerator):
async def audioui(session_id: str, uichat: DeltaGenerator):
"""
Audiocast interface
"""
Expand All @@ -17,7 +17,7 @@ async def audioui(uichat: DeltaGenerator):

summary = st.session_state.user_specification
content_category = st.session_state.content_category
await use_audiocast_request(summary, content_category)
await use_audiocast_request(session_id, summary, content_category)
else:
st.info("Audiocast generation completed!")
render_audiocast()
render_audiocast(session_id)
7 changes: 4 additions & 3 deletions src/uis/chatui.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from src.utils.render_chat import render_chat_history


async def chatui(uichat: DeltaGenerator):
async def chatui(session_id: str, uichat: DeltaGenerator):
"""
Chat interface
"""
Expand All @@ -27,12 +27,13 @@ async def chatui(uichat: DeltaGenerator):
content_category = st.session_state.content_category

if st.session_state.example_prompt:
handle_example_prompt(content_category)
prompt = st.session_state.example_prompt
handle_example_prompt(session_id, prompt, content_category)

if st.session_state.prompt:
prompt = st.session_state.prompt
st.session_state.prompt = None
ai_message = handle_user_prompt(prompt, content_category)
ai_message = handle_user_prompt(session_id, prompt, content_category)

if isinstance(ai_message, str):
await evaluate_final_response(ai_message, content_category)
Expand Down
8 changes: 3 additions & 5 deletions src/utils/audio_manager_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ def __init__(self) -> None:

def _create_voice_mapping(self, tags: List[str], voices: List[Any]):
"""Create mapping of tags to voices"""
available_voices = voices[: len(tags)]
if len(available_voices) < len(tags):
available_voices = list(islice(cycle(voices), len(tags)))
available_voices = list(islice(cycle(voices), len(tags)))
return dict(zip(tags, available_voices))

def _prepare_speech_jobs(
Expand Down Expand Up @@ -120,8 +118,8 @@ def split_content(self, content: str, tags: List[str]) -> List[Tuple[str, str]]:
# Regular expression pattern to match Tag0, Tag1, ..., TagN speaker dialogues
matches = re.findall(r"<(Speaker\d+)>(.*?)</Speaker\d+>", content, re.DOTALL)
return [
(str(person), " ".join(content.split()).strip())
for person, content in matches
(str(speaker), " ".join(content_part.split()).strip())
for speaker, content_part in matches
]

@staticmethod
Expand Down
43 changes: 34 additions & 9 deletions src/utils/chat_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
termination_suffix = "Please click the button below to start generating the audiocast."


def generate_stream_response(prompt: str, content_category: ContentCategory):
def generate_stream_response(
session_id: str,
prompt: str,
content_category: ContentCategory,
):
with st.spinner("Generating response..."):
response_generator = chat(
st.session_state.chat_session_id,
session_id,
SessionChatRequest(
message=SessionChatMessage(role="user", content=prompt),
content_category=content_category,
Expand All @@ -27,12 +31,17 @@ def generate_stream_response(prompt: str, content_category: ContentCategory):
return response_generator


def handle_example_prompt(content_category: ContentCategory):
def handle_example_prompt(
session_id: str,
prompt: str,
content_category: ContentCategory,
):
"""Handle selected example prompt"""
prompt = st.session_state.example_prompt

with st.chat_message("assistant"):
response_generator = generate_stream_response(prompt, content_category)
response_generator = generate_stream_response(
session_id, prompt, content_category
)
ai_message = st.write_stream(response_generator)
st.session_state.example_prompt = None

Expand All @@ -45,12 +54,20 @@ def handle_example_prompt(content_category: ContentCategory):
st.error("Failed to generate AI response. Please try again.")


def handle_user_prompt(prompt: str, content_category: ContentCategory):
def handle_user_prompt(
session_id: str,
prompt: str,
content_category: ContentCategory,
):
"""
Handle user input prompt
"""
with st.chat_message("assistant"):
response_generator = generate_stream_response(prompt, content_category)
response_generator = generate_stream_response(
session_id,
prompt,
content_category,
)
ai_message = st.write_stream(response_generator)

if ai_message:
Expand Down Expand Up @@ -110,7 +127,11 @@ def onclick(summary: str):
st.rerun()


async def use_audiocast_request(summary: str, content_category: ContentCategory):
async def use_audiocast_request(
session_id: str,
summary: str,
content_category: ContentCategory,
):
"""
Call audiocast creating workflow
Expand All @@ -121,7 +142,11 @@ async def use_audiocast_request(summary: str, content_category: ContentCategory)
try:
with st.spinner("Generating your audiocast..."):
audiocast_response = await generate_audiocast(
GenerateAudioCastRequest(summary=summary, category=content_category)
GenerateAudioCastRequest(
sessionId=session_id,
summary=summary,
category=content_category,
)
)
print(f"Generate AudioCast Response: {audiocast_response}")

Expand Down
6 changes: 4 additions & 2 deletions src/utils/chat_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import uuid
from typing import Dict, List, Literal

import streamlit as st
from pydantic import BaseModel
from pydantic import BaseModel, Field

ContentCategory = Literal[
"podcast",
Expand Down Expand Up @@ -49,8 +50,9 @@


class SessionChatMessage(BaseModel):
role: Literal["user", "assistant"]
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
content: str
role: Literal["user", "assistant"]


class SessionChatRequest(BaseModel):
Expand Down
Loading

0 comments on commit 82db371

Please sign in to comment.