Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add simple unit tests for server #32

Merged
merged 5 commits into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion backend/app/server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from typing import Optional

import orjson
Expand All @@ -20,6 +21,9 @@

FEATURED_PUBLIC_ASSISTANTS = []

# Get root of app, used to point to directory containing static files
ROOT = Path(__file__).parent.parent


def attach_user_id_to_config(
config: RunnableConfig, request: Request
Expand All @@ -44,6 +48,7 @@ def ingest_endpoint(files: list[UploadFile], config: str = Form(...)):

@app.get("/assistants/")
def list_assistants_endpoint(req: Request):
"""List all assistants for the current user."""
return list_assistants(req.cookies["opengpts_user_id"])


Expand Down Expand Up @@ -85,7 +90,7 @@ def put_thread_endpoint(req: Request, tid: str, payload: dict):
)


app.mount("", StaticFiles(directory="ui", html=True), name="ui")
app.mount("", StaticFiles(directory=str(ROOT / "ui"), html=True), name="ui")

if __name__ == "__main__":
import uvicorn
Expand Down
Empty file.
142 changes: 142 additions & 0 deletions backend/tests/unit_tests/app/test_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""Test the server and client together."""

import os
from contextlib import asynccontextmanager
from typing import Optional, Sequence

import pytest
from httpx import AsyncClient
from langchain.utilities.redis import get_client as _get_redis_client
from redis.client import Redis as RedisType
from typing_extensions import AsyncGenerator


@asynccontextmanager
async def get_client() -> AsyncGenerator[AsyncClient, None]:
"""Get the app."""
from app.server import app

async with AsyncClient(app=app, base_url="http://test") as ac:
yield ac


@pytest.fixture(scope="function")
def redis_client() -> RedisType:
"""Get a redis client -- and clear it before the test!"""
redis_url = os.environ.get("REDIS_URL")
if "localhost" not in redis_url:
raise ValueError(
"This test is only intended to be run against a local redis instance"
)

if not redis_url.endswith("/3"):
raise ValueError(
"This test is only intended to be run against a local redis instance. "
"For testing purposes this is expected to be database #3 (arbitrary)."
)

client = _get_redis_client(redis_url)
client.flushdb()
try:
yield client
finally:
client.close()


def _project(d: dict, *, exclude_keys: Optional[Sequence[str]]) -> dict:
"""Return a dict with only the keys specified."""
_exclude = set(exclude_keys) if exclude_keys else set()
return {k: v for k, v in d.items() if k not in _exclude}


@pytest.mark.asyncio
async def test_list_and_create_assistants(redis_client: RedisType) -> None:
"""Test list and create assistants."""
headers = {"Cookie": "opengpts_user_id=1"}
assert sorted(redis_client.keys()) == []
async with get_client() as client:
response = await client.get(
"/assistants/",
headers=headers,
)
assert response.status_code == 200
assert response.json() == []

# Create an assistant
response = await client.put(
"/assistants/bobby",
json={"name": "bobby", "config": {}, "public": False},
headers=headers,
)
assert response.status_code == 200
assert _project(response.json(), exclude_keys=["updated_at"]) == {
"assistant_id": "bobby",
"config": {},
"name": "bobby",
"public": False,
"user_id": "1",
}
assert sorted(redis_client.keys()) == [
b"opengpts:1:assistant:bobby",
b"opengpts:1:assistants",
]

response = await client.get("/assistants/", headers=headers)
assert [_project(d, exclude_keys=["updated_at"]) for d in response.json()] == [
{
"assistant_id": "bobby",
"config": {},
"name": "bobby",
"public": False,
}
]

response = await client.put(
"/assistants/bobby",
json={"name": "bobby", "config": {}, "public": False},
headers=headers,
)

assert _project(response.json(), exclude_keys=["updated_at"]) == {
"assistant_id": "bobby",
"config": {},
"name": "bobby",
"public": False,
"user_id": "1",
}

# Check not visible to other users
headers = {"Cookie": "opengpts_user_id=2 flushdb"}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a header to flush db?

response = await client.get("/assistants/", headers=headers)
assert response.status_code == 200
assert response.json() == []


@pytest.mark.asyncio
async def test_threads(redis_client: RedisType) -> None:
"""Test put thread."""
async with get_client() as client:
response = await client.put(
"/threads/1",
json={"name": "bobby", "assistant_id": "bobby"},
headers={"Cookie": "opengpts_user_id=1"},
)
assert response.status_code == 200

response = await client.get(
"/threads/1/messages", headers={"Cookie": "opengpts_user_id=1"}
)
assert response.status_code == 200
assert response.json() == {"messages": []}

response = await client.get(
"/threads/", headers={"Cookie": "opengpts_user_id=1"}
)
assert response.status_code == 200
assert [_project(d, exclude_keys=["updated_at"]) for d in response.json()] == [
{
"assistant_id": "bobby",
"name": "bobby",
"thread_id": "1",
}
]
6 changes: 6 additions & 0 deletions backend/tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import os

# Temporary handling of environment variables for testing
os.environ["REDIS_URL"] = "redis://localhost:6379/3"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty hacky, but we can update afterwards to pick up settings properly

os.environ["OPENAI_API_KEY"] = "test"
os.environ["YDC_API_KEY"] = "test"