Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: FastAPI Implementation of Sotopia Part One (wo websocket) #246

Merged
merged 22 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ groq = ["groq"]
cohere = ["cohere"]
google-generativeai = ["google-generativeai"]
examples = ["transformers", "datasets", "scipy", "torch", "pandas"]
chat = ["fastapi"]
chat = [
"fastapi[standard]",
"websockets>=13.1",
]
test = ["pytest", "pytest-cov", "pytest-asyncio"]

[tool.uv]
Expand Down
10 changes: 10 additions & 0 deletions sotopia/database/persistent_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ class AgentProfile(JsonModel):
secret: str = Field(default_factory=lambda: "")
model_id: str = Field(default_factory=lambda: "")
mbti: str = Field(default_factory=lambda: "")
tag: str = Field(
index=True,
default_factory=lambda: "",
description="The tag of the agent, used for searching, could be convenient to document agent profiles from different works and sources",
)


class EnvironmentProfile(JsonModel):
Expand Down Expand Up @@ -74,6 +79,11 @@ class EnvironmentProfile(JsonModel):
agent_constraint: list[list[str]] | None = Field(
default_factory=lambda: None,
)
tag: str = Field(
index=True,
default_factory=lambda: "",
description="The tag of the environment, used for searching, could be convenient to document environment profiles from different works and sources",
)


class RelationshipProfile(JsonModel):
Expand Down
118 changes: 118 additions & 0 deletions sotopia/ui/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Sotopia UI
> [!CAUTION]
> Work in progress: the API endpoints are being implemented. And will be released in the future major version.

## FastAPI Server

The API server is a FastAPI application that is used to connect the Sotopia UI to the Sotopia backend.
This could also help with other projects that need to connect to the Sotopia backend through HTTP requests.

Here are some initial design of the API server:

### Getting Data from the API Server

#### GET /scenarios

Get all scenarios.

returns:
- scenarios: list[EnvironmentProfile]

#### GET /scenarios/?get_by={id|tag}/{scenario_id|scenario_tag}

Get scenarios by scenario_tag.
parameters:
- get_by: Literal["id", "tag"]
- scenario_id: str or scenario_tag: str
(This scenario tag could be a keyword; so people can search for scenarios by keywords)

returns:
- scenarios: list[EnvironmentProfile]

#### GET /agents

Get all agents.

returns:
- agents: list[AgentProfile]

#### GET /agents/?get_by={id|gender|occupation}/{value}

Get agents by id, gender, or occupation.
parameters:
- get_by: Literal["id", "gender", "occupation"]
- value: str (agent_id, agent_gender, or agent_occupation)

returns:
- agents: list[AgentProfile]


#### GET /episodes/?get_by={id|tag}/{episode_id|episode_tag}

Get episode by episode_tag.
parameters:
- get_by: Literal["id", "tag"]
- episode_id: str or episode_tag: str

returns:
- episodes: list[Episode]


### Sending Data to the API Server

#### POST /agents/

Send agent profile to the API server.
Request Body:
AgentProfile

returns:
- agent_id: str

#### POST /scenarios/

Send scenario profile to the API server.
Request Body:
EnvironmentProfile

returns:
- scenario_id: str

#### DELETE /agents/{agent_id}

Delete agent profile from the API server.

returns:
- agent_id: str

#### DELETE /scenarios/{scenario_id}

Delete scenario profile from the API server.

returns:
- scenario_id: str


### Initiating a new non-streaming simulation episode

#### POST /episodes/

```python
class SimulationEpisodeInitiation(BaseModel):
scenario_id: str
agent_ids: list[str]
episode_tag: str
models: list[str]
```

Send episode profile to the API server.
Request Body:
SimulationEpisodeInitiation

returns:
- episode_id: str (This is the id of the episode that will be used to get the episode data, saved in the redis database)

### Initiating a new interactive streaming simulation episode (this operation will open a websocket connection)

