Skip to content

Commit

Permalink
chat working again
Browse files Browse the repository at this point in the history
  • Loading branch information
eksno committed Feb 29, 2024
1 parent 293161c commit 6913bd0
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 21 deletions.
13 changes: 6 additions & 7 deletions aitino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, StreamingResponse
from openai import OpenAI
from pydantic import BaseModel

from .improver import PromptType, improve_prompt
from .interfaces import db
Expand Down Expand Up @@ -65,14 +63,12 @@ def compile(id: UUID) -> dict[str, str | Composition]:
def improve(
word_limit: int, prompt: str, prompt_type: PromptType, temperature: float
) -> str:
return improve_prompt(
word_limit, prompt, prompt_type, OpenAI().chat.completions, temperature
)
return improve_prompt(word_limit, prompt, prompt_type, temperature)


@app.get("/maeve")
async def run_maeve(
id: UUID, session_id: UUID | None = None, reply: str | None = None
id: UUID, profile_id: UUID, session_id: UUID | None = None, reply: str | None = None
) -> StreamingResponse:
if reply and not session_id:
raise HTTPException(
Expand Down Expand Up @@ -109,7 +105,10 @@ async def run_maeve(
)

if session is None:
session = Session()
session = Session(
maeve_id=id,
profile_id=profile_id,
)
db.post_session(session)

q: Queue[Message | object] = Queue()
Expand Down
40 changes: 27 additions & 13 deletions aitino/improver.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,50 @@
import os
from pathlib import Path
from typing import Literal, Protocol
from dotenv import load_dotenv

from openai import OpenAI

load_dotenv()
client = OpenAI()

PromptType = Literal["generic", "system", "user"]

class InvalidPromptTypeError(BaseException):
...

class InvalidPromptTypeError(BaseException): ...


class ContentProtocol(Protocol):
content: str


class MessageProtocol(Protocol):
message: ContentProtocol


class ResponseProtocol(Protocol):
choices: list[MessageProtocol]



class CompletionsProtocol(Protocol):
def create(self, messages, model, temperature, frequency_penalty, presence_penalty) -> ResponseProtocol:
...
def create(
self, messages, model, temperature, frequency_penalty, presence_penalty
) -> ResponseProtocol: ...


def improve_prompt(
word_limit: int, prompt: str, prompt_type: PromptType, client: CompletionsProtocol, temperature: float = 0.0
) -> str | None:
if (word_limit <= 0):
word_limit: int,
prompt: str,
prompt_type: PromptType,
temperature: float = 0.0,
) -> str:
if word_limit <= 0:
raise ValueError("Word limit must be greater than 0")

if (prompt_type not in ["generic", "system", "user"]):
if prompt_type not in ["generic", "system", "user"]:
raise InvalidPromptTypeError(f"Invalid prompt type: {prompt_type}")
if (temperature < -2.0 or temperature > 2.0):

if temperature < -2.0 or temperature > 2.0:
raise ValueError("Temperature must be in between -2 and 2")

with open(
Expand All @@ -43,7 +57,7 @@ def improve_prompt(
f"\n4. Limit the amount of words in this prompt to {word_limit} words"
)

result = client.create(
result = client.chat.completions.create(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
Expand All @@ -55,4 +69,4 @@ def improve_prompt(
)

content = result.choices[0].message.content
return content if content else None
return content if content else "error"
2 changes: 1 addition & 1 deletion aitino/interfaces/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_complied(maeve_id: UUID) -> tuple[str, Composition] | tuple[None, None]:
Get the complied message and composition for a given Maeve ID.
"""
logger.debug(f"Getting complied message and composition for {maeve_id}")

Check failure

Code scanning / CodeQL

Log Injection High

This log entry depends on a
user-provided value
.
This log entry depends on a
user-provided value
.
response = supabase.table("maeve_nodes").select("*").eq("id", maeve_id).execute()
response = supabase.table("maeves").select("*").eq("id", maeve_id).execute()

if len(response.data) == 0:
return None, None
Expand Down
2 changes: 2 additions & 0 deletions aitino/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@

class Session(BaseModel):
id: UUID = Field(default_factory=lambda: uuid4())
maeve_id: UUID
profile_id: UUID
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=UTC))

0 comments on commit 6913bd0

Please sign in to comment.