Skip to content

Commit

Permalink
add custom eval fast api
Browse files Browse the repository at this point in the history
  • Loading branch information
XuhuiZhou committed Dec 13, 2024
1 parent ec5c394 commit 442809a
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 3 deletions.
1 change: 0 additions & 1 deletion examples/use_custom_dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def save_dimensions(dimensions: list[dict[str, Union[str, int]]]) -> None:
def save_dimension_list(
dimensions: list[dict[str, Union[str, int]]], list_name: str
) -> None:
Migrator().run()
dimension_list = CustomEvaluationDimensionList.find(
CustomEvaluationDimensionList.name == list_name
).all()
Expand Down
81 changes: 80 additions & 1 deletion sotopia/ui/fastapi_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
RelationshipProfile,
RelationshipType,
NonStreamingSimulationStatus,
CustomEvaluationDimensionList,
CustomEvaluationDimension,
)
from sotopia.envs.parallel import ParallelSotopiaEnv
from sotopia.envs.evaluators import (
Expand All @@ -33,7 +35,7 @@
)
from typing import Optional, Any
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, model_validator, field_validator
from pydantic import BaseModel, model_validator, field_validator, Field

from sotopia.ui.websocket_utils import (
WebSocketSotopiaSimulator,
Expand Down Expand Up @@ -112,6 +114,16 @@ class EnvironmentProfileWrapper(BaseModel):
tag: str = ""


class CustomEvaluationDimensionsWrapper(BaseModel):
pk: str = ""
name: str = Field(
default="", description="The name of the custom evaluation dimension list"
)
dimensions: list[CustomEvaluationDimension] = Field(
default=[], description="The dimensions of the custom evaluation dimension list"
)


class SimulationRequest(BaseModel):
env_id: str
agent_ids: list[str]
Expand Down Expand Up @@ -228,6 +240,20 @@ async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[Episode
return episodes


@app.get(
"/evaluation_dimensions/", response_model=dict[str, list[CustomEvaluationDimension]]
)
async def get_evaluation_dimensions() -> dict[str, list[CustomEvaluationDimension]]:
custom_evaluation_dimensions: dict[str, list[CustomEvaluationDimension]] = {}
custom_evaluation_dimension_list = CustomEvaluationDimensionList.all()
for custom_evaluation_dimension_list in custom_evaluation_dimension_list:
custom_evaluation_dimensions[custom_evaluation_dimension_list.name] = [
CustomEvaluationDimension.get(pk=pk)
for pk in custom_evaluation_dimension_list.dimension_pks
]
return custom_evaluation_dimensions


@app.post("/scenarios/", response_model=str)
async def create_scenario(scenario: EnvironmentProfileWrapper) -> str:
scenario_profile = EnvironmentProfile(**scenario.model_dump())
Expand Down Expand Up @@ -255,6 +281,49 @@ async def create_relationship(relationship: RelationshipWrapper) -> str:
return pk


@app.post("/evaluation_dimensions/", response_model=str)
async def create_evaluation_dimensions(
evaluation_dimensions: CustomEvaluationDimensionsWrapper,
) -> str:
dimension_list = CustomEvaluationDimensionList.find(
CustomEvaluationDimensionList.name == evaluation_dimensions.name
).all()

if len(dimension_list) == 0:
all_dimensions_pks = []
for dimension in evaluation_dimensions.dimensions:
find_dimension = CustomEvaluationDimension.find(
CustomEvaluationDimension.name == dimension.name
).all()
if len(find_dimension) == 0:
dimension.save()
all_dimensions_pks.append(dimension.pk)
elif len(find_dimension) == 1:
all_dimensions_pks.append(find_dimension[0].pk)
else:
raise HTTPException(
status_code=409,
detail=f"Evaluation dimension with name={dimension.name} already exists",
)

custom_evaluation_dimension_list = CustomEvaluationDimensionList(
pk=evaluation_dimensions.pk,
name=evaluation_dimensions.name,
dimension_pks=all_dimensions_pks,
)
custom_evaluation_dimension_list.save()
logger.info(f"Created evaluation dimension list {evaluation_dimensions.name}")
else:
raise HTTPException(
status_code=409,
detail=f"Evaluation dimension list with name={evaluation_dimensions.name} already exists",
)

pk = custom_evaluation_dimension_list.pk
assert pk is not None
return pk


async def run_simulation(
episode_pk: str,
simulation_request: SimulationRequest,
Expand Down Expand Up @@ -426,6 +495,16 @@ async def delete_episode(episode_id: str) -> str:
return episode_id


@app.delete("/evaluation_dimensions/{evaluation_dimension_list_pk}", response_model=str)
async def delete_evaluation_dimension_list(evaluation_dimension_list_pk: str) -> str:
for dimension_pk in CustomEvaluationDimensionList.get(
evaluation_dimension_list_pk
).dimension_pks:
CustomEvaluationDimension.delete(dimension_pk)
CustomEvaluationDimensionList.delete(evaluation_dimension_list_pk)
return evaluation_dimension_list_pk


active_simulations: Dict[
str, bool
] = {} # TODO check whether this is the correct way to store the active simulations
Expand Down
65 changes: 64 additions & 1 deletion tests/ui/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
AgentProfile,
EpisodeLog,
RelationshipProfile,
CustomEvaluationDimension,
CustomEvaluationDimensionList,
)
from sotopia.messages import SimpleMessage
from sotopia.ui.fastapi_server import app
Expand Down Expand Up @@ -68,7 +70,9 @@ def create_dummy_episode_log() -> None:


@pytest.fixture
def create_mock_data(for_posting: bool = False) -> Generator[None, None, None]:
def create_mock_data(request: pytest.FixtureRequest) -> Generator[None, None, None]:
for_posting = request.param if hasattr(request, "param") else False

def _create_mock_agent_profile() -> None:
AgentProfile(
first_name="John",
Expand Down Expand Up @@ -108,10 +112,26 @@ def _create_mock_relationship() -> None:
relationship=1.0,
).save()

def _create_mock_evaluation_dimension() -> None:
CustomEvaluationDimension(
pk="tmppk_evaluation_dimension",
name="test_dimension",
description="test_description",
range_high=10,
range_low=-10,
).save()
CustomEvaluationDimensionList(
pk="tmppk_evaluation_dimension_list",
name="test_dimension_list",
dimension_pks=["tmppk_evaluation_dimension"],
).save()

if not for_posting:
_create_mock_agent_profile()
_create_mock_env_profile()
_create_mock_relationship()
_create_mock_evaluation_dimension()
print("created mock data")
yield

try:
Expand Down Expand Up @@ -147,6 +167,15 @@ def _create_mock_relationship() -> None:
except Exception as e:
print(e)

try:
CustomEvaluationDimension.delete("tmppk_evaluation_dimension")
except Exception as e:
print(e)
try:
CustomEvaluationDimensionList.delete("tmppk_evaluation_dimension_list")
except Exception as e:
print(e)


def test_get_scenarios_all(create_mock_data: Callable[[], None]) -> None:
response = client.get("/scenarios")
Expand Down Expand Up @@ -221,6 +250,13 @@ def test_get_relationship(create_mock_data: Callable[[], None]) -> None:
assert response.json() == "1: know_by_name"


def test_get_evaluation_dimensions(create_mock_data: Callable[[], None]) -> None:
response = client.get("/evaluation_dimensions/")
assert response.status_code == 200
assert isinstance(response.json(), dict)
assert response.json()["test_dimension_list"][0]["name"] == "test_dimension"


@pytest.mark.parametrize("create_mock_data", [True], indirect=True)
def test_create_agent(create_mock_data: Callable[[], None]) -> None:
agent_data = {
Expand Down Expand Up @@ -260,6 +296,27 @@ def test_create_relationship(create_mock_data: Callable[[], None]) -> None:
assert isinstance(response.json(), str)


@pytest.mark.parametrize("create_mock_data", [True], indirect=True)
def test_create_evaluation_dimensions(create_mock_data: Callable[[], None]) -> None:
evaluation_dimension_data = {
"pk": "tmppk_evaluation_dimension_list",
"name": "test_dimension_list",
"dimensions": [
{
"pk": "tmppk_evaluation_dimension",
"name": "test_dimension",
"description": "test_description",
"range_high": 10,
"range_low": -10,
}
],
}
response = client.post("/evaluation_dimensions", json=evaluation_dimension_data)
print(response.json())
assert response.status_code == 200
assert isinstance(response.json(), str)


def test_delete_agent(create_mock_data: Callable[[], None]) -> None:
response = client.delete("/agents/tmppk_agent1")
assert response.status_code == 200
Expand All @@ -278,6 +335,12 @@ def test_delete_relationship(create_mock_data: Callable[[], None]) -> None:
assert isinstance(response.json(), str)


def test_delete_evaluation_dimension(create_mock_data: Callable[[], None]) -> None:
response = client.delete("/evaluation_dimensions/tmppk_evaluation_dimension_list")
assert response.status_code == 200
assert isinstance(response.json(), str)


# def test_simulate(create_mock_data: Callable[[], None]) -> None:
# response = client.post(
# "/simulate",
Expand Down

0 comments on commit 442809a

Please sign in to comment.