Skip to content

Commit

Permalink
Attempt to detect content category (#22)
Browse files Browse the repository at this point in the history
* fix bug in render_category_selection

* build the ui workflow to automatically detect content category

* allow overloading loading state in chat_item

* write the logic and endpoint to detect content category

* add improved workflow for selecting auto-detected contentcategory

* move ui logic for auto_detected_category to own file

* add createdAt property to chat items and implement error handling for loading state

* allow user to regenerate errored ai responses
  • Loading branch information
nwaughachukwuma authored Dec 5, 2024
1 parent c68ed25 commit 350081a
Show file tree
Hide file tree
Showing 9 changed files with 258 additions and 21 deletions.
10 changes: 10 additions & 0 deletions api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .services.storage import StorageManager
from .utils.chat_request import chat_request
from .utils.chat_utils import (
ContentCategory,
SessionChatItem,
SessionChatRequest,
)
Expand All @@ -25,6 +26,7 @@
)
from .utils.custom_sources.save_copied_source import CopiedPasteSourceRequest, save_copied_source
from .utils.custom_sources.save_uploaded_sources import UploadedFiles
from .utils.detect_content_category import DetectContentCategoryRequest, detect_content_category
from .utils.generate_audiocast import GenerateAudioCastRequest, GenerateAudiocastException, generate_audiocast
from .utils.generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source
from .utils.get_audiocast import get_audiocast
Expand Down Expand Up @@ -206,3 +208,11 @@ def delete_session_endpoint(sessionId: str):
"""
SessionManager._delete_session(sessionId)
return "Deleted"


@app.post("/detect-category", response_model=ContentCategory)
async def detect_category_endpoint(request: DetectContentCategoryRequest):
"""
Detect category of a given content
"""
return await detect_content_category(request.content)
68 changes: 68 additions & 0 deletions api/src/utils/detect_content_category.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from pydantic import BaseModel

from src.services.gemini_client import get_gemini
from src.utils.chat_utils import ContentCategory, content_categories


class DetectContentCategoryRequest(BaseModel):
content: str


def detect_category_prompt(content: str) -> str:
"""
System Prompt to detect the category of a given content
"""
return f"""You are an intelligent content type classifier. Your task is to analyze the given content and categorize it into one of the following types:
CONTENT: "{content}"
CATEGORIES: {', '.join(content_categories)}
IMPORTANT NOTE:
- You must ONLY output one of these exact category names, with no additional text, explanation, preamble or formatting.
- If the content doesn't fit any of the categories, you should output "other".
Examples:
Input: "Welcome to today's episode where we'll be discussing the fascinating world of quantum computing..."
Output: podcast
Input: "And now, dear brothers and sisters, let us reflect on the profound message in today's scripture..."
Output: sermon
Input: "Let's dive into today's lecture on advanced machine learning algorithms..."
Output: lecture
"""


def validate_category_output(output: str) -> ContentCategory:
"""
Validate that the AI output is a valid content category.
Throws ValueError if the output is not a valid category.
"""
cleaned_output = output.strip().lower()
if cleaned_output not in content_categories:
raise ValueError(f"Invalid category '{cleaned_output}'. Must be one of: {', '.join(content_categories)}")
return cleaned_output


async def detect_content_category(content: str) -> ContentCategory:
"""
Detect the category of the given content using Gemini Flash.
"""
client = get_gemini()

model = client.GenerativeModel(
model_name="gemini-1.5-flash-002",
system_instruction=detect_category_prompt(content),
generation_config=client.GenerationConfig(
temperature=0.1,
max_output_tokens=30,
response_mime_type="text/plain",
),
)

response = model.generate_content(["Now, please categorize the content."])
if not response.text:
raise Exception("Error obtaining response from Gemini Flash")

return validate_category_output(response.text)
43 changes: 43 additions & 0 deletions app/src/lib/components/AutoDetectedCategory.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
<script lang="ts" context="module">
import type { ContentCategory } from '@/utils/types';
export const categoryArticleMap: Record<ContentCategory, string> = {
podcast: 'a podcast',
sermon: 'a sermon',
audiodrama: 'an audiodrama',
lecture: 'a lecture',
commentary: 'a commentary',
voicenote: 'a voicenote',
interview: 'an interview',
soundbite: 'a soundbite'
};
</script>

<script lang="ts">
import ChatListItem from './ChatListItem.svelte';
import { Button } from './ui/button';
import { createEventDispatcher } from 'svelte';
import { ArrowRight } from 'lucide-svelte';
export let category: ContentCategory;
const dispatch = createEventDispatcher<{ selectCategory: { value: ContentCategory } }>();
$: categoryWithArticle = `"${categoryArticleMap[category]}"`;
</script>

<ChatListItem
type="assistant"
content="I auto-detected you want {categoryWithArticle}. Press NEXT to continue if correct."
/>

<Button
variant="ghost"
class="text-base px-10 py-6 bg-gray-800 w-fit hover:bg-gray-700"
on:click={() => dispatch('selectCategory', { value: category })}
>
<span> Next </span>
<ArrowRight class="w-4 ml-1 inline" />
</Button>

<ChatListItem type="assistant" content="Else, select your audiocast category" />
37 changes: 34 additions & 3 deletions app/src/lib/components/ChatListItem.svelte
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
<script context="module">
const TWO_MINUTES_MS = 120000;
</script>

<script lang="ts">
import { UserIcon, BotIcon } from 'lucide-svelte';
import { UserIcon, BotIcon, RotateCw } from 'lucide-svelte';
import cs from 'clsx';
import { parse } from 'marked';
import { Button } from './ui/button';
import { createEventDispatcher } from 'svelte';
export let type: 'user' | 'assistant';
export let content: string;
export let loading = false;
export let createdAt: number | undefined = undefined;
const dispatch = createEventDispatcher<{ regenerate: void }>();
$: likelyErrored = loading && (!createdAt || Date.now() - createdAt > TWO_MINUTES_MS);
</script>

<div
Expand All @@ -28,12 +39,32 @@
</div>

<div class="max-w-full justify-center break-words text-gray-200 flex flex-col text-base">
{#if loading}
<span class="animate-pulse">Generating response...</span>
{#if likelyErrored}
<span>
Failed to generate response.
<span class="text-red-300">Likely errored</span>
</span>
{:else if loading}
<slot name="loading">
<span class="animate-pulse">Generating response...</span>
</slot>
{:else}
{#await parse(content) then parsedContent}
{@html parsedContent}
{:catch error}
<p class="text-red-300">{String(error)}</p>
{/await}
{/if}
</div>
</div>

{#if likelyErrored}
<Button
variant="ghost"
class="w-fit bg-gray-800 flex gap-x-2 text-gray-400 items-center hover:bg-gray-700 transition-all px-4 py-0.5"
on:click={() => dispatch('regenerate')}
>
<span>Regenerate</span>
<RotateCw class="inline w-4" />
</Button>
{/if}
2 changes: 1 addition & 1 deletion app/src/lib/components/ExampleCard.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
async function handleClick() {
startSession(category);
addChatItem({ id: uuid(), content, role: 'user', loading: false });
addChatItem({ id: uuid(), content, role: 'user', loading: false, createdAt: Date.now() });
return goto(href, { invalidateAll: true, replaceState: true });
}
</script>
Expand Down
51 changes: 47 additions & 4 deletions app/src/lib/components/RenderCategorySelection.svelte
Original file line number Diff line number Diff line change
@@ -1,13 +1,56 @@
<script lang="ts">
import ChatListItem from './ChatListItem.svelte';
import SelectCategory from './SelectCategory.svelte';
import type { ContentCategory } from '@/utils/types';
import SelectCategory, { categories } from './SelectCategory.svelte';
import { env } from '@env';
import AutoDetectedCategory from './AutoDetectedCategory.svelte';
export let content: string;
let detectingCategory = false;
let detectedCategory: ContentCategory | null = null;
$: getCategorySelection();
async function getCategorySelection() {
if (detectingCategory) return;
detectingCategory = true;
return fetch(`${env.API_BASE_URL}/detect-category`, {
method: 'POST',
body: JSON.stringify({ content }),
headers: { 'Content-Type': 'application/json' }
})
.then<ContentCategory>((res) => {
if (res.ok) return res.json();
throw new Error(res.statusText);
})
.then((res) => categories.includes(res) && (detectedCategory = res))
.catch(() => void 0)
.finally(() => (detectingCategory = false));
}
</script>

<!-- Rename to get_category_selection.svelte -->

<div class="flex flex-col gap-y-3 h-full">
<ChatListItem type="user" {content} />
<div class="flex flex-col gap-y-3 w-full">
<ChatListItem type="user" {content} />

<ChatListItem type="assistant" content="Please select your audiocast category" />
{#if detectingCategory}
<ChatListItem type="assistant" content="Auto-detecting content category..." loading>
<span slot="loading" class="animate-pulse">
Auto-detecting content category...Please wait
</span>
</ChatListItem>
{:else if detectedCategory}
<AutoDetectedCategory category={detectedCategory} on:selectCategory />
{:else}
<ChatListItem type="assistant" content="Please select your audiocast category" />
{/if}
</div>

<SelectCategory on:selectCategory />
{#if !detectingCategory}
<SelectCategory on:selectCategory />
{/if}
</div>
12 changes: 12 additions & 0 deletions app/src/lib/stores/sessionContext.svelte.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export type ChatItem = {
content: string;
role: 'user' | 'assistant';
loading?: boolean;
createdAt?: number;
};

export type Session = {
Expand Down Expand Up @@ -74,6 +75,17 @@ export function setSessionContext(sessionId: string) {
return session;
});
},
removeChatItem: (chatId: string) => {
session$.update((session) => {
if (!session) return session;

const chats = session.chats.filter((i) => i.id !== chatId);
session.chats = chats;
return session;
});

return session$;
},
updateChatContent: (chatId: string, chunk: string) => {
session$.update((session) => {
if (!session) return session;
Expand Down
2 changes: 1 addition & 1 deletion app/src/routes/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
startSession(category);
const content = `${selectContent}\nCategory: ${category} `;
addChatItem({ id: uuid(), content, role: 'user', loading: false });
addChatItem({ id: uuid(), content, role: 'user', loading: false, createdAt: Date.now() });
const href = `/chat/${sessionId}?category=${category}&chat`;
return goto(href, { invalidateAll: true, replaceState: true });
Expand Down
Loading

0 comments on commit 350081a

Please sign in to comment.