diff --git a/api/src/env_var.py b/api/src/env_var.py index 7754af5..ca36d49 100644 --- a/api/src/env_var.py +++ b/api/src/env_var.py @@ -15,3 +15,5 @@ CSE_ID = environ["CSE_ID"] CSE_API_KEY = environ["GOOGLE_API_KEY"] + +PROD_ENV = environ.get("ENV", "dev") == "prod" diff --git a/api/src/main.py b/api/src/main.py index 9336d64..a2a4187 100644 --- a/api/src/main.py +++ b/api/src/main.py @@ -1,4 +1,4 @@ -import asyncio +from io import BytesIO from time import time from typing import Any, Callable, Generator @@ -7,7 +7,7 @@ from fastapi.responses import JSONResponse, StreamingResponse from fastapi_utilities import add_timer_middleware -from .services.storage import StorageManager +from .services.storage import BLOB_BASE_URI, StorageManager, UploadItemParams from .utils.chat_request import chat_request from .utils.chat_utils import ( ContentCategory, @@ -24,14 +24,18 @@ GetCustomSourcesRequest, generate_custom_source, ) +from .utils.custom_sources.manage_attachments import ManageAttachments +from .utils.custom_sources.read_content import ReadContent from .utils.custom_sources.save_copied_source import CopiedPasteSourceRequest, save_copied_source from .utils.custom_sources.save_uploaded_sources import UploadedFiles +from .utils.decorators.retry_decorator import RetryConfig, retry from .utils.detect_content_category import DetectContentCategoryRequest, detect_content_category from .utils.generate_audiocast import GenerateAudioCastRequest, GenerateAudiocastException, generate_audiocast from .utils.generate_audiocast_source import GenerateAudiocastSource, generate_audiocast_source from .utils.get_audiocast import get_audiocast from .utils.get_session_title import GetSessionTitleModel, get_session_title from .utils.session_manager import SessionManager, SessionModel +from .utils.summarize_custom_sources import SummarizeCustomSourcesRequest, summarize_custom_sources app = FastAPI(title="Audiora", version="1.0.0") @@ -69,16 +73,21 @@ def root(): 
async def get_signed_url_endpoint(blobname: str):
    """
    Get signed URL for generated audiocast

    The lookup is attempted up to three times (5s initial wait, multiplied by
    1.5 after each failure) through the `retry` decorator.

    Args:
        blobname: name of the blob to sign.

    Returns:
        JSONResponse whose body is the signed URL, sent with a long-lived
        `Cache-Control` header.

    Raises:
        HTTPException: status 500 when no URL was obtained after all attempts.
    """

    # The handler is a coroutine on purpose: `retry` checks
    # `asyncio.iscoroutinefunction` and only then waits between attempts with
    # `asyncio.sleep`. A plain `def` handler is retried with `time.sleep`, which
    # would freeze the event loop (and every other request) during the backoff.
    @retry(RetryConfig(max_retries=3, delay=5, backoff=1.5))
    async def handler() -> str | None:
        return StorageManager().get_signed_url(blobname=blobname)

    url = await handler()
    if not url:
        # `retry` swallows the underlying errors and yields its default (None).
        raise HTTPException(status_code=500, detail="Failed to get signed URL")

    return JSONResponse(
        content=url,
        headers={
            "Content-Type": "application/json",
            "Cache-Control": "public, max-age=86390, immutable",
        },
    )
