diff --git a/docs/pages/concepts/evaluation_dimension.md b/docs/pages/concepts/evaluation_dimension.md new file mode 100644 index 000000000..f86b7a894 --- /dev/null +++ b/docs/pages/concepts/evaluation_dimension.md @@ -0,0 +1,116 @@ +## Overview + +Evaluation dimensions are used to evaluate the quality of social interactions. +In original Sotopia paper, there are 7 dimensions to evaluate the quality of social interactions, where we named them as `sotopia` evaluation dimensions: +- believability +- relationship +- knowledge +- secret +- social rules +- financial and material benefits +- goal + +The `SotopiaDimensions` can be used directly without initializing the database. It provides a set of predefined evaluation dimensions that are ready to use for evaluating social interactions. For example, + +```python +from sotopia.envs.parallel import ParallelSotopiaEnv +from sotopia.envs.evaluators import EvaluationForTwoAgents, ReachGoalLLMEvaluator, RuleBasedTerminatedEvaluator, SotopiaDimensions + +env = ParallelSotopiaEnv( + env_profile=env_profile, + model_name=model_names["env"], + action_order="round-robin", + evaluators=[ + RuleBasedTerminatedEvaluator(max_turn_number=20, max_stale_turn=2), + ], + terminal_evaluators=[ + ReachGoalLLMEvaluator( + model_names["env"], + EvaluationForTwoAgents[SotopiaDimensions], # type: ignore + # TODO check how to do type annotation + ), + ], + ) +``` + + +However we observe under many use cases people may want to evaluate with customized evaluation metrics, so we provide a way to build custom evaluation dimensions. +For a quick reference, you can directly check out the `examples/use_custom_dimensions.py`. + +### CustomEvaluationDimension +The [`CustomEvaluationDimension`](/python_API/database/evaluation_dimensions) is a class that can be used to create a custom evaluation dimension. +There are four parameters: +- name: the name of the dimension +- description: the description of the dimension +- range_low: the minimum score of the dimension (should be an integer) +- range_high: the maximum score of the dimension (should be an integer) + +### CustomEvaluationDimensionList +The [`CustomEvaluationDimensionList`](/python_API/database/evaluation_dimensions) is a class that can be used to create a custom evaluation dimension list based on the existing dimensions. It helps one to group multiple dimensions together for a specific use case. +There are two parameters: +- name: the name of the dimension list +- dimension_pks: the primary keys of the dimensions in the dimension list + +### EvaluationDimensionBuilder +The [`EvaluationDimensionBuilder`](/python_API/database/evaluation_dimensions) is a class that can be used to generate a custom evaluation dimension model based on the existing dimensions. + + +## Usage +### Initialize the database +The default evaluation metric is still `SotopiaDimensions` in `sotopia.env.evaluators`.There is no `CustomEvaluationDimension` in the database by default. To initialize the database, please refer to `examples/use_custom_dimensions.py`. + + +### Use the custom evaluation dimensions +After you initialize your customized evaluation dimensions, you can choose to use any one of these methods provided below: + +#### Method 1: Choose dimensions by names +```python +evaluation_dimensions = ( + EvaluationDimensionBuilder.select_existing_dimension_model_by_name( + ["transactivity", "verbal_equity"] + ) +) +``` + +#### Method 2: Directly choose the grouped evaluation dimension list +```python +evaluation_dimensions = ( + EvaluationDimensionBuilder.select_existing_dimension_model_by_list_name( + "sotopia" + ) +) +``` + +#### Method 3: Build a custom evaluation dimension model temporarily +We provide multiple ways to build a custom evaluation dimension model with `EvaluationDimensionBuilder`, specifically: +- `generate_dimension_model`: build an evaluation dimension from existing dimension primary keys. +- `generate_dimension_model_from_dict`: build an evaluation dimension from a dictionary that specifies the parameters of the `CustomEvaluationDimension`. For example +```json +[ + { + "name": "believability", + "description": "The believability of the interaction", + "range_low": 0, + "range_high": 10 + }, + ... +] +``` +- `select_existing_dimension_model_by_name`: build an evaluation dimension from existing dimension names. For example `['believability', 'goal']` +- `select_existing_dimension_model_by_list_name`: build an evaluation dimension from existing `CustomEvaluationDimensionList` list names. For example, directly use `sotopia`. + + +After you get the evaluation dimension model, you can pass it as a parameter for the `Evaluator`, for example, +```python +evaluation_dimensions = ( + EvaluationDimensionBuilder.select_existing_dimension_model_by_list_name( + "sotopia" + ) +) +terminal_evaluators=[ + ReachGoalLLMEvaluator( + model_names["env"], + EvaluationForTwoAgents[evaluation_dimensions], # type: ignore + ), +], +``` diff --git a/docs/pages/python_API/database/evaluation_dimensions.md b/docs/pages/python_API/database/evaluation_dimensions.md new file mode 100644 index 000000000..4a826a555 --- /dev/null +++ b/docs/pages/python_API/database/evaluation_dimensions.md @@ -0,0 +1,54 @@ +# `evaluation_dimensions.py` + +This module provides classes and utilities for defining and managing custom evaluation dimensions within the Sotopia environment. It includes classes for individual dimensions, lists of dimensions, and a builder for creating dimension models. + +## Classes + +### `CustomEvaluationDimension` + +Represents a custom evaluation dimension with specific attributes such as name, description, and score range. + +#### Attributes +- `name`: `str`. The name of the dimension. +- `description`: `str`. A brief description of the dimension. +- `range_low`: `int`. The minimum score for the dimension. +- `range_high`: `int`. The maximum score for the dimension. + +### `CustomEvaluationDimensionList` + +Groups multiple custom evaluation dimensions together. + +#### Attributes +- `name`: `str`. The name of the dimension list. +- `dimension_pks`: `list[str]`. A list of primary keys for the dimensions included in the list. + +### `EvaluationDimensionBuilder` + +Provides utility methods to create and manage evaluation dimension models. + +#### Methods +- `create_range_validator(low: int, high: int)`: Creates a validator for score ranges. + + **Arguments:** + - `low`: `int`. The minimum score allowed. + - `high`: `int`. The maximum score allowed. + +- `build_dimension_model(dimension_ids: list[str])`: Builds a dimension model from primary keys. + + **Arguments:** + - `dimension_ids`: `list[str]`. A list of dimension primary keys. + +- `build_dimension_model_from_dict(dimensions: list[dict[str, Union[str, int]]])`: Builds a dimension model from a dictionary. + + **Arguments:** + - `dimensions`: `list[dict[str, Union[str, int]]]`. A list of dictionaries specifying dimension attributes. + +- `select_existing_dimension_model_by_name(dimension_names: list[str])`: Selects a dimension model by dimension names. + + **Arguments:** + - `dimension_names`: `list[str]`. A list of dimension names. + +- `select_existing_dimension_model_by_list_name(list_name: str)`: Selects a dimension model by list name. + + **Arguments:** + - `list_name`: `str`. The name of the dimension list. diff --git a/examples/experiment_eval.py b/examples/experiment_eval.py index ee0df3f11..82fe4bbd0 100644 --- a/examples/experiment_eval.py +++ b/examples/experiment_eval.py @@ -17,6 +17,7 @@ EnvAgentComboStorage, EnvironmentProfile, EpisodeLog, + EvaluationDimensionBuilder, ) from sotopia.envs.evaluators import ( EvaluationForTwoAgents, @@ -34,6 +35,7 @@ ) from sotopia.server import run_async_server from sotopia_conf.gin_utils import parse_gin_flags, run +# from sotopia.database import EvaluationDimensionBuilder _DEFAULT_GIN_SEARCH_PATHS = [ os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -109,6 +111,18 @@ def _iterate_env_agent_combo_not_in_db( tag: str | None = None, ) -> Generator[EnvAgentCombo[Observation, AgentAction], None, None]: """We iterate over each environment and return the **first** env-agent combo that is not in the database.""" + # loading evaluation metric + try: + evaluation_dimensions = EvaluationDimensionBuilder.select_existing_dimension_model_by_list_name( + "sotopia" + ) # Initialize your customized dimension, please refer to `examples/use_custom_dimensions.py` + except Exception as e: + print( + "No customized evaluation dimensions found, using default SotopiaDimensions", + e, + ) + evaluation_dimensions = SotopiaDimensions + if not env_ids: env_ids = list(EnvironmentProfile.all_pks()) for env_id in env_ids: @@ -152,7 +166,8 @@ def _iterate_env_agent_combo_not_in_db( terminal_evaluators=[ ReachGoalLLMEvaluator( model_names["env"], - EvaluationForTwoAgents[SotopiaDimensions], + EvaluationForTwoAgents[evaluation_dimensions], # type: ignore + # TODO check how to do type annotation ), ], ) diff --git a/examples/experimental/nodes/initial_message_node.py b/examples/experimental/nodes/initial_message_node.py index 9cb7f63ca..9ff4c3bdf 100644 --- a/examples/experimental/nodes/initial_message_node.py +++ b/examples/experimental/nodes/initial_message_node.py @@ -18,6 +18,7 @@ def __init__( input_tick_channel: str, output_channels: list[str], env_scenario: str, + node_name: str, redis_url: str = "redis://localhost:6379/0", ): super().__init__( @@ -26,6 +27,7 @@ def __init__( (output_channel, Text) for output_channel in output_channels ], redis_url=redis_url, + node_name=node_name, ) self.env_scenario = env_scenario self.output_channels = output_channels diff --git a/examples/experimental/sotopia_original_replica/llm_agent_sotopia.py b/examples/experimental/sotopia_original_replica/llm_agent_sotopia.py new file mode 100644 index 000000000..abe959294 --- /dev/null +++ b/examples/experimental/sotopia_original_replica/llm_agent_sotopia.py @@ -0,0 +1,113 @@ +import logging +import sys +from rich.logging import RichHandler + +from aact import NodeFactory + +from sotopia.experimental.agents.base_agent import BaseAgent +from sotopia.experimental.agents.datamodels import Observation, AgentAction + +from sotopia.generation_utils import agenerate +from sotopia.generation_utils.generate import StrOutputParser + +# Check Python version +if sys.version_info >= (3, 11): + pass +else: + pass + +# Configure logging +FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" +logging.basicConfig( + level=logging.WARNING, + format=FORMAT, + datefmt="[%X]", + handlers=[RichHandler()], +) + + +@NodeFactory.register("llm_agent") +class LLMAgent(BaseAgent[Observation, AgentAction]): + def __init__( + self, + input_channels: list[str], + output_channel: str, + query_interval: int, + agent_name: str, + node_name: str, + goal: str, + model_name: str, + redis_url: str, + ): + super().__init__( + [(input_channel, Observation) for input_channel in input_channels], + [(output_channel, AgentAction)], + redis_url, + node_name, + ) + self.output_channel = output_channel + self.query_interval = query_interval + self.count_ticks = 0 + self.message_history: list[Observation] = [] + self.name = agent_name + self.model_name = model_name + self.goal = goal + + def _format_message_history(self, message_history: list[Observation]) -> str: + ## TODO: akhatua Fix the mapping of action to be gramatically correct + return "\n".join(message.to_natural_language() for message in message_history) + + async def aact(self, obs: Observation) -> AgentAction: + if obs.turn_number == -1: + return AgentAction( + agent_name=self.name, + output_channel=self.output_channel, + action_type="none", + argument=self.model_name, + ) + + self.message_history.append(obs) + + if len(obs.available_actions) == 1 and "none" in obs.available_actions: + return AgentAction( + agent_name=self.name, + output_channel=self.output_channel, + action_type="none", + argument="", + ) + elif len(obs.available_actions) == 1 and "leave" in obs.available_actions: + self.shutdown_event.set() + return AgentAction( + agent_name=self.name, + output_channel=self.output_channel, + action_type="leave", + argument="", + ) + else: + history = self._format_message_history(self.message_history) + action: str = await agenerate( + model_name=self.model_name, + template="Imagine that you are a friend of the other persons. Here is the " + "conversation between you and them.\n" + "You are {agent_name} in the conversation.\n" + "{message_history}\n" + "and you plan to {goal}.\n" + "You can choose to interrupt the other person " + "by saying something or not to interrupt by outputting notiong. What would you say? " + "Please only output a sentence or not outputting anything." + "{format_instructions}", + input_values={ + "message_history": history, + "goal": self.goal, + "agent_name": self.name, + }, + temperature=0.7, + output_parser=StrOutputParser(), + ) + + return AgentAction( + agent_name=self.name, + output_channel=self.output_channel, + action_type="speak", + argument=action, + ) diff --git a/examples/experimental/sotopia_original_replica/origin.svg b/examples/experimental/sotopia_original_replica/origin.svg new file mode 100644 index 000000000..78717b14a --- /dev/null +++ b/examples/experimental/sotopia_original_replica/origin.svg @@ -0,0 +1 @@ +

examples/experimental/sotopia_original_replica/origin.toml

Jane:moderator

Jack:moderator

moderator:Jane

moderator:Jack

Jane:Jack

Jack:Jane

Agent:Runtime

'Jane'