We use the websocket connection to send the simulation step-by-step results to the UI.
Please see an example protocol [here](https://claude.site/artifacts/322011f6-597f-4819-8afb-bf8137dfb56a)
15 changes: 15 additions & 0 deletions sotopia/ui/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .fastapi_server import (
get_scenarios_all,
get_scenarios,
get_agents_all,
get_agents,
get_episodes,
)

__all__ = [
"get_scenarios_all",
"get_scenarios",
"get_agents_all",
"get_agents",
"get_episodes",
]
139 changes: 139 additions & 0 deletions sotopia/ui/fastapi_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from fastapi import FastAPI
from typing import Literal, cast, Dict
from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog
from pydantic import BaseModel
import uvicorn

app = FastAPI()


class AgentProfileWrapper(BaseModel):
"""
Wrapper for AgentProfile to avoid pydantic v2 issues
"""

first_name: str
last_name: str
age: int = 0
occupation: str = ""
gender: str = ""
gender_pronoun: str = ""
public_info: str = ""
big_five: str = ""
moral_values: list[str] = []
schwartz_personal_values: list[str] = []
personality_and_values: str = ""
decision_making_style: str = ""
secret: str = ""
model_id: str = ""
mbti: str = ""
tag: str = ""


class EnvironmentProfileWrapper(BaseModel):
"""
Wrapper for EnvironmentProfile to avoid pydantic v2 issues
"""

codename: str
source: str = ""
scenario: str = ""
agent_goals: list[str] = []
relationship: Literal[0, 1, 2, 3, 4, 5] = 0
age_constraint: str | None = None
occupation_constraint: str | None = None
agent_constraint: list[list[str]] | None = None
tag: str = ""


@app.get("/scenarios", response_model=list[EnvironmentProfile])
async def get_scenarios_all() -> list[EnvironmentProfile]:
return EnvironmentProfile.all()


@app.get("/scenarios/{get_by}/{value}", response_model=list[EnvironmentProfile])
async def get_scenarios(
get_by: Literal["id", "codename"], value: str
) -> list[EnvironmentProfile]:
# Implement logic to fetch scenarios based on the parameters
scenarios: list[EnvironmentProfile] = [] # Replace with actual fetching logic
if get_by == "id":
scenarios.append(EnvironmentProfile.get(pk=value))
elif get_by == "codename":
json_models = EnvironmentProfile.find(
EnvironmentProfile.codename == value
).all()
scenarios.extend(cast(list[EnvironmentProfile], json_models))
return scenarios


@app.get("/agents", response_model=list[AgentProfile])
async def get_agents_all() -> list[AgentProfile]:
return AgentProfile.all()


@app.get("/agents/{get_by}/{value}", response_model=list[AgentProfile])
async def get_agents(
get_by: Literal["id", "gender", "occupation"], value: str
) -> list[AgentProfile]:
agents_profiles: list[AgentProfile] = []
if get_by == "id":
agents_profiles.append(AgentProfile.get(pk=value))
elif get_by == "gender":
json_models = AgentProfile.find(AgentProfile.gender == value).all()
agents_profiles.extend(cast(list[AgentProfile], json_models))
elif get_by == "occupation":
json_models = AgentProfile.find(AgentProfile.occupation == value).all()
agents_profiles.extend(cast(list[AgentProfile], json_models))
return agents_profiles


@app.get("/episodes/{get_by}/{value}", response_model=list[EpisodeLog])
async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[EpisodeLog]:
episodes: list[EpisodeLog] = []
if get_by == "id":
episodes.append(EpisodeLog.get(pk=value))
elif get_by == "tag":
json_models = EpisodeLog.find(EpisodeLog.tag == value).all()
episodes.extend(cast(list[EpisodeLog], json_models))
return episodes


@app.post("/agents/")
async def create_agent(agent: AgentProfileWrapper) -> str:
agent_profile = AgentProfile(**agent.model_dump())
agent_profile.save()
pk = agent_profile.pk
assert pk is not None
return pk


@app.post("/scenarios/", response_model=str)
async def create_scenario(scenario: EnvironmentProfileWrapper) -> str:
print(scenario)
scenario_profile = EnvironmentProfile(**scenario.model_dump())
scenario_profile.save()
pk = scenario_profile.pk
assert pk is not None
return pk


@app.delete("/agents/{agent_id}", response_model=str)
async def delete_agent(agent_id: str) -> str:
AgentProfile.delete(agent_id)
return agent_id


@app.delete("/scenarios/{scenario_id}", response_model=str)
async def delete_scenario(scenario_id: str) -> str:
EnvironmentProfile.delete(scenario_id)
return scenario_id


active_simulations: Dict[
str, bool
] = {} # TODO check whether this is the correct way to store the active simulations


if __name__ == "__main__":
uvicorn.run(app, host="127.0.0.1", port=8800)
Loading