Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Saving sessions, working on messages #15

Merged
merged 3 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 58 additions & 31 deletions aitino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from uuid import UUID

from autogen import Agent, ConversableAgent
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, StreamingResponse
from pydantic import BaseModel
Expand Down Expand Up @@ -59,20 +59,62 @@ def improve(


@app.get("/maeve")
async def run_maeve(id: UUID, session_id: UUID | None = None) -> StreamingResponse:
q: Queue[Message | object] = Queue()
job_done = object()
async def run_maeve(
id: UUID, session_id: UUID | None = None, reply: str | None = None
) -> StreamingResponse:
if reply and not session_id:
raise HTTPException(
status_code=400,
detail="If a reply is provided, a session_id must also be provided.",
)
if session_id and not reply:
raise HTTPException(
status_code=400,
detail="If a session_id is provided, a reply must also be provided.",
)

message, composition = db.get_complied(id)

if reply:
message = reply

if not message or not composition:
raise HTTPException(status_code=400, detail=f"Maeve with id {id} not found")

session = db.get_session(session_id) if session_id else None
messages = db.get_messages(session_id) if session_id else None

if session_id and not session:
raise HTTPException(
status_code=400,
detail=f"Session with id {session_id} not found",
)

if session_id and not messages:
raise HTTPException(
status_code=400,
detail=f"Session with id {session_id} found, but has no messages",
)

# Get or create session
session = None
if session_id:
session = db.get_session(session_id)
if session is None:
session = Session()
db.post_session(session)

q: Queue[Message | object] = Queue()
job_done = object()

async def watch_queue() -> AsyncGenerator:
# Watch the queue and yield items (messages) as they arrive
i = 0

# Yield session
yield json.dumps(
StreamReply(id=i, data={"session_id": str(session.id)}).model_dump(),
default=str,
) + "\n"

i += 1

while True:
# Gets and dequeues item
next_item = await q.get()
Expand Down Expand Up @@ -109,30 +151,15 @@ async def on_reply(
return

logger.info(f"on_reply: {recipient.name} {messages[-1]}")
await q.put(
Message(
session_id=session.id,
recipient=recipient.name,
name=messages[-1]["name"],
content=messages[-1]["content"],
role=messages[-1]["role"],
)
)

message, composition = db.get_complied(id)

if not message or not composition:
return StreamingResponse(
json.dumps(
StreamReply(
id=0,
status="error",
data=f"Maeve with id {id} not found",
).model_dump(),
default=str,
),
media_type="application/x-ndjson",
message = Message(
session_id=session.id,
recipient=recipient.name,
name=messages[-1]["name"],
content=messages[-1]["content"],
role=messages[-1]["role"],
)
# db.post_message(message)
await q.put(message)

maeve = Maeve(composition, on_reply)

Expand Down
9 changes: 7 additions & 2 deletions aitino/interfaces/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import json
from uuid import UUID

from dotenv import load_dotenv
Expand Down Expand Up @@ -49,7 +50,9 @@ def post_session(session: Session) -> None:
"""
Post a session to the database.
"""
supabase.table("sessions").upsert(session.model_dump()).execute()
supabase.table("sessions").upsert(
json.loads(json.dumps(session.model_dump(), default=str))
).execute()


def get_messages(session_id: UUID) -> list[Message] | None:
Expand All @@ -74,4 +77,6 @@ def post_message(message: Message) -> None:
"""
Post a message to the database.
"""
supabase.table("sessions").upsert(message.model_dump()).execute()
supabase.table("messages").upsert(
json.loads(json.dumps(message.model_dump(), default=str))
).execute()
1 change: 0 additions & 1 deletion aitino/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ async def call_maeve(url: str) -> AsyncGenerator[str, None]:
async with session.get(url) as response:
while True:
line = await response.content.readline()
await asyncio.sleep(0.1)
if not line:
break
yield line.decode("utf-8")
Expand Down