diff --git a/examples/experimental/websocket/websocket_test_client.py b/examples/experimental/websocket/websocket_test_client.py index c1bd74b67..ab9f30272 100644 --- a/examples/experimental/websocket/websocket_test_client.py +++ b/examples/experimental/websocket/websocket_test_client.py @@ -12,8 +12,8 @@ class WebSocketClient: - def __init__(self, uri: str, token: str, client_id: int): - self.uri = uri + def __init__(self, url: str, token: str, client_id: int): + self.url = url self.token = token self.client_id = client_id self.message_file = Path(f"message_{client_id}.txt") @@ -25,11 +25,11 @@ async def save_message(self, message: str) -> None: async def connect(self) -> None: """Establish and maintain websocket connection""" - uri_with_token = f"{self.uri}?token=test_token_{self.client_id}" + url_with_token = f"{self.url}?token=test_token_{self.client_id}" try: - async with websockets.connect(uri_with_token) as websocket: - print(f"Client {self.client_id}: Connected to {self.uri}") + async with websockets.connect(url_with_token) as websocket: + print(f"Client {self.client_id}: Connected to {self.url}") # Send initial message # Note: You'll need to implement the logic to get agent_ids and env_id @@ -70,16 +70,16 @@ async def connect(self) -> None: async def main() -> None: # Create multiple WebSocket clients num_clients = 0 - uri = "ws://localhost:8800/ws/simulation" + url = "ws://localhost:8800/ws/simulation" # Create and store client instances clients = [ - WebSocketClient(uri=uri, token=f"test_token_{i}", client_id=i) + WebSocketClient(url=url, token=f"test_token_{i}", client_id=i) for i in range(num_clients) ] - clients.append(WebSocketClient(uri=uri, token="test_token_10", client_id=10)) + clients.append(WebSocketClient(url=url, token="test_token_10", client_id=10)) clients.append( - WebSocketClient(uri=uri, token="test_token_10", client_id=10) + WebSocketClient(url=url, token="test_token_10", client_id=10) ) # test duplicate token # Create tasks for each client diff --git a/examples/fast_api_example.py b/examples/fast_api_example.py new file mode 100644 index 000000000..f10b850b6 --- /dev/null +++ b/examples/fast_api_example.py @@ -0,0 +1,122 @@ +# Example curl command to call the simulate endpoint: +import requests +import time + +BASE_URL = "http://localhost:8080" + + +def _create_mock_agent_profile() -> None: + agent1_data = { + "first_name": "John", + "last_name": "Doe", + "occupation": "test_occupation", + "gender": "test_gender", + "pk": "tmppk_agent1", + "tag": "test_tag", + } + response = requests.post( + f"{BASE_URL}/agents/", + headers={"Content-Type": "application/json"}, + json=agent1_data, + ) + assert response.status_code == 200 + + agent2_data = { + "first_name": "Jane", + "last_name": "Doe", + "occupation": "test_occupation", + "gender": "test_gender", + "pk": "tmppk_agent2", + "tag": "test_tag", + } + response = requests.post( + f"{BASE_URL}/agents/", + headers={"Content-Type": "application/json"}, + json=agent2_data, + ) + assert response.status_code == 200 + + +def _create_mock_env_profile() -> None: + env_data = { + "codename": "test_codename", + "scenario": "A", + "agent_goals": [ + "B", + "C", + ], + "pk": "tmppk_env_profile", + "tag": "test_tag", + } + response = requests.post( + f"{BASE_URL}/scenarios/", + headers={"Content-Type": "application/json"}, + json=env_data, + ) + assert response.status_code == 200 + + +_create_mock_agent_profile() +_create_mock_env_profile() + + +data = { + "env_id": "tmppk_env_profile", + "agent_ids": ["tmppk_agent1", "tmppk_agent2"], + "models": ["custom/structured-llama3.2:1b@http://localhost:8000/v1"] * 3, + "max_turns": 10, + "tag": "test_tag", +} +try: + response = requests.post( + f"{BASE_URL}/simulate/", headers={"Content-Type": "application/json"}, json=data + ) + print(response) + assert response.status_code == 202 + assert isinstance(response.content.decode(), str) + episode_pk = response.content.decode() + print(episode_pk) + max_retries = 200 + retry_count = 0 + while retry_count < max_retries: + try: + response = requests.get(f"{BASE_URL}/simulation_status/{episode_pk}") + assert response.status_code == 200 + status = response.content.decode() + print(status) + if status == "Error": + raise Exception("Error running simulation") + elif status == "Completed": + break + # Status is "Started", keep polling + time.sleep(1) + retry_count += 1 + except Exception as e: + print(f"Error checking simulation status: {e}") + time.sleep(1) + retry_count += 1 + else: + raise TimeoutError("Simulation timed out after 10 retries") + +finally: + try: + response = requests.delete(f"{BASE_URL}/agents/tmppk_agent1") + assert response.status_code == 200 + except Exception as e: + print(e) + try: + response = requests.delete(f"{BASE_URL}/agents/tmppk_agent2") + assert response.status_code == 200 + except Exception as e: + print(e) + try: + response = requests.delete(f"{BASE_URL}/scenarios/tmppk_env_profile") + assert response.status_code == 200 + except Exception as e: + print(e) + + try: + response = requests.delete(f"{BASE_URL}/episodes/{episode_pk}") + assert response.status_code == 200 + except Exception as e: + print(e) diff --git a/sotopia/database/__init__.py b/sotopia/database/__init__.py index 07f02c1dd..d9156a989 100644 --- a/sotopia/database/__init__.py +++ b/sotopia/database/__init__.py @@ -2,7 +2,7 @@ from redis_om import JsonModel, Migrator from .annotators import Annotator from .env_agent_combo_storage import EnvAgentComboStorage -from .logs import AnnotationForEpisode, EpisodeLog +from .logs import AnnotationForEpisode, EpisodeLog, NonStreamingSimulationStatus from .persistent_profile import ( AgentProfile, EnvironmentProfile, @@ -44,6 +44,7 @@ "AgentProfile", "EnvironmentProfile", "EpisodeLog", + "NonStreamingSimulationStatus", "EnvAgentComboStorage", "AnnotationForEpisode", "Annotator", @@ -73,6 +74,7 @@ "EvaluationDimensionBuilder", "CustomEvaluationDimension", "CustomEvaluationDimensionList", + "NonStreamingSimulationStatus", ] InheritedJsonModel = TypeVar("InheritedJsonModel", bound="JsonModel") diff --git a/sotopia/database/logs.py b/sotopia/database/logs.py index b3c5ff41e..4a2551aed 100644 --- a/sotopia/database/logs.py +++ b/sotopia/database/logs.py @@ -8,10 +8,15 @@ from pydantic import model_validator from redis_om import JsonModel from redis_om.model.model import Field - +from typing import Literal from sotopia.database.persistent_profile import AgentProfile +class NonStreamingSimulationStatus(JsonModel): + episode_pk: str = Field(index=True) + status: Literal["Started", "Error", "Completed"] + + class EpisodeLog(JsonModel): # Note that we did not validate the following constraints: # 1. The number of turns in messages and rewards should be the same or off by 1 diff --git a/sotopia/database/persistent_profile.py b/sotopia/database/persistent_profile.py index ee99f6601..c2e0e8e86 100644 --- a/sotopia/database/persistent_profile.py +++ b/sotopia/database/persistent_profile.py @@ -94,6 +94,11 @@ class RelationshipProfile(JsonModel): description="0 means stranger, 1 means know_by_name, 2 means acquaintance, 3 means friend, 4 means romantic_relationship, 5 means family_member", ) # this could be improved by limiting str to a relationship Enum background_story: str | None = Field(default_factory=lambda: None) + tag: str = Field( + index=True, + default_factory=lambda: "", + description="The tag of the relationship, used for searching, could be convenient to document relationship profiles from different works and sources", + ) class EnvironmentList(JsonModel): diff --git a/sotopia/database/serialization.py b/sotopia/database/serialization.py index 1fcc8b69e..c38e3c6c3 100644 --- a/sotopia/database/serialization.py +++ b/sotopia/database/serialization.py @@ -84,7 +84,7 @@ def _map_gender_to_adj(gender: str) -> str: "Nonbinary": "nonbinary", } if gender: - return gender_to_adj[gender] + return gender_to_adj.get(gender, "") else: return "" diff --git a/sotopia/envs/parallel.py b/sotopia/envs/parallel.py index 5d27f687b..e0a928d36 100644 --- a/sotopia/envs/parallel.py +++ b/sotopia/envs/parallel.py @@ -51,7 +51,7 @@ def _map_gender_to_adj(gender: str) -> str: "Nonbinary": "nonbinary", } if gender: - return gender_to_adj[gender] + return gender_to_adj.get(gender, "") else: return "" diff --git a/sotopia/server.py b/sotopia/server.py index aec81a0f7..ba88e9a7b 100644 --- a/sotopia/server.py +++ b/sotopia/server.py @@ -15,7 +15,7 @@ ScriptWritingAgent, ) from sotopia.agents.base_agent import BaseAgent -from sotopia.database import EpisodeLog +from sotopia.database import EpisodeLog, NonStreamingSimulationStatus from sotopia.envs import ParallelSotopiaEnv from sotopia.envs.evaluators import ( EvaluationForTwoAgents, @@ -119,12 +119,15 @@ async def arun_one_episode( json_in_script: bool = False, tag: str | None = None, push_to_db: bool = False, + episode_pk: str | None = None, streaming: bool = False, + simulation_status: NonStreamingSimulationStatus | None = None, ) -> Union[ list[tuple[str, str, Message]], AsyncGenerator[list[list[tuple[str, str, Message]]], None], ]: agents = Agents({agent.agent_name: agent for agent in agent_list}) + print(f"Running episode with tag: {tag}------------------") async def generate_messages() -> ( AsyncGenerator[list[list[tuple[str, str, Message]]], None] @@ -188,7 +191,7 @@ async def generate_messages() -> ( for agent_name in env.agents ] ) - + print(f"Messages: {messages}") yield messages rewards.append([rewards_in_turn[agent_name] for agent_name in env.agents]) reasons.append( @@ -228,7 +231,14 @@ async def generate_messages() -> ( if push_to_db: try: - epilog.save() + if episode_pk: + epilog.pk = episode_pk + epilog.save() + else: + epilog.save() + if simulation_status: + simulation_status.status = "Completed" + simulation_status.save() except Exception as e: logging.error(f"Failed to save episode log: {e}") diff --git a/sotopia/ui/README.md b/sotopia/ui/README.md index 156050a4b..4d8bb773a 100644 --- a/sotopia/ui/README.md +++ b/sotopia/ui/README.md @@ -4,6 +4,17 @@ ## FastAPI Server +To run the FastAPI server, you can use the following command: +```bash +uv run rq worker +uv run fastapi run sotopia/ui/fastapi_server.py --workers 4 --port 8080 +``` + +Here is also an example of using the FastAPI server: +```bash +uv run python examples/fast_api_example.py +``` + The API server is a FastAPI application that is used to connect the Sotopia UI to the Sotopia backend. This could also help with other projects that need to connect to the Sotopia backend through HTTP requests. diff --git a/sotopia/ui/fastapi_server.py b/sotopia/ui/fastapi_server.py index 543dafd2b..1d3a9d0eb 100644 --- a/sotopia/ui/fastapi_server.py +++ b/sotopia/ui/fastapi_server.py @@ -1,9 +1,34 @@ -from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect -from typing import Literal, cast, Optional, Any +from typing import Literal, cast, Dict, Self + +from redis_om import get_redis_connection +import rq +from sotopia.database import ( + EnvironmentProfile, + AgentProfile, + EpisodeLog, + RelationshipProfile, + RelationshipType, + NonStreamingSimulationStatus, +) +from sotopia.envs.parallel import ParallelSotopiaEnv +from sotopia.envs.evaluators import ( + RuleBasedTerminatedEvaluator, + ReachGoalLLMEvaluator, + EvaluationForTwoAgents, + SotopiaDimensions, +) +from sotopia.server import arun_one_episode +from sotopia.agents import LLMAgent, Agents +from fastapi import ( + FastAPI, + WebSocket, + HTTPException, + WebSocketDisconnect, +) +from typing import Optional, Any from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel +from pydantic import BaseModel, model_validator, field_validator -from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog from sotopia.ui.websocket_utils import ( WebSocketSotopiaSimulator, WSMessageType, @@ -15,9 +40,11 @@ from contextlib import asynccontextmanager from typing import AsyncIterator import logging +from fastapi.responses import Response logger = logging.getLogger(__name__) + app = FastAPI() app.add_middleware( @@ -29,11 +56,21 @@ ) # TODO: Whether allowing CORS for all origins +class RelationshipWrapper(BaseModel): + pk: str = "" + agent_1_id: str = "" + agent_2_id: str = "" + relationship: Literal[0, 1, 2, 3, 4, 5] = 0 + backstory: str = "" + tag: str = "" + + class AgentProfileWrapper(BaseModel): """ Wrapper for AgentProfile to avoid pydantic v2 issues """ + pk: str = "" first_name: str last_name: str age: int = 0 @@ -57,6 +94,7 @@ class EnvironmentProfileWrapper(BaseModel): Wrapper for EnvironmentProfile to avoid pydantic v2 issues """ + pk: str = "" codename: str source: str = "" scenario: str = "" @@ -68,6 +106,33 @@ class EnvironmentProfileWrapper(BaseModel): tag: str = "" +class SimulationRequest(BaseModel): + env_id: str + agent_ids: list[str] + models: list[str] + max_turns: int + tag: str + + @field_validator("agent_ids") + @classmethod + def validate_agent_ids(cls, v: list[str]) -> list[str]: + if len(v) != 2: + raise ValueError( + "Currently only 2 agents are supported, we are working on supporting more agents" + ) + return v + + @model_validator(mode="after") + def validate_models(self) -> Self: + models = self.models + agent_ids = self.agent_ids + if len(models) != len(agent_ids) + 1: + raise ValueError( + f"models must have exactly {len(agent_ids) + 1} elements, if there are {len(agent_ids)} agents, the first model is the evaluator model" + ) + return self + + @app.get("/scenarios", response_model=list[EnvironmentProfile]) async def get_scenarios_all() -> list[EnvironmentProfile]: return EnvironmentProfile.all() @@ -122,6 +187,18 @@ async def get_agents( return agents_profiles +@app.get("/relationship/{agent_1_id}/{agent_2_id}", response_model=str) +async def get_relationship(agent_1_id: str, agent_2_id: str) -> str: + relationship_profiles = RelationshipProfile.find( + (RelationshipProfile.agent_1_id == agent_1_id) + & (RelationshipProfile.agent_2_id == agent_2_id) + ).all() + assert len(relationship_profiles) == 1 + relationship_profile = relationship_profiles[0] + assert isinstance(relationship_profile, RelationshipProfile) + return f"{str(relationship_profile.relationship)}: {RelationshipType(relationship_profile.relationship).name}" + + @app.get("/episodes", response_model=list[EpisodeLog]) async def get_episodes_all() -> list[EpisodeLog]: return EpisodeLog.all() @@ -143,6 +220,15 @@ async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[Episode return episodes +@app.post("/scenarios/", response_model=str) +async def create_scenario(scenario: EnvironmentProfileWrapper) -> str: + scenario_profile = EnvironmentProfile(**scenario.model_dump()) + scenario_profile.save() + pk = scenario_profile.pk + assert pk is not None + return pk + + @app.post("/agents/", response_model=str) async def create_agent(agent: AgentProfileWrapper) -> str: agent_profile = AgentProfile(**agent.model_dump()) @@ -152,39 +238,146 @@ async def create_agent(agent: AgentProfileWrapper) -> str: return pk -@app.post("/scenarios/", response_model=str) -async def create_scenario(scenario: EnvironmentProfileWrapper) -> str: - scenario_profile = EnvironmentProfile(**scenario.model_dump()) - scenario_profile.save() - pk = scenario_profile.pk +@app.post("/relationship/", response_model=str) +async def create_relationship(relationship: RelationshipWrapper) -> str: + relationship_profile = RelationshipProfile(**relationship.model_dump()) + relationship_profile.save() + pk = relationship_profile.pk assert pk is not None return pk -@app.put("/agents/{agent_id}", response_model=str) -async def update_agent(agent_id: str, agent: AgentProfileWrapper) -> str: +async def run_simulation( + episode_pk: str, + simulation_request: SimulationRequest, + simulation_status: NonStreamingSimulationStatus, +) -> None: try: - old_agent = AgentProfile.get(pk=agent_id) + env_profile: EnvironmentProfile = EnvironmentProfile.get( + pk=simulation_request.env_id + ) except Exception: # TODO Check the exception type raise HTTPException( - status_code=404, detail=f"Agent with id={agent_id} not found" + status_code=404, + detail=f"Environment with id={simulation_request.env_id} not found", + ) + try: + agent_1_profile = AgentProfile.get(pk=simulation_request.agent_ids[0]) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, + detail=f"Agent with id={simulation_request.agent_ids[0]} not found", + ) + try: + agent_2_profile = AgentProfile.get(pk=simulation_request.agent_ids[1]) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, + detail=f"Agent with id={simulation_request.agent_ids[1]} not found", ) - old_agent.update(**agent.model_dump()) # type: ignore - assert old_agent.pk is not None - return old_agent.pk - -@app.put("/scenarios/{scenario_id}", response_model=str) -async def update_scenario(scenario_id: str, scenario: EnvironmentProfileWrapper) -> str: + env_params: dict[str, Any] = { + "model_name": simulation_request.models[0], + "action_order": "round-robin", + "evaluators": [ + RuleBasedTerminatedEvaluator( + max_turn_number=simulation_request.max_turns, max_stale_turn=2 + ), + ], + "terminal_evaluators": [ + ReachGoalLLMEvaluator( + simulation_request.models[0], + EvaluationForTwoAgents[SotopiaDimensions], + ), + ], + } + env = ParallelSotopiaEnv(env_profile=env_profile, **env_params) + agents = Agents( + { + "agent1": LLMAgent( + "agent1", + model_name=simulation_request.models[1], + agent_profile=agent_1_profile, + ), + "agent2": LLMAgent( + "agent2", + model_name=simulation_request.models[2], + agent_profile=agent_2_profile, + ), + } + ) + + await arun_one_episode( + env=env, + agent_list=list(agents.values()), + push_to_db=True, + tag=simulation_request.tag, + episode_pk=episode_pk, + simulation_status=simulation_status, + ) + + +@app.post("/simulate/", response_model=str) +def simulate(simulation_request: SimulationRequest) -> Response: try: - old_scenario = EnvironmentProfile.get(pk=scenario_id) + _: EnvironmentProfile = EnvironmentProfile.get(pk=simulation_request.env_id) except Exception: # TODO Check the exception type raise HTTPException( - status_code=404, detail=f"Scenario with id={scenario_id} not found" + status_code=404, + detail=f"Environment with id={simulation_request.env_id} not found", + ) + try: + __ = AgentProfile.get(pk=simulation_request.agent_ids[0]) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, + detail=f"Agent with id={simulation_request.agent_ids[0]} not found", + ) + try: + ___ = AgentProfile.get(pk=simulation_request.agent_ids[1]) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, + detail=f"Agent with id={simulation_request.agent_ids[1]} not found", + ) + + episode_pk = EpisodeLog( + environment="", + agents=[], + models=[], + messages=[], + reasoning="", + rewards=[], # Pseudorewards + rewards_prompt="", + ).pk + try: + simulation_status = NonStreamingSimulationStatus( + episode_pk=episode_pk, + status="Started", ) - old_scenario.update(**scenario.model_dump()) # type: ignore - assert old_scenario.pk is not None - return old_scenario.pk + simulation_status.save() + queue = rq.Queue("default", connection=get_redis_connection()) + queue.enqueue( + run_simulation, + episode_pk=episode_pk, + simulation_request=simulation_request, + simulation_status=simulation_status, + ) + + except Exception as e: + logger.error(f"Error starting simulation: {e}") + simulation_status.status = "Error" + simulation_status.save() + return Response(content=episode_pk, status_code=202) + + +@app.get("/simulation_status/{episode_pk}", response_model=str) +async def get_simulation_status(episode_pk: str) -> str: + status = NonStreamingSimulationStatus.find( + NonStreamingSimulationStatus.episode_pk == episode_pk + ).all()[0] + assert isinstance(status, NonStreamingSimulationStatus) + return status.status @app.delete("/agents/{agent_id}", response_model=str) @@ -213,6 +406,23 @@ async def delete_scenario(scenario_id: str) -> str: return scenario.pk +@app.delete("/relationship/{relationship_id}", response_model=str) +async def delete_relationship(relationship_id: str) -> str: + RelationshipProfile.delete(relationship_id) + return relationship_id + + +@app.delete("/episodes/{episode_id}", response_model=str) +async def delete_episode(episode_id: str) -> str: + EpisodeLog.delete(episode_id) + return episode_id + + +active_simulations: Dict[ + str, bool +] = {} # TODO check whether this is the correct way to store the active simulations + + @app.get("/models", response_model=list[str]) async def get_models() -> list[str]: # TODO figure out how to get the available models diff --git a/stubs/redis_om/__init__.pyi b/stubs/redis_om/__init__.pyi index abbae6f43..133b6caff 100644 --- a/stubs/redis_om/__init__.pyi +++ b/stubs/redis_om/__init__.pyi @@ -2,6 +2,7 @@ import abc from typing import Any, Generator, TypeVar from pydantic import BaseModel +import redis from redis_om.model.model import Field from pydantic._internal._model_construction import ModelMetaclass from redis_om.model.model import FindQuery @@ -37,3 +38,5 @@ class EmbeddedJsonModel(JsonModel): ... class Migrator: def run(self) -> None: ... + +def get_redis_connection() -> redis.Redis[bytes]: ... diff --git a/tests/ui/test_fastapi.py b/tests/ui/test_fastapi.py index 9395104f3..b8c7bb7df 100644 --- a/tests/ui/test_fastapi.py +++ b/tests/ui/test_fastapi.py @@ -1,5 +1,10 @@ from fastapi.testclient import TestClient -from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog +from sotopia.database import ( + EnvironmentProfile, + AgentProfile, + EpisodeLog, + RelationshipProfile, +) from sotopia.messages import SimpleMessage from sotopia.ui.fastapi_server import app import pytest @@ -63,7 +68,7 @@ def create_dummy_episode_log() -> None: @pytest.fixture -def create_mock_data() -> Generator[None, None, None]: +def create_mock_data(for_posting: bool = False) -> Generator[None, None, None]: def _create_mock_agent_profile() -> None: AgentProfile( first_name="John", @@ -71,6 +76,7 @@ def _create_mock_agent_profile() -> None: occupation="test_occupation", gender="test_gender", pk="tmppk_agent1", + tag="test_tag", ).save() AgentProfile( first_name="Jane", @@ -78,6 +84,7 @@ def _create_mock_agent_profile() -> None: occupation="test_occupation", gender="test_gender", pk="tmppk_agent2", + tag="test_tag", ).save() def _create_mock_env_profile() -> None: @@ -89,18 +96,55 @@ def _create_mock_env_profile() -> None: "C", ], pk="tmppk_env_profile", + tag="test_tag", ) env_profile.save() + def _create_mock_relationship() -> None: + RelationshipProfile( + pk="tmppk_relationship", + agent_1_id="tmppk_agent1", + agent_2_id="tmppk_agent2", + relationship=1.0, + ).save() + _create_mock_agent_profile() _create_mock_env_profile() - + _create_mock_relationship() yield - AgentProfile.delete("tmppk_agent1") - AgentProfile.delete("tmppk_agent2") - EnvironmentProfile.delete("tmppk_env_profile") - EpisodeLog.delete("tmppk_episode_log") + try: + AgentProfile.delete("tmppk_agent1") + except Exception as e: + print(e) + try: + AgentProfile.delete("tmppk_agent2") + except Exception as e: + print(e) + try: + EnvironmentProfile.delete("tmppk_env_profile") + except Exception as e: + print(e) + try: + RelationshipProfile.delete("tmppk_relationship") + except Exception as e: + print(e) + try: + EpisodeLog.delete("tmppk_episode_log") + except Exception as e: + print(e) + + try: + EpisodeLog.delete("tmppk_episode_log") + except Exception as e: + print(e) + + try: + episodes = EpisodeLog.find(EpisodeLog.tag == "test_tag").all() + for episode in episodes: + EpisodeLog.delete(episode.pk) + except Exception as e: + print(e) def test_get_scenarios_all(create_mock_data: Callable[[], None]) -> None: @@ -169,8 +213,17 @@ def test_get_episodes_by_tag(create_mock_data: Callable[[], None]) -> None: assert response.json()[0]["tag"] == tag +def test_get_relationship(create_mock_data: Callable[[], None]) -> None: + response = client.get("/relationship/tmppk_agent1/tmppk_agent2") + assert response.status_code == 200 + assert isinstance(response.json(), str) + assert response.json() == "1: know_by_name" + + +@pytest.mark.parametrize("create_mock_data", [True], indirect=True) def test_create_agent(create_mock_data: Callable[[], None]) -> None: agent_data = { + "pk": "tmppk_agent1", "first_name": "test_first_name", "last_name": "test_last_name", } @@ -179,13 +232,30 @@ def test_create_agent(create_mock_data: Callable[[], None]) -> None: assert isinstance(response.json(), str) +@pytest.mark.parametrize("create_mock_data", [True], indirect=True) def test_create_scenario(create_mock_data: Callable[[], None]) -> None: scenario_data = { + "pk": "tmppk_env_profile", "codename": "test_codename", "scenario": "test_scenario", "tag": "test", } response = client.post("/scenarios/", json=scenario_data) + EnvironmentProfile.delete("tmppk_env_profile") + assert response.status_code == 200 + assert isinstance(response.json(), str) + + +@pytest.mark.parametrize("create_mock_data", [True], indirect=True) +def test_create_relationship(create_mock_data: Callable[[], None]) -> None: + relationship_data = { + "pk": "tmppk_relationship", + "agent_1_id": "tmppk_agent1", + "agent_2_id": "tmppk_agent2", + "relationship": 1.0, + "tag": "test_tag", + } + response = client.post("/relationship", json=relationship_data) assert response.status_code == 200 assert isinstance(response.json(), str) @@ -200,3 +270,54 @@ def test_delete_scenario(create_mock_data: Callable[[], None]) -> None: response = client.delete("/scenarios/tmppk_env_profile") assert response.status_code == 200 assert isinstance(response.json(), str) + + +def test_delete_relationship(create_mock_data: Callable[[], None]) -> None: + response = client.delete("/relationship/tmppk_relationship") + assert response.status_code == 200 + assert isinstance(response.json(), str) + + +# def test_simulate(create_mock_data: Callable[[], None]) -> None: +# response = client.post( +# "/simulate", +# json={ +# "env_id": "tmppk_env_profile", +# "agent_ids": ["tmppk_agent1", "tmppk_agent2"], +# "models": [ +# # "custom/llama3.2:1b@http://localhost:8000/v1", +# # "custom/llama3.2:1b@http://localhost:8000/v1", +# # "custom/llama3.2:1b@http://localhost:8000/v1" +# "gpt-4o-mini", +# "gpt-4o-mini", +# "gpt-4o-mini", +# ], +# "max_turns": 2, +# "tag": "test_tag", +# }, +# ) +# assert response.status_code == 200 +# assert isinstance(response.json(), str) +# max_retries = 20 +# retry_count = 0 +# while retry_count < max_retries: +# try: +# status = NonStreamingSimulationStatus.find( +# NonStreamingSimulationStatus.episode_pk == response.json() +# ).all()[0] +# assert isinstance(status, NonStreamingSimulationStatus) +# print(status) +# if status.status == "Error": +# raise Exception("Error running simulation") +# elif status.status == "Completed": +# # EpisodeLog.get(response.json()) +# break +# # Status is "Started", keep polling +# time.sleep(1) +# retry_count += 1 +# except Exception as e: +# print(f"Error checking simulation status: {e}") +# time.sleep(1) +# retry_count += 1 +# else: +# raise TimeoutError("Simulation timed out after 10 retries") diff --git a/uv.lock b/uv.lock index 71152b36b..5017e0e00 100644 --- a/uv.lock +++ b/uv.lock @@ -115,51 +115,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f5/76/a57ceff577ae26fe9a6f31ac799bc638ecf26e4acdf04295290b9929b349/aiohttp-3.11.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:9d18a8b44ec8502a7fde91446cd9c9b95ce7c49f1eacc1fb2358b8907d4369fd", size = 1690038 }, { url = "https://files.pythonhosted.org/packages/4b/81/b20e09003b6989a7f23a721692137a6143420a151063c750ab2a04878e3c/aiohttp-3.11.7-cp312-cp312-win32.whl", hash = "sha256:3d1c9c15d3999107cbb9b2d76ca6172e6710a12fda22434ee8bd3f432b7b17e8", size = 409887 }, { url = "https://files.pythonhosted.org/packages/b7/0b/607c98bff1d07bb21e0c39e7711108ef9ff4f2a361a3ec1ce8dce93623a5/aiohttp-3.11.7-cp312-cp312-win_amd64.whl", hash = "sha256:018f1b04883a12e77e7fc161934c0f298865d3a484aea536a6a2ca8d909f0ba0", size = 436462 }, - { url = "https://files.pythonhosted.org/packages/3d/dd/3d40c0e67e79c5c42671e3e268742f1ff96c6573ca43823563d01abd9475/aiohttp-3.10.10-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:be7443669ae9c016b71f402e43208e13ddf00912f47f623ee5994e12fc7d4b3f", size = 586969 }, - { url = "https://files.pythonhosted.org/packages/75/64/8de41b5555e5b43ef6d4ed1261891d33fe45ecc6cb62875bfafb90b9ab93/aiohttp-3.10.10-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7b06b7843929e41a94ea09eb1ce3927865387e3e23ebe108e0d0d09b08d25be9", size = 399367 }, - { url = "https://files.pythonhosted.org/packages/96/36/27bd62ea7ce43906d1443a73691823fc82ffb8fa03276b0e2f7e1037c286/aiohttp-3.10.10-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:333cf6cf8e65f6a1e06e9eb3e643a0c515bb850d470902274239fea02033e9a8", size = 390720 }, - { url = "https://files.pythonhosted.org/packages/e8/4d/d516b050d811ce0dd26325c383013c104ffa8b58bd361b82e52833f68e78/aiohttp-3.10.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:274cfa632350225ce3fdeb318c23b4a10ec25c0e2c880eff951a3842cf358ac1", size = 1228820 }, - { url = "https://files.pythonhosted.org/packages/53/94/964d9327a3e336d89aad52260836e4ec87fdfa1207176550fdf384eaffe7/aiohttp-3.10.10-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d9e5e4a85bdb56d224f412d9c98ae4cbd032cc4f3161818f692cd81766eee65a", size = 1264616 }, - { url = "https://files.pythonhosted.org/packages/0c/20/70ce17764b685ca8f5bf4d568881b4e1f1f4ea5e8170f512fdb1a33859d2/aiohttp-3.10.10-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2b606353da03edcc71130b52388d25f9a30a126e04caef1fd637e31683033abd", size = 1298402 }, - { url = "https://files.pythonhosted.org/packages/d1/d1/5248225ccc687f498d06c3bca5af2647a361c3687a85eb3aedcc247ee1aa/aiohttp-3.10.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab5a5a0c7a7991d90446a198689c0535be89bbd6b410a1f9a66688f0880ec026", size = 1222205 }, - { url = "https://files.pythonhosted.org/packages/f2/a3/9296b27cc5d4feadf970a14d0694902a49a985f3fae71b8322a5f77b0baa/aiohttp-3.10.10-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:578a4b875af3e0daaf1ac6fa983d93e0bbfec3ead753b6d6f33d467100cdc67b", size = 1193804 }, - { url = "https://files.pythonhosted.org/packages/d9/07/f3760160feb12ac51a6168a6da251a4a8f2a70733d49e6ceb9b3e6ee2f03/aiohttp-3.10.10-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8105fd8a890df77b76dd3054cddf01a879fc13e8af576805d667e0fa0224c35d", size = 1193544 }, - { url = "https://files.pythonhosted.org/packages/7e/4c/93a70f9a4ba1c30183a6dd68bfa79cddbf9a674f162f9c62e823a74a5515/aiohttp-3.10.10-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3bcd391d083f636c06a68715e69467963d1f9600f85ef556ea82e9ef25f043f7", size = 1193047 }, - { url = "https://files.pythonhosted.org/packages/ff/a3/36a1e23ff00c7a0cd696c5a28db05db25dc42bfc78c508bd78623ff62a4a/aiohttp-3.10.10-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:fbc6264158392bad9df19537e872d476f7c57adf718944cc1e4495cbabf38e2a", size = 1247201 }, - { url = "https://files.pythonhosted.org/packages/55/ae/95399848557b98bb2c402d640b2276ce3a542b94dba202de5a5a1fe29abe/aiohttp-3.10.10-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:e48d5021a84d341bcaf95c8460b152cfbad770d28e5fe14a768988c461b821bc", size = 1264102 }, - { url = "https://files.pythonhosted.org/packages/38/f5/02e5c72c1b60d7cceb30b982679a26167e84ac029fd35a93dd4da52c50a3/aiohttp-3.10.10-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2609e9ab08474702cc67b7702dbb8a80e392c54613ebe80db7e8dbdb79837c68", size = 1215760 }, - { url = "https://files.pythonhosted.org/packages/30/17/1463840bad10d02d0439068f37ce5af0b383884b0d5838f46fb027e233bf/aiohttp-3.10.10-cp310-cp310-win32.whl", hash = "sha256:84afcdea18eda514c25bc68b9af2a2b1adea7c08899175a51fe7c4fb6d551257", size = 362678 }, - { url = "https://files.pythonhosted.org/packages/dd/01/a0ef707d93e867a43abbffee3a2cdf30559910750b9176b891628c7ad074/aiohttp-3.10.10-cp310-cp310-win_amd64.whl", hash = "sha256:9c72109213eb9d3874f7ac8c0c5fa90e072d678e117d9061c06e30c85b4cf0e6", size = 381097 }, - { url = "https://files.pythonhosted.org/packages/72/31/3c351d17596194e5a38ef169a4da76458952b2497b4b54645b9d483cbbb0/aiohttp-3.10.10-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c30a0eafc89d28e7f959281b58198a9fa5e99405f716c0289b7892ca345fe45f", size = 586501 }, - { url = "https://files.pythonhosted.org/packages/a4/a8/a559d09eb08478cdead6b7ce05b0c4a133ba27fcdfa91e05d2e62867300d/aiohttp-3.10.10-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:258c5dd01afc10015866114e210fb7365f0d02d9d059c3c3415382ab633fcbcb", size = 398993 }, - { url = "https://files.pythonhosted.org/packages/c5/47/7736d4174613feef61d25332c3bd1a4f8ff5591fbd7331988238a7299485/aiohttp-3.10.10-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:15ecd889a709b0080f02721255b3f80bb261c2293d3c748151274dfea93ac871", size = 390647 }, - { url = "https://files.pythonhosted.org/packages/27/21/e9ba192a04b7160f5a8952c98a1de7cf8072ad150fa3abd454ead1ab1d7f/aiohttp-3.10.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3935f82f6f4a3820270842e90456ebad3af15810cf65932bd24da4463bc0a4c", size = 1306481 }, - { url = "https://files.pythonhosted.org/packages/cf/50/f364c01c8d0def1dc34747b2470969e216f5a37c7ece00fe558810f37013/aiohttp-3.10.10-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:413251f6fcf552a33c981c4709a6bba37b12710982fec8e558ae944bfb2abd38", size = 1344652 }, - { url = "https://files.pythonhosted.org/packages/1d/c2/74f608e984e9b585649e2e83883facad6fa3fc1d021de87b20cc67e8e5ae/aiohttp-3.10.10-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d1720b4f14c78a3089562b8875b53e36b51c97c51adc53325a69b79b4b48ebcb", size = 1378498 }, - { url = "https://files.pythonhosted.org/packages/9f/a7/05a48c7c0a7a80a5591b1203bf1b64ca2ed6a2050af918d09c05852dc42b/aiohttp-3.10.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:679abe5d3858b33c2cf74faec299fda60ea9de62916e8b67e625d65bf069a3b7", size = 1292718 }, - { url = "https://files.pythonhosted.org/packages/7d/78/a925655018747e9790350180330032e27d6e0d7ed30bde545fae42f8c49c/aiohttp-3.10.10-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:79019094f87c9fb44f8d769e41dbb664d6e8fcfd62f665ccce36762deaa0e911", size = 1251776 }, - { url = "https://files.pythonhosted.org/packages/47/9d/85c6b69f702351d1236594745a4fdc042fc43f494c247a98dac17e004026/aiohttp-3.10.10-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:fe2fb38c2ed905a2582948e2de560675e9dfbee94c6d5ccdb1301c6d0a5bf092", size = 1271716 }, - { url = "https://files.pythonhosted.org/packages/7f/a7/55fc805ff9b14af818903882ece08e2235b12b73b867b521b92994c52b14/aiohttp-3.10.10-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:a3f00003de6eba42d6e94fabb4125600d6e484846dbf90ea8e48a800430cc142", size = 1266263 }, - { url = "https://files.pythonhosted.org/packages/1f/ec/d2be2ca7b063e4f91519d550dbc9c1cb43040174a322470deed90b3d3333/aiohttp-3.10.10-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:1bbb122c557a16fafc10354b9d99ebf2f2808a660d78202f10ba9d50786384b9", size = 1321617 }, - { url = "https://files.pythonhosted.org/packages/c9/a3/b29f7920e1cd0a9a68a45dd3eb16140074d2efb1518d2e1f3e140357dc37/aiohttp-3.10.10-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:30ca7c3b94708a9d7ae76ff281b2f47d8eaf2579cd05971b5dc681db8caac6e1", size = 1339227 }, - { url = "https://files.pythonhosted.org/packages/8a/81/34b67235c47e232d807b4bbc42ba9b927c7ce9476872372fddcfd1e41b3d/aiohttp-3.10.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:df9270660711670e68803107d55c2b5949c2e0f2e4896da176e1ecfc068b974a", size = 1299068 }, - { url = "https://files.pythonhosted.org/packages/04/1f/26a7fe11b6ad3184f214733428353c89ae9fe3e4f605a657f5245c5e720c/aiohttp-3.10.10-cp311-cp311-win32.whl", hash = "sha256:aafc8ee9b742ce75044ae9a4d3e60e3d918d15a4c2e08a6c3c3e38fa59b92d94", size = 362223 }, - { url = "https://files.pythonhosted.org/packages/10/91/85dcd93f64011434359ce2666bece981f08d31bc49df33261e625b28595d/aiohttp-3.10.10-cp311-cp311-win_amd64.whl", hash = "sha256:362f641f9071e5f3ee6f8e7d37d5ed0d95aae656adf4ef578313ee585b585959", size = 381576 }, - { url = "https://files.pythonhosted.org/packages/ae/99/4c5aefe5ad06a1baf206aed6598c7cdcbc7c044c46801cd0d1ecb758cae3/aiohttp-3.10.10-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9294bbb581f92770e6ed5c19559e1e99255e4ca604a22c5c6397b2f9dd3ee42c", size = 583536 }, - { url = "https://files.pythonhosted.org/packages/a9/36/8b3bc49b49cb6d2da40ee61ff15dbcc44fd345a3e6ab5bb20844df929821/aiohttp-3.10.10-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a8fa23fe62c436ccf23ff930149c047f060c7126eae3ccea005f0483f27b2e28", size = 395693 }, - { url = "https://files.pythonhosted.org/packages/e1/77/0aa8660dcf11fa65d61712dbb458c4989de220a844bd69778dff25f2d50b/aiohttp-3.10.10-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5c6a5b8c7926ba5d8545c7dd22961a107526562da31a7a32fa2456baf040939f", size = 390898 }, - { url = "https://files.pythonhosted.org/packages/38/d2/b833d95deb48c75db85bf6646de0a697e7fb5d87bd27cbade4f9746b48b1/aiohttp-3.10.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:007ec22fbc573e5eb2fb7dec4198ef8f6bf2fe4ce20020798b2eb5d0abda6138", size = 1312060 }, - { url = "https://files.pythonhosted.org/packages/aa/5f/29fd5113165a0893de8efedf9b4737e0ba92dfcd791415a528f947d10299/aiohttp-3.10.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9627cc1a10c8c409b5822a92d57a77f383b554463d1884008e051c32ab1b3742", size = 1350553 }, - { url = "https://files.pythonhosted.org/packages/ad/cc/f835f74b7d344428469200105236d44606cfa448be1e7c95ca52880d9bac/aiohttp-3.10.10-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:50edbcad60d8f0e3eccc68da67f37268b5144ecc34d59f27a02f9611c1d4eec7", size = 1392646 }, - { url = "https://files.pythonhosted.org/packages/bf/fe/1332409d845ca601893bbf2d76935e0b93d41686e5f333841c7d7a4a770d/aiohttp-3.10.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a45d85cf20b5e0d0aa5a8dca27cce8eddef3292bc29d72dcad1641f4ed50aa16", size = 1306310 }, - { url = "https://files.pythonhosted.org/packages/e4/a1/25a7633a5a513278a9892e333501e2e69c83e50be4b57a62285fb7a008c3/aiohttp-3.10.10-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b00807e2605f16e1e198f33a53ce3c4523114059b0c09c337209ae55e3823a8", size = 1260255 }, - { url = "https://files.pythonhosted.org/packages/f2/39/30eafe89e0e2a06c25e4762844c8214c0c0cd0fd9ffc3471694a7986f421/aiohttp-3.10.10-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f2d4324a98062be0525d16f768a03e0bbb3b9fe301ceee99611dc9a7953124e6", size = 1271141 }, - { url = "https://files.pythonhosted.org/packages/5b/fc/33125df728b48391ef1fcb512dfb02072158cc10d041414fb79803463020/aiohttp-3.10.10-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:438cd072f75bb6612f2aca29f8bd7cdf6e35e8f160bc312e49fbecab77c99e3a", size = 1280244 }, - { url = "https://files.pythonhosted.org/packages/3b/61/e42bf2c2934b5caa4e2ec0b5e5fd86989adb022b5ee60c2572a9d77cf6fe/aiohttp-3.10.10-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:baa42524a82f75303f714108fea528ccacf0386af429b69fff141ffef1c534f9", size = 1316805 }, - { url = "https://files.pythonhosted.org/packages/18/32/f52a5e2ae9ad3bba10e026a63a7a23abfa37c7d97aeeb9004eaa98df3ce3/aiohttp-3.10.10-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:a7d8d14fe962153fc681f6366bdec33d4356f98a3e3567782aac1b6e0e40109a", size = 1343930 }, - { url = "https://files.pythonhosted.org/packages/05/be/6a403b464dcab3631fe8e27b0f1d906d9e45c5e92aca97ee007e5a895560/aiohttp-3.10.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c1277cd707c465cd09572a774559a3cc7c7a28802eb3a2a9472588f062097205", size = 1306186 }, - { url = "https://files.pythonhosted.org/packages/8e/fd/bb50fe781068a736a02bf5c7ad5f3ab53e39f1d1e63110da6d30f7605edc/aiohttp-3.10.10-cp312-cp312-win32.whl", hash = "sha256:59bb3c54aa420521dc4ce3cc2c3fe2ad82adf7b09403fa1f48ae45c0cbde6628", size = 359289 }, - { url = "https://files.pythonhosted.org/packages/70/9e/5add7e240f77ef67c275c82cc1d08afbca57b77593118c1f6e920ae8ad3f/aiohttp-3.10.10-cp312-cp312-win_amd64.whl", hash = "sha256:0e1b370d8007c4ae31ee6db7f9a2fe801a42b146cec80a86766e7ad5c4a259cf", size = 379313 }, ] [[package]] @@ -3211,7 +3166,6 @@ requires-dist = [ { name = "pytest-asyncio", marker = "extra == 'test'" }, { name = "pytest-cov", marker = "extra == 'test'" }, { name = "redis-om", specifier = ">=0.3.0,<0.4.0" }, - { name = "rel", marker = "extra == 'chat'" }, { name = "rich", specifier = ">=13.6.0,<14.0.0" }, { name = "scipy", marker = "extra == 'examples'" }, { name = "together", specifier = ">=0.2.4,<1.4.0" },