Skip to content

Commit

Permalink
Custom selector function for SelectorGroupChat (#4026)
Browse files Browse the repository at this point in the history
* Custom selector function for SelectorGroupChat

* Update documentation
  • Loading branch information
ekzhu authored Nov 1, 2024
1 parent 369ffb5 commit 173acc6
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,21 +131,27 @@ class AssistantAgent(BaseChatAgent):
.. code-block:: python
import asyncio
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.task import MaxMessageTermination
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(name="assistant", model_client=model_client)
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(name="assistant", model_client=model_client)
await agent.run("What is the capital of France?", termination_condition=MaxMessageTermination(2))
result await agent.run("What is the capital of France?", termination_condition=MaxMessageTermination(2))
print(result)
asyncio.run(main())
The following example demonstrates how to create an assistant agent with
a model client and a tool, and generate a stream of messages for a task.
.. code-block:: python
import asyncio
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.task import MaxMessageTermination
Expand All @@ -155,14 +161,17 @@ async def get_current_time() -> str:
return "The current time is 12:00 PM."
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])
stream = agent.run_stream("What is the current time?", termination_condition=MaxMessageTermination(3))
stream = agent.run_stream("What is the current time?", termination_condition=MaxMessageTermination(3))
async for message in stream:
print(message)
async for message in stream:
print(message)
asyncio.run(main())
"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,25 @@ class TerminationCondition(ABC):
.. code-block:: python
import asyncio
from autogen_agentchat.teams import MaxTurnsTermination, TextMentionTermination
# Terminate the conversation after 10 turns or if the text "TERMINATE" is mentioned.
cond1 = MaxTurnsTermination(10) | TextMentionTermination("TERMINATE")
# Terminate the conversation after 10 turns and if the text "TERMINATE" is mentioned.
cond2 = MaxTurnsTermination(10) & TextMentionTermination("TERMINATE")
async def main() -> None:
# Terminate the conversation after 10 turns or if the text "TERMINATE" is mentioned.
cond1 = MaxTurnsTermination(10) | TextMentionTermination("TERMINATE")
...
# Terminate the conversation after 10 turns and if the text "TERMINATE" is mentioned.
cond2 = MaxTurnsTermination(10) & TextMentionTermination("TERMINATE")
# Reset the termination condition.
await cond1.reset()
await cond2.reset()
# ...
# Reset the termination condition.
await cond1.reset()
await cond2.reset()
asyncio.run(main())
"""

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,46 +61,55 @@ class RoundRobinGroupChat(BaseGroupChat):
.. code-block:: python
import asyncio
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.task import StopMessageTermination
model_client = OpenAIChatCompletionClient(model="gpt-4o")
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
async def get_weather(location: str) -> str:
return f"The weather in {location} is sunny."
async def get_weather(location: str) -> str:
return f"The weather in {location} is sunny."
assistant = AssistantAgent(
"Assistant",
model_client=model_client,
tools=[get_weather],
)
team = RoundRobinGroupChat([assistant])
stream = team.run_stream("What's the weather in New York?", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
assistant = AssistantAgent(
"Assistant",
model_client=model_client,
tools=[get_weather],
)
team = RoundRobinGroupChat([assistant])
stream = team.run_stream("What's the weather in New York?", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
asyncio.run(main())
A team with multiple participants:
.. code-block:: python
import asyncio
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.task import StopMessageTermination
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent1 = AssistantAgent("Assistant1", model_client=model_client)
agent2 = AssistantAgent("Assistant2", model_client=model_client)
team = RoundRobinGroupChat([agent1, agent2])
stream = team.run_stream("Tell me some jokes.", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent1 = AssistantAgent("Assistant1", model_client=model_client)
agent2 = AssistantAgent("Assistant2", model_client=model_client)
team = RoundRobinGroupChat([agent1, agent2])
stream = team.run_stream("Tell me some jokes.", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
asyncio.run(main())
"""

def __init__(self, participants: List[ChatAgent]):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging
import re
from typing import Callable, Dict, List
from typing import Callable, Dict, List, Sequence

from autogen_core.components.models import ChatCompletionClient, SystemMessage

from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
from ...base import ChatAgent, TerminationCondition
from ...messages import MultiModalMessage, StopMessage, TextMessage
from ...messages import ChatMessage, MultiModalMessage, StopMessage, TextMessage
from .._events import (
GroupChatPublishEvent,
GroupChatSelectSpeakerEvent,
Expand All @@ -20,7 +20,7 @@

class SelectorGroupChatManager(BaseGroupChatManager):
"""A group chat manager that selects the next speaker using a ChatCompletion
model."""
model and a custom selector function."""

def __init__(
self,
Expand All @@ -32,6 +32,7 @@ def __init__(
model_client: ChatCompletionClient,
selector_prompt: str,
allow_repeated_speaker: bool,
selector_func: Callable[[Sequence[ChatMessage]], str | None] | None,
) -> None:
super().__init__(
parent_topic_type,
Expand All @@ -44,12 +45,24 @@ def __init__(
self._selector_prompt = selector_prompt
self._previous_speaker: str | None = None
self._allow_repeated_speaker = allow_repeated_speaker
self._selector_func = selector_func

async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str:
"""Selects the next speaker in a group chat using a ChatCompletion client.
"""Selects the next speaker in a group chat using a ChatCompletion client,
with the selector function as override if it returns a speaker name.
A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
"""

# Use the selector function if provided.
if self._selector_func is not None:
speaker = self._selector_func([msg.agent_message for msg in thread])
if speaker is not None:
# Skip the model based selection.
event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=speaker, source=self.id))
return speaker

