This repository has been archived by the owner on Sep 19, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f119a72
commit 5ab9eda
Showing
6 changed files
with
333 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from falkordb_gemini_kg.kg import KnowledgeGraph | ||
|
||
|
||
class Agent: | ||
|
||
def __init__(self, id: str, kg: KnowledgeGraph, introduction: str): | ||
self.id = id | ||
self._kg = kg | ||
self._introduction = introduction | ||
|
||
def ask(self, question: str): | ||
return self._kg.ask(question) | ||
|
||
def to_orchestrator(self): | ||
return f""" | ||
--- | ||
Agent ID: {self.id} | ||
Knowledge Graph Name: {self._kg.name} | ||
Introduction: {self._introduction} | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
from json import loads | ||
|
||
|
||
class StepBlockType: | ||
PARALLEL = "parallel" | ||
PROMPT_AGENT = "prompt_agent" | ||
SUMMARY = "summary" | ||
|
||
@staticmethod | ||
def from_str(text: str) -> "StepBlockType": | ||
if text == StepBlockType.PARALLEL: | ||
return StepBlockType.PARALLEL | ||
elif text == StepBlockType.PROMPT_AGENT: | ||
return StepBlockType.PROMPT_AGENT | ||
elif text == StepBlockType.SUMMARY: | ||
return StepBlockType.SUMMARY | ||
|
||
|
||
class PromptAgentProperties: | ||
agent_id: str | ||
prompt: str | ||
response: str | None = None | ||
|
||
@staticmethod | ||
def from_json(json: dict) -> "PromptAgentProperties": | ||
return PromptAgentProperties( | ||
json["agent_id"], json["prompt"], json.get("response", None) | ||
) | ||
|
||
def to_json(self) -> dict: | ||
return { | ||
"agent_id": self.agent_id, | ||
"prompt": self.prompt, | ||
"response": self.response, | ||
} | ||
|
||
|
||
class ParallelProperties: | ||
steps: list["PlanStep"] | ||
|
||
@staticmethod | ||
def from_json(json: dict) -> "ParallelProperties": | ||
return ParallelProperties([PlanStep.from_json(step) for step in json["steps"]]) | ||
|
||
def to_json(self) -> dict: | ||
return {"steps": [step.to_json() for step in self.steps]} | ||
|
||
|
||
class PlanStep: | ||
id: str | ||
block: StepBlockType | ||
properties: PromptAgentProperties | ParallelProperties | ||
|
||
@staticmethod | ||
def from_json(json: dict) -> "PlanStep": | ||
block = StepBlockType.from_str(json["block"]) | ||
if block == StepBlockType.PROMPT_AGENT: | ||
properties = PromptAgentProperties.from_json(json["properties"]) | ||
elif block == StepBlockType.PARALLEL: | ||
properties = ParallelProperties.from_json(json["properties"]) | ||
else: | ||
raise ValueError(f"Unknown block type: {block}") | ||
return PlanStep(json["id"], block, properties) | ||
|
||
def to_json(self) -> dict: | ||
return { | ||
"id": self.id, | ||
"block": self.block, | ||
"properties": self.properties.to_json(), | ||
} | ||
|
||
|
||
class ExecutionPlan: | ||
|
||
_steps = [] | ||
|
||
@staticmethod | ||
def from_json(json: str | dict) -> "ExecutionPlan": | ||
if isinstance(json, str): | ||
json = loads(json) | ||
return ExecutionPlan([PlanStep.from_json(step) for step in json["steps"]]) | ||
|
||
def find_step(self, step_id: str) -> PlanStep: | ||
for step in self._steps: | ||
if step.id == step_id: | ||
return step | ||
if step.block == StepBlockType.PARALLEL: | ||
for sub_step in step.properties.steps: | ||
if sub_step.id == step_id: | ||
return sub_step | ||
raise ValueError(f"Step with id {step_id} not found") | ||
|
||
def to_json(self) -> dict: | ||
return {"steps": [step.to_json() for step in self._steps]} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from falkordb_gemini_kg.models import GenerativeModel | ||
from falkordb_gemini_kg.classes.agent import Agent | ||
from falkordb_gemini_kg.classes.orchestrator_runner import OrchestratorRunner | ||
from falkordb_gemini_kg.fixtures.prompts import ( | ||
ORCHESTRATOR_SYSTEM, | ||
ORCHESTRATOR_EXECUTION_PLAN_PROMPT, | ||
) | ||
from falkordb_gemini_kg.helpers import extract_json | ||
from falkordb_gemini_kg.classes.execution_plan import ( | ||
ExecutionPlan, | ||
PlanStep, | ||
StepBlockType, | ||
) | ||
|
||
|
||
class Orchestrator: | ||
|
||
_agents = [] | ||
_chat = None | ||
|
||
def __init__(self, model: GenerativeModel): | ||
self._model = model | ||
|
||
def register_agent(self, agent: Agent): | ||
self._agents.append(agent) | ||
|
||
def ask(self, question: str): | ||
|
||
self._chat = self._model.with_system_instruction( | ||
ORCHESTRATOR_SYSTEM.replace( | ||
"#AGENTS", ",".join([agent.to_orchestrator() for agent in self._agents]) | ||
) | ||
).start_chat({"response_validation": False}) | ||
|
||
plan = self._create_execution_plan(question) | ||
|
||
runner = OrchestratorRunner(self._chat, self._agents, plan) | ||
|
||
return runner | ||
|
||
def _create_execution_plan(self, question: str): | ||
response = self._chat.send_message( | ||
ORCHESTRATOR_EXECUTION_PLAN_PROMPT.replace("#QUESTION", question) | ||
) | ||
|
||
plan = ExecutionPlan.from_json(extract_json(response)) | ||
|
||
return plan |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from falkordb_gemini_kg.classes.agent import Agent | ||
from falkordb_gemini_kg.models import GenerativeModelChatSession | ||
from falkordb_gemini_kg.classes.execution_plan import ( | ||
ExecutionPlan, | ||
PlanStep, | ||
StepBlockType, | ||
) | ||
from concurrent.futures import ThreadPoolExecutor, wait | ||
from falkordb_gemini_kg.fixtures.prompts import ORCHESTRATOR_SUMMARY_PROMPT | ||
|
||
|
||
class OrchestratorRunner: | ||
|
||
def __init__( | ||
self, | ||
chat: GenerativeModelChatSession, | ||
agents: list[Agent], | ||
plan: ExecutionPlan, | ||
config: dict = { | ||
"max_workers": 16, | ||
}, | ||
): | ||
self._chat = chat | ||
self._agents = agents | ||
self._plan = plan | ||
self._config = config | ||
|
||
def _run(self): | ||
for step in self._plan.steps: | ||
self._run_step(step) | ||
|
||
return self._run_summary() | ||
|
||
def _run_summary(self): | ||
return self._chat.send_message( | ||
ORCHESTRATOR_SUMMARY_PROMPT.replace("#EXECUTION_PLAN", self._plan.to_json()) | ||
) | ||
|
||
def _run_step(self, step: PlanStep): | ||
if step.block == StepBlockType.PROMPT_AGENT: | ||
return self._run_prompt_agent(step) | ||
elif step.block == StepBlockType.PARALLEL: | ||
return self._run_parallel(step) | ||
else: | ||
raise ValueError(f"Unknown block type: {step.block}") | ||
|
||
def _run_prompt_agent(self, step: PlanStep): | ||
agent = next( | ||
agent for agent in self._agents if agent.id == step.properties.agent_id | ||
) | ||
response = agent.ask(step.properties.prompt) | ||
step.properties.response = response | ||
|
||
def _run_parallel(self, step: PlanStep): | ||
tasks = [] | ||
with ThreadPoolExecutor( | ||
max_workers=min(self._config["max_workers"], len(step.properties.steps)) | ||
) as executor: | ||
for step in step.properties.steps: | ||
tasks.append(executor.submit(self._run_step, step)) | ||
|
||
wait(tasks) | ||
|
||
return [task.result() for task in tasks] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters