Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempt to detect content category #22

Merged
merged 8 commits into from
Dec 5, 2024
10 changes: 10 additions & 0 deletions api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .services.storage import StorageManager
from .utils.chat_request import chat_request
from .utils.chat_utils import (
ContentCategory,
SessionChatItem,
SessionChatRequest,
)
Expand All @@ -25,6 +26,7 @@
)
from .utils.custom_sources.save_copied_source import CopiedPasteSourceRequest, save_copied_source
from .utils.custom_sources.save_uploaded_sources import UploadedFiles
from .utils.detect_content_category import DetectContentCategoryRequest, detect_content_category
from .utils.generate_audiocast import GenerateAudioCastRequest, GenerateAudiocastException, generate_audiocast
from .utils.generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source
from .utils.get_audiocast import get_audiocast
Expand Down Expand Up @@ -206,3 +208,11 @@ def delete_session_endpoint(sessionId: str):
"""
SessionManager._delete_session(sessionId)
return "Deleted"


@app.post("/detect-category", response_model=ContentCategory)
async def detect_category_endpoint(request: DetectContentCategoryRequest):
"""
Detect category of a given content
"""
return await detect_content_category(request.content)
68 changes: 68 additions & 0 deletions api/src/utils/detect_content_category.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from pydantic import BaseModel

from src.services.gemini_client import get_gemini
from src.utils.chat_utils import ContentCategory, content_categories


class DetectContentCategoryRequest(BaseModel):
content: str


def detect_category_prompt(content: str) -> str:
"""
System Prompt to detect the category of a given content
"""
return f"""You are an intelligent content type classifier. Your task is to analyze the given content and categorize it into one of the following types:
CONTENT: "{content}"

CATEGORIES: {', '.join(content_categories)}


IMPORTANT NOTE:
- You must ONLY output one of these exact category names, with no additional text, explanation, preamble or formatting.
- If the content doesn't fit any of the categories, you should output "other".

Examples:
Input: "Welcome to today's episode where we'll be discussing the fascinating world of quantum computing..."
Output: podcast

Input: "And now, dear brothers and sisters, let us reflect on the profound message in today's scripture..."
Output: sermon

