Skip to content

Commit

Permalink
update relationship api
Browse files Browse the repository at this point in the history
  • Loading branch information
XuhuiZhou committed Dec 8, 2024
1 parent 84e45fd commit 5993f09
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 81 deletions.
5 changes: 5 additions & 0 deletions sotopia/database/persistent_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ class RelationshipProfile(JsonModel):
description="0 means stranger, 1 means know_by_name, 2 means acquaintance, 3 means friend, 4 means romantic_relationship, 5 means family_member",
) # this could be improved by limiting str to a relationship Enum
background_story: str | None = Field(default_factory=lambda: None)
tag: str = Field(
index=True,
default_factory=lambda: "",
description="The tag of the relationship, used for searching, could be convenient to document relationship profiles from different works and sources",
)


class EnvironmentList(JsonModel):
Expand Down
38 changes: 10 additions & 28 deletions sotopia/ui/fastapi_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
AgentProfile,
EpisodeLog,
RelationshipProfile,
RelationshipType,
)
from sotopia.envs.parallel import ParallelSotopiaEnv
from sotopia.server import arun_one_episode
Expand Down Expand Up @@ -43,6 +44,7 @@ class RelationshipWrapper(BaseModel):
agent_2_id: str = ""
relationship: Literal[0, 1, 2, 3, 4, 5] = 0
backstory: str = ""
tag: str = ""


class AgentProfileWrapper(BaseModel):
Expand Down Expand Up @@ -89,6 +91,7 @@ class SimulateRequest(BaseModel):
agent_ids: list[str]
models: list[str]
max_turns: int
tag: str


@app.get("/scenarios", response_model=list[EnvironmentProfile])
Expand Down Expand Up @@ -154,7 +157,7 @@ async def get_relationship(agent_1_id: str, agent_2_id: str) -> str:
assert len(relationship_profiles) == 1
relationship_profile = relationship_profiles[0]
assert isinstance(relationship_profile, RelationshipProfile)
return str(relationship_profile.relationship)
return f"{str(relationship_profile.relationship)}: {RelationshipType(relationship_profile.relationship).name}"


@app.get("/episodes", response_model=list[EpisodeLog])
Expand Down Expand Up @@ -226,38 +229,17 @@ async def simulate(simulate_request: SimulateRequest) -> str:
)

episode_pk = await arun_one_episode(
env=env, agent_list=list(agents.values()), only_return_episode_pk=True
env=env,
agent_list=list(agents.values()),
only_return_episode_pk=True,
push_to_db=True,
tag=simulate_request.tag,
)
assert isinstance(episode_pk, str)
EpisodeLog.delete(episode_pk)
return episode_pk


@app.put("/agents/{agent_id}", response_model=str)
async def update_agent(agent_id: str, agent: AgentProfileWrapper) -> str:
try:
old_agent = AgentProfile.get(pk=agent_id)
except Exception: # TODO Check the exception type
raise HTTPException(
status_code=404, detail=f"Agent with id={agent_id} not found"
)
old_agent.update(**agent.model_dump()) # type: ignore
assert old_agent.pk is not None
return old_agent.pk


@app.put("/scenarios/{scenario_id}", response_model=str)
async def update_scenario(scenario_id: str, scenario: EnvironmentProfileWrapper) -> str:
try:
old_scenario = EnvironmentProfile.get(pk=scenario_id)
except Exception: # TODO Check the exception type
raise HTTPException(
status_code=404, detail=f"Scenario with id={scenario_id} not found"
)
old_scenario.update(**scenario.model_dump()) # type: ignore
assert old_scenario.pk is not None
return old_scenario.pk


@app.delete("/agents/{agent_id}", response_model=str)
async def delete_agent(agent_id: str) -> str:
try:
Expand Down
78 changes: 71 additions & 7 deletions tests/ui/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from fastapi.testclient import TestClient
from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog
from sotopia.database import (
EnvironmentProfile,
AgentProfile,
EpisodeLog,
RelationshipProfile,
)
from sotopia.messages import SimpleMessage
from sotopia.ui.fastapi_server import app
import pytest
Expand Down Expand Up @@ -63,21 +68,23 @@ def create_dummy_episode_log() -> None:


@pytest.fixture
def create_mock_data() -> Generator[None, None, None]:
def create_mock_data(for_posting: bool = False) -> Generator[None, None, None]:
def _create_mock_agent_profile() -> None:
AgentProfile(
first_name="John",
last_name="Doe",
occupation="test_occupation",
gender="test_gender",
pk="tmppk_agent1",
tag="test_tag",
).save()
AgentProfile(
first_name="Jane",
last_name="Doe",
occupation="test_occupation",
gender="test_gender",
pk="tmppk_agent2",
tag="test_tag",
).save()

def _create_mock_env_profile() -> None:
Expand All @@ -89,18 +96,43 @@ def _create_mock_env_profile() -> None:
"C",
],
pk="tmppk_env_profile",
tag="test_tag",
)
env_profile.save()

def _create_mock_relationship() -> None:
RelationshipProfile(
pk="tmppk_relationship",
agent_1_id="tmppk_agent1",
agent_2_id="tmppk_agent2",
relationship=1.0,
).save()

_create_mock_agent_profile()
_create_mock_env_profile()

_create_mock_relationship()
yield

AgentProfile.delete("tmppk_agent1")
AgentProfile.delete("tmppk_agent2")
EnvironmentProfile.delete("tmppk_env_profile")
EpisodeLog.delete("tmppk_episode_log")
try:
AgentProfile.delete("tmppk_agent1")
except Exception as e:
print(e)
try:
AgentProfile.delete("tmppk_agent2")
except Exception as e:
print(e)
try:
EnvironmentProfile.delete("tmppk_env_profile")
except Exception as e:
print(e)
try:
RelationshipProfile.delete("tmppk_relationship")
except Exception as e:
print(e)
try:
EpisodeLog.delete("tmppk_episode_log")
except Exception as e:
print(e)


def test_get_scenarios_all(create_mock_data: Callable[[], None]) -> None:
Expand Down Expand Up @@ -169,8 +201,17 @@ def test_get_episodes_by_tag(create_mock_data: Callable[[], None]) -> None:
assert response.json()[0]["tag"] == tag


def test_get_relationship(create_mock_data: Callable[[], None]) -> None:
response = client.get("/relationship/tmppk_agent1/tmppk_agent2")
assert response.status_code == 200
assert isinstance(response.json(), str)
assert response.json() == "1: know_by_name"


@pytest.mark.parametrize("create_mock_data", [True], indirect=True)
def test_create_agent(create_mock_data: Callable[[], None]) -> None:
agent_data = {
"pk": "tmppk_agent1",
"first_name": "test_first_name",
"last_name": "test_last_name",
}
Expand All @@ -179,13 +220,30 @@ def test_create_agent(create_mock_data: Callable[[], None]) -> None:
assert isinstance(response.json(), str)


@pytest.mark.parametrize("create_mock_data", [True], indirect=True)
def test_create_scenario(create_mock_data: Callable[[], None]) -> None:
scenario_data = {
"pk": "tmppk_env_profile",
"codename": "test_codename",
"scenario": "test_scenario",
"tag": "test",
}
response = client.post("/scenarios/", json=scenario_data)
EnvironmentProfile.delete("tmppk_env_profile")
assert response.status_code == 200
assert isinstance(response.json(), str)


@pytest.mark.parametrize("create_mock_data", [True], indirect=True)
def test_create_relationship(create_mock_data: Callable[[], None]) -> None:
relationship_data = {
"pk": "tmppk_relationship",
"agent_1_id": "tmppk_agent1",
"agent_2_id": "tmppk_agent2",
"relationship": 1.0,
"tag": "test_tag",
}
response = client.post("/relationship", json=relationship_data)
assert response.status_code == 200
assert isinstance(response.json(), str)

Expand All @@ -200,3 +258,9 @@ def test_delete_scenario(create_mock_data: Callable[[], None]) -> None:
response = client.delete("/scenarios/tmppk_env_profile")
assert response.status_code == 200
assert isinstance(response.json(), str)


def test_delete_relationship(create_mock_data: Callable[[], None]) -> None:
response = client.delete("/relationship/tmppk_relationship")
assert response.status_code == 200
assert isinstance(response.json(), str)
Loading

0 comments on commit 5993f09

Please sign in to comment.