Skip to content

Commit

Permalink
simulate episode non-streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
XuhuiZhou committed Dec 8, 2024
1 parent 5993f09 commit cc77ab5
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 10 deletions.
2 changes: 1 addition & 1 deletion sotopia/database/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _map_gender_to_adj(gender: str) -> str:
"Nonbinary": "nonbinary",
}
if gender:
return gender_to_adj[gender]
return gender_to_adj.get(gender, "")
else:
return ""

Expand Down
2 changes: 1 addition & 1 deletion sotopia/envs/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _map_gender_to_adj(gender: str) -> str:
"Nonbinary": "nonbinary",
}
if gender:
return gender_to_adj[gender]
return gender_to_adj.get(gender, "")
else:
return ""

Expand Down
4 changes: 2 additions & 2 deletions sotopia/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,8 @@ async def generate_messages() -> (
elif only_return_episode_pk:
async for last_messages in generate_messages():
pass
assert isinstance(last_messages[-1][-1], SimpleMessage)
return last_messages[-1][-1].message
assert isinstance(last_messages[-1][-1][-1], SimpleMessage)
return last_messages[-1][-1][-1].message
else:
async for last_messages in generate_messages():
pass
Expand Down
31 changes: 25 additions & 6 deletions sotopia/ui/fastapi_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
RelationshipType,
)
from sotopia.envs.parallel import ParallelSotopiaEnv
from sotopia.envs.evaluators import (
RuleBasedTerminatedEvaluator,
ReachGoalLLMEvaluator,
EvaluationForTwoAgents,
SotopiaDimensions,
)
from sotopia.server import arun_one_episode
from sotopia.agents import LLMAgent, Agents
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
Expand Down Expand Up @@ -210,24 +216,38 @@ async def create_relationship(relationship: RelationshipWrapper) -> str:

@app.post("/simulate/", response_model=str)
async def simulate(simulate_request: SimulateRequest) -> str:
assert len(simulate_request.models) == 3, "There should be three models"
env_profile: EnvironmentProfile = EnvironmentProfile.get(pk=simulate_request.env_id)
env = ParallelSotopiaEnv(env_profile=env_profile)

env_params: dict[str, Any] = {
"model_name": simulate_request.models[0],
"action_order": "round-robin",
"evaluators": [
RuleBasedTerminatedEvaluator(
max_turn_number=simulate_request.max_turns, max_stale_turn=2
),
],
"terminal_evaluators": [
ReachGoalLLMEvaluator(
simulate_request.models[0],
EvaluationForTwoAgents[SotopiaDimensions],
),
],
}
env = ParallelSotopiaEnv(env_profile=env_profile, **env_params)
agents = Agents(
{
"agent1": LLMAgent(
"agent1",
model_name=simulate_request.models[0],
model_name=simulate_request.models[1],
agent_profile=AgentProfile.get(pk=simulate_request.agent_ids[0]),
),
"agent2": LLMAgent(
"agent2",
model_name=simulate_request.models[1],
model_name=simulate_request.models[2],
agent_profile=AgentProfile.get(pk=simulate_request.agent_ids[1]),
),
}
)

episode_pk = await arun_one_episode(
env=env,
agent_list=list(agents.values()),
Expand All @@ -236,7 +256,6 @@ async def simulate(simulate_request: SimulateRequest) -> str:
tag=simulate_request.tag,
)
assert isinstance(episode_pk, str)
EpisodeLog.delete(episode_pk)
return episode_pk


Expand Down
36 changes: 36 additions & 0 deletions tests/ui/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,18 @@ def _create_mock_relationship() -> None:
except Exception as e:
print(e)

try:
EpisodeLog.delete("tmppk_episode_log")
except Exception as e:
print(e)

try:
episodes = EpisodeLog.find(EpisodeLog.tag == "test_tag").all()
for episode in episodes:
EpisodeLog.delete(episode.pk)
except Exception as e:
print(e)


def test_get_scenarios_all(create_mock_data: Callable[[], None]) -> None:
response = client.get("/scenarios")
Expand Down Expand Up @@ -264,3 +276,27 @@ def test_delete_relationship(create_mock_data: Callable[[], None]) -> None:
response = client.delete("/relationship/tmppk_relationship")
assert response.status_code == 200
assert isinstance(response.json(), str)


def test_simulate(create_mock_data: Callable[[], None]) -> None:
response = client.post(
"/simulate",
json={
"env_id": "tmppk_env_profile",
"agent_ids": ["tmppk_agent1", "tmppk_agent2"],
"models": [
# "custom/llama3.2:1b@http://localhost:8000/v1",
# "custom/llama3.2:1b@http://localhost:8000/v1",
# "custom/llama3.2:1b@http://localhost:8000/v1"
"gpt-4o-mini",
"gpt-4o-mini",
"gpt-4o-mini",
],
"max_turns": 2,
"tag": "test_tag",
},
)
assert response.status_code == 200
assert isinstance(response.json(), str)
episode = EpisodeLog.get(response.json())
print(episode)

0 comments on commit cc77ab5

Please sign in to comment.