From e56d6c88971d4d32b27cdbd49e9e2a466cd2c9b9 Mon Sep 17 00:00:00 2001 From: Chukwuma Nwaugha Date: Mon, 28 Oct 2024 19:24:58 +0000 Subject: [PATCH] refine and simplify --- app.py | 9 +++++- src/utils/chat_thread.py | 59 ++++++++++++++------------------------ src/utils/session_state.py | 21 +++++++++----- 3 files changed, 43 insertions(+), 46 deletions(-) diff --git a/app.py b/app.py index 7336c0f..ffc504a 100644 --- a/app.py +++ b/app.py @@ -6,6 +6,7 @@ evaluate_final_response, handle_example_prompt, handle_user_prompt, + use_audiocast_request, ) from src.utils.chat_utils import display_example_cards from src.utils.render_audiocast import render_audiocast @@ -29,7 +30,13 @@ async def main(): # Main chat interface st.title("🎧 AudioCaster") - if st.session_state.generating_audiocast: + if st.session_state.user_specification: + st.info("Generating audiocast from your specifications") + + summary = st.session_state.user_specification + content_category = st.session_state.content_category + await use_audiocast_request(summary, content_category) + if st.session_state.current_audiocast: render_audiocast() else: diff --git a/src/utils/chat_thread.py b/src/utils/chat_thread.py index 61d7014..91f8b02 100644 --- a/src/utils/chat_thread.py +++ b/src/utils/chat_thread.py @@ -1,4 +1,3 @@ -import asyncio import re import streamlit as st @@ -89,10 +88,15 @@ async def evaluate_final_response(ai_message: str, content_category: ContentCate col1, col2 = st.columns(2) with col1: + + def onclick(v: str): + st.session_state.user_specification = v + if st.button( "Generate Audiocast", use_container_width=True, - on_click=UseAudiocast(summary, content_category).run, + on_click=onclick, + args=(summary,), ): st.rerun() with col2: @@ -100,41 +104,20 @@ async def evaluate_final_response(ai_message: str, content_category: ContentCate st.rerun() -class UseAudiocast: - summary: str - content_category: ContentCategory - - def __init__(self, summary: str, content_category: ContentCategory): - self.summary = summary - self.content_category = content_category - - def run(self): - """ - Run command to start generating audiocast - """ - st.session_state.generating_audiocast = True - - async def handler(): - await self.__use_audiocast_request(self.summary, self.content_category) - - asyncio.run(handler()) - - async def __use_audiocast_request( - self, summary: str, content_category: ContentCategory - ): - """ - Call audiocast creating workflow +async def use_audiocast_request(summary: str, content_category: ContentCategory): + """ + Call audiocast creating workflow - Args: - summary (str): user request summary or user specification - content_category (ContentCategory): content category - """ - with st.spinner("Generating your audiocast..."): - audiocast_response = await generate_audiocast( - GenerateAudioCastRequest( - summary=summary, - category=content_category, - ) + Args: + summary (str): user request summary or user specification + content_category (ContentCategory): content category + """ + with st.spinner("Generating your audiocast..."): + audiocast_response = await generate_audiocast( + GenerateAudioCastRequest( + summary=summary, + category=content_category, ) - print(f"Generate AudioCast Response: {audiocast_response}") - st.session_state.current_audiocast = audiocast_response + ) + print(f"Generate AudioCast Response: {audiocast_response}") + st.session_state.current_audiocast = audiocast_response diff --git a/src/utils/session_state.py b/src/utils/session_state.py index 9e5a971..1386e1e 100644 --- a/src/utils/session_state.py +++ b/src/utils/session_state.py @@ -17,18 +17,21 @@ def init_session_state(): """Initialize session state""" if "chat_session_id" not in st.session_state: st.session_state.chat_session_id = str(uuid.uuid4()) + if "messages" not in st.session_state: st.session_state.messages = cast(List[ChatMessage], []) - if "current_audiocast" not in st.session_state: - st.session_state.current_audiocast = None if "example_prompt" not in st.session_state: st.session_state.example_prompt = None if "prompt" not in st.session_state: st.session_state.prompt = None + if "content_category" not in st.session_state: st.session_state.content_category = cast(ContentCategory | None, None) - if "generating_audiocast" not in st.session_state: - st.session_state.generating_audiocast = False + + if "user_specification" not in st.session_state: + st.session_state.user_specification = None + if "current_audiocast" not in st.session_state: + st.session_state.current_audiocast = None def reset_session(): @@ -37,11 +40,15 @@ def reset_session(): #### Client must call st.rerun() """ - st.session_state.messages = [] st.session_state.chat_session_id = str(uuid.uuid4()) - st.session_state.current_audiocast = None + st.session_state.messages = [] st.session_state.example_prompt = None st.session_state.prompt = None - st.session_state.generating_audiocast = False + + st.session_state.content_category = None + + st.session_state.user_specification = None + st.session_state.current_audiocast = None + st.cache_data.clear()