response_model=str) @@ -216,3 +224,42 @@ async def detect_category_endpoint(request: DetectContentCategoryRequest): Detect category of a given content """ return await detect_content_category(request.content) + + +@app.post("/store-file-upload", response_model=str) +async def store_file_upload(file: UploadFile, filename: str = Form(...), preserve: bool = Form(False)): + """ + Store file uploaded from the frontend + """ + print(f"Storing file: {filename}. Preserve: {preserve}") + + storage_manager = StorageManager() + file_exists = storage_manager.check_blob_exists(filename) + if file_exists: + return storage_manager.get_gcs_url(filename) + + file_content = await ReadContent()._read_file(file, preserve=preserve) + content_type = ( + file.content_type or "application/octet-stream" + if preserve or isinstance(file_content, BytesIO) + else "text/plain" + ) + + result = storage_manager.upload_to_gcs( + item=file_content, + blobname=f"{BLOB_BASE_URI}/{filename}", + params=UploadItemParams( + cache_control="public, max-age=31536000", + content_type=content_type, + ), + ) + + return result + + +@app.post("/summarize-custom-sources", response_model=str) +async def summarize_custom_sources_endpoint(request: SummarizeCustomSourcesRequest): + """ + Summarize custom sources from specified source URLs + """ + return await summarize_custom_sources(request.sourceURLs) diff --git a/api/src/services/storage.py b/api/src/services/storage.py index a8b2097..93accbb 100644 --- a/api/src/services/storage.py +++ b/api/src/services/storage.py @@ -31,6 +31,8 @@ class UploadItemParams: class StorageManager: + bucket_name = BUCKET_NAME + def check_blob_exists(self, filename: str, root_path=BLOB_BASE_URI): """check if a file exists in the bucket""" blobname = f"{root_path}/{filename}" @@ -123,3 +125,16 @@ def get_signed_url(self, blobname, expiration=datetime.timedelta(days=1)): expiration=expiration, method="GET", ) + + def get_gcs_url(self, filename: str): + """get full path to a file in the 
bucket""" + blobname = f"{BLOB_BASE_URI}/{filename}" + return f"gs://{BUCKET_NAME}/{blobname}" + + def get_blob(self, blobname: str): + """get a blob object""" + return bucket.blob(blobname) + + def get_blobname_from_url(self, url: str): + """get blobname from a URL""" + return url.replace(f"gs://{self.bucket_name}/", "") diff --git a/api/src/utils/chat_request.py b/api/src/utils/chat_request.py index 1774602..895f632 100644 --- a/api/src/utils/chat_request.py +++ b/api/src/utils/chat_request.py @@ -4,19 +4,22 @@ from src.utils.chat_utils import ContentCategory, SessionChatItem -def get_system_message(content_category: ContentCategory): +def get_system_message(content_category: ContentCategory, reference_material: str | None = None): return f""" 1. You're a super-intelligent AI. Your task is to understand what audiocast a user wants to listen to. 2. You will steer the conversation providing eliciting questions until you have enough context. - 3. Keep the conversation exchange short, say 3-5 back and forth i.e., questions and answers. + 3. If the user provides a reference material, steer the conversation based on it until you have enough context to understand what audiocast the user wants. + 4. Keep the conversation exchange short, say 3-5 back and forth i.e., questions and answers. 4. As soon as you have enough context and the user's request is clear terminate the conversation by saying "Ok, thanks for clarifying! You want to listen to [Best case summary of user request so far]. Please click the button below to start generating the audiocast." 6. If the user's request remains unclear after 5 responses for clarity, terminate the conversation by saying "Your request is not very specific but from what I understand, you want to listen to [Best case summary of user request so far]. Please click the button below to start generating the audiocast." + {"REFERENCE MATERIAL: " + reference_material if reference_material else ""} GENERAL IDEA AND WORKFLOW: 1. 
A user comes to you with a request for an audiocast of type {content_category}. - 2. You need to ask the user questions (elicitation) to understand what kind of audiocast they want to listen to. - 3. Once you have enough context, within 3-5 exchanges, you should terminate the conversation. + 2. The request can include a reference material: a high-level description of the audiocast they want. + 3. You will ask the user questions (elicitation) to understand what kind of audiocast they want to listen to. + 4. Once you have enough context, within 3-5 exchanges, you should terminate the conversation. IMPORTANT NOTES: 1. Your task is to understand the user's request only by eliciting questions. @@ -28,12 +31,13 @@ def get_system_message(content_category: ContentCategory): def chat_request( content_category: ContentCategory, previous_messages: List[SessionChatItem], + reference_material: Optional[str] = None, on_finish: Optional[Callable[[str], Any]] = None, ): response_stream = get_openai().chat.completions.create( model="gpt-4o", messages=[ - {"role": "system", "content": get_system_message(content_category)}, + {"role": "system", "content": get_system_message(content_category, reference_material)}, *[ {"role": "user", "content": msg.content} if msg.role == "user" diff --git a/api/src/utils/chat_utils.py b/api/src/utils/chat_utils.py index b3bf592..83f2e6c 100644 --- a/api/src/utils/chat_utils.py +++ b/api/src/utils/chat_utils.py @@ -1,5 +1,5 @@ import uuid -from typing import Dict, List, Literal +from typing import Dict, List, Literal, Optional from pydantic import BaseModel, Field @@ -57,3 +57,4 @@ class SessionChatItem(BaseModel): class SessionChatRequest(BaseModel): contentCategory: ContentCategory chatItem: SessionChatItem + attachments: Optional[List[str]] = None diff --git a/api/src/utils/custom_sources/base_utils.py b/api/src/utils/custom_sources/base_utils.py index 6291f6e..ed180e9 100644 --- a/api/src/utils/custom_sources/base_utils.py +++ 
b/api/src/utils/custom_sources/base_utils.py @@ -1,6 +1,7 @@ from typing import Literal, Optional, TypedDict, cast from google.cloud.firestore_v1 import DocumentReference +from google.cloud.firestore_v1.base_query import FieldFilter from pydantic import BaseModel from src.services.firestore_sdk import ( @@ -103,3 +104,19 @@ def _get_custom_sources(self) -> list[CustomSourceModelDict]: def _delete_custom_source(self, source_id: str): return self._get_doc_ref(source_id).delete() + + def _get_custom_source_by_url(self, url: str): + self._check_document() + try: + session_ref = self._get_collection(self.collection).document(self.doc_id) + + query = session_ref.collection(self.sub_collection).where(filter=FieldFilter("url", "==", url)) + + docs = query.get() + + for doc in docs: + if doc.exists: + return cast(CustomSourceModel, self._safe_to_dict(doc.to_dict())) + except Exception as e: + print(f"Error getting custom sources for Session: {self.doc_id}", e) + return None diff --git a/api/src/utils/custom_sources/extract_url_content.py b/api/src/utils/custom_sources/extract_url_content.py index d51f4b4..685e4f8 100644 --- a/api/src/utils/custom_sources/extract_url_content.py +++ b/api/src/utils/custom_sources/extract_url_content.py @@ -5,7 +5,8 @@ from bs4 import BeautifulSoup, Tag from pydantic import BaseModel -from src.utils.decorators import process_time +from src.services.storage import StorageManager +from src.utils.decorators.base import process_time from .base_utils import SourceContent from .read_content import ReadContent @@ -42,14 +43,24 @@ def _extract_html(self, content: bytes) -> tuple[str, dict]: return self._clean_text(text_content), metadata + def _resolve_gcs_url(self, url) -> str: + if url.startswith("gs://"): + storage_manager = StorageManager() + blobame = storage_manager.get_blobname_from_url(url) + return storage_manager.get_signed_url(blobame) + + return url + @process_time() def _extract(self, url: str) -> SourceContent: - parsed_url = 
import asyncio

from src.utils.decorators.base import use_cache_manager
from src.utils.make_seed import get_hash
from src.utils.session_manager import SessionManager
from src.utils.summarize_custom_sources import summarize_custom_sources

from .base_utils import CustomSourceManager
from .generate_url_source import GenerateCustomSourceRequest, generate_custom_source


class ManageAttachments:
    """Summarize and persist the files a user attached to a chat session."""

    def __init__(self, session_id: str):
        self.session_id = session_id

    async def get_attachments_summary(self, db: SessionManager, attachments: list[str] | None):
        """
        Manage custom sources uploaded by the user

        Args:
            db: session manager whose source is updated with a freshly
                generated summary.
            attachments: URLs of the uploaded files, or None/empty when the
                request carries no attachments.

        Returns:
            The summary text, or None when there are no attachments.
        """
        if not attachments:
            return None

        # Sort a copy: the ordering only exists to make the cache key
        # independent of upload order, so the caller's list is left untouched.
        ordered = sorted(attachments, key=str.lower)

        # NOTE(review): the cache key does not include the session id, so a
        # cache hit presumably skips `db._update_source` for a different
        # session using the same files — confirm against `use_cache_manager`.
        @use_cache_manager(get_hash(ordered))
        async def handler():
            summary = await summarize_custom_sources(ordered)
            db._update_source(summary)
            return summary

        return await handler()

    async def store_attachments(self, attachments: list[str]):
        """
        Store attachments as custom sources of type links

        URLs that already have a custom source in this session are skipped.
        Failures are reported per URL and never propagate.

        Returns:
            Always True.
        """
        cs_manager = CustomSourceManager(self.session_id)

        def _store(url: str):
            # Both calls below are synchronous (Firestore query, HTTP fetch and
            # Firestore write), so they run in worker threads instead of on the
            # event loop.
            # NOTE(review): `cs_manager` is shared between those threads; it is
            # only used for reads here — confirm it keeps no per-call state.
            if cs_manager._get_custom_source_by_url(url):
                return None
            request = GenerateCustomSourceRequest(url=url, sessionId=self.session_id)
            return generate_custom_source(request)

        results = await asyncio.gather(
            *(asyncio.to_thread(_store, url) for url in attachments),
            return_exceptions=True,
        )

        # Best-effort: keep going on failure, but do not hide what went wrong.
        for url, result in zip(attachments, results):
            if isinstance(result, BaseException):
                print(f"Error storing attachment {url} for Session: {self.session_id}", result)

        return True
b/api/src/utils/decorators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/src/utils/decorators.py b/api/src/utils/decorators/base.py similarity index 92% rename from api/src/utils/decorators.py rename to api/src/utils/decorators/base.py index 3d4c05f..9d70a92 100644 --- a/api/src/utils/decorators.py +++ b/api/src/utils/decorators/base.py @@ -38,7 +38,7 @@ def wrapper(*args, **kwargs): return decorator -def use_cache_manager(cache_key: str): +def use_cache_manager(cache_key: str, expiry=86400): """decorator to use cache manager""" def decorator(func): @@ -54,7 +54,7 @@ async def wrapper(*args, **kwargs): if result and cache: redis = cache.get("redis") - await redis.set(cache_key, result, ex=cache.get("expiry")) + await redis.set(cache_key, result, ex=expiry) return result diff --git a/api/src/utils/decorators/retry_decorator.py b/api/src/utils/decorators/retry_decorator.py new file mode 100644 index 0000000..25b0016 --- /dev/null +++ b/api/src/utils/decorators/retry_decorator.py @@ -0,0 +1,62 @@ +import asyncio +from dataclasses import dataclass +from functools import wraps +from time import sleep +from typing import Any, Optional + + +@dataclass +class RetryConfig: + max_retries: int = 3 + delay: float = 1.0 + backoff: Optional[float] = None + + +def retry(retry_config: RetryConfig | None, default_return: Any = None): + """ + Retry logic for async functions with exponential backoff. 
+ """ + config = retry_config or RetryConfig() + + def decorator(func): + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args, **kwargs): + async def _async_retry(): + delay = config.delay + for attempt in range(config.max_retries): + try: + return await func(*args, **kwargs) + except Exception as e: + print(f"Retry attempt {attempt + 1}/{config.max_retries} failed: {e}") + await asyncio.sleep(delay) + if config.backoff: + delay *= config.backoff + + return default_return + + return await _async_retry() + + return async_wrapper + + @wraps(func) + def wrapper(*args, **kwargs): + def _sync_retry(): + delay = config.delay + for attempt in range(config.max_retries): + try: + return func(*args, **kwargs) + except Exception as e: + print(f"Retry attempt {attempt + 1}/{config.max_retries} failed: {e}") + sleep(delay) + if config.backoff: + delay *= config.backoff + + return default_return + + return _sync_retry() + + return wrapper + + return decorator diff --git a/api/src/utils/generate_audiocast.py b/api/src/utils/generate_audiocast.py index 561b7d0..0e44454 100644 --- a/api/src/utils/generate_audiocast.py +++ b/api/src/utils/generate_audiocast.py @@ -3,6 +3,7 @@ from fastapi import BackgroundTasks, HTTPException +from src.env_var import PROD_ENV from src.services.storage import StorageManager from .audio_manager import AudioManager, AudioManagerConfig @@ -122,7 +123,9 @@ def update_session_info(info: str): # Generate audio update_session_info("Generating audio...") - audio_manager = AudioManager(custom_config=AudioManagerConfig(tts_provider="elevenlabs")) + + tts_provider = "elevenlabs" if PROD_ENV else "openai" + audio_manager = AudioManager(custom_config=AudioManagerConfig(tts_provider=tts_provider)) audio_path = await audio_manager.generate_speech(audio_script) background_tasks.add_task( diff --git a/api/src/utils/generate_audiocast_source.py b/api/src/utils/generate_audiocast_source.py index 3372fbf..9ce9481 100644 --- 
a/api/src/utils/generate_audiocast_source.py +++ b/api/src/utils/generate_audiocast_source.py @@ -2,7 +2,7 @@ from src.utils.audiocast_request import GenerateSourceContent from src.utils.chat_utils import ContentCategory -from src.utils.decorators import use_cache_manager +from src.utils.decorators.base import use_cache_manager from src.utils.make_seed import get_hash from src.utils.session_manager import SessionManager @@ -28,6 +28,7 @@ async def generate_audiocast_source(request: GenerateAudiocastSource): async def _handler(): db = SessionManager(session_id, category) db._update_info("Generating source content...") + generator = GenerateSourceContent(category, preference_summary) source_content = await generator._run() db._update_source(source_content) diff --git a/api/src/utils/generate_speech_utils.py b/api/src/utils/generate_speech_utils.py index 429f036..ab29afe 100644 --- a/api/src/utils/generate_speech_utils.py +++ b/api/src/utils/generate_speech_utils.py @@ -4,7 +4,7 @@ from src.services.elevenlabs_client import get_elevenlabs_client from src.services.openai_client import get_openai -from src.utils.decorators import process_time +from src.utils.decorators.base import process_time TTSProvider = Literal["openai", "elevenlabs"] diff --git a/api/src/utils/get_audiocast.py b/api/src/utils/get_audiocast.py index e69e2cd..c5e6321 100644 --- a/api/src/utils/get_audiocast.py +++ b/api/src/utils/get_audiocast.py @@ -2,7 +2,7 @@ from src.services.storage import StorageManager -from .decorators import process_time +from .decorators.base import process_time from .session_manager import SessionManager diff --git a/api/src/utils/summarize_custom_sources.py b/api/src/utils/summarize_custom_sources.py new file mode 100644 index 0000000..015a6b1 --- /dev/null +++ b/api/src/utils/summarize_custom_sources.py @@ -0,0 +1,88 @@ +import asyncio + +from pydantic import BaseModel + +from src.services.gemini_client import get_gemini +from src.services.storage import StorageManager + 
from .custom_sources.read_content import ReadContent
from .decorators.base import process_time


