Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: FastAPI Implementation of Sotopia Part One (wo websocket) #246

Merged
merged 22 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/cli_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install uv
uv sync --extra test --extra chat
uv sync --extra test --extra api
- name: Test with pytest
run: |
uv run pytest tests/cli/test_install.py --cov=. --cov-report=xml
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install uv
uv sync --extra test --extra chat
uv sync --extra test --extra api
- name: Type-checking package with mypy
run: |
# Run this mypy instance against our main package.
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
uv run --extra test --extra chat pytest --ignore tests/cli --cov=. --cov-report=xml
uv run --extra test --extra api pytest --ignore tests/cli --cov=. --cov-report=xml
2 changes: 1 addition & 1 deletion .github/workflows/tests_in_docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- name: Docker Compose
run: docker compose -f .devcontainer/docker-compose.yml up -d
- name: Run tests
run: docker compose -f .devcontainer/docker-compose.yml run --rm -u root -v /home/runner/work/sotopia/sotopia:/workspaces/sotopia devcontainer /bin/sh -c "cd /workspaces/sotopia; ls; uv sync --extra test --extra chat; uv run pytest --ignore tests/cli --cov=. --cov-report=xml"
run: docker compose -f .devcontainer/docker-compose.yml run --rm -u root -v /home/runner/work/sotopia/sotopia:/workspaces/sotopia devcontainer /bin/sh -c "cd /workspaces/sotopia; ls; uv sync --extra test --extra api; uv run pytest --ignore tests/cli --cov=. --cov-report=xml"
- name: Upload coverage report to Codecov
uses: codecov/codecov-action@v4.0.1
with:
Expand Down
2 changes: 1 addition & 1 deletion docs/pages/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ export REDIS_OM_URL="redis://localhost:6379"
```
if you are developing Sotopia using uv, you can sync your dependency with
```bash
uv sync --extra examples --extra chat
uv sync --extra examples --extra api
```
</AccordionContent>
</AccordionItem>
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies = [
"together>=0.2.4,<1.4.0",
"pydantic>=2.5.0,<3.0.0",
"beartype>=0.14.0,<0.20.0",
"langchain-openai>=0.1.8,<0.3.0",
"langchain-openai>=0.1.8,<0.2",
"hiredis>=3.0.0",
"aact"
]
Expand All @@ -37,7 +37,7 @@ groq = ["groq"]
cohere = ["cohere"]
google-generativeai = ["google-generativeai"]
examples = ["transformers", "datasets", "scipy", "torch", "pandas"]
chat = [
api = [
"fastapi[standard]",
"uvicorn",
]
Expand Down
6 changes: 4 additions & 2 deletions sotopia-chat/chat_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,10 @@ async def _assign_left_or_right_and_run(session_id: str) -> None:
case 2:
if await r.llen("chat_server_combos_double") == 0:
await gather(
_assign_left_or_right_and_run(session_id)
for session_id in session_ids
*[
_assign_left_or_right_and_run(session_id)
for session_id in session_ids
]
)
else:
agent_env_combo_pk: str = (
Expand Down
10 changes: 10 additions & 0 deletions sotopia/database/persistent_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ class AgentProfile(JsonModel):
secret: str = Field(default_factory=lambda: "")
model_id: str = Field(default_factory=lambda: "")
mbti: str = Field(default_factory=lambda: "")
tag: str = Field(
index=True,
default_factory=lambda: "",
description="The tag of the agent, used for searching, could be convenient to document agent profiles from different works and sources",
)


class EnvironmentProfile(JsonModel):
Expand Down Expand Up @@ -74,6 +79,11 @@ class EnvironmentProfile(JsonModel):
agent_constraint: list[list[str]] | None = Field(
default_factory=lambda: None,
)
tag: str = Field(
index=True,
default_factory=lambda: "",
description="The tag of the environment, used for searching, could be convenient to document environment profiles from different works and sources",
)


class RelationshipProfile(JsonModel):
Expand Down
4 changes: 2 additions & 2 deletions sotopia/server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import itertools
import logging
from typing import Literal, Sequence, Type, cast
from typing import Literal, Sequence, Type

import gin
import rich
Expand Down Expand Up @@ -310,7 +310,7 @@ def get_agent_class(
else [await i for i in episode_futures]
)

return cast(list[list[tuple[str, str, Message]]], batch_results)
return batch_results


async def arun_one_script(
Expand Down
22 changes: 4 additions & 18 deletions sotopia/ui/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,6 @@ returns:
- agents: list[AgentProfile]


#### GET /episodes

Get all episodes.

returns:
- episodes: list[Episode]

#### GET /episodes/?get_by={id|tag}/{episode_id|episode_tag}

Get episode by episode_tag.
Expand Down Expand Up @@ -85,23 +78,16 @@ EnvironmentProfile
returns:
- scenario_id: str

### Updating Data in the API Server
#### DELETE /agents/{agent_id}

#### PUT /agents/{agent_id}

Update agent profile in the API server.
Request Body:
AgentProfile
Delete agent profile from the API server.

returns:
- agent_id: str

#### DELETE /scenarios/{scenario_id}

#### PUT /scenarios/{scenario_id}

Update scenario profile in the API server.
Request Body:
EnvironmentProfile
Delete scenario profile from the API server.

returns:
- scenario_id: str
Expand Down
15 changes: 15 additions & 0 deletions sotopia/ui/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .fastapi_server import (
get_scenarios_all,
get_scenarios,
get_agents_all,
get_agents,
get_episodes,
)

__all__ = [
"get_scenarios_all",
"get_scenarios",
"get_agents_all",
"get_agents",
"get_episodes",
]
139 changes: 139 additions & 0 deletions sotopia/ui/fastapi_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from fastapi import FastAPI
from typing import Literal, cast, Dict
from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog
from pydantic import BaseModel
import uvicorn

app = FastAPI()


class AgentProfileWrapper(BaseModel):
"""
Wrapper for AgentProfile to avoid pydantic v2 issues
"""

first_name: str
last_name: str
age: int = 0
occupation: str = ""
gender: str = ""
gender_pronoun: str = ""
public_info: str = ""
big_five: str = ""
moral_values: list[str] = []
schwartz_personal_values: list[str] = []
personality_and_values: str = ""
decision_making_style: str = ""
secret: str = ""
model_id: str = ""
mbti: str = ""
tag: str = ""


class EnvironmentProfileWrapper(BaseModel):
"""
Wrapper for EnvironmentProfile to avoid pydantic v2 issues
"""

codename: str
source: str = ""
scenario: str = ""
agent_goals: list[str] = []
relationship: Literal[0, 1, 2, 3, 4, 5] = 0
age_constraint: str | None = None
occupation_constraint: str | None = None
agent_constraint: list[list[str]] | None = None
tag: str = ""


@app.get("/scenarios", response_model=list[EnvironmentProfile])
async def get_scenarios_all() -> list[EnvironmentProfile]:
return EnvironmentProfile.all()


@app.get("/scenarios/{get_by}/{value}", response_model=list[EnvironmentProfile])
async def get_scenarios(
get_by: Literal["id", "codename"], value: str
) -> list[EnvironmentProfile]:
# Implement logic to fetch scenarios based on the parameters
scenarios: list[EnvironmentProfile] = [] # Replace with actual fetching logic
if get_by == "id":
scenarios.append(EnvironmentProfile.get(pk=value))
elif get_by == "codename":
json_models = EnvironmentProfile.find(
EnvironmentProfile.codename == value
).all()
scenarios.extend(cast(list[EnvironmentProfile], json_models))
return scenarios


@app.get("/agents", response_model=list[AgentProfile])
async def get_agents_all() -> list[AgentProfile]:
return AgentProfile.all()


@app.get("/agents/{get_by}/{value}", response_model=list[AgentProfile])
async def get_agents(
get_by: Literal["id", "gender", "occupation"], value: str
) -> list[AgentProfile]:
agents_profiles: list[AgentProfile] = []
if get_by == "id":
agents_profiles.append(AgentProfile.get(pk=value))
elif get_by == "gender":
json_models = AgentProfile.find(AgentProfile.gender == value).all()
agents_profiles.extend(cast(list[AgentProfile], json_models))
elif get_by == "occupation":
json_models = AgentProfile.find(AgentProfile.occupation == value).all()
agents_profiles.extend(cast(list[AgentProfile], json_models))
return agents_profiles


@app.get("/episodes/{get_by}/{value}", response_model=list[EpisodeLog])
async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[EpisodeLog]:
episodes: list[EpisodeLog] = []
if get_by == "id":
episodes.append(EpisodeLog.get(pk=value))
elif get_by == "tag":
json_models = EpisodeLog.find(EpisodeLog.tag == value).all()
episodes.extend(cast(list[EpisodeLog], json_models))
return episodes


@app.post("/agents/")
async def create_agent(agent: AgentProfileWrapper) -> str:
agent_profile = AgentProfile(**agent.model_dump())
agent_profile.save()
pk = agent_profile.pk
assert pk is not None
return pk


@app.post("/scenarios/", response_model=str)
async def create_scenario(scenario: EnvironmentProfileWrapper) -> str:
print(scenario)
scenario_profile = EnvironmentProfile(**scenario.model_dump())
scenario_profile.save()
pk = scenario_profile.pk
assert pk is not None
return pk


@app.delete("/agents/{agent_id}", response_model=str)
async def delete_agent(agent_id: str) -> str:
AgentProfile.delete(agent_id)
return agent_id


@app.delete("/scenarios/{scenario_id}", response_model=str)
async def delete_scenario(scenario_id: str) -> str:
EnvironmentProfile.delete(scenario_id)
return scenario_id


active_simulations: Dict[
str, bool
] = {} # TODO check whether this is the correct way to store the active simulations


if __name__ == "__main__":
uvicorn.run(app, host="127.0.0.1", port=8800)
2 changes: 1 addition & 1 deletion tests/tests.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
docker compose -f .devcontainer/docker-compose.yml run --rm -u root -v $(pwd):/workspaces/sotopia devcontainer /bin/sh -c "export UV_PROJECT_ENVIRONMENT=/workspaces/.venv; cd /workspaces/sotopia; uv run --extra test --extra chat pytest tests/experimental"
docker compose -f .devcontainer/docker-compose.yml run --rm -u root -v $(pwd):/workspaces/sotopia devcontainer /bin/sh -c "export UV_PROJECT_ENVIRONMENT=/workspaces/.venv; cd /workspaces/sotopia; uv run --extra test --extra api pytest tests/experimental"
Loading
Loading