# Construct the history of the conversation.
history_messages: List[str] = []
for event in thread:
msg = event.agent_message
Expand Down Expand Up @@ -160,6 +173,10 @@ class SelectorGroupChat(BaseGroupChat):
Must contain '{roles}', '{participants}', and '{history}' to be filled in.
allow_repeated_speaker (bool, optional): Whether to allow the same speaker to be selected
consecutively. Defaults to False.
selector_func (Callable[[Sequence[ChatMessage]], str | None], optional): A custom selector
function that takes the conversation history and returns the name of the next speaker.
If provided, this function will be used to override the model to select the next speaker.
If the function returns None, the model will be used to select the next speaker.
Raises:
ValueError: If the number of participants is less than two or if the selector prompt is invalid.
Expand All @@ -175,51 +192,104 @@ class SelectorGroupChat(BaseGroupChat):
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.task import StopMessageTermination
model_client = OpenAIChatCompletionClient(model="gpt-4o")
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
async def lookup_hotel(location: str) -> str:
return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
async def lookup_hotel(location: str) -> str:
return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
async def lookup_flight(origin: str, destination: str) -> str:
return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
async def lookup_flight(origin: str, destination: str) -> str:
return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
async def book_trip() -> str:
return "Your trip is booked!"
travel_advisor = AssistantAgent(
"Travel_Advisor",
model_client,
tools=[book_trip],
description="Helps with travel planning.",
)
hotel_agent = AssistantAgent(
"Hotel_Agent",
model_client,
tools=[lookup_hotel],
description="Helps with hotel booking.",
)
flight_agent = AssistantAgent(
"Flight_Agent",
model_client,
tools=[lookup_flight],
description="Helps with flight booking.",
)
team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=model_client)
stream = team.run_stream("Book a 3-day trip to new york.", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
async def book_trip() -> str:
return "Your trip is booked!"
import asyncio
travel_advisor = AssistantAgent(
"Travel_Advisor",
model_client,
tools=[book_trip],
description="Helps with travel planning.",
)
hotel_agent = AssistantAgent(
"Hotel_Agent",
model_client,
tools=[lookup_hotel],
description="Helps with hotel booking.",
)
flight_agent = AssistantAgent(
"Flight_Agent",
model_client,
tools=[lookup_flight],
description="Helps with flight booking.",
)
team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=model_client)
stream = team.run_stream("Book a 3-day trip to new york.", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
asyncio.run(main())
A team with a custom selector function:
.. code-block:: python
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.task import TextMentionTermination
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
def check_caculation(x: int, y: int, answer: int) -> str:
if x + y == answer:
return "Correct!"
else:
return "Incorrect!"
agent1 = AssistantAgent(
"Agent1",
model_client,
description="For calculation",
system_message="Calculate the sum of two numbers",
)
agent2 = AssistantAgent(
"Agent2",
model_client,
tools=[check_caculation],
description="For checking calculation",
system_message="Check the answer and respond with 'Correct!' or 'Incorrect!'",
)
def selector_func(messages):
if len(messages) == 1 or messages[-1].content == "Incorrect!":
return "Agent1"
if messages[-1].source == "Agent1":
return "Agent2"
return None
team = SelectorGroupChat([agent1, agent2], model_client=model_client, selector_func=selector_func)
stream = team.run_stream("What is 1 + 1?", termination_condition=TextMentionTermination("Correct!"))
async for message in stream:
print(message)
import asyncio
asyncio.run(main())
"""

def __init__(
self,
participants: List[ChatAgent],
model_client: ChatCompletionClient,
*,
termination_condition: TerminationCondition | None = None,
selector_prompt: str = """You are in a role play game. The following roles are available:
{roles}.
Read the following conversation. Then select the next role from {participants} to play. Only return the role.
Expand All @@ -229,6 +299,7 @@ def __init__(
Read the above conversation. Then select the next role from {participants} to play. Only return the role.
""",
allow_repeated_speaker: bool = False,
selector_func: Callable[[Sequence[ChatMessage]], str | None] | None = None,
):
super().__init__(participants, group_chat_manager_class=SelectorGroupChatManager)
# Validate the participants.
Expand All @@ -244,6 +315,7 @@ def __init__(
self._selector_prompt = selector_prompt
self._model_client = model_client
self._allow_repeated_speaker = allow_repeated_speaker
self._selector_func = selector_func

def _create_group_chat_manager_factory(
self,
Expand All @@ -262,4 +334,5 @@ def _create_group_chat_manager_factory(
self._model_client,
self._selector_prompt,
self._allow_repeated_speaker,
self._selector_func,
)
Loading

0 comments on commit 173acc6

Please sign in to comment.