class SummarizeCustomSourcesRequest(BaseModel):
    """Request body carrying the URLs of the sources to summarize."""

    sourceURLs: list[str]


def summarize_custom_sources_prompt(combined_content: str) -> str:
    """
    Build the instruction asking for a 3-paragraph summary of `combined_content`.
    """
    return f"""Please provide a comprehensive 3-paragraph summary of the following content.
    Each paragraph should serve a specific purpose:

    Paragraph 1: Introduce the main topics and key themes discussed across all sources.
    Paragraph 2: Dive into the most significant details, findings, or arguments presented.
    Paragraph 3: Conclude with the implications, connections between ideas, or final insights.

    Maintain high fidelity to the source material and ensure all critical information is preserved.

    Content to summarize: {combined_content}
    """


def _read_source_content(source_url: str) -> str:
    """Download one stored source and return its text; blocking, meant for a worker thread."""
    storage_manager = StorageManager()
    content_reader = ReadContent()

    # Use the storage helper rather than re-implementing the gs:// prefix strip.
    blob = storage_manager.get_blob(storage_manager.get_blobname_from_url(source_url))
    content_byte = blob.download_as_bytes()

    # The content type is read only after the download, as before.
    # NOTE(review): presumably the download populates the blob metadata — confirm.
    if blob.content_type == "application/pdf":
        text_content, _ = content_reader._read_pdf(content_byte)
        return text_content
    if blob.content_type == "text/plain":
        return content_reader._read_txt(content_byte)

    raise ValueError(f"Unsupported content type: {blob.content_type}")


