Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: FastAPI Implementation of Sotopia Part Two (w websocket) #252

Merged
merged 26 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
decacd9
api doc
XuhuiZhou Oct 30, 2024
b62a554
add PUT
XuhuiZhou Nov 5, 2024
1942a14
add an temp example for websocket
XuhuiZhou Nov 7, 2024
a5e0cc2
websocket
XuhuiZhou Nov 8, 2024
1cddc10
update readme
XuhuiZhou Nov 8, 2024
4cd747f
Update README.md
ProKil Nov 9, 2024
2342dc5
update websocket live simulation api doc
bugsz Nov 10, 2024
4977993
[autofix.ci] apply automated fixes
autofix-ci[bot] Nov 10, 2024
215cf8c
update websocket doc
bugsz Nov 11, 2024
ea2f92c
Merge branch 'feature/sotopia-ui-doc' of https://github.com/sotopia-l…
bugsz Nov 11, 2024
ae7f7a8
add api server with websocket as well as a client
bugsz Nov 15, 2024
f6aeecf
fix mypy errors
bugsz Nov 15, 2024
ead08b4
support stopping the chat
bugsz Nov 17, 2024
ed6e437
add 404 to the status code
bugsz Nov 17, 2024
c345711
fix mypy issue
bugsz Nov 17, 2024
5e5a9c0
update the returned message types
bugsz Nov 20, 2024
f8a878d
redesign websocket api
bugsz Nov 27, 2024
031cb92
update websocket, fix mypy error
XuhuiZhou Dec 2, 2024
006d6e9
add example of using websocket
XuhuiZhou Dec 3, 2024
48a48fa
clean code & change to existing functions for simulation
bugsz Dec 4, 2024
fea806d
fix typing mismatch
bugsz Dec 4, 2024
a3df910
Merge branch 'feature/sotopia-ui-fastapi-websocket' of https://github…
bugsz Dec 4, 2024
d50cbd4
Merge branch 'main' of https://github.com/sotopia-lab/sotopia into fe…
bugsz Dec 4, 2024
3e5beb6
update doc & mypy type fix
bugsz Dec 4, 2024
af63410
add type check for run_async_server
bugsz Dec 4, 2024
67eb180
move example
XuhuiZhou Dec 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 115 additions & 86 deletions sotopia/server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import itertools
import logging
from typing import Literal, Sequence, Type
from typing import Literal, Sequence, Type, AsyncGenerator, Union