'moderator'

'Jack'

'chat_print'

diff --git a/examples/experimental/sotopia_original_replica/origin.toml b/examples/experimental/sotopia_original_replica/origin.toml new file mode 100644 index 000000000..7bf225273 --- /dev/null +++ b/examples/experimental/sotopia_original_replica/origin.toml @@ -0,0 +1,52 @@ +redis_url = "redis://localhost:6379/0" +extra_modules = ["examples.experimental.sotopia_original_replica.llm_agent_sotopia", "examples.experimental.nodes.chat_print_node", "sotopia.experimental.agents.moderator"] + + +[[nodes]] +node_name = "moderator" +node_class = "moderator" + +[nodes.node_args] +output_channels = ["moderator:Jane", "moderator:Jack"] +input_channels = ["Jane:moderator", "Jack:moderator"] +agent_backgrounds = {"Jane" = "", "Jack" = ""} +agent_mapping = {"moderator:Jane" = "Jane", "moderator:Jack" = "Jack"} +scenario = "Two friends are sitting in a cafe and catching up with each other's lives." +max_turns = 2 +push_to_db = false + +[[nodes]] +node_name = "Jack" +node_class = "llm_agent" + +[nodes.node_args] +query_interval = 5 +input_channels = ["moderator:Jack"] +output_channel = "Jack:moderator" +goal = "Your goal is to borrow 5000 dollars from Jane." +model_name = "gpt-4o-mini" +agent_name = "Jack" + + +[[nodes]] +node_name = "Jane" +node_class = "llm_agent" + +[nodes.node_args] +query_interval = 7 +output_channel = "Jane:moderator" +input_channels = ["moderator:Jane"] +goal = "Your goal is to help Jack however, you are in a finicial crisis yourself and can only afford to give him 500 dollars." +model_name = "gpt-4o-mini" +agent_name = "Jane" + +[[nodes]] +node_name = "chat_print" +node_class = "chat_print" + +[nodes.node_args.print_channel_types] +"Jane:moderator" = "agent_action" +"Jack:moderator" = "agent_action" + +[nodes.node_args] +env_agents = ["Jack", "Jane"] diff --git a/examples/experimental/sotopia_original_replica/readme.md b/examples/experimental/sotopia_original_replica/readme.md new file mode 100644 index 000000000..cb3931dc7 --- /dev/null +++ b/examples/experimental/sotopia_original_replica/readme.md @@ -0,0 +1,13 @@ +To run this example, please use aact to launch. + +```bash +aact run-dataflow examples/experimental/sotopia_original_replica/origin.toml +``` + +To view the flow of the information, please run: + +```bash +aact draw-dataflow examples/experimental/sotopia_original_replica/origin.toml --svg-path examples/experimental/sotopia_original_replica/origin.svg +``` + +![Alt text](./origin.svg) diff --git a/examples/experimental/websocket/websocket_test_client.py b/examples/experimental/websocket/websocket_test_client.py new file mode 100644 index 000000000..ab9f30272 --- /dev/null +++ b/examples/experimental/websocket/websocket_test_client.py @@ -0,0 +1,97 @@ +""" +A test client for the WebSocket server +""" + +import json +from sotopia.database import EnvironmentProfile, AgentProfile + +import asyncio +import websockets +import sys +from pathlib import Path + + +class WebSocketClient: + def __init__(self, url: str, token: str, client_id: int): + self.url = url + self.token = token + self.client_id = client_id + self.message_file = Path(f"message_{client_id}.txt") + + async def save_message(self, message: str) -> None: + """Save received message to a file""" + with open(self.message_file, "a", encoding="utf-8") as f: + f.write(f"{message}\n") + + async def connect(self) -> None: + """Establish and maintain websocket connection""" + url_with_token = f"{self.url}?token=test_token_{self.client_id}" + + try: + async with websockets.connect(url_with_token) as websocket: + print(f"Client {self.client_id}: Connected to {self.url}") + + # Send initial message + # Note: You'll need to implement the logic to get agent_ids and env_id + # This is just an example structure + agent_ids = [agent.pk for agent in AgentProfile.find().all()[:2]] + env_id = EnvironmentProfile.find().all()[0].pk + start_message = { + "type": "START_SIM", + "data": { + "env_id": env_id, # Replace with actual env_id + "agent_ids": agent_ids, # Replace with actual agent_ids + }, + } + await websocket.send(json.dumps(start_message)) + print(f"Client {self.client_id}: Sent START_SIM message") + + # Receive and process messages + while True: + try: + message = await websocket.recv() + print( + f"\nClient {self.client_id} received message:", + json.dumps(json.loads(message), indent=2), + ) + assert isinstance(message, str) + await self.save_message(message) + except websockets.ConnectionClosed: + print(f"Client {self.client_id}: Connection closed") + break + except Exception as e: + print(f"Client {self.client_id} error:", str(e)) + break + + except Exception as e: + print(f"Client {self.client_id} connection error:", str(e)) + + +async def main() -> None: + # Create multiple WebSocket clients + num_clients = 0 + url = "ws://localhost:8800/ws/simulation" + + # Create and store client instances + clients = [ + WebSocketClient(url=url, token=f"test_token_{i}", client_id=i) + for i in range(num_clients) + ] + clients.append(WebSocketClient(url=url, token="test_token_10", client_id=10)) + clients.append( + WebSocketClient(url=url, token="test_token_10", client_id=10) + ) # test duplicate token + + # Create tasks for each client + tasks = [asyncio.create_task(client.connect()) for client in clients] + + # Wait for all tasks to complete + await asyncio.gather(*tasks) + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nShutting down clients...") + sys.exit(0) diff --git a/examples/fast_api_example.py b/examples/fast_api_example.py new file mode 100644 index 000000000..f10b850b6 --- /dev/null +++ b/examples/fast_api_example.py @@ -0,0 +1,122 @@ +# Example curl command to call the simulate endpoint: +import requests +import time + +BASE_URL = "http://localhost:8080" + + +def _create_mock_agent_profile() -> None: + agent1_data = { + "first_name": "John", + "last_name": "Doe", + "occupation": "test_occupation", + "gender": "test_gender", + "pk": "tmppk_agent1", + "tag": "test_tag", + } + response = requests.post( + f"{BASE_URL}/agents/", + headers={"Content-Type": "application/json"}, + json=agent1_data, + ) + assert response.status_code == 200 + + agent2_data = { + "first_name": "Jane", + "last_name": "Doe", + "occupation": "test_occupation", + "gender": "test_gender", + "pk": "tmppk_agent2", + "tag": "test_tag", + } + response = requests.post( + f"{BASE_URL}/agents/", + headers={"Content-Type": "application/json"}, + json=agent2_data, + ) + assert response.status_code == 200 + + +def _create_mock_env_profile() -> None: + env_data = { + "codename": "test_codename", + "scenario": "A", + "agent_goals": [ + "B", + "C", + ], + "pk": "tmppk_env_profile", + "tag": "test_tag", + } + response = requests.post( + f"{BASE_URL}/scenarios/", + headers={"Content-Type": "application/json"}, + json=env_data, + ) + assert response.status_code == 200 + + +_create_mock_agent_profile() +_create_mock_env_profile() + + +data = { + "env_id": "tmppk_env_profile", + "agent_ids": ["tmppk_agent1", "tmppk_agent2"], + "models": ["custom/structured-llama3.2:1b@http://localhost:8000/v1"] * 3, + "max_turns": 10, + "tag": "test_tag", +} +try: + response = requests.post( + f"{BASE_URL}/simulate/", headers={"Content-Type": "application/json"}, json=data + ) + print(response) + assert response.status_code == 202 + assert isinstance(response.content.decode(), str) + episode_pk = response.content.decode() + print(episode_pk) + max_retries = 200 + retry_count = 0 + while retry_count < max_retries: + try: + response = requests.get(f"{BASE_URL}/simulation_status/{episode_pk}") + assert response.status_code == 200 + status = response.content.decode() + print(status) + if status == "Error": + raise Exception("Error running simulation") + elif status == "Completed": + break + # Status is "Started", keep polling + time.sleep(1) + retry_count += 1 + except Exception as e: + print(f"Error checking simulation status: {e}") + time.sleep(1) + retry_count += 1 + else: + raise TimeoutError("Simulation timed out after 10 retries") + +finally: + try: + response = requests.delete(f"{BASE_URL}/agents/tmppk_agent1") + assert response.status_code == 200 + except Exception as e: + print(e) + try: + response = requests.delete(f"{BASE_URL}/agents/tmppk_agent2") + assert response.status_code == 200 + except Exception as e: + print(e) + try: + response = requests.delete(f"{BASE_URL}/scenarios/tmppk_env_profile") + assert response.status_code == 200 + except Exception as e: + print(e) + + try: + response = requests.delete(f"{BASE_URL}/episodes/{episode_pk}") + assert response.status_code == 200 + except Exception as e: + print(e) diff --git a/examples/use_custom_dimensions.py b/examples/use_custom_dimensions.py new file mode 100644 index 000000000..ff7122242 --- /dev/null +++ b/examples/use_custom_dimensions.py @@ -0,0 +1,233 @@ +from pydantic import BaseModel +from sotopia.database import ( + EvaluationDimensionBuilder, + CustomEvaluationDimensionList, + CustomEvaluationDimension, +) +from typing import Type, Union +from redis_om import Migrator +from sotopia.envs.evaluators import ( + ReachGoalLLMEvaluator, + EvaluationForTwoAgents, + RuleBasedTerminatedEvaluator, +) +from sotopia.server import arun_one_episode +from typing import Optional, cast +from sotopia.envs import ParallelSotopiaEnv +from sotopia.agents import LLMAgent +from sotopia.database import AgentProfile, EnvironmentProfile +import asyncio + + +def save_dimensions(dimensions: list[dict[str, Union[str, int]]]) -> None: + Migrator().run() + for dimension in dimensions: + if ( + len( + CustomEvaluationDimension.find( + CustomEvaluationDimension.name == dimension["name"] + ).all() + ) + == 0 + ): + print("No existing dimension found, creating a new one") + CustomEvaluationDimension(**dimension).save() + print("Saved {}".format(dimension["name"])) + else: + print( + CustomEvaluationDimension.find( + CustomEvaluationDimension.name == dimension["name"] + ).all()[0], + "already exists", + ) + + +def save_dimension_list( + dimensions: list[dict[str, Union[str, int]]], list_name: str +) -> None: + dimension_list = CustomEvaluationDimensionList.find( + CustomEvaluationDimensionList.name == list_name + ).all() + + if len(dimension_list) == 0: + all_dimensions_pks = [] + for dimension in dimensions: + find_dimension = CustomEvaluationDimension.find( + CustomEvaluationDimension.name == dimension["name"] + ).all() + assert ( + len(find_dimension) == 1 + ), f"Expected 1 dimension for {dimension['name']}, but found {len(find_dimension)}" + all_dimensions_pks.append(find_dimension[0].pk) + CustomEvaluationDimensionList( + name=list_name, dimension_pks=all_dimensions_pks + ).save() + print("Saved {}".format(list_name)) + else: + print(dimension_list[0], "already exists") + + +def build_sotopia_dimensions() -> Type[BaseModel]: + """ + Build the default Sotopia dimensions, the same as + `sotopia.env.evaluators.SotopiaDimensions` + """ + sotopia_dimensions: list[dict[str, Union[str, int]]] = [ + { + "name": "believability", + "description": "Reasoning requirement: 1. Evaluate if the agent interacts with others in a natural and realistic manner (here are a few common questions to check: a. whether the agent is confusing with its own identity? b. whether the agent repeats others' words/actions without any reason? c. whether the agent is being overly polite considering the context?). Start the analysis with tag 2. Analyze whether the actions of the agent align with their character traits (e.g., personality, values, and etc.). Start the analysis with tag . Output your reasoning process to the 'reasoning' field. Output an integer score ranging from 0 and 10 in the 'score' field. A higher score indicates that the agent is more believable.", + "range_high": 10, + "range_low": 0, + }, + { + "name": "relationship", + "description": "Please first analyze what relationship the participant has with the other agent(s) before the interaction. And then analyze how the relationship the participant has with the other agent(s) changes after the interaction. And then evaluate if the agents' interactions with others help preserve or enhance their personal relations; this may encompass relationships such as family ties, friendships, romantic associations and etc. Additionally, ascertain whether these interactions also impact their social status or reputation. In the 'reasoning' field, provide a comprehensive account of the logic or thought process that led you to your conclusion. Further, provide an integer score ranging from -5 to 5 in the 'score' field. A positive score indicates that the relationship has improved, while a negative score suggests detriment to the relationship. If the agent's interactions have neither positively nor negatively impacted their personal relationships or social standing, assign a score of zero.", + "range_high": 5, + "range_low": -5, + }, + { + "name": "knowledge", + "description": "Please first assess what information the agent has gained through the interaction. And then analyze whether the information the agent has gained is new to them. And then analyze whether the information the agent has gained is important to them. In the 'reasoning' field, provide a comprehensive account of the logic or thought process that led you to your conclusion. Further, provide an integer score ranging from 0 and 10 in the 'score' field. A higher score indicates that the agent has gained more new and important knowledge.", + "range_high": 10, + "range_low": 0, + }, + { + "name": "secret", + "description": "First please analyze what secret or secretive intention the participant wants to keep. And then analyze whether they keep it successfully. Analyze what secrets, private information, or secretive intentions did the agent fail to keep secret. Please output the reasoning in the reasoning field. Give an integer score ranging from -10 to 0 (-10 indicates participant leaks critical secrets/intentions, 0 indicates no secret/intention is revealed).", + "range_high": 0, + "range_low": -10, + }, + { + "name": "social_rules", + "description": "Please evaluate whether the agent violates any moral rules or laws in their interactions with others. In the 'reasoning' field, provide a comprehensive account of the logic or thought process that led you to your conclusion. Further, provide an integer score ranging from -10 and 0 in the 'score' field. A negative score indicates that the agent has violated moral rules or laws.", + "range_high": 0, + "range_low": -10, + }, + { + "name": "financial_and_material_benefits", + "description": "Please evaluate whether the agent's interactions with others contribute towards financial and material benefits. Analyze what the agent would gain/lose after the interactions. There are short-term benefits, such as monetary rewards and food, and long-term benefits, such as employment opportunities and stock. In the 'reasoning' field, provide a comprehensive account of the logic or thought process that led you to your conclusion. Further, provide an integer score ranging from -5 and 5 in the 'score' field. Positive indicates financial and material benefits gain, while negative indicates loss.", + "range_high": 5, + "range_low": -5, + }, + { + "name": "goal", + "description": "Please first reiterate agent's social goals. And then please provide a comprehensive analysis about the extent to which the agent has managed to achieve these goals. In the 'reasoning' field, provide a comprehensive account of the logic or thought process that led you to your conclusion. Further, provide an integer score ranging from 0 and 10 in the 'score' field. 0 represents minimal goals achievement, 10 represents complete goal achievement, and a higher score indicates that the agent is making progress towards their social goals.", + "range_high": 10, + "range_low": 0, + }, + ] + + dimensions = EvaluationDimensionBuilder.build_dimension_model_from_dict( + dimensions=sotopia_dimensions + ) + save_dimensions(sotopia_dimensions) + save_dimension_list(sotopia_dimensions, "sotopia") + + return dimensions + + +def build_custom_dimensions( + custom_dimensions: list[dict[str, Union[str, int]]], list_name: Optional[str] = None +) -> Type[BaseModel]: + """ + Build a custom evaluation dimension model, + : param custom_dimensions: a list of dictionaries that specify the parameters of the `CustomEvaluationDimension`. + : param list_name: the name of the list to save the custom dimensions to. If None, no list will be saved. + """ + dimensions = EvaluationDimensionBuilder.build_dimension_model_from_dict( + dimensions=custom_dimensions + ) + + save_dimensions(custom_dimensions) + if list_name is not None: + save_dimension_list(custom_dimensions, list_name=list_name) + + return dimensions + + +def run_simple_sample_with_custom_samples( + custom_dimensions: list[dict[str, Union[str, int]]], +) -> None: + custom_dimensions_type = build_custom_dimensions( + custom_dimensions, list_name="custom" + ) + evaluator = RuleBasedTerminatedEvaluator(max_turn_number=10, max_stale_turn=2) + terminal_evaluator = ReachGoalLLMEvaluator( + model_name="gpt-4o-mini", + response_format_class=EvaluationForTwoAgents[custom_dimensions_type], # type: ignore + ) + + all_agents: list[AgentProfile] = cast( + list[AgentProfile], + AgentProfile.find().page(0, 2), # type: ignore + ) + all_envs: list[EnvironmentProfile] = cast( + list[EnvironmentProfile], + EnvironmentProfile.find().page(0, 1), # type: ignore + ) + environment: ParallelSotopiaEnv = ParallelSotopiaEnv( + env_profile=all_envs[0], + model_name="gpt-4o-mini", + action_order="round-robin", + evaluators=[evaluator], + terminal_evaluators=[terminal_evaluator], + ) + agents: list[LLMAgent] = [ + LLMAgent(agent_profile=agent_profile, model_name="gpt-4o-mini") + for agent_profile in all_agents[:2] + ] + + res = asyncio.run( + arun_one_episode( + env=environment, + agent_list=agents, + omniscient=False, + script_like=False, + tag=None, + push_to_db=False, + ) + ) + + print(res) + + +if __name__ == "__main__": + """ + A sample dimension: + custom_dimensions: list[dict[str, Union[str, int]]] = [ + { + "name": "transactivity", + "description": "Analyze the provided social interaction episode between the given pair/team, focusing on identifying instances of transactive exchanges. Evaluate the level of transactivity by considering the following aspects: elaboration, building upon ideas, questioning, argumentation. Analyze whether these transactive patterns persist consistently across the entire interaction or if there are notable variations throughout the exchange. In the 'reasoning' field, provide a comprehensive account of the logic and thought process that led to your conclusion. Consider how the observed instances of transactivity contribute to or detract from the overall quality and depth of the interaction. In the 'score' field, provide an integer score ranging from 0 to 10, where a higher score indicates a higher level of transactivity.", + "range_high": 10, + "range_low": 0, + }, + { + "name": "verbal_equity", + "description": "Analyze the script and measure the level of verbal equity reflected in the interaction between the agents. And then analyze the extent to which the interaction shows a balanced distribution of speaking opportunities among team members. In the 'reasoning' field, provide a comprehensive account of the logic or thought process that led you to your conclusion. Further, provide an integer score ranging from 0 and 10 in the 'score' field. A higher score indicates a higher level of verbal equity.", + "range_high": 10, + "range_low": 0, + }, + ] + """ + + custom_dimensions: list[dict[str, Union[str, int]]] = [ + { + "name": "transactivity", + "description": "Analyze the provided social interaction episode between the given pair/team, focusing on identifying instances of transactive exchanges. Evaluate the level of transactivity by considering the following aspects: elaboration, building upon ideas, questioning, argumentation. Analyze whether these transactive patterns persist consistently across the entire interaction or if there are notable variations throughout the exchange. In the 'reasoning' field, provide a comprehensive account of the logic and thought process that led to your conclusion. Consider how the observed instances of transactivity contribute to or detract from the overall quality and depth of the interaction. In the 'score' field, provide an integer score ranging from 0 to 10, where a higher score indicates a higher level of transactivity.", + "range_high": 10, + "range_low": 0, + }, + { + "name": "verbal_equity", + "description": "Analyze the script and measure the level of verbal equity reflected in the interaction between the agents. And then analyze the extent to which the interaction shows a balanced distribution of speaking opportunities among team members. In the 'reasoning' field, provide a comprehensive account of the logic or thought process that led you to your conclusion. Further, provide an integer score ranging from 0 and 10 in the 'score' field. A higher score indicates a higher level of verbal equity.", + "range_high": 10, + "range_low": 0, + }, + ] + + # Only build evaluation dimensions + build_sotopia_dimensions() + build_custom_dimensions(custom_dimensions=custom_dimensions, list_name="custom") + + # Build and use evaluation dimensions + run_simple_sample_with_custom_samples(custom_dimensions=custom_dimensions) diff --git a/pyproject.toml b/pyproject.toml index 57af6cc3a..b9edcc942 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,9 @@ plugins = [ module = "transformers.*" ignore_missing_imports = true +[tool.uv.sources] +aact = { git = "https://github.com/ProKil/aact" , branch = "feature/node-manager" } + [tool.pytest.ini_options] testpaths = ["tests"] python_files = "test_*.py" diff --git a/sotopia/database/__init__.py b/sotopia/database/__init__.py index bd737855d..d9156a989 100644 --- a/sotopia/database/__init__.py +++ b/sotopia/database/__init__.py @@ -2,7 +2,7 @@ from redis_om import JsonModel, Migrator from .annotators import Annotator from .env_agent_combo_storage import EnvAgentComboStorage -from .logs import AnnotationForEpisode, EpisodeLog +from .logs import AnnotationForEpisode, EpisodeLog, NonStreamingSimulationStatus from .persistent_profile import ( AgentProfile, EnvironmentProfile, @@ -30,6 +30,11 @@ from .session_transaction import MessageTransaction, SessionTransaction from .waiting_room import MatchingInWaitingRoom from .aggregate_annotations import map_human_annotations_to_episode_logs +from .evaluation_dimensions import ( + EvaluationDimensionBuilder, + CustomEvaluationDimension, + CustomEvaluationDimensionList, +) from logging import Logger @@ -39,6 +44,7 @@ "AgentProfile", "EnvironmentProfile", "EpisodeLog", + "NonStreamingSimulationStatus", "EnvAgentComboStorage", "AnnotationForEpisode", "Annotator", @@ -65,6 +71,10 @@ "jsonl_to_relationshipprofiles", "jsonl_to_envagnetcombostorage", "get_rewards_from_episode", + "EvaluationDimensionBuilder", + "CustomEvaluationDimension", + "CustomEvaluationDimensionList", + "NonStreamingSimulationStatus", ] InheritedJsonModel = TypeVar("InheritedJsonModel", bound="JsonModel") diff --git a/sotopia/database/evaluation_dimensions.py b/sotopia/database/evaluation_dimensions.py new file mode 100644 index 000000000..4b2a2c2a1 --- /dev/null +++ b/sotopia/database/evaluation_dimensions.py @@ -0,0 +1,147 @@ +from redis_om import JsonModel +from redis_om.model.model import Field +from pydantic import BaseModel, create_model +from typing import Type, Callable, Tuple, Annotated, Union, cast, Any + + +class CustomEvaluationDimension(JsonModel): + name: str = Field(index=True) + description: str = Field(index=True) + range_high: int = Field(index=True) + range_low: int = Field(index=True) + + +class CustomEvaluationDimensionList(JsonModel): + name: str = Field(index=True) + dimension_pks: list[str] = Field(default_factory=lambda: [], index=True) + + +class EvaluationDimensionBuilder: + """ + EvaluationDimensionBuilder is a utility class for creating and managing evaluation dimensions. + It provides methods to build evaluation dimension models from various inputs such as primary keys, dictionaries, and names. + """ + + @staticmethod + def create_range_validator( + low: int, high: int + ) -> Callable[[Tuple[str, int]], Tuple[str, int]]: + def validator(x: Tuple[str, int]) -> Tuple[str, int]: + if not isinstance(x, tuple) or len(x) != 2: + raise ValueError("Must be a tuple of (str, int)") + if not isinstance(x[1], int) or not low <= x[1] <= high: + raise ValueError(f"Score must be between {low} and {high}") + return x + + return validator + + @staticmethod + def build_dimension_model(dimension_ids: list[str]) -> Type[BaseModel]: + """ + Build an evaluation dimension from existing dimension primary keys. + The returned model is a pydantic model that can be used to evaluate the conversation. + """ + fields: dict[str, Any] = {} + + for dimension_id in dimension_ids: + dimension = CustomEvaluationDimension.get(dimension_id) + range_validator = EvaluationDimensionBuilder.create_range_validator( + dimension.range_low, dimension.range_high + ) + field_type = Annotated[Tuple[str, int], range_validator] + + fields[dimension.name] = ( + field_type, + Field(..., description=dimension.description), + ) + + model: Type[BaseModel] = create_model( + "CustomEvaluationDimensionModel", + __base__=BaseModel, + **fields, + ) + return model + + @staticmethod + def build_dimension_model_from_dict( + dimensions: list[dict[str, Union[str, int]]], + ) -> Type[BaseModel]: + """ + Build an evaluation dimension from a dictionary that specifies the parameters of the `CustomEvaluationDimension`. + The returned model is a pydantic model that can be used to evaluate the conversation. + """ + fields: dict[str, Any] = {} + for dimension_dict in dimensions: + dimension = CustomEvaluationDimension(**dimension_dict) + range_validator = EvaluationDimensionBuilder.create_range_validator( + dimension.range_low, dimension.range_high + ) + field_type = Annotated[Tuple[str, int], range_validator] + + fields[dimension.name] = ( + field_type, + Field(..., description=dimension.description), + ) + + dimension_model = create_model( + "CustomEvaluationDimensionModel", + __base__=BaseModel, + **fields, + ) + return dimension_model + + @staticmethod + def select_existing_dimension_model_by_name( + dimension_names: list[str], + ) -> Type[BaseModel]: + """ + Build an evaluation dimension from existing dimension names. For example `['believability', 'goal']` + The returned model is a pydantic model that can be used to evaluate the conversation. + """ + fields: dict[str, Any] = {} + for dimension_name in dimension_names: + dimensions = CustomEvaluationDimension.find( + CustomEvaluationDimension.name == dimension_name + ).all() + assert ( + len(dimensions) == 1 + ), f"Expected 1 dimension for {dimension_name}, but found {len(dimensions)}" + dimension = cast(CustomEvaluationDimension, dimensions[0]) + range_validator = EvaluationDimensionBuilder.create_range_validator( + dimension.range_low, dimension.range_high + ) + field_type = Annotated[Tuple[str, int], range_validator] + + fields[dimension.name] = ( + field_type, + Field(..., description=dimension.description), + ) + + model: Type[BaseModel] = create_model( + "CustomEvaluationDimensionModel", + __base__=BaseModel, + **fields, + ) + return model + + @staticmethod + def select_existing_dimension_model_by_list_name( + list_name: str, + ) -> Type[BaseModel]: + """ + Build an evaluation dimension from existing `CustomEvaluationDimensionList` list names. For example, directly use `sotopia` + The returned model is a pydantic model that can be used to evaluate the conversation. + """ + # if list_name == "sotopia": + # return SotopiaDimensions # TODO see if we could make this work in `experiment_eval.py`. Right now there is a circular import + + dimensions = CustomEvaluationDimensionList.find( + CustomEvaluationDimensionList.name == list_name + ).all() + assert ( + len(dimensions) == 1 + ), f"Expected 1 dimension list for {list_name}, but found {len(dimensions)}" + dimension_list = cast(CustomEvaluationDimensionList, dimensions[0]) + dimension_ids = dimension_list.dimension_pks + model = EvaluationDimensionBuilder.build_dimension_model(dimension_ids) + return model diff --git a/sotopia/database/logs.py b/sotopia/database/logs.py index b3c5ff41e..4a2551aed 100644 --- a/sotopia/database/logs.py +++ b/sotopia/database/logs.py @@ -8,10 +8,15 @@ from pydantic import model_validator from redis_om import JsonModel from redis_om.model.model import Field - +from typing import Literal from sotopia.database.persistent_profile import AgentProfile +class NonStreamingSimulationStatus(JsonModel): + episode_pk: str = Field(index=True) + status: Literal["Started", "Error", "Completed"] + + class EpisodeLog(JsonModel): # Note that we did not validate the following constraints: # 1. The number of turns in messages and rewards should be the same or off by 1 diff --git a/sotopia/database/persistent_profile.py b/sotopia/database/persistent_profile.py index ee99f6601..c2e0e8e86 100644 --- a/sotopia/database/persistent_profile.py +++ b/sotopia/database/persistent_profile.py @@ -94,6 +94,11 @@ class RelationshipProfile(JsonModel): description="0 means stranger, 1 means know_by_name, 2 means acquaintance, 3 means friend, 4 means romantic_relationship, 5 means family_member", ) # this could be improved by limiting str to a relationship Enum background_story: str | None = Field(default_factory=lambda: None) + tag: str = Field( + index=True, + default_factory=lambda: "", + description="The tag of the relationship, used for searching, could be convenient to document relationship profiles from different works and sources", + ) class EnvironmentList(JsonModel): diff --git a/sotopia/database/serialization.py b/sotopia/database/serialization.py index 1fcc8b69e..c38e3c6c3 100644 --- a/sotopia/database/serialization.py +++ b/sotopia/database/serialization.py @@ -84,7 +84,7 @@ def _map_gender_to_adj(gender: str) -> str: "Nonbinary": "nonbinary", } if gender: - return gender_to_adj[gender] + return gender_to_adj.get(gender, "") else: return "" diff --git a/sotopia/envs/parallel.py b/sotopia/envs/parallel.py index 5d27f687b..e0a928d36 100644 --- a/sotopia/envs/parallel.py +++ b/sotopia/envs/parallel.py @@ -51,7 +51,7 @@ def _map_gender_to_adj(gender: str) -> str: "Nonbinary": "nonbinary", } if gender: - return gender_to_adj[gender] + return gender_to_adj.get(gender, "") else: return "" diff --git a/sotopia/experimental/agents/base_agent.py b/sotopia/experimental/agents/base_agent.py index a7bbafae6..6d9466bbc 100644 --- a/sotopia/experimental/agents/base_agent.py +++ b/sotopia/experimental/agents/base_agent.py @@ -22,11 +22,13 @@ def __init__( input_channel_types: list[tuple[str, type[T_agent_observation]]], output_channel_types: list[tuple[str, type[T_agent_action]]], redis_url: str = "redis://localhost:6379/0", + node_name: str = "base_agent", ): super().__init__( input_channel_types=input_channel_types, output_channel_types=output_channel_types, redis_url=redis_url, + node_name=node_name, ) self.observation_queue: asyncio.Queue[T_agent_observation] = asyncio.Queue() diff --git a/sotopia/experimental/agents/datamodels.py b/sotopia/experimental/agents/datamodels.py new file mode 100644 index 000000000..a243a52a3 --- /dev/null +++ b/sotopia/experimental/agents/datamodels.py @@ -0,0 +1,42 @@ +from sotopia.messages import ActionType +from aact.messages import DataModel, DataModelFactory +from pydantic import Field + + +@DataModelFactory.register("observation") +class Observation(DataModel): + agent_name: str = Field(description="the name of the agent") + last_turn: str = Field(description="the last turn of the conversation") + turn_number: int = Field(description="the turn number of the conversation") + available_actions: list[ActionType] = Field(description="the available actions") + + def to_natural_language(self) -> str: + if self.turn_number == 0: + return f"\n{self.last_turn}\nConversation Starts:\n" + else: + return f"Turn #{self.turn_number-1}: {self.last_turn}\n" + + +@DataModelFactory.register("agent_action") +class AgentAction(DataModel): + agent_name: str = Field(description="the name of the agent") + output_channel: str = Field(description="the name of the output channel") + action_type: ActionType = Field( + description="whether to speak at this turn or choose to not do anything" + ) + argument: str = Field( + description="the utterance if choose to speak, the expression or gesture if choose non-verbal communication, or the physical action if choose action" + ) + + def to_natural_language(self) -> str: + match self.action_type: + case "none": + return "did nothing" + case "speak": + return f'said: "{self.argument}"' + case "non-verbal communication": + return f"[{self.action_type}] {self.argument}" + case "action": + return f"[{self.action_type}] {self.argument}" + case "leave": + return "left the conversation" diff --git a/sotopia/experimental/agents/moderator.py b/sotopia/experimental/agents/moderator.py new file mode 100644 index 000000000..ce57fb38b --- /dev/null +++ b/sotopia/experimental/agents/moderator.py @@ -0,0 +1,270 @@ +import asyncio +import sys + + +if sys.version_info < (3, 11): + from typing_extensions import Self +else: + from typing import Self + + +from aact import Message, NodeFactory, Node +from aact.messages import DataModel, DataModelFactory + +from typing import Literal, Any, AsyncIterator +from pydantic import Field + +from sotopia.database import EpisodeLog +from .datamodels import AgentAction, Observation +from sotopia.messages import ActionType + + +@DataModelFactory.register("observations") +class Observations(DataModel): + observations_map: dict[str, Observation] = Field( + description="the observations of the agents" + ) + + +@NodeFactory.register("moderator") +class Moderator(Node[AgentAction, Observation]): + def __init__( + self, + input_channels: list[str], + output_channels: list[str], + scenario: str, + agent_mapping: dict[str, str], + node_name: str, + agent_backgrounds: dict[str, str], + redis_url: str = "redis://localhost:6379/0", + action_order: Literal["simultaneous", "round-robin", "random"] = "round-robin", + available_actions: list[ActionType] = [ + "none", + "speak", + "non-verbal communication", + "action", + "leave", + ], + max_turns: int = 20, + push_to_db: bool = False, + ): + super().__init__( + input_channel_types=[ + (input_channel, AgentAction) for input_channel in input_channels + ], + output_channel_types=[ + (output_channel, Observation) for output_channel in output_channels + ], + redis_url=redis_url, + node_name=node_name, + ) + self.observation_queue: asyncio.Queue[AgentAction] = asyncio.Queue() + self.task_scheduler: asyncio.Task[None] | None = None + self.shutdown_event: asyncio.Event = asyncio.Event() + self.agent_mapping: dict[str, str] = agent_mapping + self.action_order: Literal["simultaneous", "round-robin", "random"] = ( + action_order + ) + self.available_actions: list[ActionType] = available_actions + self.turn_number: int = 0 + self.max_turns: int = max_turns + self.current_agent_index: int = 0 + self.scenario: str = scenario + self.agents: list[str] = list(agent_mapping.values()) + self.agent_models: dict[str, str] = {} + self.agents_awake: dict[str, bool] = {name: False for name in self.agents} + self.all_agents_awake: asyncio.Event = asyncio.Event() + self.message_history: list[list[tuple[str, str, str]]] = [ + [("Environment", "Environment", self.scenario)] + ] + self.push_to_db = push_to_db + self.agent_backgrounds = agent_backgrounds + + if self.action_order == "round-robin": + pass + else: + raise NotImplementedError( + "the selected action order is currently not implemented" + ) + + async def __aenter__(self) -> Self: + print(self.scenario) + asyncio.create_task(self.booting()) + self.task_scheduler = asyncio.create_task(self._task_scheduler()) + return await super().__aenter__() + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.shutdown_event.set() + if self.task_scheduler is not None: + self.task_scheduler.cancel() + return await super().__aexit__(exc_type, exc_value, traceback) + + async def send(self, observations: Observations) -> None: + for output_channel, output_channel_type in self.output_channel_types.items(): + if output_channel in observations.observations_map: + await self.r.publish( + output_channel, + Message[output_channel_type]( # type:ignore[valid-type] + data=observations.observations_map[output_channel] + ).model_dump_json(), + ) + + async def event_handler( + self, channel: str, message: Message[AgentAction] + ) -> AsyncIterator[tuple[str, Message[Observation]]]: + if channel in self.input_channel_types: + await self.observation_queue.put(message.data) + else: + raise ValueError(f"Invalid channel: {channel}") + yield "", self.output_type() + + async def _task_scheduler(self) -> None: + await self.all_agents_awake.wait() + while not self.shutdown_event.is_set(): + observation = await self.observation_queue.get() + action_or_none = await self.aact(observation) + if action_or_none is not None: + await self.send(action_or_none) + self.observation_queue.task_done() + + async def booting(self) -> None: + """ + 1. send checking message to agents for every 0.1 seconds, until all agents are awake + - this message has turn_number of -1 for identification, agents should not record this into actual message_history + - if the agent booted succesfully, he is expected to return its model name for record. + 2. (under round-robin action order)after all agents are awake, send agent[0] a message to allow the agent to start speaking + """ + while not self.all_agents_awake.is_set(): + await self.send( + Observations( + observations_map={ + output_channel: Observation( + agent_name="moderator", + last_turn=self.scenario, + turn_number=-1, + available_actions=["none"], + ) + for output_channel, agent_name in self.agent_mapping.items() + } + ) + ) + await asyncio.sleep(0.1) + while not self.observation_queue.empty(): + agent_action = await self.observation_queue.get() + self.agents_awake[agent_action.agent_name] = True + self.agent_models[agent_action.agent_name] = agent_action.argument + if False not in self.agents_awake.values(): + self.all_agents_awake.set() + + if self.action_order == "round-robin": + await self.send( + Observations( + observations_map={ + output_channel: Observation( + agent_name="moderator", + last_turn=self.agent_backgrounds[agent_name], + turn_number=0, + available_actions=self.available_actions + if agent_name == self.agents[0] + else ["none"], + ) + for output_channel, agent_name in self.agent_mapping.items() + } + ) + ) + self.current_agent_index += 1 + + async def wrap_up_and_stop(self) -> None: + if self.push_to_db: + await self.save() + await asyncio.sleep(0.5) + print("stopping all agents") + await self.r.publish( + f"shutdown:{self.node_name}", + "shutdown", + ) + + async def save(self) -> EpisodeLog: + """ + save the EpisodeLog to redis, without evaluating + TODO: specify what to be added inside tag + TODO: update the code so that EpisodeLog.render_for_humans() can work + -currently it cannot work because no AgentProfile has been uploaded to redis + -such a process should be done back in the agents' end + -also the current agentslist is consist of names, but not uuid's of agents + """ + epilog = EpisodeLog( + environment=self.scenario, + agents=self.agents, + tag=None, + models=list(self.agent_models.values()), + messages=self.message_history, + reasoning="", + rewards=[0] * len(self.agents), + rewards_prompt="", + ) + epilog.save() + # print(epilog.render_for_humans()) + return epilog + + async def aact(self, agent_action: AgentAction) -> Observations | None: + if agent_action.action_type == "leave": + self.agents_awake[agent_action.agent_name] = False + if True not in self.agents_awake.values(): + await self.wrap_up_and_stop() + return None + if agent_action.action_type == "none": + return None + + if len(self.message_history) == 1: + self.message_history[0].append( + ( + agent_action.agent_name, + "Environment", + agent_action.to_natural_language(), + ) + ) + else: + self.message_history.append( + [ + ( + agent_action.agent_name, + "Environment", + agent_action.to_natural_language(), + ) + ] + ) + + if self.turn_number < self.max_turns: + self.turn_number += 1 + else: + return Observations( + observations_map={ + output_channel: Observation( + agent_name="moderator", + last_turn=self.scenario, + turn_number=self.turn_number + 1, + available_actions=["leave"], + ) + for output_channel, agent_name in self.agent_mapping.items() + } + ) + + observations_map: dict[str, Observation] = {} + for output_channel, output_channel_type in self.output_channel_types.items(): + agent_name = self.agent_mapping[output_channel] + available_actions: list[ActionType] = ["none"] + if self.action_order == "round-robin": + if agent_name == self.agents[self.current_agent_index]: + available_actions = self.available_actions + + observation = Observation( + agent_name=agent_name, + last_turn=agent_action.to_natural_language(), + turn_number=self.turn_number, + available_actions=available_actions, + ) + observations_map[output_channel] = observation + self.current_agent_index = (self.current_agent_index + 1) % len(self.agents) + + return Observations(observations_map=observations_map) diff --git a/sotopia/server.py b/sotopia/server.py index d285558a5..ba88e9a7b 100644 --- a/sotopia/server.py +++ b/sotopia/server.py @@ -1,7 +1,7 @@ import asyncio import itertools import logging -from typing import Literal, Sequence, Type +from typing import Literal, Sequence, Type, AsyncGenerator, Union import gin import rich @@ -15,7 +15,7 @@ ScriptWritingAgent, ) from sotopia.agents.base_agent import BaseAgent -from sotopia.database import EpisodeLog +from sotopia.database import EpisodeLog, NonStreamingSimulationStatus from sotopia.envs import ParallelSotopiaEnv from sotopia.envs.evaluators import ( EvaluationForTwoAgents, @@ -25,7 +25,7 @@ unweighted_aggregate_evaluate, ) from sotopia.generation_utils.generate import LLM_Name, agenerate_script -from sotopia.messages import AgentAction, Message, Observation +from sotopia.messages import AgentAction, Message, Observation, SimpleMessage from sotopia.messages.message_classes import ( ScriptBackground, ScriptEnvironmentResponse, @@ -104,6 +104,12 @@ def run_sync_server( return messages +def flatten_listed_messages( + messages: list[list[tuple[str, str, Message]]], +) -> list[tuple[str, str, Message]]: + return list(itertools.chain.from_iterable(messages)) + + @gin.configurable async def arun_one_episode( env: ParallelSotopiaEnv, @@ -113,102 +119,135 @@ async def arun_one_episode( json_in_script: bool = False, tag: str | None = None, push_to_db: bool = False, -) -> list[tuple[str, str, Message]]: + episode_pk: str | None = None, + streaming: bool = False, + simulation_status: NonStreamingSimulationStatus | None = None, +) -> Union[ + list[tuple[str, str, Message]], + AsyncGenerator[list[list[tuple[str, str, Message]]], None], +]: agents = Agents({agent.agent_name: agent for agent in agent_list}) - environment_messages = env.reset(agents=agents, omniscient=omniscient) - agents.reset() - - messages: list[list[tuple[str, str, Message]]] = [] + print(f"Running episode with tag: {tag}------------------") - # Main Event Loop - done = False - messages.append( - [ - ("Environment", agent_name, environment_messages[agent_name]) - for agent_name in env.agents - ] - ) - # set goal for agents - for index, agent_name in enumerate(env.agents): - agents[agent_name].goal = env.profile.agent_goals[index] - rewards: list[list[float]] = [] - reasons: list[str] = [] - while not done: - # gather agent messages - agent_messages: dict[str, AgentAction] = dict() - actions = await asyncio.gather( - *[ - agents[agent_name].aact(environment_messages[agent_name]) - for agent_name in env.agents - ] - ) - if script_like: - # manually mask one message - agent_mask = env.action_mask - for idx in range(len(agent_mask)): - print("Current mask: ", agent_mask) - if agent_mask[idx] == 0: - print("Action not taken: ", actions[idx]) - actions[idx] = AgentAction(action_type="none", argument="") - else: - print("Current action taken: ", actions[idx]) + async def generate_messages() -> ( + AsyncGenerator[list[list[tuple[str, str, Message]]], None] + ): + environment_messages = env.reset(agents=agents, omniscient=omniscient) + agents.reset() + messages: list[list[tuple[str, str, Message]]] = [] - # actions = cast(list[AgentAction], actions) - for idx, agent_name in enumerate(env.agents): - agent_messages[agent_name] = actions[idx] - - messages[-1].append((agent_name, "Environment", agent_messages[agent_name])) - - # send agent messages to environment - ( - environment_messages, - rewards_in_turn, - terminated, - ___, - info, - ) = await env.astep(agent_messages) + # Main Event Loop + done = False messages.append( [ ("Environment", agent_name, environment_messages[agent_name]) for agent_name in env.agents ] ) - # print("Environment message: ", environment_messages) - # exit(0) - rewards.append([rewards_in_turn[agent_name] for agent_name in env.agents]) - reasons.append( - " ".join(info[agent_name]["comments"] for agent_name in env.agents) - ) - done = all(terminated.values()) + yield messages + + # set goal for agents + for index, agent_name in enumerate(env.agents): + agents[agent_name].goal = env.profile.agent_goals[index] + rewards: list[list[float]] = [] + reasons: list[str] = [] + while not done: + # gather agent messages + agent_messages: dict[str, AgentAction] = dict() + actions = await asyncio.gather( + *[ + agents[agent_name].aact(environment_messages[agent_name]) + for agent_name in env.agents + ] + ) + if script_like: + # manually mask one message + agent_mask = env.action_mask + for idx in range(len(agent_mask)): + if agent_mask[idx] == 0: + actions[idx] = AgentAction(action_type="none", argument="") + else: + pass + + # actions = cast(list[AgentAction], actions) + for idx, agent_name in enumerate(env.agents): + agent_messages[agent_name] = actions[idx] + + messages[-1].append( + (agent_name, "Environment", agent_messages[agent_name]) + ) - # TODO: clean up this part - epilog = EpisodeLog( - environment=env.profile.pk, - agents=[agent.profile.pk for agent in agent_list], - tag=tag, - models=[env.model_name, agent_list[0].model_name, agent_list[1].model_name], - messages=[ - [(m[0], m[1], m[2].to_natural_language()) for m in messages_in_turn] - for messages_in_turn in messages - ], - reasoning=info[env.agents[0]]["comments"], - rewards=[info[agent_name]["complete_rating"] for agent_name in env.agents], - rewards_prompt=info["rewards_prompt"]["overall_prompt"], - ) - rich.print(epilog.rewards_prompt) - agent_profiles, conversation = epilog.render_for_humans() - for agent_profile in agent_profiles: - rich.print(agent_profile) - for message in conversation: - rich.print(message) + # send agent messages to environment + ( + environment_messages, + rewards_in_turn, + terminated, + ___, + info, + ) = await env.astep(agent_messages) + messages.append( + [ + ("Environment", agent_name, environment_messages[agent_name]) + for agent_name in env.agents + ] + ) + print(f"Messages: {messages}") + yield messages + rewards.append([rewards_in_turn[agent_name] for agent_name in env.agents]) + reasons.append( + " ".join(info[agent_name]["comments"] for agent_name in env.agents) + ) + done = all(terminated.values()) - if push_to_db: - try: - epilog.save() - except Exception as e: - logging.error(f"Failed to save episode log: {e}") - # flatten nested list messages - return list(itertools.chain(*messages)) + epilog = EpisodeLog( + environment=env.profile.pk, + agents=[agent.profile.pk for agent in agent_list], + tag=tag, + models=[env.model_name, agent_list[0].model_name, agent_list[1].model_name], + messages=[ + [(m[0], m[1], m[2].to_natural_language()) for m in messages_in_turn] + for messages_in_turn in messages + ], + reasoning=info[env.agents[0]]["comments"], + rewards=[info[agent_name]["complete_rating"] for agent_name in env.agents], + rewards_prompt=info["rewards_prompt"]["overall_prompt"], + ) + rich.print(epilog.rewards_prompt) + agent_profiles, conversation = epilog.render_for_humans() + for agent_profile in agent_profiles: + rich.print(agent_profile) + for message in conversation: + rich.print(message) + + if streaming: + # yield the rewards and reasonings + messages.append( + [("Evaluation", "Rewards", SimpleMessage(message=str(epilog.rewards)))] + ) + messages.append( + [("Evaluation", "Reasoning", SimpleMessage(message=epilog.reasoning))] + ) + yield messages + + if push_to_db: + try: + if episode_pk: + epilog.pk = episode_pk + epilog.save() + else: + epilog.save() + if simulation_status: + simulation_status.status = "Completed" + simulation_status.save() + except Exception as e: + logging.error(f"Failed to save episode log: {e}") + + if streaming: + return generate_messages() + else: + async for last_messages in generate_messages(): + pass + return flatten_listed_messages(last_messages) @gin.configurable @@ -310,7 +349,13 @@ def get_agent_class( else [await i for i in episode_futures] ) - return batch_results + if len(batch_results) > 0: + first_result = batch_results[0] + assert isinstance( + first_result, list + ), f"Unexpected result type: {type(first_result)}" + + return batch_results # type: ignore async def arun_one_script( diff --git a/sotopia/ui/README.md b/sotopia/ui/README.md index ca8b679d2..4d8bb773a 100644 --- a/sotopia/ui/README.md +++ b/sotopia/ui/README.md @@ -4,6 +4,17 @@ ## FastAPI Server +To run the FastAPI server, you can use the following command: +```bash +uv run rq worker +uv run fastapi run sotopia/ui/fastapi_server.py --workers 4 --port 8080 +``` + +Here is also an example of using the FastAPI server: +```bash +uv run python examples/fast_api_example.py +``` + The API server is a FastAPI application that is used to connect the Sotopia UI to the Sotopia backend. This could also help with other projects that need to connect to the Sotopia backend through HTTP requests. @@ -78,33 +89,31 @@ EnvironmentProfile returns: - scenario_id: str -#### DELETE /agents/{agent_id} +### Updating Data in the API Server -Delete agent profile from the API server. +#### PUT /agents/{agent_id} + +Update agent profile in the API server. +Request Body: +AgentProfile returns: - agent_id: str -#### DELETE /scenarios/{scenario_id} -Delete scenario profile from the API server. +#### PUT /scenarios/{scenario_id} + +Update scenario profile in the API server. +Request Body: +EnvironmentProfile returns: - scenario_id: str - -### Error Code -For RESTful APIs above we have the following error codes: -| **Error Code** | **Description** | -|-----------------|--------------------------------------| -| **404** | A resource is not found | -| **403** | The query is not authorized | -| **500** | Internal running error | - ### Initiating a new non-streaming simulation episode #### POST /episodes/ -[!] Currently not planning to implement + ```python class SimulationEpisodeInitiation(BaseModel): scenario_id: str @@ -147,14 +156,14 @@ returns: | Type | Direction | Description | |-----------|--------|-------------| | SERVER_MSG | Server → Client | Standard message from server (payload: `messageForRendering` [here](https://github.com/sotopia-lab/sotopia-demo/blob/main/socialstream/rendering_utils.py) ) | -| CLIENT_MSG | Client → Server | Standard message from client (payload: Currently not needed) | -| ERROR | Server → Client | Error notification (payload: `{"type": ERROR_TYPE, "description": DESC}`) | +| CLIENT_MSG | Client → Server | Standard message from client (payload: TBD) | +| ERROR | Server → Client | Error notification (payload: TBD) | | START_SIM | Client → Server | Initialize simulation (payload: `SimulationEpisodeInitialization`) | | END_SIM | Client → Server | End simulation (payload: not needed) | | FINISH_SIM | Server → Client | Terminate simulation (payload: not needed) | -**ERROR_TYPE** +**Error Type** | Error Code | Description | |------------|-------------| @@ -167,14 +176,53 @@ returns: | OTHER | Other unspecified errors | -**Conversation Message From the Server** -The server returns messages encapsulated in a structured format which is defined as follows: +**Implementation plan**: Currently only support LLM-LLM simulation based on [this function](https://github.com/sotopia-lab/sotopia/blob/19d39e068c3bca9246fc366e5759414f62284f93/sotopia/server.py#L108). + + +## An example to run simulation with the API + +**Get all scenarios**: +```bash +curl -X GET "http://localhost:8000/scenarios" +``` + +This gonna give you all the scenarios, and you can randomly pick one + + +**Get all agents**: +```bash +curl -X GET "http://localhost:8000/agents" +``` + +This gonna give you all the agents, and you can randomly pick one + +**Connecting to the websocket server**: +We recommend using Python. Here is the simplist way to start a simulation and receive the results in real time: ```python -class MessageForRendering(TypedDict): - role: str # Specifies the origin of the message. Common values include "Background Info", "Environment", "{Agent Names} - type: str # Categorizes the nature of the message. Common types include: "comment", "said", "action" - content: str +import aiohttp +import asyncio +import json + +async def main(): + async with aiohttp.ClientSession() as session: + async with session.ws_connect(f'ws://{API_BASE}/ws/simulation?token={YOUR_TOKEN}') as ws: + start_message = { + "type": "START_SIM", + "data": { + "env_id": "{ENV_ID}", + "agent_ids": ["{AGENT1_PK}", "{AGENT2_PK}"], + }, + } + await ws.send_json(start_message) + + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + print(f"Received: {msg.data}") + elif msg.type == aiohttp.WSMsgType.CLOSED: + break + elif msg.type == aiohttp.WSMsgType.ERROR: + break ``` -**Implementation plan**: Currently only support LLM-LLM simulation based on [this function](https://github.com/sotopia-lab/sotopia/blob/19d39e068c3bca9246fc366e5759414f62284f93/sotopia/server.py#L108). +Please check out an detailed example in `examples/experimental/websocket/websocket_test_client.py` diff --git a/sotopia/ui/fastapi_server.py b/sotopia/ui/fastapi_server.py index ea53f4e56..2eacef028 100644 --- a/sotopia/ui/fastapi_server.py +++ b/sotopia/ui/fastapi_server.py @@ -1,17 +1,84 @@ -from fastapi import FastAPI from typing import Literal, cast, Dict -from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog -from pydantic import BaseModel +import sys + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +from redis_om import get_redis_connection +import rq +from sotopia.database import ( + EnvironmentProfile, + AgentProfile, + EpisodeLog, + RelationshipProfile, + RelationshipType, + NonStreamingSimulationStatus, + CustomEvaluationDimensionList, + CustomEvaluationDimension, +) +from sotopia.envs.parallel import ParallelSotopiaEnv +from sotopia.envs.evaluators import ( + RuleBasedTerminatedEvaluator, + ReachGoalLLMEvaluator, + EvaluationForTwoAgents, + SotopiaDimensions, +) +from sotopia.server import arun_one_episode +from sotopia.agents import LLMAgent, Agents +from fastapi import ( + FastAPI, + WebSocket, + HTTPException, + WebSocketDisconnect, +) +from typing import Optional, Any +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, model_validator, field_validator, Field + +from sotopia.ui.websocket_utils import ( + WebSocketSotopiaSimulator, + WSMessageType, + ErrorType, +) import uvicorn +import asyncio + +from contextlib import asynccontextmanager +from typing import AsyncIterator +import logging +from fastapi.responses import Response + +logger = logging.getLogger(__name__) + app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) # TODO: Whether allowing CORS for all origins + + +class RelationshipWrapper(BaseModel): + pk: str = "" + agent_1_id: str = "" + agent_2_id: str = "" + relationship: Literal[0, 1, 2, 3, 4, 5] = 0 + backstory: str = "" + tag: str = "" + class AgentProfileWrapper(BaseModel): """ Wrapper for AgentProfile to avoid pydantic v2 issues """ + pk: str = "" first_name: str last_name: str age: int = 0 @@ -35,6 +102,7 @@ class EnvironmentProfileWrapper(BaseModel): Wrapper for EnvironmentProfile to avoid pydantic v2 issues """ + pk: str = "" codename: str source: str = "" scenario: str = "" @@ -46,6 +114,43 @@ class EnvironmentProfileWrapper(BaseModel): tag: str = "" +class CustomEvaluationDimensionsWrapper(BaseModel): + pk: str = "" + name: str = Field( + default="", description="The name of the custom evaluation dimension list" + ) + dimensions: list[CustomEvaluationDimension] = Field( + default=[], description="The dimensions of the custom evaluation dimension list" + ) + + +class SimulationRequest(BaseModel): + env_id: str + agent_ids: list[str] + models: list[str] + max_turns: int + tag: str + + @field_validator("agent_ids") + @classmethod + def validate_agent_ids(cls, v: list[str]) -> list[str]: + if len(v) != 2: + raise ValueError( + "Currently only 2 agents are supported, we are working on supporting more agents" + ) + return v + + @model_validator(mode="after") + def validate_models(self) -> Self: + models = self.models + agent_ids = self.agent_ids + if len(models) != len(agent_ids) + 1: + raise ValueError( + f"models must have exactly {len(agent_ids) + 1} elements, if there are {len(agent_ids)} agents, the first model is the evaluator model" + ) + return self + + @app.get("/scenarios", response_model=list[EnvironmentProfile]) async def get_scenarios_all() -> list[EnvironmentProfile]: return EnvironmentProfile.all() @@ -64,6 +169,12 @@ async def get_scenarios( EnvironmentProfile.codename == value ).all() scenarios.extend(cast(list[EnvironmentProfile], json_models)) + + if not scenarios: + raise HTTPException( + status_code=404, detail=f"No scenarios found with {get_by}={value}" + ) + return scenarios @@ -85,9 +196,34 @@ async def get_agents( elif get_by == "occupation": json_models = AgentProfile.find(AgentProfile.occupation == value).all() agents_profiles.extend(cast(list[AgentProfile], json_models)) + + if not agents_profiles: + raise HTTPException( + status_code=404, detail=f"No agents found with {get_by}={value}" + ) + return agents_profiles +@app.get("/relationship/{agent_1_id}/{agent_2_id}", response_model=str) +async def get_relationship(agent_1_id: str, agent_2_id: str) -> str: + relationship_profiles = RelationshipProfile.find( + (RelationshipProfile.agent_1_id == agent_1_id) + & (RelationshipProfile.agent_2_id == agent_2_id) + ).all() + assert ( + len(relationship_profiles) == 1 + ), f"{len(relationship_profiles)} relationship profiles found for agents {agent_1_id} and {agent_2_id}, expected 1" + relationship_profile = relationship_profiles[0] + assert isinstance(relationship_profile, RelationshipProfile) + return f"{str(relationship_profile.relationship)}: {RelationshipType(relationship_profile.relationship).name}" + + +@app.get("/episodes", response_model=list[EpisodeLog]) +async def get_episodes_all() -> list[EpisodeLog]: + return EpisodeLog.all() + + @app.get("/episodes/{get_by}/{value}", response_model=list[EpisodeLog]) async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[EpisodeLog]: episodes: list[EpisodeLog] = [] @@ -96,10 +232,41 @@ async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[Episode elif get_by == "tag": json_models = EpisodeLog.find(EpisodeLog.tag == value).all() episodes.extend(cast(list[EpisodeLog], json_models)) + + if not episodes: + raise HTTPException( + status_code=404, detail=f"No episodes found with {get_by}={value}" + ) return episodes -@app.post("/agents/") +@app.get( + "/evaluation_dimensions/", response_model=dict[str, list[CustomEvaluationDimension]] +) +async def get_evaluation_dimensions() -> dict[str, list[CustomEvaluationDimension]]: + custom_evaluation_dimensions: dict[str, list[CustomEvaluationDimension]] = {} + all_custom_evaluation_dimension_list = CustomEvaluationDimensionList.all() + for custom_evaluation_dimension_list in all_custom_evaluation_dimension_list: + assert isinstance( + custom_evaluation_dimension_list, CustomEvaluationDimensionList + ) + custom_evaluation_dimensions[custom_evaluation_dimension_list.name] = [ + CustomEvaluationDimension.get(pk=pk) + for pk in custom_evaluation_dimension_list.dimension_pks + ] + return custom_evaluation_dimensions + + +@app.post("/scenarios/", response_model=str) +async def create_scenario(scenario: EnvironmentProfileWrapper) -> str: + scenario_profile = EnvironmentProfile(**scenario.model_dump()) + scenario_profile.save() + pk = scenario_profile.pk + assert pk is not None + return pk + + +@app.post("/agents/", response_model=str) async def create_agent(agent: AgentProfileWrapper) -> str: agent_profile = AgentProfile(**agent.model_dump()) agent_profile.save() @@ -108,26 +275,237 @@ async def create_agent(agent: AgentProfileWrapper) -> str: return pk -@app.post("/scenarios/", response_model=str) -async def create_scenario(scenario: EnvironmentProfileWrapper) -> str: - print(scenario) - scenario_profile = EnvironmentProfile(**scenario.model_dump()) - scenario_profile.save() - pk = scenario_profile.pk +@app.post("/relationship/", response_model=str) +async def create_relationship(relationship: RelationshipWrapper) -> str: + relationship_profile = RelationshipProfile(**relationship.model_dump()) + relationship_profile.save() + pk = relationship_profile.pk + assert pk is not None + return pk + + +@app.post("/evaluation_dimensions/", response_model=str) +async def create_evaluation_dimensions( + evaluation_dimensions: CustomEvaluationDimensionsWrapper, +) -> str: + dimension_list = CustomEvaluationDimensionList.find( + CustomEvaluationDimensionList.name == evaluation_dimensions.name + ).all() + + if len(dimension_list) == 0: + all_dimensions_pks = [] + for dimension in evaluation_dimensions.dimensions: + find_dimension = CustomEvaluationDimension.find( + CustomEvaluationDimension.name == dimension.name + ).all() + if len(find_dimension) == 0: + dimension.save() + all_dimensions_pks.append(dimension.pk) + elif len(find_dimension) == 1: + all_dimensions_pks.append(find_dimension[0].pk) + else: + raise HTTPException( + status_code=409, + detail=f"Evaluation dimension with name={dimension.name} already exists", + ) + + custom_evaluation_dimension_list = CustomEvaluationDimensionList( + pk=evaluation_dimensions.pk, + name=evaluation_dimensions.name, + dimension_pks=all_dimensions_pks, + ) + custom_evaluation_dimension_list.save() + logger.info(f"Created evaluation dimension list {evaluation_dimensions.name}") + else: + raise HTTPException( + status_code=409, + detail=f"Evaluation dimension list with name={evaluation_dimensions.name} already exists", + ) + + pk = custom_evaluation_dimension_list.pk assert pk is not None return pk +async def run_simulation( + episode_pk: str, + simulation_request: SimulationRequest, + simulation_status: NonStreamingSimulationStatus, +) -> None: + try: + env_profile: EnvironmentProfile = EnvironmentProfile.get( + pk=simulation_request.env_id + ) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, + detail=f"Environment with id={simulation_request.env_id} not found", + ) + try: + agent_1_profile = AgentProfile.get(pk=simulation_request.agent_ids[0]) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, + detail=f"Agent with id={simulation_request.agent_ids[0]} not found", + ) + try: + agent_2_profile = AgentProfile.get(pk=simulation_request.agent_ids[1]) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, + detail=f"Agent with id={simulation_request.agent_ids[1]} not found", + ) + + env_params: dict[str, Any] = { + "model_name": simulation_request.models[0], + "action_order": "round-robin", + "evaluators": [ + RuleBasedTerminatedEvaluator( + max_turn_number=simulation_request.max_turns, max_stale_turn=2 + ), + ], + "terminal_evaluators": [ + ReachGoalLLMEvaluator( + simulation_request.models[0], + EvaluationForTwoAgents[SotopiaDimensions], + ), + ], + } + env = ParallelSotopiaEnv(env_profile=env_profile, **env_params) + agents = Agents( + { + "agent1": LLMAgent( + "agent1", + model_name=simulation_request.models[1], + agent_profile=agent_1_profile, + ), + "agent2": LLMAgent( + "agent2", + model_name=simulation_request.models[2], + agent_profile=agent_2_profile, + ), + } + ) + + await arun_one_episode( + env=env, + agent_list=list(agents.values()), + push_to_db=True, + tag=simulation_request.tag, + episode_pk=episode_pk, + simulation_status=simulation_status, + ) + + +@app.post("/simulate/", response_model=str) +def simulate(simulation_request: SimulationRequest) -> Response: + try: + _: EnvironmentProfile = EnvironmentProfile.get(pk=simulation_request.env_id) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, + detail=f"Environment with id={simulation_request.env_id} not found", + ) + try: + __ = AgentProfile.get(pk=simulation_request.agent_ids[0]) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, + detail=f"Agent with id={simulation_request.agent_ids[0]} not found", + ) + try: + ___ = AgentProfile.get(pk=simulation_request.agent_ids[1]) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, + detail=f"Agent with id={simulation_request.agent_ids[1]} not found", + ) + + episode_pk = EpisodeLog( + environment="", + agents=[], + models=[], + messages=[], + reasoning="", + rewards=[], # Pseudorewards + rewards_prompt="", + ).pk + try: + simulation_status = NonStreamingSimulationStatus( + episode_pk=episode_pk, + status="Started", + ) + simulation_status.save() + queue = rq.Queue("default", connection=get_redis_connection()) + queue.enqueue( + run_simulation, + episode_pk=episode_pk, + simulation_request=simulation_request, + simulation_status=simulation_status, + ) + + except Exception as e: + logger.error(f"Error starting simulation: {e}") + simulation_status.status = "Error" + simulation_status.save() + return Response(content=episode_pk, status_code=202) + + +@app.get("/simulation_status/{episode_pk}", response_model=str) +async def get_simulation_status(episode_pk: str) -> str: + status = NonStreamingSimulationStatus.find( + NonStreamingSimulationStatus.episode_pk == episode_pk + ).all()[0] + assert isinstance(status, NonStreamingSimulationStatus) + return status.status + + @app.delete("/agents/{agent_id}", response_model=str) async def delete_agent(agent_id: str) -> str: - AgentProfile.delete(agent_id) - return agent_id + try: + agent = AgentProfile.get(pk=agent_id) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, detail=f"Agent with id={agent_id} not found" + ) + AgentProfile.delete(agent.pk) + assert agent.pk is not None + return agent.pk @app.delete("/scenarios/{scenario_id}", response_model=str) async def delete_scenario(scenario_id: str) -> str: - EnvironmentProfile.delete(scenario_id) - return scenario_id + try: + scenario = EnvironmentProfile.get(pk=scenario_id) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, detail=f"Scenario with id={scenario_id} not found" + ) + EnvironmentProfile.delete(scenario.pk) + assert scenario.pk is not None + return scenario.pk + + +@app.delete("/relationship/{relationship_id}", response_model=str) +async def delete_relationship(relationship_id: str) -> str: + RelationshipProfile.delete(relationship_id) + return relationship_id + + +@app.delete("/episodes/{episode_id}", response_model=str) +async def delete_episode(episode_id: str) -> str: + EpisodeLog.delete(episode_id) + return episode_id + + +@app.delete("/evaluation_dimensions/{evaluation_dimension_list_pk}", response_model=str) +async def delete_evaluation_dimension_list(evaluation_dimension_list_pk: str) -> str: + for dimension_pk in CustomEvaluationDimensionList.get( + evaluation_dimension_list_pk + ).dimension_pks: + CustomEvaluationDimension.delete(dimension_pk) + CustomEvaluationDimensionList.delete(evaluation_dimension_list_pk) + return evaluation_dimension_list_pk active_simulations: Dict[ @@ -135,5 +513,157 @@ async def delete_scenario(scenario_id: str) -> str: ] = {} # TODO check whether this is the correct way to store the active simulations +@app.get("/models", response_model=list[str]) +async def get_models() -> list[str]: + # TODO figure out how to get the available models + return ["gpt-4o-mini", "gpt-4o", "gpt-3.5-turbo"] + + +class SimulationState: + _instance: Optional["SimulationState"] = None + _lock = asyncio.Lock() + _active_simulations: dict[str, bool] = {} + + def __new__(cls) -> "SimulationState": + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._active_simulations = {} + return cls._instance + + async def try_acquire_token(self, token: str) -> tuple[bool, str]: + async with self._lock: + if not token: + return False, "Invalid token" + + if self._active_simulations.get(token): + return False, "Token is active already" + + self._active_simulations[token] = True + return True, "Token is valid" + + async def release_token(self, token: str) -> None: + async with self._lock: + self._active_simulations.pop(token, None) + + @asynccontextmanager + async def start_simulation(self, token: str) -> AsyncIterator[bool]: + try: + yield True + finally: + await self.release_token(token) + + +class SimulationManager: + def __init__(self) -> None: + self.state = SimulationState() + + async def verify_token(self, token: str) -> dict[str, Any]: + is_valid, msg = await self.state.try_acquire_token(token) + return {"is_valid": is_valid, "msg": msg} + + async def create_simulator( + self, env_id: str, agent_ids: list[str] + ) -> WebSocketSotopiaSimulator: + try: + return WebSocketSotopiaSimulator(env_id=env_id, agent_ids=agent_ids) + except Exception as e: + error_msg = f"Failed to create simulator: {e}" + logger.error(error_msg) + raise Exception(error_msg) + + async def handle_client_message( + self, + websocket: WebSocket, + simulator: WebSocketSotopiaSimulator, + message: dict[str, Any], + timeout: float = 0.1, + ) -> bool: + try: + msg_type = message.get("type") + if msg_type == WSMessageType.FINISH_SIM.value: + return True + # TODO handle other message types + return False + except Exception as e: + msg = f"Error handling client message: {e}" + logger.error(msg) + await self.send_error(websocket, ErrorType.INVALID_MESSAGE, msg) + return False + + async def run_simulation( + self, websocket: WebSocket, simulator: WebSocketSotopiaSimulator + ) -> None: + try: + async for message in simulator.arun(): + await self.send_message(websocket, WSMessageType.SERVER_MSG, message) + + try: + data = await asyncio.wait_for(websocket.receive_json(), timeout=0.1) + if await self.handle_client_message(websocket, simulator, data): + break + except asyncio.TimeoutError: + continue + + except Exception as e: + msg = f"Error running simulation: {e}" + logger.error(msg) + await self.send_error(websocket, ErrorType.SIMULATION_ISSUE, msg) + finally: + await self.send_message(websocket, WSMessageType.END_SIM, {}) + + @staticmethod + async def send_message( + websocket: WebSocket, msg_type: WSMessageType, data: dict[str, Any] + ) -> None: + await websocket.send_json({"type": msg_type.value, "data": data}) + + @staticmethod + async def send_error( + websocket: WebSocket, error_type: ErrorType, details: str = "" + ) -> None: + await websocket.send_json( + { + "type": WSMessageType.ERROR.value, + "data": {"type": error_type.value, "details": details}, + } + ) + + +@app.websocket("/ws/simulation") +async def websocket_endpoint(websocket: WebSocket, token: str) -> None: + manager = SimulationManager() + + token_status = await manager.verify_token(token) + if not token_status["is_valid"]: + await websocket.close(code=1008, reason=token_status["msg"]) + return + + try: + await websocket.accept() + + while True: + start_msg = await websocket.receive_json() + if start_msg.get("type") != WSMessageType.START_SIM.value: + continue + + async with manager.state.start_simulation(token): + simulator = await manager.create_simulator( + env_id=start_msg["data"]["env_id"], + agent_ids=start_msg["data"]["agent_ids"], + ) + await manager.run_simulation(websocket, simulator) + + except WebSocketDisconnect: + logger.info(f"Client disconnected: {token}") + except Exception as e: + logger.error(f"Unexpected error: {e}") + await manager.send_error(websocket, ErrorType.SIMULATION_ISSUE, str(e)) + finally: + try: + await websocket.close() + except Exception as e: + logger.error(f"Error closing websocket: {e}") + + if __name__ == "__main__": uvicorn.run(app, host="127.0.0.1", port=8800) diff --git a/sotopia/ui/websocket_utils.py b/sotopia/ui/websocket_utils.py new file mode 100644 index 000000000..5b29da732 --- /dev/null +++ b/sotopia/ui/websocket_utils.py @@ -0,0 +1,186 @@ +from sotopia.envs.evaluators import ( + EvaluationForTwoAgents, + ReachGoalLLMEvaluator, + RuleBasedTerminatedEvaluator, + SotopiaDimensions, +) +from sotopia.agents import Agents, LLMAgent +from sotopia.messages import Observation +from sotopia.envs import ParallelSotopiaEnv +from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog +from sotopia.server import arun_one_episode + +from enum import Enum +from typing import TypedDict, Any, AsyncGenerator +from pydantic import BaseModel + + +class WSMessageType(str, Enum): + SERVER_MSG = "SERVER_MSG" + CLIENT_MSG = "CLIENT_MSG" + ERROR = "ERROR" + START_SIM = "START_SIM" + END_SIM = "END_SIM" + FINISH_SIM = "FINISH_SIM" + + +class ErrorType(str, Enum): + NOT_AUTHORIZED = "NOT_AUTHORIZED" + SIMULATION_ALREADY_STARTED = "SIMULATION_ALREADY_STARTED" + SIMULATION_NOT_STARTED = "SIMULATION_NOT_STARTED" + SIMULATION_ISSUE = "SIMULATION_ISSUE" + INVALID_MESSAGE = "INVALID_MESSAGE" + OTHER = "OTHER" + + +class MessageForRendering(TypedDict): + role: str + type: str + content: str + + +class WSMessage(BaseModel): + type: WSMessageType + data: dict[str, Any] + + model_config = {"arbitrary_types_allowed": True, "protected_namespaces": ()} + + def to_json(self) -> dict[str, Any]: + return { + "type": self.type.value, # TODO check whether we want to use the enum value or the enum itself + "data": self.data, + } + + +def get_env_agents( + env_id: str, + agent_ids: list[str], + agent_models: list[str], + evaluator_model: str, +) -> tuple[ParallelSotopiaEnv, Agents, dict[str, Observation]]: + # environment_profile = EnvironmentProfile.find().all()[0] + # agent_profiles = AgentProfile.find().all()[:2] + assert len(agent_ids) == len( + agent_models + ), f"Provided {len(agent_ids)} agent_ids but {len(agent_models)} agent_models" + + environment_profile: EnvironmentProfile = EnvironmentProfile.get(env_id) + agent_profiles: list[AgentProfile] = [ + AgentProfile.get(agent_id) for agent_id in agent_ids + ] + + agent_list = [ + LLMAgent( + agent_profile=agent_profile, + model_name=agent_models[idx], + ) + for idx, agent_profile in enumerate(agent_profiles) + ] + for idx, goal in enumerate(environment_profile.agent_goals): + agent_list[idx].goal = goal + + agents = Agents({agent.agent_name: agent for agent in agent_list}) + env = ParallelSotopiaEnv( + action_order="round-robin", + model_name="gpt-4o-mini", + evaluators=[ + RuleBasedTerminatedEvaluator(max_turn_number=20, max_stale_turn=2), + ], + terminal_evaluators=[ + ReachGoalLLMEvaluator( + evaluator_model, + EvaluationForTwoAgents[SotopiaDimensions], + ), + ], + env_profile=environment_profile, + ) + + environment_messages = env.reset(agents=agents, omniscient=False) + agents.reset() + + return env, agents, environment_messages + + +def parse_reasoning(reasoning: str, num_agents: int) -> tuple[list[str], str]: + """Parse the reasoning string into a dictionary.""" + sep_token = "SEPSEP" + for i in range(1, num_agents + 1): + reasoning = ( + reasoning.replace(f"Agent {i} comments:\n", sep_token) + .strip(" ") + .strip("\n") + ) + all_chunks = reasoning.split(sep_token) + general_comment = all_chunks[0].strip(" ").strip("\n") + comment_chunks = all_chunks[-num_agents:] + + return comment_chunks, general_comment + + +class WebSocketSotopiaSimulator: + def __init__( + self, + env_id: str, + agent_ids: list[str], + agent_models: list[str] = ["gpt-4o-mini", "gpt-4o-mini"], + evaluator_model: str = "gpt-4o", + ) -> None: + self.env, self.agents, self.environment_messages = get_env_agents( + env_id, agent_ids, agent_models, evaluator_model + ) + self.messages: list[list[tuple[str, str, str]]] = [] + self.messages.append( + [ + ( + "Environment", + agent_name, + self.environment_messages[agent_name].to_natural_language(), + ) + for agent_name in self.env.agents + ] + ) + for index, agent_name in enumerate(self.env.agents): + self.agents[agent_name].goal = self.env.profile.agent_goals[index] + + async def arun(self) -> AsyncGenerator[dict[str, Any], None]: + # Use sotopia to run the simulation + generator = arun_one_episode( + env=self.env, + agent_list=list(self.agents.values()), + push_to_db=False, + streaming=True, + ) + + assert isinstance( + generator, AsyncGenerator + ), "generator should be async generator" + + async for messages in await generator: # type: ignore + reasoning, rewards = "", [0.0, 0.0] + eval_available = False + if messages[-1][0][0] == "Evaluation": + reasoning = messages[-1][0][2].to_natural_language() + rewards = eval(messages[-2][0][2].to_natural_language()) + eval_available = True + + epilog = EpisodeLog( + environment=self.env.profile.pk, + agents=[agent.profile.pk for agent in self.agents.values()], + tag="test", + models=["gpt-4o", "gpt-4o", "gpt-4o-mini"], + messages=[ + [(m[0], m[1], m[2].to_natural_language()) for m in messages_in_turn] + for messages_in_turn in messages + ], + reasoning=reasoning, + rewards=rewards, + rewards_prompt="", + ) + agent_profiles, parsed_messages = epilog.render_for_humans() + if not eval_available: + parsed_messages = parsed_messages[:-2] + + yield { + "type": "messages", + "messages": parsed_messages, + } diff --git a/stubs/redis_om/__init__.pyi b/stubs/redis_om/__init__.pyi index abbae6f43..133b6caff 100644 --- a/stubs/redis_om/__init__.pyi +++ b/stubs/redis_om/__init__.pyi @@ -2,6 +2,7 @@ import abc from typing import Any, Generator, TypeVar from pydantic import BaseModel +import redis from redis_om.model.model import Field from pydantic._internal._model_construction import ModelMetaclass from redis_om.model.model import FindQuery @@ -37,3 +38,5 @@ class EmbeddedJsonModel(JsonModel): ... class Migrator: def run(self) -> None: ... + +def get_redis_connection() -> redis.Redis[bytes]: ... diff --git a/tests/database/test_database.py b/tests/database/test_database.py index 5279ecc9c..142e2bd13 100644 --- a/tests/database/test_database.py +++ b/tests/database/test_database.py @@ -8,6 +8,7 @@ AgentProfile, EnvironmentProfile, EpisodeLog, + CustomEvaluationDimension, ) from sotopia.envs.parallel import ParallelSotopiaEnv from sotopia.messages import SimpleMessage @@ -42,6 +43,25 @@ def test_create_agent_profile() -> None: AgentProfile.delete(pk) +def test_create_custom_dimension() -> None: + custom_dimension = CustomEvaluationDimension( + name="verbosity_custom", + description="The verbosity of the conversation", + range_low=0, + range_high=10, + ) + custom_dimension.save() + pk = custom_dimension.pk + dimension = CustomEvaluationDimension.get(pk) + assert ( + dimension.name == custom_dimension.name + and dimension.description == custom_dimension.description + and dimension.range_low == custom_dimension.range_low + and dimension.range_high == custom_dimension.range_high + ) + CustomEvaluationDimension.delete(pk) + + @pytest.fixture def _test_create_episode_log_setup_and_tear_down() -> Generator[None, None, None]: AgentProfile(first_name="John", last_name="Doe", pk="tmppk_agent1").save() diff --git a/tests/experimental/test_agent.py b/tests/experimental/test_agent.py index 020c2131b..834c4286c 100644 --- a/tests/experimental/test_agent.py +++ b/tests/experimental/test_agent.py @@ -19,11 +19,13 @@ async def aact(self, observation: Tick) -> Tick: @pytest.mark.asyncio async def test_base_agent() -> None: async with ReturnPlusOneAgent( + node_name="test_base_agent", input_channel_types=[("input", Tick)], output_channel_types=[("output", Tick)], redis_url="redis://localhost:6379/0", ) as agent1: async with ReturnPlusOneAgent( + node_name="test_base_agent_2", input_channel_types=[("output", Tick)], output_channel_types=[("final", Tick)], redis_url="redis://localhost:6379/0", diff --git a/tests/sampler/test_sampler.py b/tests/sampler/test_sampler.py index 3b039f177..53a406918 100644 --- a/tests/sampler/test_sampler.py +++ b/tests/sampler/test_sampler.py @@ -33,16 +33,24 @@ def _test_create_episode_log_setup_and_tear_down() -> Generator[None, None, None age_constraint="[(18, 70), (18, 70)]", ).save() RelationshipProfile( - agent_1_id="tmppk_agent1", agent_2_id="tmppk_agent2", relationship=2 + agent_1_id="tmppk_agent1", + agent_2_id="tmppk_agent2", + relationship=2, + pk="tmppk_relationship1", ).save() RelationshipProfile( - agent_1_id="tmppk_agent1", agent_2_id="tmppk_agent3", relationship=2 + agent_1_id="tmppk_agent1", + agent_2_id="tmppk_agent3", + relationship=2, + pk="tmppk_relationship2", ).save() yield AgentProfile.delete("tmppk_agent1") AgentProfile.delete("tmppk_agent2") AgentProfile.delete("tmppk_agent3") EnvironmentProfile.delete("tmppk_environment") + RelationshipProfile.delete("tmppk_relationship1") + RelationshipProfile.delete("tmppk_relationship2") def _generate_name() -> str: diff --git a/tests/ui/test_fastapi.py b/tests/ui/test_fastapi.py index 9395104f3..0ce87278f 100644 --- a/tests/ui/test_fastapi.py +++ b/tests/ui/test_fastapi.py @@ -1,5 +1,12 @@ from fastapi.testclient import TestClient -from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog +from sotopia.database import ( + EnvironmentProfile, + AgentProfile, + EpisodeLog, + RelationshipProfile, + CustomEvaluationDimension, + CustomEvaluationDimensionList, +) from sotopia.messages import SimpleMessage from sotopia.ui.fastapi_server import app import pytest @@ -63,7 +70,9 @@ def create_dummy_episode_log() -> None: @pytest.fixture -def create_mock_data() -> Generator[None, None, None]: +def create_mock_data(request: pytest.FixtureRequest) -> Generator[None, None, None]: + for_posting = request.param if hasattr(request, "param") else False + def _create_mock_agent_profile() -> None: AgentProfile( first_name="John", @@ -71,6 +80,7 @@ def _create_mock_agent_profile() -> None: occupation="test_occupation", gender="test_gender", pk="tmppk_agent1", + tag="test_tag", ).save() AgentProfile( first_name="Jane", @@ -78,6 +88,7 @@ def _create_mock_agent_profile() -> None: occupation="test_occupation", gender="test_gender", pk="tmppk_agent2", + tag="test_tag", ).save() def _create_mock_env_profile() -> None: @@ -89,18 +100,81 @@ def _create_mock_env_profile() -> None: "C", ], pk="tmppk_env_profile", + tag="test_tag", ) env_profile.save() - _create_mock_agent_profile() - _create_mock_env_profile() + def _create_mock_relationship() -> None: + RelationshipProfile( + pk="tmppk_relationship", + agent_1_id="tmppk_agent1", + agent_2_id="tmppk_agent2", + relationship=1.0, + ).save() + + def _create_mock_evaluation_dimension() -> None: + CustomEvaluationDimension( + pk="tmppk_evaluation_dimension", + name="test_dimension", + description="test_description", + range_high=10, + range_low=-10, + ).save() + CustomEvaluationDimensionList( + pk="tmppk_evaluation_dimension_list", + name="test_dimension_list", + dimension_pks=["tmppk_evaluation_dimension"], + ).save() + if not for_posting: + _create_mock_agent_profile() + _create_mock_env_profile() + _create_mock_relationship() + _create_mock_evaluation_dimension() + print("created mock data") yield - AgentProfile.delete("tmppk_agent1") - AgentProfile.delete("tmppk_agent2") - EnvironmentProfile.delete("tmppk_env_profile") - EpisodeLog.delete("tmppk_episode_log") + try: + AgentProfile.delete("tmppk_agent1") + except Exception as e: + print(e) + try: + AgentProfile.delete("tmppk_agent2") + except Exception as e: + print(e) + try: + EnvironmentProfile.delete("tmppk_env_profile") + except Exception as e: + print(e) + try: + RelationshipProfile.delete("tmppk_relationship") + except Exception as e: + print(e) + try: + EpisodeLog.delete("tmppk_episode_log") + except Exception as e: + print(e) + + try: + EpisodeLog.delete("tmppk_episode_log") + except Exception as e: + print(e) + + try: + episodes = EpisodeLog.find(EpisodeLog.tag == "test_tag").all() + for episode in episodes: + EpisodeLog.delete(episode.pk) + except Exception as e: + print(e) + + try: + CustomEvaluationDimension.delete("tmppk_evaluation_dimension") + except Exception as e: + print(e) + try: + CustomEvaluationDimensionList.delete("tmppk_evaluation_dimension_list") + except Exception as e: + print(e) def test_get_scenarios_all(create_mock_data: Callable[[], None]) -> None: @@ -169,8 +243,24 @@ def test_get_episodes_by_tag(create_mock_data: Callable[[], None]) -> None: assert response.json()[0]["tag"] == tag +def test_get_relationship(create_mock_data: Callable[[], None]) -> None: + response = client.get("/relationship/tmppk_agent1/tmppk_agent2") + assert response.status_code == 200 + assert isinstance(response.json(), str) + assert response.json() == "1: know_by_name" + + +def test_get_evaluation_dimensions(create_mock_data: Callable[[], None]) -> None: + response = client.get("/evaluation_dimensions/") + assert response.status_code == 200 + assert isinstance(response.json(), dict) + assert response.json()["test_dimension_list"][0]["name"] == "test_dimension" + + +@pytest.mark.parametrize("create_mock_data", [True], indirect=True) def test_create_agent(create_mock_data: Callable[[], None]) -> None: agent_data = { + "pk": "tmppk_agent1", "first_name": "test_first_name", "last_name": "test_last_name", } @@ -179,8 +269,10 @@ def test_create_agent(create_mock_data: Callable[[], None]) -> None: assert isinstance(response.json(), str) +@pytest.mark.parametrize("create_mock_data", [True], indirect=True) def test_create_scenario(create_mock_data: Callable[[], None]) -> None: scenario_data = { + "pk": "tmppk_env_profile", "codename": "test_codename", "scenario": "test_scenario", "tag": "test", @@ -190,6 +282,41 @@ def test_create_scenario(create_mock_data: Callable[[], None]) -> None: assert isinstance(response.json(), str) +@pytest.mark.parametrize("create_mock_data", [True], indirect=True) +def test_create_relationship(create_mock_data: Callable[[], None]) -> None: + relationship_data = { + "pk": "tmppk_relationship", + "agent_1_id": "tmppk_agent1", + "agent_2_id": "tmppk_agent2", + "relationship": 1.0, + "tag": "test_tag", + } + response = client.post("/relationship", json=relationship_data) + assert response.status_code == 200 + assert isinstance(response.json(), str) + + +@pytest.mark.parametrize("create_mock_data", [True], indirect=True) +def test_create_evaluation_dimensions(create_mock_data: Callable[[], None]) -> None: + evaluation_dimension_data = { + "pk": "tmppk_evaluation_dimension_list", + "name": "test_dimension_list", + "dimensions": [ + { + "pk": "tmppk_evaluation_dimension", + "name": "test_dimension", + "description": "test_description", + "range_high": 10, + "range_low": -10, + } + ], + } + response = client.post("/evaluation_dimensions", json=evaluation_dimension_data) + print(response.json()) + assert response.status_code == 200 + assert isinstance(response.json(), str) + + def test_delete_agent(create_mock_data: Callable[[], None]) -> None: response = client.delete("/agents/tmppk_agent1") assert response.status_code == 200 @@ -200,3 +327,60 @@ def test_delete_scenario(create_mock_data: Callable[[], None]) -> None: response = client.delete("/scenarios/tmppk_env_profile") assert response.status_code == 200 assert isinstance(response.json(), str) + + +def test_delete_relationship(create_mock_data: Callable[[], None]) -> None: + response = client.delete("/relationship/tmppk_relationship") + assert response.status_code == 200 + assert isinstance(response.json(), str) + + +def test_delete_evaluation_dimension(create_mock_data: Callable[[], None]) -> None: + response = client.delete("/evaluation_dimensions/tmppk_evaluation_dimension_list") + assert response.status_code == 200 + assert isinstance(response.json(), str) + + +# def test_simulate(create_mock_data: Callable[[], None]) -> None: +# response = client.post( +# "/simulate", +# json={ +# "env_id": "tmppk_env_profile", +# "agent_ids": ["tmppk_agent1", "tmppk_agent2"], +# "models": [ +# # "custom/llama3.2:1b@http://localhost:8000/v1", +# # "custom/llama3.2:1b@http://localhost:8000/v1", +# # "custom/llama3.2:1b@http://localhost:8000/v1" +# "gpt-4o-mini", +# "gpt-4o-mini", +# "gpt-4o-mini", +# ], +# "max_turns": 2, +# "tag": "test_tag", +# }, +# ) +# assert response.status_code == 200 +# assert isinstance(response.json(), str) +# max_retries = 20 +# retry_count = 0 +# while retry_count < max_retries: +# try: +# status = NonStreamingSimulationStatus.find( +# NonStreamingSimulationStatus.episode_pk == response.json() +# ).all()[0] +# assert isinstance(status, NonStreamingSimulationStatus) +# print(status) +# if status.status == "Error": +# raise Exception("Error running simulation") +# elif status.status == "Completed": +# # EpisodeLog.get(response.json()) +# break +# # Status is "Started", keep polling +# time.sleep(1) +# retry_count += 1 +# except Exception as e: +# print(f"Error checking simulation status: {e}") +# time.sleep(1) +# retry_count += 1 +# else: +# raise TimeoutError("Simulation timed out after 10 retries") diff --git a/uv.lock b/uv.lock index 5017e0e00..217200dd2 100644 --- a/uv.lock +++ b/uv.lock @@ -10,9 +10,10 @@ resolution-markers = [ [[package]] name = "aact" version = "0.0.10" -source = { registry = "https://pypi.org/simple" } +source = { git = "https://github.com/ProKil/aact?branch=feature%2Fnode-manager#56cd2a2aad8a0e806e4f3a170e848cb1e1ad0720" } dependencies = [ { name = "aiofiles" }, + { name = "aiohttp" }, { name = "aiostream" }, { name = "numpy" }, { name = "pydantic" }, @@ -22,10 +23,6 @@ dependencies = [ { name = "tomlkit", marker = "python_full_version < '3.11'" }, { name = "typer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6e/9f/2b32aca3e2fe614df4e04a074870b6b27ef037af62f639b0e4d0b33abb31/aact-0.0.10.tar.gz", hash = "sha256:0cde5360d27bab002a43e9895c4006bfa541f6c2db798412f4aad1fdb685632e", size = 113329 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/31/18/32beed32416f8c9618ed4fc42e33eef94d7c181caf59c6909b3841047006/aact-0.0.10-py3-none-any.whl", hash = "sha256:2c1959666270acc681aafc1452aa089cb26a24a0871b01faa7761fa300b2fc9a", size = 29102 }, -] [[package]] name = "absl-py" @@ -3144,7 +3141,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "aact" }, + { name = "aact", git = "https://github.com/ProKil/aact?branch=feature%2Fnode-manager" }, { name = "absl-py", specifier = ">=2.0.0,<3.0.0" }, { name = "anthropic", marker = "extra == 'anthropic'" }, { name = "beartype", specifier = ">=0.14.0,<0.20.0" },