-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: FastAPI Implementation of Sotopia Part One (wo websocket) (#246)
* api doc * add PUT * add an temp example for websocket * websocket * update readme * Update README.md * fastapi server wo websocket * update project toml * add websocket api and a sample client sample * add initial test * post update * finalize the doc * fix the create * finish test * add files * change chata to api * fix mypy error * fix mypy bug * fix mypy * create mock agents * downgrade langchain openai upperbound to fix CI --------- Co-authored-by: Hao Zhu <prokilchu@gmail.com> Co-authored-by: Zhe Su <360307598@qq.com>
- Loading branch information
1 parent
07ee2c4
commit d55ec34
Showing
15 changed files
with
1,342 additions
and
1,778 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
uv run --extra test --extra chat pytest --ignore tests/cli --cov=. --cov-report=xml | ||
uv run --extra test --extra api pytest --ignore tests/cli --cov=. --cov-report=xml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from .fastapi_server import ( | ||
get_scenarios_all, | ||
get_scenarios, | ||
get_agents_all, | ||
get_agents, | ||
get_episodes, | ||
) | ||
|
||
__all__ = [ | ||
"get_scenarios_all", | ||
"get_scenarios", | ||
"get_agents_all", | ||
"get_agents", | ||
"get_episodes", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
from fastapi import FastAPI | ||
from typing import Literal, cast, Dict | ||
from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog | ||
from pydantic import BaseModel | ||
import uvicorn | ||
|
||
app = FastAPI() | ||
|
||
|
||
class AgentProfileWrapper(BaseModel): | ||
""" | ||
Wrapper for AgentProfile to avoid pydantic v2 issues | ||
""" | ||
|
||
first_name: str | ||
last_name: str | ||
age: int = 0 | ||
occupation: str = "" | ||
gender: str = "" | ||
gender_pronoun: str = "" | ||
public_info: str = "" | ||
big_five: str = "" | ||
moral_values: list[str] = [] | ||
schwartz_personal_values: list[str] = [] | ||
personality_and_values: str = "" | ||
decision_making_style: str = "" | ||
secret: str = "" | ||
model_id: str = "" | ||
mbti: str = "" | ||
tag: str = "" | ||
|
||
|
||
class EnvironmentProfileWrapper(BaseModel): | ||
""" | ||
Wrapper for EnvironmentProfile to avoid pydantic v2 issues | ||
""" | ||
|
||
codename: str | ||
source: str = "" | ||
scenario: str = "" | ||
agent_goals: list[str] = [] | ||
relationship: Literal[0, 1, 2, 3, 4, 5] = 0 | ||
age_constraint: str | None = None | ||
occupation_constraint: str | None = None | ||
agent_constraint: list[list[str]] | None = None | ||
tag: str = "" | ||
|
||
|
||
@app.get("/scenarios", response_model=list[EnvironmentProfile]) | ||
async def get_scenarios_all() -> list[EnvironmentProfile]: | ||
return EnvironmentProfile.all() | ||
|
||
|
||
@app.get("/scenarios/{get_by}/{value}", response_model=list[EnvironmentProfile]) | ||
async def get_scenarios( | ||
get_by: Literal["id", "codename"], value: str | ||
) -> list[EnvironmentProfile]: | ||
# Implement logic to fetch scenarios based on the parameters | ||
scenarios: list[EnvironmentProfile] = [] # Replace with actual fetching logic | ||
if get_by == "id": | ||
scenarios.append(EnvironmentProfile.get(pk=value)) | ||
elif get_by == "codename": | ||
json_models = EnvironmentProfile.find( | ||
EnvironmentProfile.codename == value | ||
).all() | ||
scenarios.extend(cast(list[EnvironmentProfile], json_models)) | ||
return scenarios | ||
|
||
|
||
@app.get("/agents", response_model=list[AgentProfile]) | ||
async def get_agents_all() -> list[AgentProfile]: | ||
return AgentProfile.all() | ||
|
||
|
||
@app.get("/agents/{get_by}/{value}", response_model=list[AgentProfile]) | ||
async def get_agents( | ||
get_by: Literal["id", "gender", "occupation"], value: str | ||
) -> list[AgentProfile]: | ||
agents_profiles: list[AgentProfile] = [] | ||
if get_by == "id": | ||
agents_profiles.append(AgentProfile.get(pk=value)) | ||
elif get_by == "gender": | ||
json_models = AgentProfile.find(AgentProfile.gender == value).all() | ||
agents_profiles.extend(cast(list[AgentProfile], json_models)) | ||
elif get_by == "occupation": | ||
json_models = AgentProfile.find(AgentProfile.occupation == value).all() | ||
agents_profiles.extend(cast(list[AgentProfile], json_models)) | ||
return agents_profiles | ||
|
||
|
||
@app.get("/episodes/{get_by}/{value}", response_model=list[EpisodeLog]) | ||
async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[EpisodeLog]: | ||
episodes: list[EpisodeLog] = [] | ||
if get_by == "id": | ||
episodes.append(EpisodeLog.get(pk=value)) | ||
elif get_by == "tag": | ||
json_models = EpisodeLog.find(EpisodeLog.tag == value).all() | ||
episodes.extend(cast(list[EpisodeLog], json_models)) | ||
return episodes | ||
|
||
|
||
@app.post("/agents/") | ||
async def create_agent(agent: AgentProfileWrapper) -> str: | ||
agent_profile = AgentProfile(**agent.model_dump()) | ||
agent_profile.save() | ||
pk = agent_profile.pk | ||
assert pk is not None | ||
return pk | ||
|
||
|
||
@app.post("/scenarios/", response_model=str) | ||
async def create_scenario(scenario: EnvironmentProfileWrapper) -> str: | ||
print(scenario) | ||
scenario_profile = EnvironmentProfile(**scenario.model_dump()) | ||
scenario_profile.save() | ||
pk = scenario_profile.pk | ||
assert pk is not None | ||
return pk | ||
|
||
|
||
@app.delete("/agents/{agent_id}", response_model=str) | ||
async def delete_agent(agent_id: str) -> str: | ||
AgentProfile.delete(agent_id) | ||
return agent_id | ||
|
||
|
||
@app.delete("/scenarios/{scenario_id}", response_model=str) | ||
async def delete_scenario(scenario_id: str) -> str: | ||
EnvironmentProfile.delete(scenario_id) | ||
return scenario_id | ||
|
||
|
||
active_simulations: Dict[ | ||
str, bool | ||
] = {} # TODO check whether this is the correct way to store the active simulations | ||
|
||
|
||
if __name__ == "__main__": | ||
uvicorn.run(app, host="127.0.0.1", port=8800) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
docker compose -f .devcontainer/docker-compose.yml run --rm -u root -v $(pwd):/workspaces/sotopia devcontainer /bin/sh -c "export UV_PROJECT_ENVIRONMENT=/workspaces/.venv; cd /workspaces/sotopia; uv run --extra test --extra chat pytest tests/experimental" | ||
docker compose -f .devcontainer/docker-compose.yml run --rm -u root -v $(pwd):/workspaces/sotopia devcontainer /bin/sh -c "export UV_PROJECT_ENVIRONMENT=/workspaces/.venv; cd /workspaces/sotopia; uv run --extra test --extra api pytest tests/experimental" |
Oops, something went wrong.