async def get_source_content(source_url: str) -> str:
    """
    Get the content of a source URL.

    The download and parsing are synchronous, so they run in a worker thread
    to keep the event loop free and let several sources load concurrently.

    Raises:
        ValueError: when the blob is neither a PDF nor plain text.
    """
    return await asyncio.to_thread(_read_source_content, source_url)


async def get_sources_str(source_urls: list[str]) -> str:
    """
    Get the content of a list of source URLs.

    Sources that fail to load are reported and left out; the remaining texts
    are joined with a blank line between them.
    """
    results = await asyncio.gather(
        *(get_source_content(source_url) for source_url in source_urls),
        return_exceptions=True,
    )

    valid_sources: list[str] = []
    for source_url, result in zip(source_urls, results):
        if isinstance(result, str):
            valid_sources.append(result)
        else:
            # Still best-effort, but a dropped source is no longer invisible.
            print(f"Error reading source {source_url}: {result}")

    return "\n\n".join(valid_sources)


@process_time()
async def summarize_custom_sources(source_urls: list[str]) -> str:
    """
    Summarize the contents of list of custom sources using Gemini Flash.

    Args:
        source_urls: URLs of the stored sources to read and summarize.

    Returns:
        The summary text produced by the model.

    Raises:
        Exception: when the model response carries no text.
    """
    content = await get_sources_str(source_urls)

    client = get_gemini()

    model = client.GenerativeModel(
        model_name="gemini-1.5-flash-002",
        system_instruction=summarize_custom_sources_prompt(content),
        generation_config=client.GenerationConfig(
            temperature=0.1,
            max_output_tokens=2048,
            response_mime_type="text/plain",
        ),
    )

    # `generate_content` is a blocking call; run it off the event loop.
    response = await asyncio.to_thread(model.generate_content, ["Now, provide the summary"])
    if not response.text:
        raise Exception("Error obtaining response from Gemini Flash")

    return response.text
