diff --git a/src/seer/automation/agent/agent.py b/src/seer/automation/agent/agent.py index 8e335248a..6c7a43e74 100644 --- a/src/seer/automation/agent/agent.py +++ b/src/seer/automation/agent/agent.py @@ -1,150 +1,169 @@ import logging from abc import ABC, abstractmethod -from typing import Any +from typing import Optional from openai._types import NotGiven +from openai.types.chat import ChatCompletionMessageToolCall +from pydantic import BaseModel, Field -from seer.automation.agent.client import GptClient, LlmClient +from seer.automation.agent.client import DEFAULT_GPT_MODEL, GptClient from seer.automation.agent.models import Message, ToolCall, Usage from seer.automation.agent.tools import FunctionTool from seer.automation.agent.utils import parse_json_with_keys +from seer.dependency_injection import inject, injected logger = logging.getLogger("autofix") -class LlmAgent(ABC): - name: str - tools: list[FunctionTool] - memory: list[Message] +class AgentConfig(BaseModel): + max_iterations: int = Field( + default=16, description="Maximum number of iterations the agent can perform" + ) + model: str = Field(default=DEFAULT_GPT_MODEL, description="The model to be used by the agent") + stop_message: Optional[str] = Field( + default=None, description="Message that signals the agent to stop" + ) + + class Config: + validate_assignment = True - client: LlmClient - iterations: int = 0 - max_iterations: int = 16 +class LlmAgent(ABC): def __init__( self, - tools: list[FunctionTool] | None = None, - memory: list[Message] | None = None, - name="Agent", - stop_message: str | None = None, + config: AgentConfig, + tools: Optional[list[FunctionTool]] = None, + memory: Optional[list[Message]] = None, + name: str = "Agent", ): + self.config = config self.tools = tools or [] self.memory = memory or [] self.usage = Usage() self.name = name - self.stop_message = stop_message + self.iterations = 0 @abstractmethod def run_iteration(self): pass - def run(self, prompt: str): - self.memory.append( - Message( - role="user", - content=prompt, - ) - ) + def should_continue(self) -> bool: + # If this is the first iteration or there are no messages, continue + if self.iterations == 0 or not self.memory: + return True + + # Stop if we've reached the maximum number of iterations + if self.iterations >= self.config.max_iterations: + return False + + last_message = self.memory[-1] + if last_message and last_message.role in ["assistant", "model"]: + if last_message.content: + # Stop if the stop message is found in the content + if self.config.stop_message and self.config.stop_message in last_message.content: + return False + # Stop if there are no tool calls + if not last_message.tool_calls: + return False + # Continue in all other cases + return True + + def run(self, prompt: str): + self.add_user_message(prompt) logger.debug(f"----[{self.name}] Running Agent----") - logger.debug("Previous messages: ") - for message in self.memory: - logger.debug(f"{message.role}: {message.content}") - - while ( - self.iterations == 0 - or ( - not ( - self.memory[-1].role - in ["assistant", "model"] # Will never end on a message not from the assistant - and ( - self.memory[-1].content - and ( - self.stop_message in self.memory[-1].content - ) # If stop message is defined; will end if the assistant response contains the stop message - if self.stop_message - else self.memory[-1].content - is not None # If a stop message is not defined; will end on any non-empty assistant response (OpenAI tool call does not output a message!) - ) - ) - ) - and self.iterations < self.max_iterations # Went above max iterations - ): - # runs until the assistant sends a message with no more tool calls. + + while self.should_continue(): self.run_iteration() - if self.iterations == self.max_iterations: - raise Exception(f"Agent {self.name} reached maximum iterations without finishing.") + if self.iterations == self.config.max_iterations: + raise MaxIterationsReachedException( + f"Agent {self.name} reached maximum iterations without finishing." + ) - return self.memory[-1].content + return self.get_last_message_content() - def call_tool(self, tool_call: ToolCall): - logger.debug( - f"[{tool_call.id}] Calling tool {tool_call.function} with arguments {tool_call.args}" - ) + def add_user_message(self, content: str): + self.memory.append(Message(role="user", content=content)) - tool = next(tool for tool in self.tools if tool.name == tool_call.function) + def get_last_message_content(self) -> str | None: + return self.memory[-1].content if self.memory else None - kwargs = parse_json_with_keys( - tool_call.args, - [param["name"] for param in tool.parameters if isinstance(param["name"], str)], - ) + def call_tool(self, tool_call: ToolCall) -> Message: + logger.debug(f"[{tool_call.id}] Calling tool {tool_call.function}") + + tool = self.get_tool_by_name(tool_call.function) + kwargs = self.parse_tool_arguments(tool, tool_call.args) tool_result = tool.call(**kwargs) - logger.debug(f"Tool {tool_call.function} returned \n{tool_result}") + return Message(role="tool", content=tool_result, tool_call_id=tool_call.id) - return Message( - role="tool", - content=tool_result, - tool_call_id=tool_call.id, + def get_tool_by_name(self, name: str) -> FunctionTool: + return next(tool for tool in self.tools if tool.name == name) + + def parse_tool_arguments(self, tool: FunctionTool, args: str) -> dict: + return parse_json_with_keys( + args, [param["name"] for param in tool.parameters if isinstance(param["name"], str)] ) class GptAgent(LlmAgent): - model: str = "gpt-4o-2024-05-13" - - chat_completion_kwargs: dict[str, Any] = {} - + @inject def __init__( self, - tools: list[FunctionTool] | None = None, - memory: list[Message] | None = None, - name="GptAgent", - chat_completion_kwargs=None, - stop_message: str | None = None, + config: AgentConfig = AgentConfig(), + client: GptClient = injected, + tools: Optional[list[FunctionTool]] = None, + memory: Optional[list[Message]] = None, + name: str = "GptAgent", + chat_completion_kwargs: Optional[dict] = None, ): - super().__init__(tools, memory, name=name, stop_message=stop_message) - self.client = GptClient(model=self.model) - + super().__init__(config, tools, memory, name) + self.client = client self.chat_completion_kwargs = chat_completion_kwargs or {} def run_iteration(self): logger.debug(f"----[{self.name}] Running Iteration {self.iterations}----") - message, usage = self.client.completion( + message, usage = self.get_completion() + self.process_message(message) + self.update_usage(usage) + + return self.memory + + def get_completion(self): + return self.client.completion( messages=self.memory, - tools=([tool.to_dict() for tool in self.tools] if len(self.tools) > 0 else NotGiven()), + model=self.config.model, + tools=([tool.to_dict() for tool in self.tools] if self.tools else NotGiven()), **self.chat_completion_kwargs, ) + def process_message(self, message: Message): self.memory.append(message) - logger.debug(f"Message content:\n{message.content}") - logger.debug(f"Message tool calls:\n{message.tool_calls}") - if message.tool_calls: - for tool_call in message.tool_calls: - tool_response = self.call_tool( - ToolCall( - id=tool_call.id, - function=tool_call.function.name, - args=tool_call.function.arguments, - ) - ) - - self.memory.append(tool_response) + converted_tool_calls = self.convert_tool_calls(message.tool_calls) + self.process_tool_calls(converted_tool_calls) self.iterations += 1 + + def convert_tool_calls(self, tool_calls: list[ChatCompletionMessageToolCall]) -> list[ToolCall]: + return [ + ToolCall( + id=tool_call.id, function=tool_call.function.name, args=tool_call.function.arguments + ) + for tool_call in tool_calls + ] + + def process_tool_calls(self, tool_calls: list[ToolCall]): + for tool_call in tool_calls: + tool_response = self.call_tool(tool_call) + self.memory.append(tool_response) + + def update_usage(self, usage: Usage): self.usage += usage - return self.memory + +class MaxIterationsReachedException(Exception): + pass diff --git a/src/seer/automation/agent/client.py b/src/seer/automation/agent/client.py index bca948e51..cd22f4c09 100644 --- a/src/seer/automation/agent/client.py +++ b/src/seer/automation/agent/client.py @@ -7,6 +7,7 @@ from openai.types.chat import ChatCompletion from seer.automation.agent.models import Message, Usage +from seer.bootup import module, stub_module T = TypeVar("T") @@ -14,7 +15,7 @@ class LlmClient(ABC): @abstractmethod def completion( - self, messages: list[Message], **chat_completion_kwargs + self, messages: list[Message], model: str, **chat_completion_kwargs ) -> tuple[Message, Usage]: pass @@ -22,21 +23,26 @@ def completion_with_parser( self, messages: list[Message], parser: Callable[[str | None], T], + model: str, **chat_completion_kwargs, ) -> tuple[T, Message, Usage]: - message, usage = self.completion(messages, **chat_completion_kwargs) + message, usage = self.completion(messages, model, **chat_completion_kwargs) return parser(message.content), message, usage +DEFAULT_GPT_MODEL = "gpt-4o-2024-05-13" + + class GptClient(LlmClient): - def __init__(self, model: str = "gpt-4o-2024-05-13"): - self.model = model + def __init__(self): self.openai_client = openai.Client() - def completion(self, messages: list[Message], **chat_completion_kwargs): + def completion( + self, messages: list[Message], model=DEFAULT_GPT_MODEL, **chat_completion_kwargs + ): completion: ChatCompletion = self.openai_client.chat.completions.create( - model=self.model, + model=model, messages=[message.to_openai_message() for message in messages], temperature=0.0, **chat_completion_kwargs, @@ -52,17 +58,34 @@ def completion(self, messages: list[Message], **chat_completion_kwargs): return message, usage + def completion_with_parser( + self, + messages: list[Message], + parser: Callable[[str | None], T], + model=DEFAULT_GPT_MODEL, + **chat_completion_kwargs, + ) -> tuple[T, Message, Usage]: + message, usage = self.completion(messages, model, **chat_completion_kwargs) + + return parser(message.content), message, usage + def json_completion( - self, messages: list[Message], **chat_completion_kwargs + self, messages: list[Message], model=DEFAULT_GPT_MODEL, **chat_completion_kwargs ) -> tuple[dict[str, Any] | None, Message, Usage]: return self.completion_with_parser( messages, parser=lambda x: json.loads(x) if x else None, + model=model, response_format={"type": "json_object"}, **chat_completion_kwargs, ) +@module.provider +def provide_gpt_client() -> GptClient: + return GptClient() + + GptCompletionHandler = Callable[[list[Message], dict[str, Any]], Optional[tuple[Message, Usage]]] @@ -73,10 +96,15 @@ class DummyGptClient(GptClient): default_factory=list ) - def completion(self, messages: list[Message], **chat_completion_kwargs): + def completion(self, messages: list[Message], model="test-gpt", **chat_completion_kwargs): for handler in self.handlers: result = handler(messages, chat_completion_kwargs) if result: return result self.missed_calls.append((messages, chat_completion_kwargs)) return Message(), Usage() + + +@stub_module.provider +def provide_stub_gpt_client() -> GptClient: + return DummyGptClient() diff --git a/src/seer/automation/autofix/components/executor/component.py b/src/seer/automation/autofix/components/executor/component.py index 3c5e9d99a..ea04d3bf4 100644 --- a/src/seer/automation/autofix/components/executor/component.py +++ b/src/seer/automation/autofix/components/executor/component.py @@ -1,7 +1,7 @@ from langfuse.decorators import observe from sentry_sdk.ai.monitoring import ai_track -from seer.automation.agent.agent import GptAgent +from seer.automation.agent.agent import AgentConfig, GptAgent from seer.automation.agent.models import Message from seer.automation.autofix.autofix_context import AutofixContext from seer.automation.autofix.components.executor.models import ExecutorOutput, ExecutorRequest @@ -22,7 +22,10 @@ def invoke(self, request: ExecutorRequest) -> None: code_action_tools = CodeActionTools(self.context) execution_agent = GptAgent( - name="executor", + config=AgentConfig( + stop_message="", + ), + name="Executor", tools=code_action_tools.get_tools(), memory=[ Message( @@ -30,7 +33,6 @@ def invoke(self, request: ExecutorRequest) -> None: content=ExecutionPrompts.format_system_msg(), ), ], - stop_message="", ) execution_agent.run( diff --git a/tests/automation/agent/test_agent.py b/tests/automation/agent/test_agent.py index 8f2cdfc36..9fe322cd3 100644 --- a/tests/automation/agent/test_agent.py +++ b/tests/automation/agent/test_agent.py @@ -1,43 +1,208 @@ -import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch -from seer.automation.agent.agent import GptAgent -from seer.automation.agent.models import ToolCall +import pytest +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, + Function, +) + +from seer.automation.agent.agent import ( + AgentConfig, + GptAgent, + LlmAgent, + MaxIterationsReachedException, +) +from seer.automation.agent.client import GptClient +from seer.automation.agent.models import Message, ToolCall, Usage from seer.automation.agent.tools import FunctionTool +from seer.dependency_injection import resolve + + +class TestLlmAgent: + @pytest.fixture + def config(self): + return AgentConfig() + + @pytest.fixture + def agent(self, config): + class TestAgent(LlmAgent): + def run_iteration(self): + pass + + return TestAgent(config) + + def test_should_continue(self, agent: LlmAgent): + assert agent.should_continue() # Initial state + + agent.iterations = agent.config.max_iterations + agent.memory = [Message(role="assistant", content="STOP")] + assert not agent.should_continue() # Max iterations reached + + agent.iterations = 1 + agent.config.stop_message = "STOP" + assert not agent.should_continue() # Stop message found + + def test_add_user_message(self, agent: LlmAgent): + agent.add_user_message("Test message") + assert len(agent.memory) == 1 + assert agent.memory[0].role == "user" + assert agent.memory[0].content == "Test message" + + def test_get_last_message_content(self, agent: LlmAgent): + assert agent.get_last_message_content() is None # Empty memory + agent.memory = [Message(role="user", content="Test")] + assert agent.get_last_message_content() == "Test" + def test_call_tool(self, agent: LlmAgent): + mock_tool_fn = MagicMock(return_value="Tool result") + mock_tool = FunctionTool( + name="test_tool", + description="Test tool", + fn=mock_tool_fn, + parameters=[], + ) + mock_tool.name = "test_tool" + + agent.tools = [mock_tool] + + tool_call = ToolCall(id="123", function="test_tool", args='{"arg": "value"}') + result = agent.call_tool(tool_call) + + assert isinstance(result, Message) + assert result.role == "tool" + assert result.content == "Tool result" + assert result.tool_call_id == "123" + + def test_get_tool_by_name(self, agent: LlmAgent): + tool = FunctionTool( + name="test_tool", description="Test tool", fn=lambda: None, parameters=[] + ) + agent.tools = [tool] + assert agent.get_tool_by_name("test_tool") == tool + + with pytest.raises(StopIteration): + agent.get_tool_by_name("non_existent_tool") + + def test_parse_tool_arguments(self, agent: LlmAgent): + tool = FunctionTool( + name="test_tool", + description="Test tool", + fn=lambda: None, + parameters=[{"name": "arg1"}, {"name": "arg2"}], + ) + args = '{"arg1": "value1", "arg2": "value2", "arg3": "value3"}' + parsed_args = agent.parse_tool_arguments(tool, args) + assert parsed_args == {"arg1": "value1", "arg2": "value2"} -class TestGptAgentCallToolIntegration(unittest.TestCase): - def setUp(self): - self.agent = GptAgent() - - def test_call_tool(self): - # Setup mock tool and agent - mock_fn = MagicMock(return_value="Tool called successfully") - - self.agent.tools = [ - FunctionTool( - name="mock_tool", - description="tool", - fn=mock_fn, - parameters=[ - {"name": "arg1", "type": "str"}, - {"name": "arg2", "type": "str"}, - {"name": "arg3", "type": "int"}, - ], + +class TestGptAgent: + @pytest.fixture + def config(self): + return AgentConfig() + + @pytest.fixture + def agent(self, config): + return GptAgent(config) + + @pytest.fixture + def mock_client(self): + return resolve(GptClient) + + def test_run_iteration(self, agent, mock_client): + mock_message = Message(role="assistant", content="Test response") + mock_usage = Usage(completion_tokens=10, prompt_tokens=20, total_tokens=30) + mock_client.completion = MagicMock(return_value=(mock_message, mock_usage)) + + agent.run_iteration() + + assert agent.iterations == 1 + assert len(agent.memory) == 1 + assert agent.memory[0] == mock_message + assert agent.usage == mock_usage + + def test_get_completion(self, agent, mock_client): + mock_message = Message(role="assistant", content="Test response") + mock_usage = Usage(completion_tokens=10, prompt_tokens=20, total_tokens=30) + mock_client.completion = MagicMock(return_value=(mock_message, mock_usage)) + + message, usage = agent.get_completion() + + assert message == mock_message + assert usage == mock_usage + + def test_process_message(self, agent): + message = Message(role="assistant", content="Test message") + agent.process_message(message) + assert len(agent.memory) == 1 + assert agent.memory[0] == message + assert agent.iterations == 1 + + def test_convert_tool_calls(self, agent): + tool_calls = [ + ChatCompletionMessageToolCall( + id="1", + function=Function(name="test_tool", arguments='{"arg": "value"}'), + type="function", ) ] + converted = agent.convert_tool_calls(tool_calls) + assert len(converted) == 1 + assert isinstance(converted[0], ToolCall) + assert converted[0].id == "1" + assert converted[0].function == "test_tool" + assert converted[0].args == '{"arg": "value"}' + + def test_process_tool_calls(self, agent): + tool_calls = [ToolCall(id="1", function="test_tool", args='{"arg": "value"}')] + with patch.object(agent, "call_tool") as mock_call_tool: + mock_call_tool.return_value = Message( + role="tool", content="Tool result", tool_call_id="1" + ) + agent.process_tool_calls(tool_calls) + assert len(agent.memory) == 1 + assert agent.memory[0].role == "tool" + assert agent.memory[0].content == "Tool result" + + def test_update_usage(self, agent): + initial_usage = Usage(completion_tokens=10, prompt_tokens=20, total_tokens=30) + agent.usage = initial_usage + new_usage = Usage(completion_tokens=5, prompt_tokens=10, total_tokens=15) + agent.update_usage(new_usage) + expected_usage = Usage(completion_tokens=15, prompt_tokens=30, total_tokens=45) + assert agent.usage == expected_usage - tool_call = ToolCall( - id="1", - function="mock_tool", - args='{"arg1": "value1\\nbar(\'\\n\')", "arg2": "value2", "arg3": 123}', + def test_run(self, agent, mock_client): + mock_message = Message(role="assistant", content="Final response") + mock_usage = Usage(completion_tokens=10, prompt_tokens=20, total_tokens=30) + mock_client.completion = MagicMock(return_value=(mock_message, mock_usage)) + + result = agent.run("Test prompt") + + assert result == "Final response" + assert len(agent.memory) > 0 + assert agent.memory[0].role == "user" + assert agent.memory[0].content == "Test prompt" + + def test_run_max_iterations_exception(self, agent, mock_client): + agent.config.max_iterations = 1 + tool = FunctionTool( + name="test_tool", description="Test tool", fn=lambda: None, parameters=[] ) + agent.tools = [tool] - # Call the method - result = self.agent.call_tool(tool_call) + mock_message = Message( + role="assistant", + content="Response", + tool_calls=[ + ChatCompletionMessageToolCall( + id="1", + function=Function(name="test_tool", arguments='{"arg": "value"}'), + type="function", + ) + ], + ) + mock_usage = Usage(completion_tokens=10, prompt_tokens=20, total_tokens=30) + mock_client.completion = MagicMock(return_value=(mock_message, mock_usage)) - # Assertions - self.assertEqual(result.content, "Tool called successfully") - self.assertEqual(result.role, "tool") - self.assertEqual(result.tool_call_id, "1") - mock_fn.assert_called_once_with(arg1="value1\nbar('\\n')", arg2="value2", arg3=123) + with pytest.raises(MaxIterationsReachedException): + agent.run("Test prompt")