Skip to content

Commit

Permalink
Feat/addtional fast apis for non-streaming simulation and managing re…
Browse files Browse the repository at this point in the history
…lationshio (#265)

* temp run

* add relationship api

* fix mypy error

* update relationship api

* simulate episode non-streaming

* modify sim episodes

* add simulation status

* task error

* add background task

* [autofix.ci] apply automated fixes

* back to arun one episode

* upload the code

* use rq to execute background tasks

* temp sol

---------

Co-authored-by: Hao Zhu <prokilchu@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 11, 2024
1 parent 5a9f4b7 commit dea25d3
Show file tree
Hide file tree
Showing 13 changed files with 536 additions and 93 deletions.
18 changes: 9 additions & 9 deletions examples/experimental/websocket/websocket_test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@


class WebSocketClient:
def __init__(self, uri: str, token: str, client_id: int):
self.uri = uri
def __init__(self, url: str, token: str, client_id: int):
self.url = url
self.token = token
self.client_id = client_id
self.message_file = Path(f"message_{client_id}.txt")
Expand All @@ -25,11 +25,11 @@ async def save_message(self, message: str) -> None:

async def connect(self) -> None:
"""Establish and maintain websocket connection"""
uri_with_token = f"{self.uri}?token=test_token_{self.client_id}"
url_with_token = f"{self.url}?token=test_token_{self.client_id}"

try:
async with websockets.connect(uri_with_token) as websocket:
print(f"Client {self.client_id}: Connected to {self.uri}")
async with websockets.connect(url_with_token) as websocket:
print(f"Client {self.client_id}: Connected to {self.url}")

# Send initial message
# Note: You'll need to implement the logic to get agent_ids and env_id
Expand Down Expand Up @@ -70,16 +70,16 @@ async def connect(self) -> None:
async def main() -> None:
# Create multiple WebSocket clients
num_clients = 0
uri = "ws://localhost:8800/ws/simulation"
url = "ws://localhost:8800/ws/simulation"

# Create and store client instances
clients = [
WebSocketClient(uri=uri, token=f"test_token_{i}", client_id=i)
WebSocketClient(url=url, token=f"test_token_{i}", client_id=i)
for i in range(num_clients)
]
clients.append(WebSocketClient(uri=uri, token="test_token_10", client_id=10))
clients.append(WebSocketClient(url=url, token="test_token_10", client_id=10))
clients.append(
WebSocketClient(uri=uri, token="test_token_10", client_id=10)
WebSocketClient(url=url, token="test_token_10", client_id=10)
) # test duplicate token

# Create tasks for each client
Expand Down
122 changes: 122 additions & 0 deletions examples/fast_api_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Example curl command to call the simulate endpoint:
import requests
import time

BASE_URL = "http://localhost:8080"


def _create_mock_agent_profile() -> None:
agent1_data = {
"first_name": "John",
"last_name": "Doe",
"occupation": "test_occupation",
"gender": "test_gender",
"pk": "tmppk_agent1",
"tag": "test_tag",
}
response = requests.post(
f"{BASE_URL}/agents/",
headers={"Content-Type": "application/json"},
json=agent1_data,
)
assert response.status_code == 200

agent2_data = {
"first_name": "Jane",
"last_name": "Doe",
"occupation": "test_occupation",
"gender": "test_gender",
"pk": "tmppk_agent2",
"tag": "test_tag",
}
response = requests.post(
f"{BASE_URL}/agents/",
headers={"Content-Type": "application/json"},
json=agent2_data,
)
assert response.status_code == 200


def _create_mock_env_profile() -> None:
env_data = {
"codename": "test_codename",
"scenario": "A",
"agent_goals": [
"B",
"C",
],
"pk": "tmppk_env_profile",
"tag": "test_tag",
}
response = requests.post(
f"{BASE_URL}/scenarios/",
headers={"Content-Type": "application/json"},
json=env_data,
)
assert response.status_code == 200


_create_mock_agent_profile()
_create_mock_env_profile()


data = {
"env_id": "tmppk_env_profile",
"agent_ids": ["tmppk_agent1", "tmppk_agent2"],
"models": ["custom/structured-llama3.2:1b@http://localhost:8000/v1"] * 3,
"max_turns": 10,
"tag": "test_tag",
}
try:
response = requests.post(
f"{BASE_URL}/simulate/", headers={"Content-Type": "application/json"}, json=data
)
print(response)
assert response.status_code == 202
assert isinstance(response.content.decode(), str)
episode_pk = response.content.decode()
print(episode_pk)
max_retries = 200
retry_count = 0
while retry_count < max_retries:
try:
response = requests.get(f"{BASE_URL}/simulation_status/{episode_pk}")
assert response.status_code == 200
status = response.content.decode()
print(status)
if status == "Error":
raise Exception("Error running simulation")
elif status == "Completed":
break
# Status is "Started", keep polling
time.sleep(1)
retry_count += 1
except Exception as e:
print(f"Error checking simulation status: {e}")
time.sleep(1)
retry_count += 1
else:
raise TimeoutError("Simulation timed out after 10 retries")

finally:
try:
response = requests.delete(f"{BASE_URL}/agents/tmppk_agent1")
assert response.status_code == 200
except Exception as e:
print(e)
try:
response = requests.delete(f"{BASE_URL}/agents/tmppk_agent2")
assert response.status_code == 200
except Exception as e:
print(e)
try:
response = requests.delete(f"{BASE_URL}/scenarios/tmppk_env_profile")
assert response.status_code == 200
except Exception as e:
print(e)

try:
response = requests.delete(f"{BASE_URL}/episodes/{episode_pk}")
assert response.status_code == 200
except Exception as e:
print(e)
4 changes: 3 additions & 1 deletion sotopia/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from redis_om import JsonModel, Migrator
from .annotators import Annotator
from .env_agent_combo_storage import EnvAgentComboStorage
from .logs import AnnotationForEpisode, EpisodeLog
from .logs import AnnotationForEpisode, EpisodeLog, NonStreamingSimulationStatus
from .persistent_profile import (
AgentProfile,
EnvironmentProfile,
Expand Down Expand Up @@ -44,6 +44,7 @@
"AgentProfile",
"EnvironmentProfile",
"EpisodeLog",
"NonStreamingSimulationStatus",
"EnvAgentComboStorage",
"AnnotationForEpisode",
"Annotator",
Expand Down Expand Up @@ -73,6 +74,7 @@
"EvaluationDimensionBuilder",
"CustomEvaluationDimension",
"CustomEvaluationDimensionList",
"NonStreamingSimulationStatus",
]

InheritedJsonModel = TypeVar("InheritedJsonModel", bound="JsonModel")
Expand Down
7 changes: 6 additions & 1 deletion sotopia/database/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@
from pydantic import model_validator
from redis_om import JsonModel
from redis_om.model.model import Field

from typing import Literal
from sotopia.database.persistent_profile import AgentProfile


class NonStreamingSimulationStatus(JsonModel):
episode_pk: str = Field(index=True)
status: Literal["Started", "Error", "Completed"]


class EpisodeLog(JsonModel):
# Note that we did not validate the following constraints:
# 1. The number of turns in messages and rewards should be the same or off by 1
Expand Down
5 changes: 5 additions & 0 deletions sotopia/database/persistent_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ class RelationshipProfile(JsonModel):
description="0 means stranger, 1 means know_by_name, 2 means acquaintance, 3 means friend, 4 means romantic_relationship, 5 means family_member",
) # this could be improved by limiting str to a relationship Enum
background_story: str | None = Field(default_factory=lambda: None)
tag: str = Field(
index=True,
default_factory=lambda: "",
description="The tag of the relationship, used for searching, could be convenient to document relationship profiles from different works and sources",
)