Input: "Let's dive into today's lecture on advanced machine learning algorithms..."
Output: lecture
"""


def validate_category_output(output: str) -> ContentCategory:
"""
Validate that the AI output is a valid content category.
Throws ValueError if the output is not a valid category.
"""
cleaned_output = output.strip().lower()
if cleaned_output not in content_categories:
raise ValueError(f"Invalid category '{cleaned_output}'. Must be one of: {', '.join(content_categories)}")
return cleaned_output


async def detect_content_category(content: str) -> ContentCategory:
"""
Detect the category of the given content using Gemini Flash.
"""
client = get_gemini()

model = client.GenerativeModel(
model_name="gemini-1.5-flash-002",
system_instruction=detect_category_prompt(content),
generation_config=client.GenerationConfig(
temperature=0.1,
max_output_tokens=30,
response_mime_type="text/plain",
),
)

response = model.generate_content(["Now, please categorize the content."])
if not response.text:
raise Exception("Error obtaining response from Gemini Flash")

return validate_category_output(response.text)
43 changes: 43 additions & 0 deletions app/src/lib/components/AutoDetectedCategory.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
<script lang="ts" context="module">
import type { ContentCategory } from '@/utils/types';

export const categoryArticleMap: Record<ContentCategory, string> = {
podcast: 'a podcast',
sermon: 'a sermon',
audiodrama: 'an audiodrama',
lecture: 'a lecture',
commentary: 'a commentary',
voicenote: 'a voicenote',
interview: 'an interview',
soundbite: 'a soundbite'
};
</script>

<script lang="ts">
import ChatListItem from './ChatListItem.svelte';
import { Button } from './ui/button';
import { createEventDispatcher } from 'svelte';
import { ArrowRight } from 'lucide-svelte';

export let category: ContentCategory;

const dispatch = createEventDispatcher<{ selectCategory: { value: ContentCategory } }>();

$: categoryWithArticle = `"${categoryArticleMap[category]}"`;
</script>

<ChatListItem
type="assistant"
content="I auto-detected you want {categoryWithArticle}. Press NEXT to continue if correct."
/>

<Button
variant="ghost"
class="text-base px-10 py-6 bg-gray-800 w-fit hover:bg-gray-700"
on:click={() => dispatch('selectCategory', { value: category })}
>
<span> Next </span>
<ArrowRight class="w-4 ml-1 inline" />
</Button>

<ChatListItem type="assistant" content="Else, select your audiocast category" />
37 changes: 34 additions & 3 deletions app/src/lib/components/ChatListItem.svelte
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
<script context="module">
const TWO_MINUTES_MS = 120000;
</script>

<script lang="ts">
import { UserIcon, BotIcon } from 'lucide-svelte';
import { UserIcon, BotIcon, RotateCw } from 'lucide-svelte';
import cs from 'clsx';
import { parse } from 'marked';
import { Button } from './ui/button';
import { createEventDispatcher } from 'svelte';

export let type: 'user' | 'assistant';
export let content: string;
export let loading = false;
export let createdAt: number | undefined = undefined;

const dispatch = createEventDispatcher<{ regenerate: void }>();

$: likelyErrored = loading && (!createdAt || Date.now() - createdAt > TWO_MINUTES_MS);
</script>

<div
Expand All @@ -28,12 +39,32 @@
</div>

<div class="max-w-full justify-center break-words text-gray-200 flex flex-col text-base">
{#if loading}
<span class="animate-pulse">Generating response...</span>
{#if likelyErrored}
<span>
Failed to generate response.
<span class="text-red-300">Likely errored</span>
</span>
{:else if loading}
<slot name="loading">
<span class="animate-pulse">Generating response...</span>
</slot>
{:else}
{#await parse(content) then parsedContent}
{@html parsedContent}
{:catch error}
<p class="text-red-300">{String(error)}</p>
{/await}
{/if}
</div>
</div>

{#if likelyErrored}
<Button
variant="ghost"
class="w-fit bg-gray-800 flex gap-x-2 text-gray-400 items-center hover:bg-gray-700 transition-all px-4 py-0.5"
on:click={() => dispatch('regenerate')}
>
<span>Regenerate</span>
<RotateCw class="inline w-4" />
</Button>
{/if}
2 changes: 1 addition & 1 deletion app/src/lib/components/ExampleCard.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
async function handleClick() {
startSession(category);
addChatItem({ id: uuid(), content, role: 'user', loading: false });
addChatItem({ id: uuid(), content, role: 'user', loading: false, createdAt: Date.now() });
return goto(href, { invalidateAll: true, replaceState: true });
}
</script>
Expand Down
51 changes: 47 additions & 4 deletions app/src/lib/components/RenderCategorySelection.svelte
Original file line number Diff line number Diff line change
@@ -1,13 +1,56 @@
<script lang="ts">
import ChatListItem from './ChatListItem.svelte';
import SelectCategory from './SelectCategory.svelte';
import type { ContentCategory } from '@/utils/types';
import SelectCategory, { categories } from './SelectCategory.svelte';
import { env } from '@env';
import AutoDetectedCategory from './AutoDetectedCategory.svelte';

export let content: string;

let detectingCategory = false;
let detectedCategory: ContentCategory | null = null;

$: getCategorySelection();

async function getCategorySelection() {
if (detectingCategory) return;
detectingCategory = true;

return fetch(`${env.API_BASE_URL}/detect-category`, {
method: 'POST',
body: JSON.stringify({ content }),
headers: { 'Content-Type': 'application/json' }
})
.then<ContentCategory>((res) => {
if (res.ok) return res.json();
throw new Error(res.statusText);
})
.then((res) => categories.includes(res) && (detectedCategory = res))
.catch(() => void 0)
.finally(() => (detectingCategory = false));
}
</script>

<!-- Rename to get_category_selection.svelte -->

<div class="flex flex-col gap-y-3 h-full">
<ChatListItem type="user" {content} />
<div class="flex flex-col gap-y-3 w-full">
<ChatListItem type="user" {content} />

<ChatListItem type="assistant" content="Please select your audiocast category" />
{#if detectingCategory}
<ChatListItem type="assistant" content="Auto-detecting content category..." loading>
<span slot="loading" class="animate-pulse">
Auto-detecting content category...Please wait
</span>
</ChatListItem>
{:else if detectedCategory}
<AutoDetectedCategory category={detectedCategory} on:selectCategory />
{:else}
<ChatListItem type="assistant" content="Please select your audiocast category" />
{/if}
</div>

<SelectCategory on:selectCategory />
{#if !detectingCategory}
<SelectCategory on:selectCategory />
{/if}
</div>
12 changes: 12 additions & 0 deletions app/src/lib/stores/sessionContext.svelte.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export type ChatItem = {
content: string;
role: 'user' | 'assistant';
loading?: boolean;
createdAt?: number;
};

export type Session = {
Expand Down Expand Up @@ -74,6 +75,17 @@ export function setSessionContext(sessionId: string) {
return session;
});
},
removeChatItem: (chatId: string) => {
session$.update((session) => {
if (!session) return session;

const chats = session.chats.filter((i) => i.id !== chatId);
session.chats = chats;
return session;
});

return session$;
},
updateChatContent: (chatId: string, chunk: string) => {
session$.update((session) => {
if (!session) return session;
Expand Down
2 changes: 1 addition & 1 deletion app/src/routes/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
startSession(category);

const content = `${selectContent}\nCategory: ${category} `;
addChatItem({ id: uuid(), content, role: 'user', loading: false });
addChatItem({ id: uuid(), content, role: 'user', loading: false, createdAt: Date.now() });

const href = `/chat/${sessionId}?category=${category}&chat`;
return goto(href, { invalidateAll: true, replaceState: true });
Expand Down
Loading
Loading