From 442809a4204706504056e1f9ef7eee21c239e03b Mon Sep 17 00:00:00 2001 From: XuhuiZhou Date: Fri, 13 Dec 2024 18:14:40 +0000 Subject: [PATCH] add custom eval fast api --- examples/use_custom_dimensions.py | 1 - sotopia/ui/fastapi_server.py | 81 ++++++++++++++++++++++++++++++- tests/ui/test_fastapi.py | 65 ++++++++++++++++++++++++- 3 files changed, 144 insertions(+), 3 deletions(-) diff --git a/examples/use_custom_dimensions.py b/examples/use_custom_dimensions.py index 0bbdfda8..ff712224 100644 --- a/examples/use_custom_dimensions.py +++ b/examples/use_custom_dimensions.py @@ -45,7 +45,6 @@ def save_dimensions(dimensions: list[dict[str, Union[str, int]]]) -> None: def save_dimension_list( dimensions: list[dict[str, Union[str, int]]], list_name: str ) -> None: - Migrator().run() dimension_list = CustomEvaluationDimensionList.find( CustomEvaluationDimensionList.name == list_name ).all() diff --git a/sotopia/ui/fastapi_server.py b/sotopia/ui/fastapi_server.py index 611878c4..9bd0d4f3 100644 --- a/sotopia/ui/fastapi_server.py +++ b/sotopia/ui/fastapi_server.py @@ -15,6 +15,8 @@ RelationshipProfile, RelationshipType, NonStreamingSimulationStatus, + CustomEvaluationDimensionList, + CustomEvaluationDimension, ) from sotopia.envs.parallel import ParallelSotopiaEnv from sotopia.envs.evaluators import ( @@ -33,7 +35,7 @@ ) from typing import Optional, Any from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel, model_validator, field_validator +from pydantic import BaseModel, model_validator, field_validator, Field from sotopia.ui.websocket_utils import ( WebSocketSotopiaSimulator, @@ -112,6 +114,16 @@ class EnvironmentProfileWrapper(BaseModel): tag: str = "" +class CustomEvaluationDimensionsWrapper(BaseModel): + pk: str = "" + name: str = Field( + default="", description="The name of the custom evaluation dimension list" + ) + dimensions: list[CustomEvaluationDimension] = Field( + default=[], description="The dimensions of the custom evaluation dimension list" + ) + + class SimulationRequest(BaseModel): env_id: str agent_ids: list[str] @@ -228,6 +240,20 @@ async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[Episode return episodes +@app.get( + "/evaluation_dimensions/", response_model=dict[str, list[CustomEvaluationDimension]] +) +async def get_evaluation_dimensions() -> dict[str, list[CustomEvaluationDimension]]: + custom_evaluation_dimensions: dict[str, list[CustomEvaluationDimension]] = {} + custom_evaluation_dimension_list = CustomEvaluationDimensionList.all() + for custom_evaluation_dimension_list in custom_evaluation_dimension_list: + custom_evaluation_dimensions[custom_evaluation_dimension_list.name] = [ + CustomEvaluationDimension.get(pk=pk) + for pk in custom_evaluation_dimension_list.dimension_pks + ] + return custom_evaluation_dimensions + + @app.post("/scenarios/", response_model=str) async def create_scenario(scenario: EnvironmentProfileWrapper) -> str: scenario_profile = EnvironmentProfile(**scenario.model_dump()) @@ -255,6 +281,49 @@ async def create_relationship(relationship: RelationshipWrapper) -> str: return pk +@app.post("/evaluation_dimensions/", response_model=str) +async def create_evaluation_dimensions( + evaluation_dimensions: CustomEvaluationDimensionsWrapper, +) -> str: + dimension_list = CustomEvaluationDimensionList.find( + CustomEvaluationDimensionList.name == evaluation_dimensions.name + ).all() + + if len(dimension_list) == 0: + all_dimensions_pks = [] + for dimension in evaluation_dimensions.dimensions: + find_dimension = CustomEvaluationDimension.find( + CustomEvaluationDimension.name == dimension.name + ).all() + if len(find_dimension) == 0: + dimension.save() + all_dimensions_pks.append(dimension.pk) + elif len(find_dimension) == 1: + all_dimensions_pks.append(find_dimension[0].pk) + else: + raise HTTPException( + status_code=409, + detail=f"Evaluation dimension with name={dimension.name} already exists", + ) + + custom_evaluation_dimension_list = CustomEvaluationDimensionList( + pk=evaluation_dimensions.pk, + name=evaluation_dimensions.name, + dimension_pks=all_dimensions_pks, + ) + custom_evaluation_dimension_list.save() + logger.info(f"Created evaluation dimension list {evaluation_dimensions.name}") + else: + raise HTTPException( + status_code=409, + detail=f"Evaluation dimension list with name={evaluation_dimensions.name} already exists", + ) + + pk = custom_evaluation_dimension_list.pk + assert pk is not None + return pk + + async def run_simulation( episode_pk: str, simulation_request: SimulationRequest, @@ -426,6 +495,16 @@ async def delete_episode(episode_id: str) -> str: return episode_id +@app.delete("/evaluation_dimensions/{evaluation_dimension_list_pk}", response_model=str) +async def delete_evaluation_dimension_list(evaluation_dimension_list_pk: str) -> str: + for dimension_pk in CustomEvaluationDimensionList.get( + evaluation_dimension_list_pk + ).dimension_pks: + CustomEvaluationDimension.delete(dimension_pk) + CustomEvaluationDimensionList.delete(evaluation_dimension_list_pk) + return evaluation_dimension_list_pk + + active_simulations: Dict[ str, bool ] = {} # TODO check whether this is the correct way to store the active simulations diff --git a/tests/ui/test_fastapi.py b/tests/ui/test_fastapi.py index ac00eacf..0ce87278 100644 --- a/tests/ui/test_fastapi.py +++ b/tests/ui/test_fastapi.py @@ -4,6 +4,8 @@ AgentProfile, EpisodeLog, RelationshipProfile, + CustomEvaluationDimension, + CustomEvaluationDimensionList, ) from sotopia.messages import SimpleMessage from sotopia.ui.fastapi_server import app @@ -68,7 +70,9 @@ def create_dummy_episode_log() -> None: @pytest.fixture -def create_mock_data(for_posting: bool = False) -> Generator[None, None, None]: +def create_mock_data(request: pytest.FixtureRequest) -> Generator[None, None, None]: + for_posting = request.param if hasattr(request, "param") else False + def _create_mock_agent_profile() -> None: AgentProfile( first_name="John", @@ -108,10 +112,26 @@ def _create_mock_relationship() -> None: relationship=1.0, ).save() + def _create_mock_evaluation_dimension() -> None: + CustomEvaluationDimension( + pk="tmppk_evaluation_dimension", + name="test_dimension", + description="test_description", + range_high=10, + range_low=-10, + ).save() + CustomEvaluationDimensionList( + pk="tmppk_evaluation_dimension_list", + name="test_dimension_list", + dimension_pks=["tmppk_evaluation_dimension"], + ).save() + if not for_posting: _create_mock_agent_profile() _create_mock_env_profile() _create_mock_relationship() + _create_mock_evaluation_dimension() + print("created mock data") yield try: @@ -147,6 +167,15 @@ def _create_mock_relationship() -> None: except Exception as e: print(e) + try: + CustomEvaluationDimension.delete("tmppk_evaluation_dimension") + except Exception as e: + print(e) + try: + CustomEvaluationDimensionList.delete("tmppk_evaluation_dimension_list") + except Exception as e: + print(e) + def test_get_scenarios_all(create_mock_data: Callable[[], None]) -> None: response = client.get("/scenarios") @@ -221,6 +250,13 @@ def test_get_relationship(create_mock_data: Callable[[], None]) -> None: assert response.json() == "1: know_by_name" +def test_get_evaluation_dimensions(create_mock_data: Callable[[], None]) -> None: + response = client.get("/evaluation_dimensions/") + assert response.status_code == 200 + assert isinstance(response.json(), dict) + assert response.json()["test_dimension_list"][0]["name"] == "test_dimension" + + @pytest.mark.parametrize("create_mock_data", [True], indirect=True) def test_create_agent(create_mock_data: Callable[[], None]) -> None: agent_data = { @@ -260,6 +296,27 @@ def test_create_relationship(create_mock_data: Callable[[], None]) -> None: assert isinstance(response.json(), str) +@pytest.mark.parametrize("create_mock_data", [True], indirect=True) +def test_create_evaluation_dimensions(create_mock_data: Callable[[], None]) -> None: + evaluation_dimension_data = { + "pk": "tmppk_evaluation_dimension_list", + "name": "test_dimension_list", + "dimensions": [ + { + "pk": "tmppk_evaluation_dimension", + "name": "test_dimension", + "description": "test_description", + "range_high": 10, + "range_low": -10, + } + ], + } + response = client.post("/evaluation_dimensions", json=evaluation_dimension_data) + print(response.json()) + assert response.status_code == 200 + assert isinstance(response.json(), str) + + def test_delete_agent(create_mock_data: Callable[[], None]) -> None: response = client.delete("/agents/tmppk_agent1") assert response.status_code == 200 @@ -278,6 +335,12 @@ def test_delete_relationship(create_mock_data: Callable[[], None]) -> None: assert isinstance(response.json(), str) +def test_delete_evaluation_dimension(create_mock_data: Callable[[], None]) -> None: + response = client.delete("/evaluation_dimensions/tmppk_evaluation_dimension_list") + assert response.status_code == 200 + assert isinstance(response.json(), str) + + # def test_simulate(create_mock_data: Callable[[], None]) -> None: # response = client.post( # "/simulate",