class EnvironmentList(JsonModel):
Expand Down
2 changes: 1 addition & 1 deletion sotopia/database/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _map_gender_to_adj(gender: str) -> str:
"Nonbinary": "nonbinary",
}
if gender:
return gender_to_adj[gender]
return gender_to_adj.get(gender, "")
else:
return ""

Expand Down
2 changes: 1 addition & 1 deletion sotopia/envs/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _map_gender_to_adj(gender: str) -> str:
"Nonbinary": "nonbinary",
}
if gender:
return gender_to_adj[gender]
return gender_to_adj.get(gender, "")
else:
return ""

Expand Down
16 changes: 13 additions & 3 deletions sotopia/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
ScriptWritingAgent,
)
from sotopia.agents.base_agent import BaseAgent
from sotopia.database import EpisodeLog
from sotopia.database import EpisodeLog, NonStreamingSimulationStatus
from sotopia.envs import ParallelSotopiaEnv
from sotopia.envs.evaluators import (
EvaluationForTwoAgents,
Expand Down Expand Up @@ -119,12 +119,15 @@ async def arun_one_episode(
json_in_script: bool = False,
tag: str | None = None,
push_to_db: bool = False,
episode_pk: str | None = None,
streaming: bool = False,
simulation_status: NonStreamingSimulationStatus | None = None,
) -> Union[
list[tuple[str, str, Message]],
AsyncGenerator[list[list[tuple[str, str, Message]]], None],
]:
agents = Agents({agent.agent_name: agent for agent in agent_list})
print(f"Running episode with tag: {tag}------------------")

async def generate_messages() -> (
AsyncGenerator[list[list[tuple[str, str, Message]]], None]
Expand Down Expand Up @@ -188,7 +191,7 @@ async def generate_messages() -> (
for agent_name in env.agents
]
)

print(f"Messages: {messages}")
yield messages
rewards.append([rewards_in_turn[agent_name] for agent_name in env.agents])
reasons.append(
Expand Down Expand Up @@ -228,7 +231,14 @@ async def generate_messages() -> (

if push_to_db:
try:
epilog.save()
if episode_pk:
epilog.pk = episode_pk
epilog.save()
else:
epilog.save()
if simulation_status:
simulation_status.status = "Completed"
simulation_status.save()
except Exception as e:
logging.error(f"Failed to save episode log: {e}")

Expand Down
11 changes: 11 additions & 0 deletions sotopia/ui/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@
## FastAPI Server

To run the FastAPI server, you can use the following command:
```bash
uv run rq worker
uv run fastapi run sotopia/ui/fastapi_server.py --workers 4 --port 8080
```

Here is also an example of using the FastAPI server:
```bash
uv run python examples/fast_api_example.py
```

The API server is a FastAPI application that is used to connect the Sotopia UI to the Sotopia backend.
This could also help with other projects that need to connect to the Sotopia backend through HTTP requests.

Expand Down
Loading

0 comments on commit dea25d3

Please sign in to comment.