From 693f792248815cccfbdc2541278e284a652d4e53 Mon Sep 17 00:00:00 2001 From: Xuhui Zhou Date: Wed, 4 Dec 2024 23:58:52 -0500 Subject: [PATCH] feat: FastAPI Implementation of Sotopia Part Two (w websocket) (#252) * api doc * add PUT * add an temp example for websocket * websocket * update readme * Update README.md * update websocket live simulation api doc * [autofix.ci] apply automated fixes * update websocket doc * add api server with websocket as well as a client * fix mypy errors * support stopping the chat * add 404 to the status code * fix mypy issue * update the returned message types * redesign websocket api * update websocket, fix mypy error * add example of using websocket * clean code & change to existing functions for simulation * fix typing mismatch * update doc & mypy type fix * add type check for run_async_server * move example --------- Co-authored-by: Hao Zhu Co-authored-by: Zhe Su <360307598@qq.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../websocket/websocket_test_client.py | 97 +++++++ sotopia/server.py | 209 ++++++++------ sotopia/ui/README.md | 85 ++++-- sotopia/ui/fastapi_server.py | 254 +++++++++++++++++- sotopia/ui/websocket_utils.py | 186 +++++++++++++ uv.lock | 46 ++++ 6 files changed, 754 insertions(+), 123 deletions(-) create mode 100644 examples/experimental/websocket/websocket_test_client.py create mode 100644 sotopia/ui/websocket_utils.py diff --git a/examples/experimental/websocket/websocket_test_client.py b/examples/experimental/websocket/websocket_test_client.py new file mode 100644 index 00000000..c1bd74b6 --- /dev/null +++ b/examples/experimental/websocket/websocket_test_client.py @@ -0,0 +1,97 @@ +""" +A test client for the WebSocket server +""" + +import json +from sotopia.database import EnvironmentProfile, AgentProfile + +import asyncio +import websockets +import sys +from pathlib import Path + + +class WebSocketClient: + def __init__(self, uri: str, token: str, client_id: int): + self.uri = uri + self.token = token + self.client_id = client_id + self.message_file = Path(f"message_{client_id}.txt") + + async def save_message(self, message: str) -> None: + """Save received message to a file""" + with open(self.message_file, "a", encoding="utf-8") as f: + f.write(f"{message}\n") + + async def connect(self) -> None: + """Establish and maintain websocket connection""" + uri_with_token = f"{self.uri}?token=test_token_{self.client_id}" + + try: + async with websockets.connect(uri_with_token) as websocket: + print(f"Client {self.client_id}: Connected to {self.uri}") + + # Send initial message + # Note: You'll need to implement the logic to get agent_ids and env_id + # This is just an example structure + agent_ids = [agent.pk for agent in AgentProfile.find().all()[:2]] + env_id = EnvironmentProfile.find().all()[0].pk + start_message = { + "type": "START_SIM", + "data": { + "env_id": env_id, # Replace with actual env_id + "agent_ids": agent_ids, # Replace with actual agent_ids + }, + } + await websocket.send(json.dumps(start_message)) + print(f"Client {self.client_id}: Sent START_SIM message") + + # Receive and process messages + while True: + try: + message = await websocket.recv() + print( + f"\nClient {self.client_id} received message:", + json.dumps(json.loads(message), indent=2), + ) + assert isinstance(message, str) + await self.save_message(message) + except websockets.ConnectionClosed: + print(f"Client {self.client_id}: Connection closed") + break + except Exception as e: + print(f"Client 
{self.client_id} error:", str(e)) + break + + except Exception as e: + print(f"Client {self.client_id} connection error:", str(e)) + + +async def main() -> None: + # Create multiple WebSocket clients + num_clients = 0 + uri = "ws://localhost:8800/ws/simulation" + + # Create and store client instances + clients = [ + WebSocketClient(uri=uri, token=f"test_token_{i}", client_id=i) + for i in range(num_clients) + ] + clients.append(WebSocketClient(uri=uri, token="test_token_10", client_id=10)) + clients.append( + WebSocketClient(uri=uri, token="test_token_10", client_id=10) + ) # test duplicate token + + # Create tasks for each client + tasks = [asyncio.create_task(client.connect()) for client in clients] + + # Wait for all tasks to complete + await asyncio.gather(*tasks) + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nShutting down clients...") + sys.exit(0) diff --git a/sotopia/server.py b/sotopia/server.py index d285558a..aec81a0f 100644 --- a/sotopia/server.py +++ b/sotopia/server.py @@ -1,7 +1,7 @@ import asyncio import itertools import logging -from typing import Literal, Sequence, Type +from typing import Literal, Sequence, Type, AsyncGenerator, Union import gin import rich @@ -25,7 +25,7 @@ unweighted_aggregate_evaluate, ) from sotopia.generation_utils.generate import LLM_Name, agenerate_script -from sotopia.messages import AgentAction, Message, Observation +from sotopia.messages import AgentAction, Message, Observation, SimpleMessage from sotopia.messages.message_classes import ( ScriptBackground, ScriptEnvironmentResponse, @@ -104,6 +104,12 @@ def run_sync_server( return messages +def flatten_listed_messages( + messages: list[list[tuple[str, str, Message]]], +) -> list[tuple[str, str, Message]]: + return list(itertools.chain.from_iterable(messages)) + + @gin.configurable async def arun_one_episode( env: ParallelSotopiaEnv, @@ -113,102 +119,125 @@ async def arun_one_episode( json_in_script: bool = False, tag: str | None = None, push_to_db: bool = False, -) -> list[tuple[str, str, Message]]: + streaming: bool = False, +) -> Union[ + list[tuple[str, str, Message]], + AsyncGenerator[list[list[tuple[str, str, Message]]], None], +]: agents = Agents({agent.agent_name: agent for agent in agent_list}) - environment_messages = env.reset(agents=agents, omniscient=omniscient) - agents.reset() - - messages: list[list[tuple[str, str, Message]]] = [] - # Main Event Loop - done = False - messages.append( - [ - ("Environment", agent_name, environment_messages[agent_name]) - for agent_name in env.agents - ] - ) - # set goal for agents - for index, agent_name in enumerate(env.agents): - agents[agent_name].goal = env.profile.agent_goals[index] - rewards: list[list[float]] = [] - reasons: list[str] = [] - while not done: - # gather agent messages - agent_messages: dict[str, AgentAction] = dict() - actions = await asyncio.gather( - *[ - agents[agent_name].aact(environment_messages[agent_name]) - for agent_name in env.agents - ] - ) - if script_like: - # manually mask one message - agent_mask = env.action_mask - for idx in range(len(agent_mask)): - print("Current mask: ", agent_mask) - if agent_mask[idx] == 0: - print("Action not taken: ", actions[idx]) - actions[idx] = AgentAction(action_type="none", argument="") - else: - print("Current action taken: ", actions[idx]) - - # actions = cast(list[AgentAction], actions) - for idx, agent_name in enumerate(env.agents): - agent_messages[agent_name] = actions[idx] - - messages[-1].append((agent_name, 
"Environment", agent_messages[agent_name])) + async def generate_messages() -> ( + AsyncGenerator[list[list[tuple[str, str, Message]]], None] + ): + environment_messages = env.reset(agents=agents, omniscient=omniscient) + agents.reset() + messages: list[list[tuple[str, str, Message]]] = [] - # send agent messages to environment - ( - environment_messages, - rewards_in_turn, - terminated, - ___, - info, - ) = await env.astep(agent_messages) + # Main Event Loop + done = False messages.append( [ ("Environment", agent_name, environment_messages[agent_name]) for agent_name in env.agents ] ) - # print("Environment message: ", environment_messages) - # exit(0) - rewards.append([rewards_in_turn[agent_name] for agent_name in env.agents]) - reasons.append( - " ".join(info[agent_name]["comments"] for agent_name in env.agents) + yield messages + + # set goal for agents + for index, agent_name in enumerate(env.agents): + agents[agent_name].goal = env.profile.agent_goals[index] + rewards: list[list[float]] = [] + reasons: list[str] = [] + while not done: + # gather agent messages + agent_messages: dict[str, AgentAction] = dict() + actions = await asyncio.gather( + *[ + agents[agent_name].aact(environment_messages[agent_name]) + for agent_name in env.agents + ] + ) + if script_like: + # manually mask one message + agent_mask = env.action_mask + for idx in range(len(agent_mask)): + if agent_mask[idx] == 0: + actions[idx] = AgentAction(action_type="none", argument="") + else: + pass + + # actions = cast(list[AgentAction], actions) + for idx, agent_name in enumerate(env.agents): + agent_messages[agent_name] = actions[idx] + + messages[-1].append( + (agent_name, "Environment", agent_messages[agent_name]) + ) + + # send agent messages to environment + ( + environment_messages, + rewards_in_turn, + terminated, + ___, + info, + ) = await env.astep(agent_messages) + messages.append( + [ + ("Environment", agent_name, environment_messages[agent_name]) + for agent_name in env.agents + ] + ) + + yield messages + rewards.append([rewards_in_turn[agent_name] for agent_name in env.agents]) + reasons.append( + " ".join(info[agent_name]["comments"] for agent_name in env.agents) + ) + done = all(terminated.values()) + + epilog = EpisodeLog( + environment=env.profile.pk, + agents=[agent.profile.pk for agent in agent_list], + tag=tag, + models=[env.model_name, agent_list[0].model_name, agent_list[1].model_name], + messages=[ + [(m[0], m[1], m[2].to_natural_language()) for m in messages_in_turn] + for messages_in_turn in messages + ], + reasoning=info[env.agents[0]]["comments"], + rewards=[info[agent_name]["complete_rating"] for agent_name in env.agents], + rewards_prompt=info["rewards_prompt"]["overall_prompt"], ) - done = all(terminated.values()) + rich.print(epilog.rewards_prompt) + agent_profiles, conversation = epilog.render_for_humans() + for agent_profile in agent_profiles: + rich.print(agent_profile) + for message in conversation: + rich.print(message) + + if streaming: + # yield the rewards and reasonings + messages.append( + [("Evaluation", "Rewards", SimpleMessage(message=str(epilog.rewards)))] + ) + messages.append( + [("Evaluation", "Reasoning", SimpleMessage(message=epilog.reasoning))] + ) + yield messages - # TODO: clean up this part - epilog = EpisodeLog( - environment=env.profile.pk, - agents=[agent.profile.pk for agent in agent_list], - tag=tag, - models=[env.model_name, agent_list[0].model_name, agent_list[1].model_name], - messages=[ - [(m[0], m[1], m[2].to_natural_language()) for m in messages_in_turn] - 
for messages_in_turn in messages - ], - reasoning=info[env.agents[0]]["comments"], - rewards=[info[agent_name]["complete_rating"] for agent_name in env.agents], - rewards_prompt=info["rewards_prompt"]["overall_prompt"], - ) - rich.print(epilog.rewards_prompt) - agent_profiles, conversation = epilog.render_for_humans() - for agent_profile in agent_profiles: - rich.print(agent_profile) - for message in conversation: - rich.print(message) + if push_to_db: + try: + epilog.save() + except Exception as e: + logging.error(f"Failed to save episode log: {e}") - if push_to_db: - try: - epilog.save() - except Exception as e: - logging.error(f"Failed to save episode log: {e}") - # flatten nested list messages - return list(itertools.chain(*messages)) + if streaming: + return generate_messages() + else: + async for last_messages in generate_messages(): + pass + return flatten_listed_messages(last_messages) @gin.configurable @@ -310,7 +339,13 @@ def get_agent_class( else [await i for i in episode_futures] ) - return batch_results + if len(batch_results) > 0: + first_result = batch_results[0] + assert isinstance( + first_result, list + ), f"Unexpected result type: {type(first_result)}" + + return batch_results # type: ignore async def arun_one_script( diff --git a/sotopia/ui/README.md b/sotopia/ui/README.md index ca8b679d..156050a4 100644 --- a/sotopia/ui/README.md +++ b/sotopia/ui/README.md @@ -78,33 +78,31 @@ EnvironmentProfile returns: - scenario_id: str -#### DELETE /agents/{agent_id} +### Updating Data in the API Server -Delete agent profile from the API server. +#### PUT /agents/{agent_id} + +Update agent profile in the API server. +Request Body: +AgentProfile returns: - agent_id: str -#### DELETE /scenarios/{scenario_id} -Delete scenario profile from the API server. +#### PUT /scenarios/{scenario_id} + +Update scenario profile in the API server. +Request Body: +EnvironmentProfile returns: - scenario_id: str - -### Error Code -For RESTful APIs above we have the following error codes: -| **Error Code** | **Description** | -|-----------------|--------------------------------------| -| **404** | A resource is not found | -| **403** | The query is not authorized | -| **500** | Internal running error | - ### Initiating a new non-streaming simulation episode #### POST /episodes/ -[!] 
Currently not planning to implement
+
 ```python
 class SimulationEpisodeInitiation(BaseModel):
     scenario_id: str
@@ -147,14 +145,14 @@ returns:
 | Type | Direction | Description |
 |-----------|--------|-------------|
 | SERVER_MSG | Server → Client | Standard message from server (payload: `messageForRendering` [here](https://github.com/sotopia-lab/sotopia-demo/blob/main/socialstream/rendering_utils.py) ) |
-| CLIENT_MSG | Client → Server | Standard message from client (payload: Currently not needed) |
-| ERROR | Server → Client | Error notification (payload: `{"type": ERROR_TYPE, "description": DESC}`) |
+| CLIENT_MSG | Client → Server | Standard message from client (payload: TBD) |
+| ERROR | Server → Client | Error notification (payload: TBD) |
 | START_SIM | Client → Server | Initialize simulation (payload: `SimulationEpisodeInitialization`) |
 | END_SIM | Client → Server | End simulation (payload: not needed) |
 | FINISH_SIM | Server → Client | Terminate simulation (payload: not needed) |


-**ERROR_TYPE**
+**Error Type**

 | Error Code | Description |
 |------------|-------------|
@@ -167,14 +165,53 @@ returns:
 | OTHER | Other unspecified errors |


-**Conversation Message From the Server**
-The server returns messages encapsulated in a structured format which is defined as follows:
+**Implementation plan**: Currently only support LLM-LLM simulation based on [this function](https://github.com/sotopia-lab/sotopia/blob/19d39e068c3bca9246fc366e5759414f62284f93/sotopia/server.py#L108).
+
+
+## An example of running a simulation with the API
+
+**Get all scenarios**:
+```bash
+curl -X GET "http://localhost:8000/scenarios"
+```
+
+This will return all the scenarios; you can pick one at random.
+
+
+**Get all agents**:
+```bash
+curl -X GET "http://localhost:8000/agents"
+```
+
+This will return all the agents; you can pick one at random.
+
+**Connecting to the websocket server**:
+We recommend using Python. Here is the simplest way to start a simulation and receive the results in real time:
 ```python
-class MessageForRendering(TypedDict):
-    role: str # Specifies the origin of the message. Common values include "Background Info", "Environment", "{Agent Names}
-    type: str # Categorizes the nature of the message. Common types include: "comment", "said", "action"
-    content: str
+import aiohttp
+import asyncio
+import json
+
+async def main():
+    async with aiohttp.ClientSession() as session:
+        async with session.ws_connect(f'ws://{API_BASE}/ws/simulation?token={YOUR_TOKEN}') as ws:
+            start_message = {
+                "type": "START_SIM",
+                "data": {
+                    "env_id": "{ENV_ID}",
+                    "agent_ids": ["{AGENT1_PK}", "{AGENT2_PK}"],
+                },
+            }
+            await ws.send_json(start_message)
+
+            async for msg in ws:
+                if msg.type == aiohttp.WSMsgType.TEXT:
+                    print(f"Received: {msg.data}")
+                elif msg.type == aiohttp.WSMsgType.CLOSED:
+                    break
+                elif msg.type == aiohttp.WSMsgType.ERROR:
+                    break
 ```

-**Implementation plan**: Currently only support LLM-LLM simulation based on [this function](https://github.com/sotopia-lab/sotopia/blob/19d39e068c3bca9246fc366e5759414f62284f93/sotopia/server.py#L108).
+Please check out an detailed example in `examples/experimental/websocket/websocket_test_client.py` diff --git a/sotopia/ui/fastapi_server.py b/sotopia/ui/fastapi_server.py index ea53f4e5..543dafd2 100644 --- a/sotopia/ui/fastapi_server.py +++ b/sotopia/ui/fastapi_server.py @@ -1,11 +1,33 @@ -from fastapi import FastAPI -from typing import Literal, cast, Dict -from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog +from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect +from typing import Literal, cast, Optional, Any +from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel + +from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog +from sotopia.ui.websocket_utils import ( + WebSocketSotopiaSimulator, + WSMessageType, + ErrorType, +) import uvicorn +import asyncio + +from contextlib import asynccontextmanager +from typing import AsyncIterator +import logging + +logger = logging.getLogger(__name__) app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) # TODO: Whether allowing CORS for all origins + class AgentProfileWrapper(BaseModel): """ @@ -64,6 +86,12 @@ async def get_scenarios( EnvironmentProfile.codename == value ).all() scenarios.extend(cast(list[EnvironmentProfile], json_models)) + + if not scenarios: + raise HTTPException( + status_code=404, detail=f"No scenarios found with {get_by}={value}" + ) + return scenarios @@ -85,9 +113,20 @@ async def get_agents( elif get_by == "occupation": json_models = AgentProfile.find(AgentProfile.occupation == value).all() agents_profiles.extend(cast(list[AgentProfile], json_models)) + + if not agents_profiles: + raise HTTPException( + status_code=404, detail=f"No agents found with {get_by}={value}" + ) + return agents_profiles +@app.get("/episodes", response_model=list[EpisodeLog]) +async def get_episodes_all() -> list[EpisodeLog]: + return EpisodeLog.all() + + @app.get("/episodes/{get_by}/{value}", response_model=list[EpisodeLog]) async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[EpisodeLog]: episodes: list[EpisodeLog] = [] @@ -96,10 +135,15 @@ async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[Episode elif get_by == "tag": json_models = EpisodeLog.find(EpisodeLog.tag == value).all() episodes.extend(cast(list[EpisodeLog], json_models)) + + if not episodes: + raise HTTPException( + status_code=404, detail=f"No episodes found with {get_by}={value}" + ) return episodes -@app.post("/agents/") +@app.post("/agents/", response_model=str) async def create_agent(agent: AgentProfileWrapper) -> str: agent_profile = AgentProfile(**agent.model_dump()) agent_profile.save() @@ -110,7 +154,6 @@ async def create_agent(agent: AgentProfileWrapper) -> str: @app.post("/scenarios/", response_model=str) async def create_scenario(scenario: EnvironmentProfileWrapper) -> str: - print(scenario) scenario_profile = EnvironmentProfile(**scenario.model_dump()) scenario_profile.save() pk = scenario_profile.pk @@ -118,21 +161,208 @@ async def create_scenario(scenario: EnvironmentProfileWrapper) -> str: return pk +@app.put("/agents/{agent_id}", response_model=str) +async def update_agent(agent_id: str, agent: AgentProfileWrapper) -> str: + try: + old_agent = AgentProfile.get(pk=agent_id) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, detail=f"Agent with id={agent_id} not found" + ) + 
old_agent.update(**agent.model_dump()) # type: ignore + assert old_agent.pk is not None + return old_agent.pk + + +@app.put("/scenarios/{scenario_id}", response_model=str) +async def update_scenario(scenario_id: str, scenario: EnvironmentProfileWrapper) -> str: + try: + old_scenario = EnvironmentProfile.get(pk=scenario_id) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, detail=f"Scenario with id={scenario_id} not found" + ) + old_scenario.update(**scenario.model_dump()) # type: ignore + assert old_scenario.pk is not None + return old_scenario.pk + + @app.delete("/agents/{agent_id}", response_model=str) async def delete_agent(agent_id: str) -> str: - AgentProfile.delete(agent_id) - return agent_id + try: + agent = AgentProfile.get(pk=agent_id) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, detail=f"Agent with id={agent_id} not found" + ) + AgentProfile.delete(agent.pk) + assert agent.pk is not None + return agent.pk @app.delete("/scenarios/{scenario_id}", response_model=str) async def delete_scenario(scenario_id: str) -> str: - EnvironmentProfile.delete(scenario_id) - return scenario_id + try: + scenario = EnvironmentProfile.get(pk=scenario_id) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, detail=f"Scenario with id={scenario_id} not found" + ) + EnvironmentProfile.delete(scenario.pk) + assert scenario.pk is not None + return scenario.pk + + +@app.get("/models", response_model=list[str]) +async def get_models() -> list[str]: + # TODO figure out how to get the available models + return ["gpt-4o-mini", "gpt-4o", "gpt-3.5-turbo"] + + +class SimulationState: + _instance: Optional["SimulationState"] = None + _lock = asyncio.Lock() + _active_simulations: dict[str, bool] = {} + + def __new__(cls) -> "SimulationState": + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._active_simulations = {} + return cls._instance + + async def try_acquire_token(self, token: str) -> tuple[bool, str]: + async with self._lock: + if not token: + return False, "Invalid token" + + if self._active_simulations.get(token): + return False, "Token is active already" + + self._active_simulations[token] = True + return True, "Token is valid" + + async def release_token(self, token: str) -> None: + async with self._lock: + self._active_simulations.pop(token, None) + + @asynccontextmanager + async def start_simulation(self, token: str) -> AsyncIterator[bool]: + try: + yield True + finally: + await self.release_token(token) + + +class SimulationManager: + def __init__(self) -> None: + self.state = SimulationState() + + async def verify_token(self, token: str) -> dict[str, Any]: + is_valid, msg = await self.state.try_acquire_token(token) + return {"is_valid": is_valid, "msg": msg} + + async def create_simulator( + self, env_id: str, agent_ids: list[str] + ) -> WebSocketSotopiaSimulator: + try: + return WebSocketSotopiaSimulator(env_id=env_id, agent_ids=agent_ids) + except Exception as e: + error_msg = f"Failed to create simulator: {e}" + logger.error(error_msg) + raise Exception(error_msg) + + async def handle_client_message( + self, + websocket: WebSocket, + simulator: WebSocketSotopiaSimulator, + message: dict[str, Any], + timeout: float = 0.1, + ) -> bool: + try: + msg_type = message.get("type") + if msg_type == WSMessageType.FINISH_SIM.value: + return True + # TODO handle other message types + return False + except Exception as e: + msg = f"Error handling 
client message: {e}" + logger.error(msg) + await self.send_error(websocket, ErrorType.INVALID_MESSAGE, msg) + return False + + async def run_simulation( + self, websocket: WebSocket, simulator: WebSocketSotopiaSimulator + ) -> None: + try: + async for message in simulator.arun(): + await self.send_message(websocket, WSMessageType.SERVER_MSG, message) + + try: + data = await asyncio.wait_for(websocket.receive_json(), timeout=0.1) + if await self.handle_client_message(websocket, simulator, data): + break + except asyncio.TimeoutError: + continue + + except Exception as e: + msg = f"Error running simulation: {e}" + logger.error(msg) + await self.send_error(websocket, ErrorType.SIMULATION_ISSUE, msg) + finally: + await self.send_message(websocket, WSMessageType.END_SIM, {}) + + @staticmethod + async def send_message( + websocket: WebSocket, msg_type: WSMessageType, data: dict[str, Any] + ) -> None: + await websocket.send_json({"type": msg_type.value, "data": data}) + + @staticmethod + async def send_error( + websocket: WebSocket, error_type: ErrorType, details: str = "" + ) -> None: + await websocket.send_json( + { + "type": WSMessageType.ERROR.value, + "data": {"type": error_type.value, "details": details}, + } + ) + + +@app.websocket("/ws/simulation") +async def websocket_endpoint(websocket: WebSocket, token: str) -> None: + manager = SimulationManager() + + token_status = await manager.verify_token(token) + if not token_status["is_valid"]: + await websocket.close(code=1008, reason=token_status["msg"]) + return + + try: + await websocket.accept() + + while True: + start_msg = await websocket.receive_json() + if start_msg.get("type") != WSMessageType.START_SIM.value: + continue + async with manager.state.start_simulation(token): + simulator = await manager.create_simulator( + env_id=start_msg["data"]["env_id"], + agent_ids=start_msg["data"]["agent_ids"], + ) + await manager.run_simulation(websocket, simulator) -active_simulations: Dict[ - str, bool -] = {} # TODO check whether this is the correct way to store the active simulations + except WebSocketDisconnect: + logger.info(f"Client disconnected: {token}") + except Exception as e: + logger.error(f"Unexpected error: {e}") + await manager.send_error(websocket, ErrorType.SIMULATION_ISSUE, str(e)) + finally: + try: + await websocket.close() + except Exception as e: + logger.error(f"Error closing websocket: {e}") if __name__ == "__main__": diff --git a/sotopia/ui/websocket_utils.py b/sotopia/ui/websocket_utils.py new file mode 100644 index 00000000..5b29da73 --- /dev/null +++ b/sotopia/ui/websocket_utils.py @@ -0,0 +1,186 @@ +from sotopia.envs.evaluators import ( + EvaluationForTwoAgents, + ReachGoalLLMEvaluator, + RuleBasedTerminatedEvaluator, + SotopiaDimensions, +) +from sotopia.agents import Agents, LLMAgent +from sotopia.messages import Observation +from sotopia.envs import ParallelSotopiaEnv +from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog +from sotopia.server import arun_one_episode + +from enum import Enum +from typing import TypedDict, Any, AsyncGenerator +from pydantic import BaseModel + + +class WSMessageType(str, Enum): + SERVER_MSG = "SERVER_MSG" + CLIENT_MSG = "CLIENT_MSG" + ERROR = "ERROR" + START_SIM = "START_SIM" + END_SIM = "END_SIM" + FINISH_SIM = "FINISH_SIM" + + +class ErrorType(str, Enum): + NOT_AUTHORIZED = "NOT_AUTHORIZED" + SIMULATION_ALREADY_STARTED = "SIMULATION_ALREADY_STARTED" + SIMULATION_NOT_STARTED = "SIMULATION_NOT_STARTED" + SIMULATION_ISSUE = "SIMULATION_ISSUE" + INVALID_MESSAGE = 
"INVALID_MESSAGE" + OTHER = "OTHER" + + +class MessageForRendering(TypedDict): + role: str + type: str + content: str + + +class WSMessage(BaseModel): + type: WSMessageType + data: dict[str, Any] + + model_config = {"arbitrary_types_allowed": True, "protected_namespaces": ()} + + def to_json(self) -> dict[str, Any]: + return { + "type": self.type.value, # TODO check whether we want to use the enum value or the enum itself + "data": self.data, + } + + +def get_env_agents( + env_id: str, + agent_ids: list[str], + agent_models: list[str], + evaluator_model: str, +) -> tuple[ParallelSotopiaEnv, Agents, dict[str, Observation]]: + # environment_profile = EnvironmentProfile.find().all()[0] + # agent_profiles = AgentProfile.find().all()[:2] + assert len(agent_ids) == len( + agent_models + ), f"Provided {len(agent_ids)} agent_ids but {len(agent_models)} agent_models" + + environment_profile: EnvironmentProfile = EnvironmentProfile.get(env_id) + agent_profiles: list[AgentProfile] = [ + AgentProfile.get(agent_id) for agent_id in agent_ids + ] + + agent_list = [ + LLMAgent( + agent_profile=agent_profile, + model_name=agent_models[idx], + ) + for idx, agent_profile in enumerate(agent_profiles) + ] + for idx, goal in enumerate(environment_profile.agent_goals): + agent_list[idx].goal = goal + + agents = Agents({agent.agent_name: agent for agent in agent_list}) + env = ParallelSotopiaEnv( + action_order="round-robin", + model_name="gpt-4o-mini", + evaluators=[ + RuleBasedTerminatedEvaluator(max_turn_number=20, max_stale_turn=2), + ], + terminal_evaluators=[ + ReachGoalLLMEvaluator( + evaluator_model, + EvaluationForTwoAgents[SotopiaDimensions], + ), + ], + env_profile=environment_profile, + ) + + environment_messages = env.reset(agents=agents, omniscient=False) + agents.reset() + + return env, agents, environment_messages + + +def parse_reasoning(reasoning: str, num_agents: int) -> tuple[list[str], str]: + """Parse the reasoning string into a dictionary.""" + sep_token = "SEPSEP" + for i in range(1, num_agents + 1): + reasoning = ( + reasoning.replace(f"Agent {i} comments:\n", sep_token) + .strip(" ") + .strip("\n") + ) + all_chunks = reasoning.split(sep_token) + general_comment = all_chunks[0].strip(" ").strip("\n") + comment_chunks = all_chunks[-num_agents:] + + return comment_chunks, general_comment + + +class WebSocketSotopiaSimulator: + def __init__( + self, + env_id: str, + agent_ids: list[str], + agent_models: list[str] = ["gpt-4o-mini", "gpt-4o-mini"], + evaluator_model: str = "gpt-4o", + ) -> None: + self.env, self.agents, self.environment_messages = get_env_agents( + env_id, agent_ids, agent_models, evaluator_model + ) + self.messages: list[list[tuple[str, str, str]]] = [] + self.messages.append( + [ + ( + "Environment", + agent_name, + self.environment_messages[agent_name].to_natural_language(), + ) + for agent_name in self.env.agents + ] + ) + for index, agent_name in enumerate(self.env.agents): + self.agents[agent_name].goal = self.env.profile.agent_goals[index] + + async def arun(self) -> AsyncGenerator[dict[str, Any], None]: + # Use sotopia to run the simulation + generator = arun_one_episode( + env=self.env, + agent_list=list(self.agents.values()), + push_to_db=False, + streaming=True, + ) + + assert isinstance( + generator, AsyncGenerator + ), "generator should be async generator" + + async for messages in await generator: # type: ignore + reasoning, rewards = "", [0.0, 0.0] + eval_available = False + if messages[-1][0][0] == "Evaluation": + reasoning = 
messages[-1][0][2].to_natural_language() + rewards = eval(messages[-2][0][2].to_natural_language()) + eval_available = True + + epilog = EpisodeLog( + environment=self.env.profile.pk, + agents=[agent.profile.pk for agent in self.agents.values()], + tag="test", + models=["gpt-4o", "gpt-4o", "gpt-4o-mini"], + messages=[ + [(m[0], m[1], m[2].to_natural_language()) for m in messages_in_turn] + for messages_in_turn in messages + ], + reasoning=reasoning, + rewards=rewards, + rewards_prompt="", + ) + agent_profiles, parsed_messages = epilog.render_for_humans() + if not eval_available: + parsed_messages = parsed_messages[:-2] + + yield { + "type": "messages", + "messages": parsed_messages, + } diff --git a/uv.lock b/uv.lock index 5017e0e0..71152b36 100644 --- a/uv.lock +++ b/uv.lock @@ -115,6 +115,51 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f5/76/a57ceff577ae26fe9a6f31ac799bc638ecf26e4acdf04295290b9929b349/aiohttp-3.11.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:9d18a8b44ec8502a7fde91446cd9c9b95ce7c49f1eacc1fb2358b8907d4369fd", size = 1690038 }, { url = "https://files.pythonhosted.org/packages/4b/81/b20e09003b6989a7f23a721692137a6143420a151063c750ab2a04878e3c/aiohttp-3.11.7-cp312-cp312-win32.whl", hash = "sha256:3d1c9c15d3999107cbb9b2d76ca6172e6710a12fda22434ee8bd3f432b7b17e8", size = 409887 }, { url = "https://files.pythonhosted.org/packages/b7/0b/607c98bff1d07bb21e0c39e7711108ef9ff4f2a361a3ec1ce8dce93623a5/aiohttp-3.11.7-cp312-cp312-win_amd64.whl", hash = "sha256:018f1b04883a12e77e7fc161934c0f298865d3a484aea536a6a2ca8d909f0ba0", size = 436462 }, + { url = "https://files.pythonhosted.org/packages/3d/dd/3d40c0e67e79c5c42671e3e268742f1ff96c6573ca43823563d01abd9475/aiohttp-3.10.10-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:be7443669ae9c016b71f402e43208e13ddf00912f47f623ee5994e12fc7d4b3f", size = 586969 }, + { url = "https://files.pythonhosted.org/packages/75/64/8de41b5555e5b43ef6d4ed1261891d33fe45ecc6cb62875bfafb90b9ab93/aiohttp-3.10.10-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7b06b7843929e41a94ea09eb1ce3927865387e3e23ebe108e0d0d09b08d25be9", size = 399367 }, + { url = "https://files.pythonhosted.org/packages/96/36/27bd62ea7ce43906d1443a73691823fc82ffb8fa03276b0e2f7e1037c286/aiohttp-3.10.10-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:333cf6cf8e65f6a1e06e9eb3e643a0c515bb850d470902274239fea02033e9a8", size = 390720 }, + { url = "https://files.pythonhosted.org/packages/e8/4d/d516b050d811ce0dd26325c383013c104ffa8b58bd361b82e52833f68e78/aiohttp-3.10.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:274cfa632350225ce3fdeb318c23b4a10ec25c0e2c880eff951a3842cf358ac1", size = 1228820 }, + { url = "https://files.pythonhosted.org/packages/53/94/964d9327a3e336d89aad52260836e4ec87fdfa1207176550fdf384eaffe7/aiohttp-3.10.10-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d9e5e4a85bdb56d224f412d9c98ae4cbd032cc4f3161818f692cd81766eee65a", size = 1264616 }, + { url = "https://files.pythonhosted.org/packages/0c/20/70ce17764b685ca8f5bf4d568881b4e1f1f4ea5e8170f512fdb1a33859d2/aiohttp-3.10.10-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2b606353da03edcc71130b52388d25f9a30a126e04caef1fd637e31683033abd", size = 1298402 }, + { url = "https://files.pythonhosted.org/packages/d1/d1/5248225ccc687f498d06c3bca5af2647a361c3687a85eb3aedcc247ee1aa/aiohttp-3.10.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:ab5a5a0c7a7991d90446a198689c0535be89bbd6b410a1f9a66688f0880ec026", size = 1222205 }, + { url = "https://files.pythonhosted.org/packages/f2/a3/9296b27cc5d4feadf970a14d0694902a49a985f3fae71b8322a5f77b0baa/aiohttp-3.10.10-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:578a4b875af3e0daaf1ac6fa983d93e0bbfec3ead753b6d6f33d467100cdc67b", size = 1193804 }, + { url = "https://files.pythonhosted.org/packages/d9/07/f3760160feb12ac51a6168a6da251a4a8f2a70733d49e6ceb9b3e6ee2f03/aiohttp-3.10.10-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8105fd8a890df77b76dd3054cddf01a879fc13e8af576805d667e0fa0224c35d", size = 1193544 }, + { url = "https://files.pythonhosted.org/packages/7e/4c/93a70f9a4ba1c30183a6dd68bfa79cddbf9a674f162f9c62e823a74a5515/aiohttp-3.10.10-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3bcd391d083f636c06a68715e69467963d1f9600f85ef556ea82e9ef25f043f7", size = 1193047 }, + { url = "https://files.pythonhosted.org/packages/ff/a3/36a1e23ff00c7a0cd696c5a28db05db25dc42bfc78c508bd78623ff62a4a/aiohttp-3.10.10-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:fbc6264158392bad9df19537e872d476f7c57adf718944cc1e4495cbabf38e2a", size = 1247201 }, + { url = "https://files.pythonhosted.org/packages/55/ae/95399848557b98bb2c402d640b2276ce3a542b94dba202de5a5a1fe29abe/aiohttp-3.10.10-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:e48d5021a84d341bcaf95c8460b152cfbad770d28e5fe14a768988c461b821bc", size = 1264102 }, + { url = "https://files.pythonhosted.org/packages/38/f5/02e5c72c1b60d7cceb30b982679a26167e84ac029fd35a93dd4da52c50a3/aiohttp-3.10.10-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2609e9ab08474702cc67b7702dbb8a80e392c54613ebe80db7e8dbdb79837c68", size = 1215760 }, + { url = "https://files.pythonhosted.org/packages/30/17/1463840bad10d02d0439068f37ce5af0b383884b0d5838f46fb027e233bf/aiohttp-3.10.10-cp310-cp310-win32.whl", hash = "sha256:84afcdea18eda514c25bc68b9af2a2b1adea7c08899175a51fe7c4fb6d551257", size = 362678 }, + { url = "https://files.pythonhosted.org/packages/dd/01/a0ef707d93e867a43abbffee3a2cdf30559910750b9176b891628c7ad074/aiohttp-3.10.10-cp310-cp310-win_amd64.whl", hash = "sha256:9c72109213eb9d3874f7ac8c0c5fa90e072d678e117d9061c06e30c85b4cf0e6", size = 381097 }, + { url = "https://files.pythonhosted.org/packages/72/31/3c351d17596194e5a38ef169a4da76458952b2497b4b54645b9d483cbbb0/aiohttp-3.10.10-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c30a0eafc89d28e7f959281b58198a9fa5e99405f716c0289b7892ca345fe45f", size = 586501 }, + { url = "https://files.pythonhosted.org/packages/a4/a8/a559d09eb08478cdead6b7ce05b0c4a133ba27fcdfa91e05d2e62867300d/aiohttp-3.10.10-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:258c5dd01afc10015866114e210fb7365f0d02d9d059c3c3415382ab633fcbcb", size = 398993 }, + { url = "https://files.pythonhosted.org/packages/c5/47/7736d4174613feef61d25332c3bd1a4f8ff5591fbd7331988238a7299485/aiohttp-3.10.10-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:15ecd889a709b0080f02721255b3f80bb261c2293d3c748151274dfea93ac871", size = 390647 }, + { url = "https://files.pythonhosted.org/packages/27/21/e9ba192a04b7160f5a8952c98a1de7cf8072ad150fa3abd454ead1ab1d7f/aiohttp-3.10.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3935f82f6f4a3820270842e90456ebad3af15810cf65932bd24da4463bc0a4c", size = 1306481 }, + { url = 
"https://files.pythonhosted.org/packages/cf/50/f364c01c8d0def1dc34747b2470969e216f5a37c7ece00fe558810f37013/aiohttp-3.10.10-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:413251f6fcf552a33c981c4709a6bba37b12710982fec8e558ae944bfb2abd38", size = 1344652 }, + { url = "https://files.pythonhosted.org/packages/1d/c2/74f608e984e9b585649e2e83883facad6fa3fc1d021de87b20cc67e8e5ae/aiohttp-3.10.10-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d1720b4f14c78a3089562b8875b53e36b51c97c51adc53325a69b79b4b48ebcb", size = 1378498 }, + { url = "https://files.pythonhosted.org/packages/9f/a7/05a48c7c0a7a80a5591b1203bf1b64ca2ed6a2050af918d09c05852dc42b/aiohttp-3.10.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:679abe5d3858b33c2cf74faec299fda60ea9de62916e8b67e625d65bf069a3b7", size = 1292718 }, + { url = "https://files.pythonhosted.org/packages/7d/78/a925655018747e9790350180330032e27d6e0d7ed30bde545fae42f8c49c/aiohttp-3.10.10-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:79019094f87c9fb44f8d769e41dbb664d6e8fcfd62f665ccce36762deaa0e911", size = 1251776 }, + { url = "https://files.pythonhosted.org/packages/47/9d/85c6b69f702351d1236594745a4fdc042fc43f494c247a98dac17e004026/aiohttp-3.10.10-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:fe2fb38c2ed905a2582948e2de560675e9dfbee94c6d5ccdb1301c6d0a5bf092", size = 1271716 }, + { url = "https://files.pythonhosted.org/packages/7f/a7/55fc805ff9b14af818903882ece08e2235b12b73b867b521b92994c52b14/aiohttp-3.10.10-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:a3f00003de6eba42d6e94fabb4125600d6e484846dbf90ea8e48a800430cc142", size = 1266263 }, + { url = "https://files.pythonhosted.org/packages/1f/ec/d2be2ca7b063e4f91519d550dbc9c1cb43040174a322470deed90b3d3333/aiohttp-3.10.10-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:1bbb122c557a16fafc10354b9d99ebf2f2808a660d78202f10ba9d50786384b9", size = 1321617 }, + { url = "https://files.pythonhosted.org/packages/c9/a3/b29f7920e1cd0a9a68a45dd3eb16140074d2efb1518d2e1f3e140357dc37/aiohttp-3.10.10-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:30ca7c3b94708a9d7ae76ff281b2f47d8eaf2579cd05971b5dc681db8caac6e1", size = 1339227 }, + { url = "https://files.pythonhosted.org/packages/8a/81/34b67235c47e232d807b4bbc42ba9b927c7ce9476872372fddcfd1e41b3d/aiohttp-3.10.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:df9270660711670e68803107d55c2b5949c2e0f2e4896da176e1ecfc068b974a", size = 1299068 }, + { url = "https://files.pythonhosted.org/packages/04/1f/26a7fe11b6ad3184f214733428353c89ae9fe3e4f605a657f5245c5e720c/aiohttp-3.10.10-cp311-cp311-win32.whl", hash = "sha256:aafc8ee9b742ce75044ae9a4d3e60e3d918d15a4c2e08a6c3c3e38fa59b92d94", size = 362223 }, + { url = "https://files.pythonhosted.org/packages/10/91/85dcd93f64011434359ce2666bece981f08d31bc49df33261e625b28595d/aiohttp-3.10.10-cp311-cp311-win_amd64.whl", hash = "sha256:362f641f9071e5f3ee6f8e7d37d5ed0d95aae656adf4ef578313ee585b585959", size = 381576 }, + { url = "https://files.pythonhosted.org/packages/ae/99/4c5aefe5ad06a1baf206aed6598c7cdcbc7c044c46801cd0d1ecb758cae3/aiohttp-3.10.10-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9294bbb581f92770e6ed5c19559e1e99255e4ca604a22c5c6397b2f9dd3ee42c", size = 583536 }, + { url = "https://files.pythonhosted.org/packages/a9/36/8b3bc49b49cb6d2da40ee61ff15dbcc44fd345a3e6ab5bb20844df929821/aiohttp-3.10.10-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:a8fa23fe62c436ccf23ff930149c047f060c7126eae3ccea005f0483f27b2e28", size = 395693 }, + { url = "https://files.pythonhosted.org/packages/e1/77/0aa8660dcf11fa65d61712dbb458c4989de220a844bd69778dff25f2d50b/aiohttp-3.10.10-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5c6a5b8c7926ba5d8545c7dd22961a107526562da31a7a32fa2456baf040939f", size = 390898 }, + { url = "https://files.pythonhosted.org/packages/38/d2/b833d95deb48c75db85bf6646de0a697e7fb5d87bd27cbade4f9746b48b1/aiohttp-3.10.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:007ec22fbc573e5eb2fb7dec4198ef8f6bf2fe4ce20020798b2eb5d0abda6138", size = 1312060 }, + { url = "https://files.pythonhosted.org/packages/aa/5f/29fd5113165a0893de8efedf9b4737e0ba92dfcd791415a528f947d10299/aiohttp-3.10.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9627cc1a10c8c409b5822a92d57a77f383b554463d1884008e051c32ab1b3742", size = 1350553 }, + { url = "https://files.pythonhosted.org/packages/ad/cc/f835f74b7d344428469200105236d44606cfa448be1e7c95ca52880d9bac/aiohttp-3.10.10-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:50edbcad60d8f0e3eccc68da67f37268b5144ecc34d59f27a02f9611c1d4eec7", size = 1392646 }, + { url = "https://files.pythonhosted.org/packages/bf/fe/1332409d845ca601893bbf2d76935e0b93d41686e5f333841c7d7a4a770d/aiohttp-3.10.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a45d85cf20b5e0d0aa5a8dca27cce8eddef3292bc29d72dcad1641f4ed50aa16", size = 1306310 }, + { url = "https://files.pythonhosted.org/packages/e4/a1/25a7633a5a513278a9892e333501e2e69c83e50be4b57a62285fb7a008c3/aiohttp-3.10.10-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b00807e2605f16e1e198f33a53ce3c4523114059b0c09c337209ae55e3823a8", size = 1260255 }, + { url = "https://files.pythonhosted.org/packages/f2/39/30eafe89e0e2a06c25e4762844c8214c0c0cd0fd9ffc3471694a7986f421/aiohttp-3.10.10-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f2d4324a98062be0525d16f768a03e0bbb3b9fe301ceee99611dc9a7953124e6", size = 1271141 }, + { url = "https://files.pythonhosted.org/packages/5b/fc/33125df728b48391ef1fcb512dfb02072158cc10d041414fb79803463020/aiohttp-3.10.10-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:438cd072f75bb6612f2aca29f8bd7cdf6e35e8f160bc312e49fbecab77c99e3a", size = 1280244 }, + { url = "https://files.pythonhosted.org/packages/3b/61/e42bf2c2934b5caa4e2ec0b5e5fd86989adb022b5ee60c2572a9d77cf6fe/aiohttp-3.10.10-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:baa42524a82f75303f714108fea528ccacf0386af429b69fff141ffef1c534f9", size = 1316805 }, + { url = "https://files.pythonhosted.org/packages/18/32/f52a5e2ae9ad3bba10e026a63a7a23abfa37c7d97aeeb9004eaa98df3ce3/aiohttp-3.10.10-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:a7d8d14fe962153fc681f6366bdec33d4356f98a3e3567782aac1b6e0e40109a", size = 1343930 }, + { url = "https://files.pythonhosted.org/packages/05/be/6a403b464dcab3631fe8e27b0f1d906d9e45c5e92aca97ee007e5a895560/aiohttp-3.10.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c1277cd707c465cd09572a774559a3cc7c7a28802eb3a2a9472588f062097205", size = 1306186 }, + { url = "https://files.pythonhosted.org/packages/8e/fd/bb50fe781068a736a02bf5c7ad5f3ab53e39f1d1e63110da6d30f7605edc/aiohttp-3.10.10-cp312-cp312-win32.whl", hash = "sha256:59bb3c54aa420521dc4ce3cc2c3fe2ad82adf7b09403fa1f48ae45c0cbde6628", size = 359289 }, + { url = 
"https://files.pythonhosted.org/packages/70/9e/5add7e240f77ef67c275c82cc1d08afbca57b77593118c1f6e920ae8ad3f/aiohttp-3.10.10-cp312-cp312-win_amd64.whl", hash = "sha256:0e1b370d8007c4ae31ee6db7f9a2fe801a42b146cec80a86766e7ad5c4a259cf", size = 379313 }, ] [[package]] @@ -3166,6 +3211,7 @@ requires-dist = [ { name = "pytest-asyncio", marker = "extra == 'test'" }, { name = "pytest-cov", marker = "extra == 'test'" }, { name = "redis-om", specifier = ">=0.3.0,<0.4.0" }, + { name = "rel", marker = "extra == 'chat'" }, { name = "rich", specifier = ">=13.6.0,<14.0.0" }, { name = "scipy", marker = "extra == 'examples'" }, { name = "together", specifier = ">=0.2.4,<1.4.0" },