From 350081a964fe2407f23d8279ac317523684812c4 Mon Sep 17 00:00:00 2001 From: Chukwuma Nwaugha Date: Thu, 5 Dec 2024 11:24:37 +0000 Subject: [PATCH] Attempt to detect content category (#22) * fix bug in render_category_selection * build the ui workflow to automatically detect content category * allow overloading loading state in chat_item * write the logic and endpoint to detect content category * add improved workflow for selecting auto-detected contentcategory * move ui logic for auto_detected_category to own file * add createdAt property to chat items and implement error handling for loading state * allow user to regenerate errored ai responses --- api/src/main.py | 10 +++ api/src/utils/detect_content_category.py | 68 +++++++++++++++++++ .../components/AutoDetectedCategory.svelte | 43 ++++++++++++ app/src/lib/components/ChatListItem.svelte | 37 +++++++++- app/src/lib/components/ExampleCard.svelte | 2 +- .../components/RenderCategorySelection.svelte | 51 ++++++++++++-- app/src/lib/stores/sessionContext.svelte.ts | 12 ++++ app/src/routes/+page.svelte | 2 +- .../chat/[sessionId=sessionId]/+page.svelte | 54 +++++++++++---- 9 files changed, 258 insertions(+), 21 deletions(-) create mode 100644 api/src/utils/detect_content_category.py create mode 100644 app/src/lib/components/AutoDetectedCategory.svelte diff --git a/api/src/main.py b/api/src/main.py index 0409301..9336d64 100644 --- a/api/src/main.py +++ b/api/src/main.py @@ -10,6 +10,7 @@ from .services.storage import StorageManager from .utils.chat_request import chat_request from .utils.chat_utils import ( + ContentCategory, SessionChatItem, SessionChatRequest, ) @@ -25,6 +26,7 @@ ) from .utils.custom_sources.save_copied_source import CopiedPasteSourceRequest, save_copied_source from .utils.custom_sources.save_uploaded_sources import UploadedFiles +from .utils.detect_content_category import DetectContentCategoryRequest, detect_content_category from .utils.generate_audiocast import GenerateAudioCastRequest, 
GenerateAudiocastException, generate_audiocast from .utils.generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source from .utils.get_audiocast import get_audiocast @@ -206,3 +208,11 @@ def delete_session_endpoint(sessionId: str): """ SessionManager._delete_session(sessionId) return "Deleted" + + +@app.post("/detect-category", response_model=ContentCategory) +async def detect_category_endpoint(request: DetectContentCategoryRequest): + """ + Detect category of a given content + """ + return await detect_content_category(request.content) diff --git a/api/src/utils/detect_content_category.py b/api/src/utils/detect_content_category.py new file mode 100644 index 0000000..56dfe56 --- /dev/null +++ b/api/src/utils/detect_content_category.py @@ -0,0 +1,68 @@ +from pydantic import BaseModel + +from src.services.gemini_client import get_gemini +from src.utils.chat_utils import ContentCategory, content_categories + + +class DetectContentCategoryRequest(BaseModel): + content: str + + +def detect_category_prompt(content: str) -> str: + """ + System Prompt to detect the category of a given content + """ + return f"""You are an intelligent content type classifier. Your task is to analyze the given content and categorize it into one of the following types: + CONTENT: "{content}" + + CATEGORIES: {', '.join(content_categories)} + + + IMPORTANT NOTE: + - You must ONLY output one of these exact category names, with no additional text, explanation, preamble or formatting. + - If the content doesn't fit any of the categories, you should output "other". + + Examples: + Input: "Welcome to today's episode where we'll be discussing the fascinating world of quantum computing..." + Output: podcast + + Input: "And now, dear brothers and sisters, let us reflect on the profound message in today's scripture..." + Output: sermon + + Input: "Let's dive into today's lecture on advanced machine learning algorithms..." 
+ Output: lecture +""" + + +def validate_category_output(output: str) -> ContentCategory: + """ + Validate that the AI output is a valid content category. + Throws ValueError if the output is not a valid category. + """ + cleaned_output = output.strip().lower() + if cleaned_output not in content_categories: + raise ValueError(f"Invalid category '{cleaned_output}'. Must be one of: {', '.join(content_categories)}") + return cleaned_output + + +async def detect_content_category(content: str) -> ContentCategory: + """ + Detect the category of the given content using Gemini Flash, awaiting the SDK's async call so the event loop is not blocked. + """ + client = get_gemini() + + model = client.GenerativeModel( + model_name="gemini-1.5-flash-002", + system_instruction=detect_category_prompt(content), + generation_config=client.GenerationConfig( + temperature=0.1, + max_output_tokens=30, + response_mime_type="text/plain", + ), + ) + + response = await model.generate_content_async(["Now, please categorize the content."]) + if not response.text: + raise Exception("Error obtaining response from Gemini Flash") + + return validate_category_output(response.text) diff --git a/app/src/lib/components/AutoDetectedCategory.svelte b/app/src/lib/components/AutoDetectedCategory.svelte new file mode 100644 index 0000000..fdd899a --- /dev/null +++ b/app/src/lib/components/AutoDetectedCategory.svelte @@ -0,0 +1,43 @@ + + + + + + + + + diff --git a/app/src/lib/components/ChatListItem.svelte b/app/src/lib/components/ChatListItem.svelte index 5e416f8..bc5a3e6 100644 --- a/app/src/lib/components/ChatListItem.svelte +++ b/app/src/lib/components/ChatListItem.svelte @@ -1,11 +1,22 @@ + +
- {#if loading} - Generating response... + {#if likelyErrored} + + Failed to generate response. + Likely errored + + {:else if loading} + + Generating response... + {:else} {#await parse(content) then parsedContent} {@html parsedContent} + {:catch error} +

{String(error)}

{/await} {/if}
+ +{#if likelyErrored} + +{/if} diff --git a/app/src/lib/components/ExampleCard.svelte b/app/src/lib/components/ExampleCard.svelte index cbd7581..74063b1 100644 --- a/app/src/lib/components/ExampleCard.svelte +++ b/app/src/lib/components/ExampleCard.svelte @@ -12,7 +12,7 @@ async function handleClick() { startSession(category); - addChatItem({ id: uuid(), content, role: 'user', loading: false }); + addChatItem({ id: uuid(), content, role: 'user', loading: false, createdAt: Date.now() }); return goto(href, { invalidateAll: true, replaceState: true }); } diff --git a/app/src/lib/components/RenderCategorySelection.svelte b/app/src/lib/components/RenderCategorySelection.svelte index b76ed28..632c768 100644 --- a/app/src/lib/components/RenderCategorySelection.svelte +++ b/app/src/lib/components/RenderCategorySelection.svelte @@ -1,13 +1,56 @@ + +
- +
+ - + {#if detectingCategory} + + + Auto-detecting content category...Please wait + + + {:else if detectedCategory} + + {:else} + + {/if} +
- + {#if !detectingCategory} + + {/if}
diff --git a/app/src/lib/stores/sessionContext.svelte.ts b/app/src/lib/stores/sessionContext.svelte.ts index 3f37bed..518b36b 100644 --- a/app/src/lib/stores/sessionContext.svelte.ts +++ b/app/src/lib/stores/sessionContext.svelte.ts @@ -15,6 +15,7 @@ export type ChatItem = { content: string; role: 'user' | 'assistant'; loading?: boolean; + createdAt?: number; }; export type Session = { @@ -74,6 +75,17 @@ export function setSessionContext(sessionId: string) { return session; }); }, + removeChatItem: (chatId: string) => { + session$.update((session) => { + if (!session) return session; + + const chats = session.chats.filter((i) => i.id !== chatId); + session.chats = chats; + return session; + }); + + return session$; + }, updateChatContent: (chatId: string, chunk: string) => { session$.update((session) => { if (!session) return session; diff --git a/app/src/routes/+page.svelte b/app/src/routes/+page.svelte index 201dd39..a872fc4 100644 --- a/app/src/routes/+page.svelte +++ b/app/src/routes/+page.svelte @@ -25,7 +25,7 @@ startSession(category); const content = `${selectContent}\nCategory: ${category} `; - addChatItem({ id: uuid(), content, role: 'user', loading: false }); + addChatItem({ id: uuid(), content, role: 'user', loading: false, createdAt: Date.now() }); const href = `/chat/${sessionId}?category=${category}&chat`; return goto(href, { invalidateAll: true, replaceState: true }); diff --git a/app/src/routes/chat/[sessionId=sessionId]/+page.svelte b/app/src/routes/chat/[sessionId=sessionId]/+page.svelte index 77fbf3b..5671b49 100644 --- a/app/src/routes/chat/[sessionId=sessionId]/+page.svelte +++ b/app/src/routes/chat/[sessionId=sessionId]/+page.svelte @@ -1,5 +1,5 @@ @@ -92,7 +116,13 @@
{#each sessionChats as item (item.id)} {@const finalResponse = isfinalResponse(item)} - + {#if finalResponse}