Skip to content

Commit

Permalink
code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
brnaba-aws committed Oct 9, 2024
1 parent 1c7df42 commit 80f6162
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 125 deletions.
2 changes: 1 addition & 1 deletion examples/chat-chainlit-app/agents.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from orchestrator import BedrockLLMAgent, BedrockLLMAgentOptions, AgentCallbacks
from multi_agent_orchestrator.agents import BedrockLLMAgent, BedrockLLMAgentOptions, AgentCallbacks
from ollamaAgent import OllamaAgent, OllamaAgentOptions
import asyncio

Expand Down
105 changes: 23 additions & 82 deletions examples/chat-chainlit-app/app.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import uuid
import chainlit as cl
from orchestrator import MultiAgentOrchestrator, OrchestratorConfig, BedrockClassifier, BedrockClassifierOptions
from agents import create_tech_agent, create_travel_agent, create_health_agent
from multi_agent_orchestrator.classifiers import ClassifierResult
from multi_agent_orchestrator.orchestrator import MultiAgentOrchestrator, OrchestratorConfig
from multi_agent_orchestrator.classifiers import BedrockClassifier, BedrockClassifierOptions
from multi_agent_orchestrator.types import ConversationMessage
from multi_agent_orchestrator.types import ParticipantRole
from multi_agent_orchestrator.agents import AgentResponse


# Initialize the orchestrator
Expand All @@ -18,15 +18,17 @@
))

orchestrator = MultiAgentOrchestrator(options=OrchestratorConfig(
LOG_AGENT_CHAT=True,
LOG_CLASSIFIER_CHAT=True,
LOG_CLASSIFIER_RAW_OUTPUT=True,
LOG_CLASSIFIER_OUTPUT=True,
LOG_EXECUTION_TIMES=True,
MAX_RETRIES=3,
USE_DEFAULT_AGENT_IF_NONE_IDENTIFIED=False,
MAX_MESSAGE_PAIRS_PER_AGENT=10
), classifier=custom_bedrock_classifier)
LOG_AGENT_CHAT=True,
LOG_CLASSIFIER_CHAT=True,
LOG_CLASSIFIER_RAW_OUTPUT=True,
LOG_CLASSIFIER_OUTPUT=True,
LOG_EXECUTION_TIMES=True,
MAX_RETRIES=3,
USE_DEFAULT_AGENT_IF_NONE_IDENTIFIED=False,
MAX_MESSAGE_PAIRS_PER_AGENT=10
),
classifier=custom_bedrock_classifier
)

# Add agents to the orchestrator
orchestrator.add_agent(create_tech_agent())
Expand All @@ -39,90 +41,29 @@ async def start():
cl.user_session.set("session_id", str(uuid.uuid4()))
cl.user_session.set("chat_history", [])

@cl.step(type="tool")
async def classify(user_query):
user_id = cl.user_session.get("user_id")
session_id = cl.user_session.get("session_id")

chat_history = await orchestrator.storage.fetch_all_chats(user_id, session_id) or []

# Perform classification
classifier_result:ClassifierResult = await orchestrator.classifier.classify(user_query, chat_history)

cl.user_session.set("chat_history", chat_history)

# Prepare the output message
output = "**Classifying Intent** \n"
# output += "=======================\n"
output += f"> Text: {user_query}\n"
if classifier_result.selected_agent:
output += f"> Selected Agent: {classifier_result.selected_agent.name}\n"
else:
output += "> Selected Agent: No agent found\n"

output += f"> Confidence: {classifier_result.confidence:.2f}\n"

return output, classifier_result

@cl.on_message
async def main(message: cl.Message):
user_id = cl.user_session.get("user_id")
session_id = cl.user_session.get("session_id")

msg = cl.Message(content="")
output, classifier_result = await classify(message.content)
await cl.Message(content=output).send()

await msg.send() # Send the message immediately to start streaming
cl.user_session.set("current_msg", msg)

error=False

if not classifier_result.selected_agent:
if orchestrator.config.USE_DEFAULT_AGENT_IF_NONE_IDENTIFIED:
classifier_result = orchestrator.get_fallback_result()
else:
error = True
await msg.stream_token(orchestrator.config.NO_SELECTED_AGENT_MESSAGE)
await msg.update()

if not error:
agent_response = await orchestrator.dispatch_to_agent({
"user_input": message.content,
"user_id": user_id,
"session_id": session_id,
"classifier_result": classifier_result,
"additional_params": {}
})

#Save user question
await orchestrator.save_message(
ConversationMessage(
role=ParticipantRole.USER.value,
content=[{'text':message.content}]
),
user_id,
session_id,
classifier_result.selected_agent
)

#Save agent response
await orchestrator.save_message(
agent_response,
user_id,
session_id,
classifier_result.selected_agent
)
response:AgentResponse = await orchestrator.route_request(message.content, user_id, session_id, {})


# Handle non-streaming responses
if classifier_result.selected_agent.streaming is False:
# Handle regular response
if isinstance(agent_response, str):
await msg.stream_token(agent_response)
elif isinstance(agent_response, ConversationMessage):
await msg.stream_token(agent_response.content[0].get('text'))
await msg.update()
# Handle non-streaming responses
if isinstance(response, AgentResponse) and response.streaming is False:
# Handle regular response
if isinstance(response.output, str):
await msg.stream_token(response.output)
elif isinstance(response.output, ConversationMessage):
await msg.stream_token(response.output.content[0].get('text'))
await msg.update()


if __name__ == "__main__":
Expand Down
42 changes: 0 additions & 42 deletions examples/chat-chainlit-app/orchestrator.py

This file was deleted.

0 comments on commit 80f6162

Please sign in to comment.