Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some validation to endpoints #35

Merged
merged 2 commits into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions backend/app/server.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from pathlib import Path
from typing import Optional
from typing import Annotated, Optional

import orjson
from fastapi import FastAPI, Form, Request, UploadFile
from fastapi import Cookie, FastAPI, Form, Request, UploadFile
from fastapi.staticfiles import StaticFiles
from gizmo_agent import agent, ingest_runnable
from langchain.schema.runnable import RunnableConfig
from langserve import add_routes
from typing_extensions import TypedDict

from app.storage import (
get_thread_messages,
Expand All @@ -26,7 +27,8 @@


def attach_user_id_to_config(
config: RunnableConfig, request: Request
config: RunnableConfig,
request: Request,
) -> RunnableConfig:
config["configurable"]["user_id"] = request.cookies["opengpts_user_id"]
return config
Expand All @@ -48,9 +50,9 @@ def ingest_endpoint(files: list[UploadFile], config: str = Form(...)):


@app.get("/assistants/")
def list_assistants_endpoint(req: Request):
def list_assistants_endpoint(opengpts_user_id: Annotated[str, Cookie()]):
"""List all assistants for the current user."""
return list_assistants(req.cookies["opengpts_user_id"])
return list_assistants(opengpts_user_id)


@app.get("/assistants/public/")
Expand All @@ -60,10 +62,20 @@ def list_public_assistants_endpoint(shared_id: Optional[str] = None):
)


class AssistantPayload(TypedDict):
name: str
config: dict
public: bool


@app.put("/assistants/{aid}")
def put_assistant_endpoint(req: Request, aid: str, payload: dict):
def put_assistant_endpoint(
aid: str,
payload: AssistantPayload,
opengpts_user_id: Annotated[str, Cookie()],
):
return put_assistant(
req.cookies["opengpts_user_id"],
opengpts_user_id,
aid,
name=payload["name"],
config=payload["config"],
Expand All @@ -72,19 +84,26 @@ def put_assistant_endpoint(req: Request, aid: str, payload: dict):


@app.get("/threads/")
def list_threads_endpoint(req: Request):
return list_threads(req.cookies["opengpts_user_id"])
def list_threads_endpoint(opengpts_user_id: Annotated[str, Cookie()]):
return list_threads(opengpts_user_id)


@app.get("/threads/{tid}/messages")
def get_thread_messages_endpoint(req: Request, tid: str):
return get_thread_messages(req.cookies["opengpts_user_id"], tid)
def get_thread_messages_endpoint(opengpts_user_id: Annotated[str, Cookie()], tid: str):
return get_thread_messages(opengpts_user_id, tid)


class ThreadPayload(TypedDict):
name: str
assistant_id: str


@app.put("/threads/{tid}")
def put_thread_endpoint(req: Request, tid: str, payload: dict):
def put_thread_endpoint(
opengpts_user_id: Annotated[str, Cookie()], tid: str, payload: ThreadPayload
):
return put_thread(
req.cookies["opengpts_user_id"],
opengpts_user_id,
tid,
assistant_id=payload["assistant_id"],
name=payload["name"],
Expand Down
19 changes: 19 additions & 0 deletions backend/tests/unit_tests/app/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ async def test_list_and_create_assistants(redis_client: RedisType) -> None:
headers=headers,
)
assert response.status_code == 200

assert response.json() == []

# Create an assistant
Expand Down Expand Up @@ -140,3 +141,21 @@ async def test_threads(redis_client: RedisType) -> None:
"thread_id": "1",
}
]

# Test a bad requests
response = await client.put(
"/threads/1",
json={"name": "bobby", "assistant_id": "bobby"},
)
assert response.status_code == 422

response = await client.put(
"/threads/1",
headers={"Cookie": "opengpts_user_id=2"},
)
assert response.status_code == 422

response = await client.get(
"/threads/",
)
assert response.status_code == 422