Skip to content

Commit

Permalink
Add tests for endpoints involving storage on server (#32)
Browse files Browse the repository at this point in the history
Add tests for endpoints involving storage
  • Loading branch information
eyurtsev authored Nov 12, 2023
1 parent da3df17 commit d11ff09
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 1 deletion.
7 changes: 6 additions & 1 deletion backend/app/server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from typing import Optional

import orjson
Expand All @@ -20,6 +21,9 @@

FEATURED_PUBLIC_ASSISTANTS = []

# Get root of app, used to point to directory containing static files
ROOT = Path(__file__).parent.parent


def attach_user_id_to_config(
config: RunnableConfig, request: Request
Expand All @@ -45,6 +49,7 @@ def ingest_endpoint(files: list[UploadFile], config: str = Form(...)):

@app.get("/assistants/")
def list_assistants_endpoint(req: Request):
"""List all assistants for the current user."""
return list_assistants(req.cookies["opengpts_user_id"])


Expand Down Expand Up @@ -86,7 +91,7 @@ def put_thread_endpoint(req: Request, tid: str, payload: dict):
)


app.mount("", StaticFiles(directory="ui", html=True), name="ui")
app.mount("", StaticFiles(directory=str(ROOT / "ui"), html=True), name="ui")

if __name__ == "__main__":
import uvicorn
Expand Down
Empty file.
142 changes: 142 additions & 0 deletions backend/tests/unit_tests/app/test_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""Test the server and client together."""

import os
from contextlib import asynccontextmanager
from typing import Optional, Sequence

import pytest
from httpx import AsyncClient
from langchain.utilities.redis import get_client as _get_redis_client
from redis.client import Redis as RedisType
from typing_extensions import AsyncGenerator


@asynccontextmanager
async def get_client() -> AsyncGenerator[AsyncClient, None]:
"""Get the app."""
from app.server import app

async with AsyncClient(app=app, base_url="http://test") as ac:
yield ac


@pytest.fixture(scope="function")
def redis_client() -> RedisType:
"""Get a redis client -- and clear it before the test!"""
redis_url = os.environ.get("REDIS_URL")
if "localhost" not in redis_url:
raise ValueError(
"This test is only intended to be run against a local redis instance"
)

if not redis_url.endswith("/3"):
raise ValueError(
"This test is only intended to be run against a local redis instance. "
"For testing purposes this is expected to be database #3 (arbitrary)."
)

client = _get_redis_client(redis_url)
client.flushdb()
try:
yield client
finally:
client.close()


def _project(d: dict, *, exclude_keys: Optional[Sequence[str]]) -> dict:
"""Return a dict with only the keys specified."""
_exclude = set(exclude_keys) if exclude_keys else set()
return {k: v for k, v in d.items() if k not in _exclude}


@pytest.mark.asyncio
async def test_list_and_create_assistants(redis_client: RedisType) -> None:
"""Test list and create assistants."""
headers = {"Cookie": "opengpts_user_id=1"}
assert sorted(redis_client.keys()) == []
async with get_client() as client:
response = await client.get(
"/assistants/",
headers=headers,
)
assert response.status_code == 200
assert response.json() == []

# Create an assistant
response = await client.put(
"/assistants/bobby",
json={"name": "bobby", "config": {}, "public": False},
headers=headers,
)
assert response.status_code == 200
assert _project(response.json(), exclude_keys=["updated_at"]) == {
"assistant_id": "bobby",
"config": {},
"name": "bobby",
"public": False,
"user_id": "1",
}
assert sorted(redis_client.keys()) == [
b"opengpts:1:assistant:bobby",
b"opengpts:1:assistants",
]

response = await client.get("/assistants/", headers=headers)
assert [_project(d, exclude_keys=["updated_at"]) for d in response.json()] == [
{
"assistant_id": "bobby",
"config": {},
"name": "bobby",
"public": False,
}
]

response = await client.put(
"/assistants/bobby",
json={"name": "bobby", "config": {}, "public": False},
headers=headers,
)

assert _project(response.json(), exclude_keys=["updated_at"]) == {
"assistant_id": "bobby",
"config": {},
"name": "bobby",
"public": False,
"user_id": "1",
}

# Check not visible to other users
headers = {"Cookie": "opengpts_user_id=2 flushdb"}
response = await client.get("/assistants/", headers=headers)
assert response.status_code == 200
assert response.json() == []


@pytest.mark.asyncio
async def test_threads(redis_client: RedisType) -> None:
"""Test put thread."""
async with get_client() as client:
response = await client.put(
"/threads/1",
json={"name": "bobby", "assistant_id": "bobby"},
headers={"Cookie": "opengpts_user_id=1"},
)
assert response.status_code == 200

response = await client.get(
"/threads/1/messages", headers={"Cookie": "opengpts_user_id=1"}
)
assert response.status_code == 200
assert response.json() == {"messages": []}

response = await client.get(
"/threads/", headers={"Cookie": "opengpts_user_id=1"}
)
assert response.status_code == 200
assert [_project(d, exclude_keys=["updated_at"]) for d in response.json()] == [
{
"assistant_id": "bobby",
"name": "bobby",
"thread_id": "1",
}
]
6 changes: 6 additions & 0 deletions backend/tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import os

# Temporary handling of environment variables for testing
os.environ["REDIS_URL"] = "redis://localhost:6379/3"
os.environ["OPENAI_API_KEY"] = "test"
os.environ["YDC_API_KEY"] = "test"

0 comments on commit d11ff09

Please sign in to comment.