Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ParallelAgent class #103

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open

Conversation

carr324
Copy link

@carr324 carr324 commented Nov 22, 2024

Issue number: N/A

Summary

Current version of the multi-agent-orchestrator package does not have any readily available agents to run concurrent operations across multiple agents. Default agents either run independently or need to be sequentially linked with a ChainAgent. A new ParallelAgent that allows multiple agents to run simultaneously would allow for more efficient runtimes and more advanced agentic designs.

Changes

New ParallelAgent module has been added by adapting some of the core logic from ChainAgent. Minimal use of external packages (only typing and asyncio as additional imports beyond modules in the src code). Asynchronously and simultaneously runs each individual agent provided to ParallelAgent, then outputting result as ConversationMessage class. The individual agents within the overall ParallelAgent can be of any default agent class. The combined response from all agents within the ParallelAgent is embedded in the ConversationMessage as a dictionary with agent names in keys and text responses as values.

Possible next steps:

  • Finalize python implementation of ParallelAgent
  • Replicate logic to TypeScript
  • Update documentation/website (mirror ChainAgent info on main package website)
  • Anything else suggested by package authors/maintainers

User experience

Users would be able to import and use the ParallelAgent class in a similar manner as ChainAgent.

Checklist

If your change doesn't seem to apply, please leave them unchecked.

  • I have performed a self-review of this change
  • Changes have been tested
  • Changes are documented
Is this a breaking change?

RFC issue number: N/A

Checklist:

  • Migration process documented
  • Implement warnings (if it can live side by side)

Acknowledgment

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Disclaimer: We value your time and bandwidth. As such, any pull requests created on non-triaged issues might not be successful.

Copy link
Contributor

@brnaba-aws brnaba-aws left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initial review

class ParallelAgentOptions(AgentOptions):
def __init__(
self,
agents: list[str],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That should be a list of Agents

Copy link
Author

@carr324 carr324 Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


async def _get_llm_response(
self,
agent: BedrockLLMAgent,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change from BedrockLLMAgent to Agent

Copy link
Author

@carr324 carr324 Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tasks = []
for agent in self.agents:
tasks.append(
self._get_llm_response(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why adding another method? Can't you just call agent.process_request()?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wanted to include some of the ChainAgent logic in lines 41-65 here in the async function that gets run for each individual agent within the ParallelAgent, which seemed easier/cleaner in a new internal method. If we think it's unnecessary, fine to revise or remove

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see ok. Well I'd suggest to change the method name from _get_llm_response to self.agent_process_request()

The framework is not only about llm.

Copy link
Contributor

@brnaba-aws brnaba-aws left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import asyncio
from dataclasses import dataclass
from typing import Any, AsyncIterable, Dict, List, Optional
from multi_agent_orchestrator.agents import Agent, AgentOptions
from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole
from multi_agent_orchestrator.utils.logger import Logger


@dataclass
class ParallelAgentOptions(AgentOptions):
    """Configuration options for ParallelAgent."""
    agents: List[Agent]
    default_output: Optional[str] = None

    def __post_init__(self) -> None:
        super().__init__()
        if not self.agents:
            raise ValueError("ParallelAgent requires at least 1 agent to initiate!")


class ParallelAgent(Agent):
    """Agent that processes requests in parallel using multiple underlying agents."""

    def __init__(self, options: ParallelAgentOptions) -> None:
        super().__init__(options)
        self.agents = options.agents
        self.default_output = (
            options.default_output or "No output generated from the ParallelAgent."
        )

    async def _get_llm_response(
        self,
        agent: Agent,
        input_text: str,
        user_id: str,
        session_id: str,
        chat_history: List[ConversationMessage],
        additional_params: Optional[Dict[str, str]] = None,
    ) -> ConversationMessage:
        """
        Get response from a single LLM agent.

        Args:
            agent: The agent to process the request
            input_text: The input text to process
            user_id: The user identifier
            session_id: The session identifier
            chat_history: List of previous conversation messages
            additional_params: Optional additional parameters

        Returns:
            ConversationMessage: The processed response
        
        Raises:
            RuntimeError: If there's an error processing the request
        """
        try:
            response = await agent.process_request(
                input_text,
                user_id,
                session_id,
                chat_history,
                additional_params
            )

            if self._is_valid_conversation_message(response):
                return response
            elif isinstance(response, AsyncIterable):
                Logger.warn(f"Agent {agent.name}: Streaming is not supported for ParallelAgents")
                return self._create_default_response()
            else:
                Logger.warn(f"Agent {agent.name}: Invalid response type")
                return self._create_default_response()

        except Exception as error:
            error_msg = f"Error processing request with agent {agent.name}: {str(error)}"
            Logger.error(error_msg)
            raise RuntimeError(error_msg)

    async def process_request(
        self,
        input_text: str,
        user_id: str,
        session_id: str,
        chat_history: List[ConversationMessage],
        additional_params: Optional[Dict[str, str]] = None,
    ) -> ConversationMessage:
        """
        Process requests in parallel using all configured agents.

        Returns:
            ConversationMessage: Combined responses from all agents
        """
        tasks = [
            self._get_llm_response(
                agent,
                input_text,
                user_id,
                session_id,
                chat_history,
                additional_params,
            )
            for agent in self.agents
        ]

        responses = await asyncio.gather(*tasks, return_exceptions=True)
        
        # Filter out errors and create response dictionary
        response_dict = {}
        for agent, response in zip(self.agents, responses):
            if isinstance(response, Exception):
                Logger.error(f"Agent {agent.name} failed: {str(response)}")
                continue
            if response and response.content and "text" in response.content[0]:
                response_dict[agent.name] = response.content[0]["text"]

        if not response_dict:
            return self._create_default_response()

        return ConversationMessage(
            role=ParticipantRole.ASSISTANT.value,
            content=[{"text": str(response_dict)}],
        )

    @staticmethod
    def _is_valid_conversation_message(response: Any) -> bool:
        """Check if response is a valid ConversationMessage with text content."""
        return (
            isinstance(response, ConversationMessage)
            and hasattr(response, "content")
            and isinstance(response.content, list)
            and response.content
            and isinstance(response.content[0], dict)
            and "text" in response.content[0]
        )

    def _create_default_response(self) -> ConversationMessage:
        """Create a default response message."""
        return ConversationMessage(
            role=ParticipantRole.ASSISTANT.value,
            content=[{"text": self.default_output}],
        )

@brnaba-aws
Copy link
Contributor

@carr324 Please review the suggested code if it makes sense (powered by Sonnet 3.5 :) )

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants