From 501660dd736139b7ef539a510546e8add05efa9e Mon Sep 17 00:00:00 2001 From: Zhe Su <360307598@qq.com> Date: Thu, 12 Dec 2024 22:25:19 -0500 Subject: [PATCH] prototype for modal serving --- sotopia/ui/fastapi_server.py | 680 +++++++++++++++++---------------- sotopia/ui/modal_api_server.py | 116 ++++++ 2 files changed, 465 insertions(+), 331 deletions(-) create mode 100644 sotopia/ui/modal_api_server.py diff --git a/sotopia/ui/fastapi_server.py b/sotopia/ui/fastapi_server.py index 611878c4..fdc46ce0 100644 --- a/sotopia/ui/fastapi_server.py +++ b/sotopia/ui/fastapi_server.py @@ -51,15 +51,19 @@ logger = logging.getLogger(__name__) -app = FastAPI() +# app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) # TODO: Whether allowing CORS for all origins +# app.add_middleware( +# CORSMiddleware, +# allow_origins=["*"], +# allow_credentials=True, +# allow_methods=["*"], +# allow_headers=["*"], +# ) # TODO: Whether allowing CORS for all origins + +active_simulations: Dict[ + str, bool +] = {} # TODO check whether this is the correct way to store the active simulations class RelationshipWrapper(BaseModel): @@ -139,120 +143,114 @@ def validate_models(self) -> Self: return self -@app.get("/scenarios", response_model=list[EnvironmentProfile]) -async def get_scenarios_all() -> list[EnvironmentProfile]: - return EnvironmentProfile.all() - +class SimulationState: + _instance: Optional["SimulationState"] = None + _lock = asyncio.Lock() + _active_simulations: dict[str, bool] = {} -@app.get("/scenarios/{get_by}/{value}", response_model=list[EnvironmentProfile]) -async def get_scenarios( - get_by: Literal["id", "codename"], value: str -) -> list[EnvironmentProfile]: - # Implement logic to fetch scenarios based on the parameters - scenarios: list[EnvironmentProfile] = [] # Replace with actual fetching logic - if get_by == "id": - scenarios.append(EnvironmentProfile.get(pk=value)) - elif get_by == "codename": - json_models = EnvironmentProfile.find( - EnvironmentProfile.codename == value - ).all() - scenarios.extend(cast(list[EnvironmentProfile], json_models)) + def __new__(cls) -> "SimulationState": + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._active_simulations = {} + return cls._instance - if not scenarios: - raise HTTPException( - status_code=404, detail=f"No scenarios found with {get_by}={value}" - ) + async def try_acquire_token(self, token: str) -> tuple[bool, str]: + async with self._lock: + if not token: + return False, "Invalid token" - return scenarios + if self._active_simulations.get(token): + return False, "Token is active already" + self._active_simulations[token] = True + return True, "Token is valid" -@app.get("/agents", response_model=list[AgentProfile]) -async def get_agents_all() -> list[AgentProfile]: - return AgentProfile.all() + async def release_token(self, token: str) -> None: + async with self._lock: + self._active_simulations.pop(token, None) + @asynccontextmanager + async def start_simulation(self, token: str) -> AsyncIterator[bool]: + try: + yield True + finally: + await self.release_token(token) -@app.get("/agents/{get_by}/{value}", response_model=list[AgentProfile]) -async def get_agents( - get_by: Literal["id", "gender", "occupation"], value: str -) -> list[AgentProfile]: - agents_profiles: list[AgentProfile] = [] - if get_by == "id": - agents_profiles.append(AgentProfile.get(pk=value)) - elif get_by == "gender": - json_models = AgentProfile.find(AgentProfile.gender == value).all() - agents_profiles.extend(cast(list[AgentProfile], json_models)) - elif get_by == "occupation": - json_models = AgentProfile.find(AgentProfile.occupation == value).all() - agents_profiles.extend(cast(list[AgentProfile], json_models)) - if not agents_profiles: - raise HTTPException( - status_code=404, detail=f"No agents found with {get_by}={value}" - ) +class SimulationManager: + def __init__(self) -> None: + self.state = SimulationState() - return agents_profiles + async def verify_token(self, token: str) -> dict[str, Any]: + is_valid, msg = await self.state.try_acquire_token(token) + return {"is_valid": is_valid, "msg": msg} + async def create_simulator( + self, env_id: str, agent_ids: list[str] + ) -> WebSocketSotopiaSimulator: + try: + return WebSocketSotopiaSimulator(env_id=env_id, agent_ids=agent_ids) + except Exception as e: + error_msg = f"Failed to create simulator: {e}" + logger.error(error_msg) + raise Exception(error_msg) -@app.get("/relationship/{agent_1_id}/{agent_2_id}", response_model=str) -async def get_relationship(agent_1_id: str, agent_2_id: str) -> str: - relationship_profiles = RelationshipProfile.find( - (RelationshipProfile.agent_1_id == agent_1_id) - & (RelationshipProfile.agent_2_id == agent_2_id) - ).all() - assert ( - len(relationship_profiles) == 1 - ), f"{len(relationship_profiles)} relationship profiles found for agents {agent_1_id} and {agent_2_id}, expected 1" - relationship_profile = relationship_profiles[0] - assert isinstance(relationship_profile, RelationshipProfile) - return f"{str(relationship_profile.relationship)}: {RelationshipType(relationship_profile.relationship).name}" + async def handle_client_message( + self, + websocket: WebSocket, + simulator: WebSocketSotopiaSimulator, + message: dict[str, Any], + timeout: float = 0.1, + ) -> bool: + try: + msg_type = message.get("type") + if msg_type == WSMessageType.FINISH_SIM.value: + return True + # TODO handle other message types + return False + except Exception as e: + msg = f"Error handling client message: {e}" + logger.error(msg) + await self.send_error(websocket, ErrorType.INVALID_MESSAGE, msg) + return False + async def run_simulation( + self, websocket: WebSocket, simulator: WebSocketSotopiaSimulator + ) -> None: + try: + async for message in simulator.arun(): + await self.send_message(websocket, WSMessageType.SERVER_MSG, message) -@app.get("/episodes", response_model=list[EpisodeLog]) -async def get_episodes_all() -> list[EpisodeLog]: - return EpisodeLog.all() + try: + data = await asyncio.wait_for(websocket.receive_json(), timeout=0.1) + if await self.handle_client_message(websocket, simulator, data): + break + except asyncio.TimeoutError: + continue + except Exception as e: + msg = f"Error running simulation: {e}" + logger.error(msg) + await self.send_error(websocket, ErrorType.SIMULATION_ISSUE, msg) + finally: + await self.send_message(websocket, WSMessageType.END_SIM, {}) -@app.get("/episodes/{get_by}/{value}", response_model=list[EpisodeLog]) -async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[EpisodeLog]: - episodes: list[EpisodeLog] = [] - if get_by == "id": - episodes.append(EpisodeLog.get(pk=value)) - elif get_by == "tag": - json_models = EpisodeLog.find(EpisodeLog.tag == value).all() - episodes.extend(cast(list[EpisodeLog], json_models)) + @staticmethod + async def send_message( + websocket: WebSocket, msg_type: WSMessageType, data: dict[str, Any] + ) -> None: + await websocket.send_json({"type": msg_type.value, "data": data}) - if not episodes: - raise HTTPException( - status_code=404, detail=f"No episodes found with {get_by}={value}" + @staticmethod + async def send_error( + websocket: WebSocket, error_type: ErrorType, details: str = "" + ) -> None: + await websocket.send_json( + { + "type": WSMessageType.ERROR.value, + "data": {"type": error_type.value, "details": details}, + } ) - return episodes - - -@app.post("/scenarios/", response_model=str) -async def create_scenario(scenario: EnvironmentProfileWrapper) -> str: - scenario_profile = EnvironmentProfile(**scenario.model_dump()) - scenario_profile.save() - pk = scenario_profile.pk - assert pk is not None - return pk - - -@app.post("/agents/", response_model=str) -async def create_agent(agent: AgentProfileWrapper) -> str: - agent_profile = AgentProfile(**agent.model_dump()) - agent_profile.save() - pk = agent_profile.pk - assert pk is not None - return pk - - -@app.post("/relationship/", response_model=str) -async def create_relationship(relationship: RelationshipWrapper) -> str: - relationship_profile = RelationshipProfile(**relationship.model_dump()) - relationship_profile.save() - pk = relationship_profile.pk - assert pk is not None - return pk async def run_simulation( @@ -325,263 +323,283 @@ async def run_simulation( ) -@app.post("/simulate/", response_model=str) -def simulate(simulation_request: SimulationRequest) -> Response: - try: - _: EnvironmentProfile = EnvironmentProfile.get(pk=simulation_request.env_id) - except Exception: # TODO Check the exception type - raise HTTPException( - status_code=404, - detail=f"Environment with id={simulation_request.env_id} not found", - ) - try: - __ = AgentProfile.get(pk=simulation_request.agent_ids[0]) - except Exception: # TODO Check the exception type - raise HTTPException( - status_code=404, - detail=f"Agent with id={simulation_request.agent_ids[0]} not found", - ) - try: - ___ = AgentProfile.get(pk=simulation_request.agent_ids[1]) - except Exception: # TODO Check the exception type +async def get_scenarios_all() -> list[EnvironmentProfile]: + return EnvironmentProfile.all() + + +async def get_scenarios( + get_by: Literal["id", "codename"], value: str +) -> list[EnvironmentProfile]: + # Implement logic to fetch scenarios based on the parameters + scenarios: list[EnvironmentProfile] = [] # Replace with actual fetching logic + if get_by == "id": + scenarios.append(EnvironmentProfile.get(pk=value)) + elif get_by == "codename": + json_models = EnvironmentProfile.find( + EnvironmentProfile.codename == value + ).all() + scenarios.extend(cast(list[EnvironmentProfile], json_models)) + + if not scenarios: raise HTTPException( - status_code=404, - detail=f"Agent with id={simulation_request.agent_ids[1]} not found", + status_code=404, detail=f"No scenarios found with {get_by}={value}" ) - episode_pk = EpisodeLog( - environment="", - agents=[], - models=[], - messages=[], - reasoning="", - rewards=[], # Pseudorewards - rewards_prompt="", - ).pk - try: - simulation_status = NonStreamingSimulationStatus( - episode_pk=episode_pk, - status="Started", - ) - simulation_status.save() - queue = rq.Queue("default", connection=get_redis_connection()) - queue.enqueue( - run_simulation, - episode_pk=episode_pk, - simulation_request=simulation_request, - simulation_status=simulation_status, - ) + return scenarios - except Exception as e: - logger.error(f"Error starting simulation: {e}") - simulation_status.status = "Error" - simulation_status.save() - return Response(content=episode_pk, status_code=202) +async def get_agents_all() -> list[AgentProfile]: + return AgentProfile.all() -@app.get("/simulation_status/{episode_pk}", response_model=str) -async def get_simulation_status(episode_pk: str) -> str: - status = NonStreamingSimulationStatus.find( - NonStreamingSimulationStatus.episode_pk == episode_pk - ).all()[0] - assert isinstance(status, NonStreamingSimulationStatus) - return status.status +async def get_agents( + get_by: Literal["id", "gender", "occupation"], value: str +) -> list[AgentProfile]: + agents_profiles: list[AgentProfile] = [] + if get_by == "id": + agents_profiles.append(AgentProfile.get(pk=value)) + elif get_by == "gender": + json_models = AgentProfile.find(AgentProfile.gender == value).all() + agents_profiles.extend(cast(list[AgentProfile], json_models)) + elif get_by == "occupation": + json_models = AgentProfile.find(AgentProfile.occupation == value).all() + agents_profiles.extend(cast(list[AgentProfile], json_models)) -@app.delete("/agents/{agent_id}", response_model=str) -async def delete_agent(agent_id: str) -> str: - try: - agent = AgentProfile.get(pk=agent_id) - except Exception: # TODO Check the exception type + if not agents_profiles: raise HTTPException( - status_code=404, detail=f"Agent with id={agent_id} not found" + status_code=404, detail=f"No agents found with {get_by}={value}" ) - AgentProfile.delete(agent.pk) - assert agent.pk is not None - return agent.pk + return agents_profiles -@app.delete("/scenarios/{scenario_id}", response_model=str) -async def delete_scenario(scenario_id: str) -> str: - try: - scenario = EnvironmentProfile.get(pk=scenario_id) - except Exception: # TODO Check the exception type - raise HTTPException( - status_code=404, detail=f"Scenario with id={scenario_id} not found" - ) - EnvironmentProfile.delete(scenario.pk) - assert scenario.pk is not None - return scenario.pk +async def get_relationship(agent_1_id: str, agent_2_id: str) -> str: + relationship_profiles = RelationshipProfile.find( + (RelationshipProfile.agent_1_id == agent_1_id) + & (RelationshipProfile.agent_2_id == agent_2_id) + ).all() + assert ( + len(relationship_profiles) == 1 + ), f"{len(relationship_profiles)} relationship profiles found for agents {agent_1_id} and {agent_2_id}, expected 1" + relationship_profile = relationship_profiles[0] + assert isinstance(relationship_profile, RelationshipProfile) + return f"{str(relationship_profile.relationship)}: {RelationshipType(relationship_profile.relationship).name}" -@app.delete("/relationship/{relationship_id}", response_model=str) -async def delete_relationship(relationship_id: str) -> str: - RelationshipProfile.delete(relationship_id) - return relationship_id +async def get_episodes_all() -> list[EpisodeLog]: + return EpisodeLog.all() -@app.delete("/episodes/{episode_id}", response_model=str) -async def delete_episode(episode_id: str) -> str: - EpisodeLog.delete(episode_id) - return episode_id +async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[EpisodeLog]: + episodes: list[EpisodeLog] = [] + if get_by == "id": + episodes.append(EpisodeLog.get(pk=value)) + elif get_by == "tag": + json_models = EpisodeLog.find(EpisodeLog.tag == value).all() + episodes.extend(cast(list[EpisodeLog], json_models)) -active_simulations: Dict[ - str, bool -] = {} # TODO check whether this is the correct way to store the active simulations + if not episodes: + raise HTTPException( + status_code=404, detail=f"No episodes found with {get_by}={value}" + ) + return episodes -@app.get("/models", response_model=list[str]) async def get_models() -> list[str]: # TODO figure out how to get the available models return ["gpt-4o-mini", "gpt-4o", "gpt-3.5-turbo"] -class SimulationState: - _instance: Optional["SimulationState"] = None - _lock = asyncio.Lock() - _active_simulations: dict[str, bool] = {} - - def __new__(cls) -> "SimulationState": - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._active_simulations = {} - return cls._instance - - async def try_acquire_token(self, token: str) -> tuple[bool, str]: - async with self._lock: - if not token: - return False, "Invalid token" - - if self._active_simulations.get(token): - return False, "Token is active already" - - self._active_simulations[token] = True - return True, "Token is valid" - - async def release_token(self, token: str) -> None: - async with self._lock: - self._active_simulations.pop(token, None) - - @asynccontextmanager - async def start_simulation(self, token: str) -> AsyncIterator[bool]: - try: - yield True - finally: - await self.release_token(token) - - -class SimulationManager: - def __init__(self) -> None: - self.state = SimulationState() - - async def verify_token(self, token: str) -> dict[str, Any]: - is_valid, msg = await self.state.try_acquire_token(token) - return {"is_valid": is_valid, "msg": msg} - - async def create_simulator( - self, env_id: str, agent_ids: list[str] - ) -> WebSocketSotopiaSimulator: - try: - return WebSocketSotopiaSimulator(env_id=env_id, agent_ids=agent_ids) - except Exception as e: - error_msg = f"Failed to create simulator: {e}" - logger.error(error_msg) - raise Exception(error_msg) - - async def handle_client_message( - self, - websocket: WebSocket, - simulator: WebSocketSotopiaSimulator, - message: dict[str, Any], - timeout: float = 0.1, - ) -> bool: - try: - msg_type = message.get("type") - if msg_type == WSMessageType.FINISH_SIM.value: - return True - # TODO handle other message types - return False - except Exception as e: - msg = f"Error handling client message: {e}" - logger.error(msg) - await self.send_error(websocket, ErrorType.INVALID_MESSAGE, msg) - return False - - async def run_simulation( - self, websocket: WebSocket, simulator: WebSocketSotopiaSimulator - ) -> None: - try: - async for message in simulator.arun(): - await self.send_message(websocket, WSMessageType.SERVER_MSG, message) - - try: - data = await asyncio.wait_for(websocket.receive_json(), timeout=0.1) - if await self.handle_client_message(websocket, simulator, data): - break - except asyncio.TimeoutError: - continue - - except Exception as e: - msg = f"Error running simulation: {e}" - logger.error(msg) - await self.send_error(websocket, ErrorType.SIMULATION_ISSUE, msg) - finally: - await self.send_message(websocket, WSMessageType.END_SIM, {}) - - @staticmethod - async def send_message( - websocket: WebSocket, msg_type: WSMessageType, data: dict[str, Any] - ) -> None: - await websocket.send_json({"type": msg_type.value, "data": data}) - - @staticmethod - async def send_error( - websocket: WebSocket, error_type: ErrorType, details: str = "" - ) -> None: - await websocket.send_json( - { - "type": WSMessageType.ERROR.value, - "data": {"type": error_type.value, "details": details}, - } +class SotopiaFastAPI(FastAPI): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], ) + self.setup_routes() + def setup_routes(self): + self.get("/scenarios", response_model=list[EnvironmentProfile])( + get_scenarios_all + ) + self.get( + "/scenarios/{get_by}/{value}", response_model=list[EnvironmentProfile] + )(get_scenarios) + self.get("/agents", response_model=list[AgentProfile])(get_agents_all) + self.get("/agents/{get_by}/{value}", response_model=list[AgentProfile])( + get_agents + ) + self.get("/relationship/{agent_1_id}/{agent_2_id}", response_model=str)( + get_relationship + ) + self.get("/episodes", response_model=list[EpisodeLog])(get_episodes_all) + self.get("/episodes/{get_by}/{value}", response_model=list[EpisodeLog])( + get_episodes + ) + self.get("/models", response_model=list[str])(get_models) + + @self.post("/scenarios/", response_model=str) + async def create_scenario(scenario: EnvironmentProfileWrapper) -> str: + scenario_profile = EnvironmentProfile(**scenario.model_dump()) + scenario_profile.save() + pk = scenario_profile.pk + assert pk is not None + return pk + + @self.post("/agents/", response_model=str) + async def create_agent(agent: AgentProfileWrapper) -> str: + agent_profile = AgentProfile(**agent.model_dump()) + agent_profile.save() + pk = agent_profile.pk + assert pk is not None + return pk + + @self.post("/relationship/", response_model=str) + async def create_relationship(relationship: RelationshipWrapper) -> str: + relationship_profile = RelationshipProfile(**relationship.model_dump()) + relationship_profile.save() + pk = relationship_profile.pk + assert pk is not None + return pk + + @self.post("/simulate/", response_model=str) + def simulate(simulation_request: SimulationRequest) -> Response: + try: + _: EnvironmentProfile = EnvironmentProfile.get( + pk=simulation_request.env_id + ) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, + detail=f"Environment with id={simulation_request.env_id} not found", + ) + try: + __ = AgentProfile.get(pk=simulation_request.agent_ids[0]) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, + detail=f"Agent with id={simulation_request.agent_ids[0]} not found", + ) + try: + ___ = AgentProfile.get(pk=simulation_request.agent_ids[1]) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, + detail=f"Agent with id={simulation_request.agent_ids[1]} not found", + ) -@app.websocket("/ws/simulation") -async def websocket_endpoint(websocket: WebSocket, token: str) -> None: - manager = SimulationManager() - - token_status = await manager.verify_token(token) - if not token_status["is_valid"]: - await websocket.close(code=1008, reason=token_status["msg"]) - return - - try: - await websocket.accept() - - while True: - start_msg = await websocket.receive_json() - if start_msg.get("type") != WSMessageType.START_SIM.value: - continue + episode_pk = EpisodeLog( + environment="", + agents=[], + models=[], + messages=[], + reasoning="", + rewards=[], # Pseudorewards + rewards_prompt="", + ).pk + try: + simulation_status = NonStreamingSimulationStatus( + episode_pk=episode_pk, + status="Started", + ) + simulation_status.save() + queue = rq.Queue("default", connection=get_redis_connection()) + queue.enqueue( + run_simulation, + episode_pk=episode_pk, + simulation_request=simulation_request, + simulation_status=simulation_status, + ) - async with manager.state.start_simulation(token): - simulator = await manager.create_simulator( - env_id=start_msg["data"]["env_id"], - agent_ids=start_msg["data"]["agent_ids"], + except Exception as e: + logger.error(f"Error starting simulation: {e}") + simulation_status.status = "Error" + simulation_status.save() + return Response(content=episode_pk, status_code=202) + + @self.get("/simulation_status/{episode_pk}", response_model=str) + async def get_simulation_status(episode_pk: str) -> str: + status = NonStreamingSimulationStatus.find( + NonStreamingSimulationStatus.episode_pk == episode_pk + ).all()[0] + assert isinstance(status, NonStreamingSimulationStatus) + return status.status + + @self.delete("/agents/{agent_id}", response_model=str) + async def delete_agent(agent_id: str) -> str: + try: + agent = AgentProfile.get(pk=agent_id) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, detail=f"Agent with id={agent_id} not found" ) - await manager.run_simulation(websocket, simulator) - - except WebSocketDisconnect: - logger.info(f"Client disconnected: {token}") - except Exception as e: - logger.error(f"Unexpected error: {e}") - await manager.send_error(websocket, ErrorType.SIMULATION_ISSUE, str(e)) - finally: - try: - await websocket.close() - except Exception as e: - logger.error(f"Error closing websocket: {e}") + AgentProfile.delete(agent.pk) + assert agent.pk is not None + return agent.pk + + @self.delete("/scenarios/{scenario_id}", response_model=str) + async def delete_scenario(scenario_id: str) -> str: + try: + scenario = EnvironmentProfile.get(pk=scenario_id) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, detail=f"Scenario with id={scenario_id} not found" + ) + EnvironmentProfile.delete(scenario.pk) + assert scenario.pk is not None + return scenario.pk + + @self.delete("/relationship/{relationship_id}", response_model=str) + async def delete_relationship(relationship_id: str) -> str: + RelationshipProfile.delete(relationship_id) + return relationship_id + + @self.delete("/episodes/{episode_id}", response_model=str) + async def delete_episode(episode_id: str) -> str: + EpisodeLog.delete(episode_id) + return episode_id + + @self.websocket("/ws/simulation") + async def websocket_endpoint(websocket: WebSocket, token: str) -> None: + manager = SimulationManager() + + token_status = await manager.verify_token(token) + if not token_status["is_valid"]: + await websocket.close(code=1008, reason=token_status["msg"]) + return + + try: + await websocket.accept() + + while True: + start_msg = await websocket.receive_json() + if start_msg.get("type") != WSMessageType.START_SIM.value: + continue + + async with manager.state.start_simulation(token): + simulator = await manager.create_simulator( + env_id=start_msg["data"]["env_id"], + agent_ids=start_msg["data"]["agent_ids"], + ) + await manager.run_simulation(websocket, simulator) + + except WebSocketDisconnect: + logger.info(f"Client disconnected: {token}") + except Exception as e: + logger.error(f"Unexpected error: {e}") + await manager.send_error(websocket, ErrorType.SIMULATION_ISSUE, str(e)) + finally: + try: + await websocket.close() + except Exception as e: + logger.error(f"Error closing websocket: {e}") + +app = SotopiaFastAPI() if __name__ == "__main__": uvicorn.run(app, host="127.0.0.1", port=8800) diff --git a/sotopia/ui/modal_api_server.py b/sotopia/ui/modal_api_server.py new file mode 100644 index 00000000..203260cc --- /dev/null +++ b/sotopia/ui/modal_api_server.py @@ -0,0 +1,116 @@ +import modal +import subprocess +import time +import os + +import redis +from sotopia.ui.fastapi_server import SotopiaFastAPI + +# Create persistent volume for Redis data +redis_volume = modal.Volume.from_name("sotopia-api", create_if_missing=True) + + +def initialize_redis_data(): + """Download Redis data if it doesn't exist""" + if not os.path.exists("/vol/redis/dump.rdb"): + os.makedirs("/vol/redis", exist_ok=True) + print("Downloading initial Redis data...") + subprocess.run( + "curl -L https://cmu.box.com/shared/static/xiivc5z8rnmi1zr6vmk1ohxslylvynur --output /vol/redis/dump.rdb", + shell=True, + check=True, + ) + print("Redis data downloaded") + + +# Create image with all necessary dependencies +image = ( + modal.Image.debian_slim(python_version="3.11") + .apt_install( + "git", + "curl", + "gpg", + "lsb-release", + "wget", + "procps", # for ps command + "redis-tools", # for redis-cli + ) + .run_commands( + # Update and install basic dependencies + "apt-get update", + "apt-get install -y curl gpg lsb-release", + # Add Redis Stack repository + "curl -fsSL https://packages.redis.io/gpg | gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg", + "chmod 644 /usr/share/keyrings/redis-archive-keyring.gpg", + 'echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | tee /etc/apt/sources.list.d/redis.list', + "apt-get update", + "apt-get install -y redis-stack-server", + ) + .pip_install( + "pydantic>=2.5.0,<3.0.0", + "aiohttp>=3.9.3,<4.0.0", + "rich>=13.8.1,<14.0.0", + "typer>=0.12.5", + "aiostream>=0.5.2", + "fastapi[all]", + "uvicorn", + "redis>=5.0.0", + "rq", + "lxml>=4.9.3,<6.0.0", + "openai>=1.11.0,<2.0.0", + "langchain>=0.2.5,<0.4.0", + "PettingZoo==1.24.3", + "redis-om>=0.3.0,<0.4.0", + "gin-config>=0.5.0,<0.6.0", + "absl-py>=2.0.0,<3.0.0", + "together>=0.2.4,<1.4.0", + "beartype>=0.14.0,<0.20.0", + "langchain-openai>=0.1.8,<0.2", + "hiredis>=3.0.0", + "aact", + "gin", + ) +) +redis_volume = modal.Volume.from_name("sotopia-api", create_if_missing=True) + +# Create stub for the application +app = modal.App("sotopia-fastapi", image=image, volumes={"/vol/redis": redis_volume}) + + +@app.cls(image=image) +class WebAPI: + def __init__(self): + self.web_app = SotopiaFastAPI() + + @modal.enter() + def setup(self): + # Start Redis server + subprocess.Popen( + ["redis-stack-server", "--dir", "/vol/redis", "--port", "6379"] + ) + + # Wait for Redis to be ready + max_retries = 30 + for _ in range(max_retries): + try: + initialize_redis_data() + # Attempt to create Redis client and ping the server + temp_client = redis.Redis(host="localhost", port=6379, db=0) + temp_client.ping() + self.redis_client = temp_client + print("Successfully connected to Redis") + return + except (redis.exceptions.ConnectionError, redis.exceptions.ResponseError): + print("Waiting for Redis to be ready...") + time.sleep(1) + + raise Exception("Could not connect to Redis after multiple attempts") + + @modal.exit() + def cleanup(self): + if hasattr(self, "redis_client"): + self.redis_client.close() + + @modal.asgi_app() + def serve(self): + return self.web_app