Skip to content

Commit

Permalink
feat: FastAPI Implementation of Sotopia Part One (wo websocket) (#246)
Browse files Browse the repository at this point in the history
* api doc

* add PUT

* add an temp example for websocket

* websocket

* update readme

* Update README.md

* fastapi server wo websocket

* update project toml

* add websocket api and a sample client sample

* add initial test

* post update

* finalize the doc

* fix the create

* finish test

* add files

* change chata to api

* fix mypy error

* fix mypy bug

* fix mypy

* create mock agents

* downgrade langchain openai upperbound to fix CI

---------

Co-authored-by: Hao Zhu <prokilchu@gmail.com>
Co-authored-by: Zhe Su <360307598@qq.com>
  • Loading branch information
3 people authored Nov 26, 2024
1 parent 07ee2c4 commit d55ec34
Show file tree
Hide file tree
Showing 15 changed files with 1,342 additions and 1,778 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cli_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install uv
uv sync --extra test --extra chat
uv sync --extra test --extra api
- name: Test with pytest
run: |
uv run pytest tests/cli/test_install.py --cov=. --cov-report=xml
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install uv
uv sync --extra test --extra chat
uv sync --extra test --extra api
- name: Type-checking package with mypy
run: |
# Run this mypy instance against our main package.
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
uv run --extra test --extra chat pytest --ignore tests/cli --cov=. --cov-report=xml
uv run --extra test --extra api pytest --ignore tests/cli --cov=. --cov-report=xml
2 changes: 1 addition & 1 deletion .github/workflows/tests_in_docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- name: Docker Compose
run: docker compose -f .devcontainer/docker-compose.yml up -d
- name: Run tests
run: docker compose -f .devcontainer/docker-compose.yml run --rm -u root -v /home/runner/work/sotopia/sotopia:/workspaces/sotopia devcontainer /bin/sh -c "cd /workspaces/sotopia; ls; uv sync --extra test --extra chat; uv run pytest --ignore tests/cli --cov=. --cov-report=xml"
run: docker compose -f .devcontainer/docker-compose.yml run --rm -u root -v /home/runner/work/sotopia/sotopia:/workspaces/sotopia devcontainer /bin/sh -c "cd /workspaces/sotopia; ls; uv sync --extra test --extra api; uv run pytest --ignore tests/cli --cov=. --cov-report=xml"
- name: Upload coverage report to Codecov
uses: codecov/codecov-action@v4.0.1
with:
Expand Down
2 changes: 1 addition & 1 deletion docs/pages/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ export REDIS_OM_URL="redis://localhost:6379"
```
if you are developing Sotopia using uv, you can sync your dependency with
```bash
uv sync --extra examples --extra chat
uv sync --extra examples --extra api
```
</AccordionContent>
</AccordionItem>
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies = [
"together>=0.2.4,<1.4.0",
"pydantic>=2.5.0,<3.0.0",
"beartype>=0.14.0,<0.20.0",
"langchain-openai>=0.1.8,<0.3.0",
"langchain-openai>=0.1.8,<0.2",
"hiredis>=3.0.0",
"aact"
]
Expand All @@ -37,7 +37,7 @@ groq = ["groq"]
cohere = ["cohere"]
google-generativeai = ["google-generativeai"]
examples = ["transformers", "datasets", "scipy", "torch", "pandas"]
chat = [
api = [
"fastapi[standard]",
"uvicorn",
]
Expand Down
6 changes: 4 additions & 2 deletions sotopia-chat/chat_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,10 @@ async def _assign_left_or_right_and_run(session_id: str) -> None:
case 2:
if await r.llen("chat_server_combos_double") == 0:
await gather(
_assign_left_or_right_and_run(session_id)
for session_id in session_ids
*[
_assign_left_or_right_and_run(session_id)
for session_id in session_ids
]
)
else:
agent_env_combo_pk: str = (
Expand Down
10 changes: 10 additions & 0 deletions sotopia/database/persistent_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ class AgentProfile(JsonModel):
secret: str = Field(default_factory=lambda: "")
model_id: str = Field(default_factory=lambda: "")
mbti: str = Field(default_factory=lambda: "")
tag: str = Field(
index=True,
default_factory=lambda: "",
description="The tag of the agent, used for searching, could be convenient to document agent profiles from different works and sources",
)


class EnvironmentProfile(JsonModel):
Expand Down Expand Up @@ -74,6 +79,11 @@ class EnvironmentProfile(JsonModel):
agent_constraint: list[list[str]] | None = Field(
default_factory=lambda: None,
)
tag: str = Field(
index=True,
default_factory=lambda: "",
description="The tag of the environment, used for searching, could be convenient to document environment profiles from different works and sources",
)


class RelationshipProfile(JsonModel):
Expand Down
4 changes: 2 additions & 2 deletions sotopia/server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import itertools
import logging
from typing import Literal, Sequence, Type, cast
from typing import Literal, Sequence, Type

import gin
import rich
Expand Down Expand Up @@ -310,7 +310,7 @@ def get_agent_class(
else [await i for i in episode_futures]
)

return cast(list[list[tuple[str, str, Message]]], batch_results)
return batch_results


async def arun_one_script(
Expand Down
22 changes: 4 additions & 18 deletions sotopia/ui/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,6 @@ returns:
- agents: list[AgentProfile]


#### GET /episodes

Get all episodes.

returns:
- episodes: list[Episode]

#### GET /episodes/?get_by={id|tag}/{episode_id|episode_tag}

Get episode by episode_tag.
Expand Down Expand Up @@ -85,23 +78,16 @@ EnvironmentProfile
returns:
- scenario_id: str

### Updating Data in the API Server
#### DELETE /agents/{agent_id}

#### PUT /agents/{agent_id}

Update agent profile in the API server.
Request Body:
AgentProfile
Delete agent profile from the API server.

returns:
- agent_id: str

#### DELETE /scenarios/{scenario_id}

#### PUT /scenarios/{scenario_id}

Update scenario profile in the API server.
Request Body:
EnvironmentProfile
Delete scenario profile from the API server.

returns:
- scenario_id: str
Expand Down
15 changes: 15 additions & 0 deletions sotopia/ui/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .fastapi_server import (
get_scenarios_all,
get_scenarios,
get_agents_all,
get_agents,
get_episodes,
)

__all__ = [
"get_scenarios_all",
"get_scenarios",
"get_agents_all",
"get_agents",
"get_episodes",
]
139 changes: 139 additions & 0 deletions sotopia/ui/fastapi_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from fastapi import FastAPI
from typing import Literal, cast, Dict
from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog
from pydantic import BaseModel
import uvicorn

app = FastAPI()


class AgentProfileWrapper(BaseModel):
"""
Wrapper for AgentProfile to avoid pydantic v2 issues
"""

first_name: str
last_name: str
age: int = 0
occupation: str = ""
gender: str = ""
gender_pronoun: str = ""
public_info: str = ""
big_five: str = ""
moral_values: list[str] = []
schwartz_personal_values: list[str] = []
personality_and_values: str = ""
decision_making_style: str = ""
secret: str = ""
model_id: str = ""
mbti: str = ""
tag: str = ""


class EnvironmentProfileWrapper(BaseModel):
"""
Wrapper for EnvironmentProfile to avoid pydantic v2 issues
"""

codename: str
source: str = ""
scenario: str = ""
agent_goals: list[str] = []
relationship: Literal[0, 1, 2, 3, 4, 5] = 0
age_constraint: str | None = None
occupation_constraint: str | None = None
agent_constraint: list[list[str]] | None = None
tag: str = ""


@app.get("/scenarios", response_model=list[EnvironmentProfile])
async def get_scenarios_all() -> list[EnvironmentProfile]:
return EnvironmentProfile.all()


@app.get("/scenarios/{get_by}/{value}", response_model=list[EnvironmentProfile])
async def get_scenarios(
get_by: Literal["id", "codename"], value: str
) -> list[EnvironmentProfile]:
# Implement logic to fetch scenarios based on the parameters
scenarios: list[EnvironmentProfile] = [] # Replace with actual fetching logic
if get_by == "id":
scenarios.append(EnvironmentProfile.get(pk=value))
elif get_by == "codename":
json_models = EnvironmentProfile.find(
EnvironmentProfile.codename == value
).all()
scenarios.extend(cast(list[EnvironmentProfile], json_models))
return scenarios


@app.get("/agents", response_model=list[AgentProfile])
async def get_agents_all() -> list[AgentProfile]:
return AgentProfile.all()


@app.get("/agents/{get_by}/{value}", response_model=list[AgentProfile])
async def get_agents(
get_by: Literal["id", "gender", "occupation"], value: str
) -> list[AgentProfile]:
agents_profiles: list[AgentProfile] = []
if get_by == "id":
agents_profiles.append(AgentProfile.get(pk=value))
elif get_by == "gender":
json_models = AgentProfile.find(AgentProfile.gender == value).all()
agents_profiles.extend(cast(list[AgentProfile], json_models))
elif get_by == "occupation":
json_models = AgentProfile.find(AgentProfile.occupation == value).all()
agents_profiles.extend(cast(list[AgentProfile], json_models))
return agents_profiles


@app.get("/episodes/{get_by}/{value}", response_model=list[EpisodeLog])
async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[EpisodeLog]:
episodes: list[EpisodeLog] = []
if get_by == "id":
episodes.append(EpisodeLog.get(pk=value))
elif get_by == "tag":
json_models = EpisodeLog.find(EpisodeLog.tag == value).all()
episodes.extend(cast(list[EpisodeLog], json_models))
return episodes


@app.post("/agents/")
async def create_agent(agent: AgentProfileWrapper) -> str:
agent_profile = AgentProfile(**agent.model_dump())
agent_profile.save()
pk = agent_profile.pk
assert pk is not None
return pk


@app.post("/scenarios/", response_model=str)
async def create_scenario(scenario: EnvironmentProfileWrapper) -> str:
print(scenario)
scenario_profile = EnvironmentProfile(**scenario.model_dump())
scenario_profile.save()
pk = scenario_profile.pk
assert pk is not None
return pk


@app.delete("/agents/{agent_id}", response_model=str)
async def delete_agent(agent_id: str) -> str:
AgentProfile.delete(agent_id)
return agent_id


@app.delete("/scenarios/{scenario_id}", response_model=str)
async def delete_scenario(scenario_id: str) -> str:
EnvironmentProfile.delete(scenario_id)
return scenario_id


active_simulations: Dict[
str, bool
] = {} # TODO check whether this is the correct way to store the active simulations


if __name__ == "__main__":
uvicorn.run(app, host="127.0.0.1", port=8800)
2 changes: 1 addition & 1 deletion tests/tests.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
docker compose -f .devcontainer/docker-compose.yml run --rm -u root -v $(pwd):/workspaces/sotopia devcontainer /bin/sh -c "export UV_PROJECT_ENVIRONMENT=/workspaces/.venv; cd /workspaces/sotopia; uv run --extra test --extra chat pytest tests/experimental"
docker compose -f .devcontainer/docker-compose.yml run --rm -u root -v $(pwd):/workspaces/sotopia devcontainer /bin/sh -c "export UV_PROJECT_ENVIRONMENT=/workspaces/.venv; cd /workspaces/sotopia; uv run --extra test --extra api pytest tests/experimental"
Loading

0 comments on commit d55ec34

Please sign in to comment.