import gin
import rich
Expand All @@ -25,7 +25,7 @@
unweighted_aggregate_evaluate,
)
from sotopia.generation_utils.generate import LLM_Name, agenerate_script
from sotopia.messages import AgentAction, Message, Observation
from sotopia.messages import AgentAction, Message, Observation, SimpleMessage
from sotopia.messages.message_classes import (
ScriptBackground,
ScriptEnvironmentResponse,
Expand Down Expand Up @@ -104,6 +104,12 @@ def run_sync_server(
return messages


def flatten_listed_messages(
messages: list[list[tuple[str, str, Message]]],
) -> list[tuple[str, str, Message]]:
return list(itertools.chain.from_iterable(messages))


@gin.configurable
async def arun_one_episode(
env: ParallelSotopiaEnv,
Expand All @@ -113,102 +119,125 @@ async def arun_one_episode(
json_in_script: bool = False,
tag: str | None = None,
push_to_db: bool = False,
) -> list[tuple[str, str, Message]]:
streaming: bool = False,
) -> Union[
list[tuple[str, str, Message]],
AsyncGenerator[list[list[tuple[str, str, Message]]], None],
]:
agents = Agents({agent.agent_name: agent for agent in agent_list})
environment_messages = env.reset(agents=agents, omniscient=omniscient)
agents.reset()

messages: list[list[tuple[str, str, Message]]] = []

# Main Event Loop
done = False
messages.append(
[
("Environment", agent_name, environment_messages[agent_name])
for agent_name in env.agents
]
)
# set goal for agents
for index, agent_name in enumerate(env.agents):
agents[agent_name].goal = env.profile.agent_goals[index]
rewards: list[list[float]] = []
reasons: list[str] = []
while not done:
# gather agent messages
agent_messages: dict[str, AgentAction] = dict()
actions = await asyncio.gather(
*[
agents[agent_name].aact(environment_messages[agent_name])
for agent_name in env.agents
]
)
if script_like:
# manually mask one message
agent_mask = env.action_mask
for idx in range(len(agent_mask)):
print("Current mask: ", agent_mask)
if agent_mask[idx] == 0:
print("Action not taken: ", actions[idx])
actions[idx] = AgentAction(action_type="none", argument="")
else:
print("Current action taken: ", actions[idx])

# actions = cast(list[AgentAction], actions)
for idx, agent_name in enumerate(env.agents):
agent_messages[agent_name] = actions[idx]

messages[-1].append((agent_name, "Environment", agent_messages[agent_name]))
async def generate_messages() -> (
AsyncGenerator[list[list[tuple[str, str, Message]]], None]
):
environment_messages = env.reset(agents=agents, omniscient=omniscient)
agents.reset()
messages: list[list[tuple[str, str, Message]]] = []

# send agent messages to environment
(
environment_messages,
rewards_in_turn,
terminated,
___,
info,
) = await env.astep(agent_messages)
# Main Event Loop
done = False
messages.append(
[
("Environment", agent_name, environment_messages[agent_name])
for agent_name in env.agents
]
)
# print("Environment message: ", environment_messages)
# exit(0)
rewards.append([rewards_in_turn[agent_name] for agent_name in env.agents])
reasons.append(
" ".join(info[agent_name]["comments"] for agent_name in env.agents)
yield messages

# set goal for agents
for index, agent_name in enumerate(env.agents):
agents[agent_name].goal = env.profile.agent_goals[index]
rewards: list[list[float]] = []
reasons: list[str] = []
while not done:
# gather agent messages
agent_messages: dict[str, AgentAction] = dict()
actions = await asyncio.gather(
*[
agents[agent_name].aact(environment_messages[agent_name])
for agent_name in env.agents
]
)
if script_like:
# manually mask one message
agent_mask = env.action_mask
for idx in range(len(agent_mask)):
if agent_mask[idx] == 0:
actions[idx] = AgentAction(action_type="none", argument="")
else:
pass

# actions = cast(list[AgentAction], actions)
for idx, agent_name in enumerate(env.agents):
agent_messages[agent_name] = actions[idx]

messages[-1].append(
(agent_name, "Environment", agent_messages[agent_name])
)

# send agent messages to environment
(
environment_messages,
rewards_in_turn,
terminated,
___,
info,
) = await env.astep(agent_messages)
messages.append(
[
("Environment", agent_name, environment_messages[agent_name])
for agent_name in env.agents
]
)

yield messages
rewards.append([rewards_in_turn[agent_name] for agent_name in env.agents])
reasons.append(
" ".join(info[agent_name]["comments"] for agent_name in env.agents)
)
done = all(terminated.values())

epilog = EpisodeLog(
environment=env.profile.pk,
agents=[agent.profile.pk for agent in agent_list],
tag=tag,
models=[env.model_name, agent_list[0].model_name, agent_list[1].model_name],
messages=[
[(m[0], m[1], m[2].to_natural_language()) for m in messages_in_turn]
for messages_in_turn in messages
],
reasoning=info[env.agents[0]]["comments"],
rewards=[info[agent_name]["complete_rating"] for agent_name in env.agents],
rewards_prompt=info["rewards_prompt"]["overall_prompt"],
)
done = all(terminated.values())
rich.print(epilog.rewards_prompt)
agent_profiles, conversation = epilog.render_for_humans()
for agent_profile in agent_profiles:
rich.print(agent_profile)
for message in conversation:
rich.print(message)

if streaming:
# yield the rewards and reasonings
messages.append(
[("Evaluation", "Rewards", SimpleMessage(message=str(epilog.rewards)))]
)
messages.append(
[("Evaluation", "Reasoning", SimpleMessage(message=epilog.reasoning))]
)
yield messages

# TODO: clean up this part
epilog = EpisodeLog(
environment=env.profile.pk,
agents=[agent.profile.pk for agent in agent_list],
tag=tag,
models=[env.model_name, agent_list[0].model_name, agent_list[1].model_name],
messages=[
[(m[0], m[1], m[2].to_natural_language()) for m in messages_in_turn]
for messages_in_turn in messages
],
reasoning=info[env.agents[0]]["comments"],
rewards=[info[agent_name]["complete_rating"] for agent_name in env.agents],
rewards_prompt=info["rewards_prompt"]["overall_prompt"],
)
rich.print(epilog.rewards_prompt)
agent_profiles, conversation = epilog.render_for_humans()
for agent_profile in agent_profiles:
rich.print(agent_profile)
for message in conversation:
rich.print(message)
if push_to_db:
try:
epilog.save()
except Exception as e:
logging.error(f"Failed to save episode log: {e}")

if push_to_db:
try:
epilog.save()
except Exception as e:
logging.error(f"Failed to save episode log: {e}")
# flatten nested list messages
return list(itertools.chain(*messages))
if streaming:
return generate_messages()
else:
async for last_messages in generate_messages():
pass
return flatten_listed_messages(last_messages)


@gin.configurable
Expand Down
68 changes: 43 additions & 25 deletions sotopia/ui/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ returns:
- agents: list[AgentProfile]


#### GET /episodes
bugsz marked this conversation as resolved.
Show resolved Hide resolved

Get all episodes.

returns:
- episodes: list[Episode]

#### GET /episodes/?get_by={id|tag}/{episode_id|episode_tag}

Get episode by episode_tag.
Expand Down Expand Up @@ -78,33 +85,31 @@ EnvironmentProfile
returns:
- scenario_id: str

#### DELETE /agents/{agent_id}
### Updating Data in the API Server

Delete agent profile from the API server.
#### PUT /agents/{agent_id}

Update agent profile in the API server.
Request Body:
AgentProfile

returns:
- agent_id: str

#### DELETE /scenarios/{scenario_id}

Delete scenario profile from the API server.
#### PUT /scenarios/{scenario_id}

Update scenario profile in the API server.
Request Body:
EnvironmentProfile

returns:
- scenario_id: str


### Error Code
For RESTful APIs above we have the following error codes:
| **Error Code** | **Description** |
|-----------------|--------------------------------------|
| **404** | A resource is not found |
| **403** | The query is not authorized |
| **500** | Internal running error |

### Initiating a new non-streaming simulation episode

#### POST /episodes/
[!] Currently not planning to implement

```python
class SimulationEpisodeInitiation(BaseModel):
scenario_id: str
Expand Down Expand Up @@ -147,14 +152,14 @@ returns:
| Type | Direction | Description |
|-----------|--------|-------------|
| SERVER_MSG | Server β†’ Client | Standard message from server (payload: `messageForRendering` [here](https://github.com/sotopia-lab/sotopia-demo/blob/main/socialstream/rendering_utils.py) ) |
| CLIENT_MSG | Client β†’ Server | Standard message from client (payload: Currently not needed) |
| ERROR | Server β†’ Client | Error notification (payload: `{"type": ERROR_TYPE, "description": DESC}`) |
| CLIENT_MSG | Client β†’ Server | Standard message from client (payload: TBD) |
| ERROR | Server β†’ Client | Error notification (payload: TBD) |
| START_SIM | Client β†’ Server | Initialize simulation (payload: `SimulationEpisodeInitialization`) |
| END_SIM | Client β†’ Server | End simulation (payload: not needed) |
| FINISH_SIM | Server β†’ Client | Terminate simulation (payload: not needed) |


**ERROR_TYPE**
**Error Type**

| Error Code | Description |
|------------|-------------|
Expand All @@ -167,14 +172,27 @@ returns:
| OTHER | Other unspecified errors |


**Conversation Message From the Server**
The server returns messages encapsulated in a structured format which is defined as follows:

```python
class MessageForRendering(TypedDict):
role: str # Specifies the origin of the message. Common values include "Background Info", "Environment", "{Agent Names}
type: str # Categorizes the nature of the message. Common types include: "comment", "said", "action"
content: str
**Implementation plan**: Currently only support LLM-LLM simulation based on [this function](https://github.com/sotopia-lab/sotopia/blob/19d39e068c3bca9246fc366e5759414f62284f93/sotopia/server.py#L108).


## An example to run simulation with the API

**Get all scenarios**:
```bash
curl -X GET "http://localhost:8000/scenarios"
```

**Implementation plan**: Currently only support LLM-LLM simulation based on [this function](https://github.com/sotopia-lab/sotopia/blob/19d39e068c3bca9246fc366e5759414f62284f93/sotopia/server.py#L108).
Randomly select a scenario, e.g., `01HZRGTG1K4YQ2CBS9SNH28R9S`
XuhuiZhou marked this conversation as resolved.
Show resolved Hide resolved


**Get all agents**:
```bash
curl -X GET "http://localhost:8000/agents"
```

Randomly select two agents, e.g., `01H5TNE5PE9RQGH86YM6MSWZMW` and `01H5TNE5PT06B3QPXJ65HHACV7`

**Connecting to the websocket server**:

@bugsz: Adding ur example here?
Loading
Loading