diff --git a/api/src/main.py b/api/src/main.py
index 0409301..9336d64 100644
--- a/api/src/main.py
+++ b/api/src/main.py
@@ -10,6 +10,7 @@
from .services.storage import StorageManager
from .utils.chat_request import chat_request
from .utils.chat_utils import (
+ ContentCategory,
SessionChatItem,
SessionChatRequest,
)
@@ -25,6 +26,7 @@
)
from .utils.custom_sources.save_copied_source import CopiedPasteSourceRequest, save_copied_source
from .utils.custom_sources.save_uploaded_sources import UploadedFiles
+from .utils.detect_content_category import DetectContentCategoryRequest, detect_content_category
from .utils.generate_audiocast import GenerateAudioCastRequest, GenerateAudiocastException, generate_audiocast
from .utils.generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source
from .utils.get_audiocast import get_audiocast
@@ -206,3 +208,11 @@ def delete_session_endpoint(sessionId: str):
"""
SessionManager._delete_session(sessionId)
return "Deleted"
+
+
+@app.post("/detect-category", response_model=ContentCategory)
+async def detect_category_endpoint(request: DetectContentCategoryRequest):
+    """
+    Classify the content carried in the request body and respond with its category.
+    """
+    category = await detect_content_category(request.content)
+    return category
diff --git a/api/src/utils/detect_content_category.py b/api/src/utils/detect_content_category.py
new file mode 100644
index 0000000..56dfe56
--- /dev/null
+++ b/api/src/utils/detect_content_category.py
@@ -0,0 +1,68 @@
+from pydantic import BaseModel
+
+from src.services.gemini_client import get_gemini
+from src.utils.chat_utils import ContentCategory, content_categories
+
+
+class DetectContentCategoryRequest(BaseModel):
+ """Request body for the category detection endpoint."""
+
+ # Raw text whose category should be detected.
+ content: str
+
+
+def detect_category_prompt(content: str) -> str:
+ """
+ Build the system prompt used to classify a piece of content.
+
+ The prompt embeds `content` verbatim inside double quotes, lists the allowed
+ categories from `content_categories` (comma-separated), instructs the model to
+ answer with exactly one category name (or "other"), and ends with three
+ worked examples.
+
+ NOTE(review): `content` is interpolated without any escaping, so quotes or
+ instructions inside it become part of the prompt - confirm this is acceptable.
+ """
+ return f"""You are an intelligent content type classifier. Your task is to analyze the given content and categorize it into one of the following types:
+ CONTENT: "{content}"
+
+ CATEGORIES: {', '.join(content_categories)}
+
+
+ IMPORTANT NOTE:
+ - You must ONLY output one of these exact category names, with no additional text, explanation, preamble or formatting.
+ - If the content doesn't fit any of the categories, you should output "other".
+
+ Examples:
+ Input: "Welcome to today's episode where we'll be discussing the fascinating world of quantum computing..."
+ Output: podcast
+
+ Input: "And now, dear brothers and sisters, let us reflect on the profound message in today's scripture..."
+ Output: sermon
+
+ Input: "Let's dive into today's lecture on advanced machine learning algorithms..."
+ Output: lecture
+"""
+
+
+def validate_category_output(output: str) -> ContentCategory:
+    """
+    Normalize the raw model reply and check it against the known categories.
+
+    The reply is trimmed and lower-cased. If that form is not a known category,
+    wrapping quotes/backticks and stray punctuation are stripped as well and the
+    check is repeated, because the prompt itself shows a quoted category name
+    and the model may echo that decoration.
+
+    Returns the matching category name.
+    Raises ValueError if the reply is not one of `content_categories`.
+    """
+    cleaned_output = output.strip().lower()
+    # Exact match first, so replies that were already accepted behave the same.
+    if cleaned_output in content_categories:
+        return cleaned_output
+
+    # Fallback: tolerate replies such as '"podcast"', '`sermon`' or 'lecture.'.
+    unwrapped = cleaned_output.strip("\"'`.,;:! \t\r\n")
+    if unwrapped in content_categories:
+        return unwrapped
+
+    raise ValueError(f"Invalid category '{cleaned_output}'. Must be one of: {', '.join(content_categories)}")
+
+
+async def detect_content_category(content: str) -> ContentCategory:
+    """
+    Detect the category of the given content using Gemini Flash.
+
+    The classification prompt built by `detect_category_prompt` (with `content`
+    embedded) is passed as the system instruction, and the model reply is checked
+    by `validate_category_output`.
+
+    Raises Exception if the model reply carries no text, and ValueError if the
+    reply is not a known category.
+    """
+    client = get_gemini()
+
+    model = client.GenerativeModel(
+        model_name="gemini-1.5-flash-002",
+        system_instruction=detect_category_prompt(content),
+        generation_config=client.GenerationConfig(
+            temperature=0.1,
+            # Only a single category name is expected back, so keep the reply short.
+            max_output_tokens=30,
+            response_mime_type="text/plain",
+        ),
+    )
+
+    # Await the async variant: calling the synchronous generate_content() inside
+    # this coroutine would block the event loop for the whole network round-trip.
+    response = await model.generate_content_async(["Now, please categorize the content."])
+    if not response.text:
+        raise Exception("Error obtaining response from Gemini Flash")
+
+    return validate_category_output(response.text)
diff --git a/app/src/lib/components/AutoDetectedCategory.svelte b/app/src/lib/components/AutoDetectedCategory.svelte
new file mode 100644
index 0000000..fdd899a
--- /dev/null
+++ b/app/src/lib/components/AutoDetectedCategory.svelte
@@ -0,0 +1,43 @@
+
+
+
+
+
{String(error)}
{/await} {/if}