b/app/src/lib/components/ChatBoxAttachment.svelte @@ -0,0 +1,91 @@ + + + + + + + + handleFileSelect(e.currentTarget.files)} +/> diff --git a/app/src/lib/components/ChatBoxAttachmentPreview.svelte b/app/src/lib/components/ChatBoxAttachmentPreview.svelte new file mode 100644 index 0000000..c8cac22 --- /dev/null +++ b/app/src/lib/components/ChatBoxAttachmentPreview.svelte @@ -0,0 +1,61 @@ + + +
+ {#each validItems as { file, id, loading } (id)} +
+
+ {#if loading} + + {:else} + + {/if} +
+
+

{file.name}

+

+ {formatFileSize(file.size)} • {parseFileType(file)} +

+
+ +
+ {/each} +
diff --git a/app/src/lib/components/ChatContainer.svelte b/app/src/lib/components/ChatContainer.svelte index 72cdd67..6538f1e 100644 --- a/app/src/lib/components/ChatContainer.svelte +++ b/app/src/lib/components/ChatContainer.svelte @@ -6,7 +6,7 @@ export let searchTerm = ''; export let disableTextInput = false; - const { sessionCompleted$, fetchingSource$, audioSource$, session$ } = getSessionContext(); + const { sessionCompleted$, fetchingSource$, session$ } = getSessionContext(); let navLoading = false; @@ -26,7 +26,7 @@ - {#if !hasFinalResponse && !$sessionCompleted$ && !$fetchingSource$ && !$audioSource$} + {#if !hasFinalResponse && !$sessionCompleted$ && !$fetchingSource$}
import { getContext, setContext } from 'svelte';
import { derived, writable } from 'svelte/store';

const CONTEXT_KEY = {};

/** A file picked by the user, together with the state of its upload. */
export type UploadedItem = {
	id: string;
	file: File;
	loading?: boolean;
	errored?: boolean;
	gcsUrl?: string;
};

/**
 * Create the attachments store for a session and register it in the Svelte context.
 *
 * `sessionUploadItems$` exposes only the items whose id starts with `sessionId`
 * and whose `file` is a real `File`.
 */
export const setAttachmentsContext = (sessionId: string) => {
	// Explicit element type: a bare `writable([])` is inferred as `never[]`,
	// which rejects every item passed to `addUploadItem`.
	const uploadedItems$ = writable<UploadedItem[]>([]);

	const sessionUploadItems$ = derived(uploadedItems$, (items) =>
		items.filter((i) => i.file instanceof File && i.id.startsWith(sessionId))
	);

	return setContext(CONTEXT_KEY, {
		uploadedItems$,
		sessionUploadItems$,
		/** Append a new item to the list. */
		addUploadItem(item: UploadedItem) {
			uploadedItems$.update((files) => [...files, item]);
		},
		/** Merge `update` into the item with the given id; other items are untouched. */
		updateUploadItem(itemId: string, update: Partial<UploadedItem>) {
			uploadedItems$.update((files) => {
				return files.map((f) => (f.id === itemId ? { ...f, ...update } : f));
			});
		},
		/** Drop the item with the given id. */
		removeUploadItem(itemId: string) {
			uploadedItems$.update((files) => {
				return files.filter((f) => f.id !== itemId);
			});
		}
	});
};

export type AttachmentsContext = ReturnType<typeof setAttachmentsContext>;

/** Read the attachments context registered by `setAttachmentsContext`. */
export const getAttachmentsContext = () => getContext<AttachmentsContext>(CONTEXT_KEY);
session$.set({ diff --git a/app/src/lib/utils/pluralize.ts b/app/src/lib/utils/pluralize.ts new file mode 100644 index 0000000..da02f94 --- /dev/null +++ b/app/src/lib/utils/pluralize.ts @@ -0,0 +1,3 @@ +export function pluralize(count: number, singular: string, plural: string) { + return count === 1 ? singular : plural; +} diff --git a/app/src/routes/+layout.svelte b/app/src/routes/+layout.svelte index 1acefb8..9053350 100644 --- a/app/src/routes/+layout.svelte +++ b/app/src/routes/+layout.svelte @@ -14,6 +14,7 @@ import { onMount } from 'svelte'; import { getAnalytics, logEvent } from 'firebase/analytics'; import cs from 'clsx'; + import { setAttachmentsContext } from '@/stores/attachmentsContext.svelte'; export let data; @@ -25,6 +26,7 @@ $: ({ session$ } = setSessionContext(sessionId)); $: sessionTitle = $session$?.title; + $: setAttachmentsContext(sessionId); onMount(() => { logEvent(getAnalytics(), 'page_view', { diff --git a/app/src/routes/chat/[sessionId=sessionId]/+page.svelte b/app/src/routes/chat/[sessionId=sessionId]/+page.svelte index 5671b49..25e08b2 100644 --- a/app/src/routes/chat/[sessionId=sessionId]/+page.svelte +++ b/app/src/routes/chat/[sessionId=sessionId]/+page.svelte @@ -10,12 +10,16 @@ import { debounce } from 'throttle-debounce'; import AudiocastPageHeader from '@/components/AudiocastPageHeader.svelte'; import { getSummary, isfinalResponse } from '@/utils/session.utils'; + import { getAttachmentsContext } from '@/stores/attachmentsContext.svelte.js'; + import { pluralize } from '@/utils/pluralize'; export let data; const { session$, addChatItem, updateChatContent, sessionId$, removeChatItem } = getSessionContext(); + const { sessionUploadItems$ } = getAttachmentsContext(); + let searchTerm = ''; let loading = false; let mounted = false; @@ -70,7 +74,11 @@ return fetch(`${env.API_BASE_URL}/chat/${sessionId}`, { method: 'POST', - body: JSON.stringify({ chatItem: uItem, contentCategory: category }), + body: JSON.stringify({ + chatItem: uItem, 
+ contentCategory: category, + attachments: $sessionUploadItems$.map((i) => i.gcsUrl).filter(Boolean) + }), headers: { 'Content-Type': 'application/json' } }) .then((res) => handleStreamingResponse(res, aItem.id)) @@ -114,15 +122,25 @@

- {#each sessionChats as item (item.id)} + {#each sessionChats as item, ix (item.id)} {@const finalResponse = isfinalResponse(item)} + {@const firstRequest = ix === 1} + > + + {#if firstRequest && $sessionUploadItems$.length} + {@const count = $sessionUploadItems$.length} + Evaluating {pluralize(count, 'attachment', 'attachments')}... + {:else} + Generating response... + {/if} + + {#if finalResponse}