Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Supervisor Child of LLM Router #24

Merged
merged 1 commit into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions examples/agent_of_flo_ai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand All @@ -28,7 +28,7 @@
"True"
]
},
"execution_count": 6,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -52,16 +52,16 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<flo_ai.state.flo_session.FloSession at 0x117b21dd0>"
"<flo_ai.state.flo_session.FloSession at 0x112a42a90>"
]
},
"execution_count": 7,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
Expand Down
48 changes: 3 additions & 45 deletions flo_ai/router/flo_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
from flo_ai.state.flo_session import FloSession
from flo_ai.constants.prompt_constants import FLO_FINISH
from flo_ai.helpers.utils import randomize_name
from flo_ai.router.flo_router import FloRouter
from flo_ai.router.flo_llm_router import FloLLMRouter, StateUpdateComponent
from flo_ai.models.flo_team import FloTeam
from flo_ai.models.flo_routed_team import FloRoutedTeam
from langgraph.graph import StateGraph
from flo_ai.state.flo_state import TeamFloAgentState


# TODO, maybe add description about what team members can do
supervisor_system_message = (
"You are a supervisor tasked with managing a conversation between the"
Expand All @@ -22,16 +21,7 @@
" respond with FINISH. "
)

class StateUpdateComponent:
def __init__(self, name: str, session: FloSession) -> None:
self.name = name
self.inner_session = session

def __call__(self, input):
self.inner_session.append(self.name)
return input

class FloSupervisor(FloRouter):
class FloSupervisor(FloLLMRouter):

def __init__(self,
session: FloSession,
Expand All @@ -44,44 +34,12 @@ def __init__(self,
flo_team = flo_team,
executor = executor
)

def build_agent_graph(self):
flo_agent_nodes = [self.build_node(flo_agent) for flo_agent in self.members]
workflow = StateGraph(TeamFloAgentState)
for flo_agent_node in flo_agent_nodes:
workflow.add_node(flo_agent_node.name, flo_agent_node.func)
workflow.add_node(self.router_name, self.executor)
for member in self.member_names:
workflow.add_edge(member, self.router_name)
workflow.add_conditional_edges(self.router_name, self.router_fn)
workflow.set_entry_point(self.router_name)
workflow_graph = workflow.compile()
return FloRoutedTeam(self.flo_team.name, workflow_graph)

def build_team_graph(self):
flo_team_entry_chains = [self.build_node_for_teams(flo_agent) for flo_agent in self.members]
# Define the graph.
super_graph = StateGraph(TeamFloAgentState)
# First add the nodes, which will do the work
for flo_team_chain in flo_team_entry_chains:
super_graph.add_node(flo_team_chain.name, flo_team_chain.func)
super_graph.add_node(self.router_name, self.executor)

for member in self.member_names:
super_graph.add_edge(member, self.router_name)

super_graph.add_conditional_edges(self.router_name, self.router_fn)

super_graph.set_entry_point(self.router_name)
super_graph = super_graph.compile()
return FloRoutedTeam(self.flo_team.name, super_graph)

class Builder:
def __init__(self,
session: FloSession,
name: str,
flo_team: FloTeam,
supervisor_prompt: Union[ChatPromptTemplate, None] = None,
llm: Union[BaseLanguageModel, None] = None) -> None:
# TODO add validation for reporteess
self.name = randomize_name(name)
Expand All @@ -102,7 +60,7 @@ def __init__(self,
" Or should we FINISH if the task is already answered, Select one of: {options}",
),
]
).partial(options=str(self.options), members=", ".join(self.members), member_type=member_type) if supervisor_prompt is None else supervisor_prompt
).partial(options=str(self.options), members=", ".join(self.members), member_type=member_type)

def build(self):
function_def = {
Expand Down
Loading