From a5167f64870734f63bf64456ee957203da71b064 Mon Sep 17 00:00:00 2001 From: "jenn.muengtaweepongsa" Date: Wed, 17 Jul 2024 22:15:58 +0700 Subject: [PATCH 1/6] init agent ref --- src/seer/automation/agent/agent.py | 194 +++++++++++--------- tests/automation/agent/test_agent.py | 261 +++++++++++++++++++++++---- 2 files changed, 333 insertions(+), 122 deletions(-) diff --git a/src/seer/automation/agent/agent.py b/src/seer/automation/agent/agent.py index 8e335248a..540fab8d9 100644 --- a/src/seer/automation/agent/agent.py +++ b/src/seer/automation/agent/agent.py @@ -1,8 +1,10 @@ import logging from abc import ABC, abstractmethod -from typing import Any +from typing import Optional from openai._types import NotGiven +from openai.types.chat import ChatCompletionMessageToolCall +from pydantic import BaseModel, Field from seer.automation.agent.client import GptClient, LlmClient from seer.automation.agent.models import Message, ToolCall, Usage @@ -12,139 +14,153 @@ logger = logging.getLogger("autofix") -class LlmAgent(ABC): - name: str - tools: list[FunctionTool] - memory: list[Message] +class AgentConfig(BaseModel): + max_iterations: int = Field( + default=16, description="Maximum number of iterations the agent can perform" + ) + model: str = Field(default="gpt-4-0613", description="The model to be used by the agent") + stop_message: Optional[str] = Field( + default=None, description="Message that signals the agent to stop" + ) + + class Config: + validate_assignment = True - client: LlmClient - iterations: int = 0 - max_iterations: int = 16 +class LlmAgent(ABC): def __init__( self, - tools: list[FunctionTool] | None = None, - memory: list[Message] | None = None, - name="Agent", - stop_message: str | None = None, + config: AgentConfig, + tools: Optional[list[FunctionTool]] = None, + memory: Optional[list[Message]] = None, + name: str = "Agent", ): + self.config = config self.tools = tools or [] self.memory = memory or [] self.usage = Usage() self.name = name - self.stop_message = stop_message + self.iterations = 0 @abstractmethod def run_iteration(self): pass - def run(self, prompt: str): - self.memory.append( - Message( - role="user", - content=prompt, - ) - ) + def should_continue(self) -> bool: + # If this is the first iteration or there are no messages, continue + if self.iterations == 0 or not self.memory: + return True + + # Stop if we've reached the maximum number of iterations + if self.iterations >= self.config.max_iterations: + return False + + last_message = self.memory[-1] + if last_message and last_message.role in ["assistant", "model"]: + if last_message.content: + # Stop if the stop message is found in the content + if self.config.stop_message and self.config.stop_message in last_message.content: + return False + # Stop if there are no tool calls + if not last_message.tool_calls: + return False + # Continue in all other cases + return True + + def run(self, prompt: str): + self.add_user_message(prompt) logger.debug(f"----[{self.name}] Running Agent----") - logger.debug("Previous messages: ") - for message in self.memory: - logger.debug(f"{message.role}: {message.content}") - - while ( - self.iterations == 0 - or ( - not ( - self.memory[-1].role - in ["assistant", "model"] # Will never end on a message not from the assistant - and ( - self.memory[-1].content - and ( - self.stop_message in self.memory[-1].content - ) # If stop message is defined; will end if the assistant response contains the stop message - if self.stop_message - else self.memory[-1].content - is not None # If a stop message is not defined; will end on any non-empty assistant response (OpenAI tool call does not output a message!) - ) - ) - ) - and self.iterations < self.max_iterations # Went above max iterations - ): - # runs until the assistant sends a message with no more tool calls. + + while self.should_continue(): self.run_iteration() - if self.iterations == self.max_iterations: - raise Exception(f"Agent {self.name} reached maximum iterations without finishing.") + if self.iterations == self.config.max_iterations: + raise MaxIterationsReachedException( + f"Agent {self.name} reached maximum iterations without finishing." + ) - return self.memory[-1].content + return self.get_last_message_content() - def call_tool(self, tool_call: ToolCall): - logger.debug( - f"[{tool_call.id}] Calling tool {tool_call.function} with arguments {tool_call.args}" - ) + def add_user_message(self, content: str): + self.memory.append(Message(role="user", content=content)) - tool = next(tool for tool in self.tools if tool.name == tool_call.function) + def get_last_message_content(self) -> str | None: + return self.memory[-1].content if self.memory else None - kwargs = parse_json_with_keys( - tool_call.args, - [param["name"] for param in tool.parameters if isinstance(param["name"], str)], - ) + def call_tool(self, tool_call: ToolCall) -> Message: + logger.debug(f"[{tool_call.id}] Calling tool {tool_call.function}") + + tool = self.get_tool_by_name(tool_call.function) + kwargs = self.parse_tool_arguments(tool, tool_call.args) tool_result = tool.call(**kwargs) - logger.debug(f"Tool {tool_call.function} returned \n{tool_result}") + return Message(role="tool", content=tool_result, tool_call_id=tool_call.id) - return Message( - role="tool", - content=tool_result, - tool_call_id=tool_call.id, + def get_tool_by_name(self, name: str) -> FunctionTool: + return next(tool for tool in self.tools if tool.name == name) + + def parse_tool_arguments(self, tool: FunctionTool, args: str) -> dict: + return parse_json_with_keys( + args, [param["name"] for param in tool.parameters if isinstance(param["name"], str)] ) class GptAgent(LlmAgent): - model: str = "gpt-4o-2024-05-13" - - chat_completion_kwargs: dict[str, Any] = {} - def __init__( self, - tools: list[FunctionTool] | None = None, - memory: list[Message] | None = None, - name="GptAgent", - chat_completion_kwargs=None, - stop_message: str | None = None, + config: AgentConfig, + client: Optional[LlmClient] = None, + tools: Optional[list[FunctionTool]] = None, + memory: Optional[list[Message]] = None, + name: str = "GptAgent", + chat_completion_kwargs: Optional[dict] = None, ): - super().__init__(tools, memory, name=name, stop_message=stop_message) - self.client = GptClient(model=self.model) - + super().__init__(config, tools, memory, name) + self.client = client or GptClient(model=config.model) self.chat_completion_kwargs = chat_completion_kwargs or {} def run_iteration(self): logger.debug(f"----[{self.name}] Running Iteration {self.iterations}----") - message, usage = self.client.completion( + message, usage = self.get_completion() + self.process_message(message) + self.update_usage(usage) + + return self.memory + + def get_completion(self): + return self.client.completion( messages=self.memory, - tools=([tool.to_dict() for tool in self.tools] if len(self.tools) > 0 else NotGiven()), + tools=([tool.to_dict() for tool in self.tools] if self.tools else NotGiven()), **self.chat_completion_kwargs, ) + def process_message(self, message: Message): self.memory.append(message) - logger.debug(f"Message content:\n{message.content}") - logger.debug(f"Message tool calls:\n{message.tool_calls}") - if message.tool_calls: - for tool_call in message.tool_calls: - tool_response = self.call_tool( - ToolCall( - id=tool_call.id, - function=tool_call.function.name, - args=tool_call.function.arguments, - ) - ) - - self.memory.append(tool_response) + converted_tool_calls = self.convert_tool_calls(message.tool_calls) + self.process_tool_calls(converted_tool_calls) self.iterations += 1 + + def convert_tool_calls(self, tool_calls: list[ChatCompletionMessageToolCall]) -> list[ToolCall]: + return [ + ToolCall( + id=tool_call.id, function=tool_call.function.name, args=tool_call.function.arguments + ) + for tool_call in tool_calls + ] + + def process_tool_calls(self, tool_calls: list[ToolCall]): + for tool_call in tool_calls: + tool_response = self.call_tool(tool_call) + self.memory.append(tool_response) + + def update_usage(self, usage: Usage): self.usage += usage - return self.memory + +class MaxIterationsReachedException(Exception): + pass diff --git a/tests/automation/agent/test_agent.py b/tests/automation/agent/test_agent.py index 8f2cdfc36..8945cc149 100644 --- a/tests/automation/agent/test_agent.py +++ b/tests/automation/agent/test_agent.py @@ -1,43 +1,238 @@ -import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock, patch -from seer.automation.agent.agent import GptAgent -from seer.automation.agent.models import ToolCall +import pytest +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, + Function, +) +from pydantic import ValidationError + +from seer.automation.agent.agent import ( + AgentConfig, + GptAgent, + LlmAgent, + MaxIterationsReachedException, +) +from seer.automation.agent.models import Message, ToolCall, Usage from seer.automation.agent.tools import FunctionTool -class TestGptAgentCallToolIntegration(unittest.TestCase): - def setUp(self): - self.agent = GptAgent() - - def test_call_tool(self): - # Setup mock tool and agent - mock_fn = MagicMock(return_value="Tool called successfully") - - self.agent.tools = [ - FunctionTool( - name="mock_tool", - description="tool", - fn=mock_fn, - parameters=[ - {"name": "arg1", "type": "str"}, - {"name": "arg2", "type": "str"}, - {"name": "arg3", "type": "int"}, - ], +class TestAgentConfig: + def test_default_values(self): + config = AgentConfig() + assert config.max_iterations == 16 + assert config.model == "gpt-4-0613" + assert config.stop_message is None + + def test_custom_values(self): + config = AgentConfig(max_iterations=10, model="gpt-3.5-turbo", stop_message="STOP") + assert config.max_iterations == 10 + assert config.model == "gpt-3.5-turbo" + assert config.stop_message == "STOP" + + def test_validation(self): + with pytest.raises(ValidationError): + AgentConfig(max_iterations="not a number") # type: ignore + + +class TestLlmAgent: + @pytest.fixture + def config(self): + return AgentConfig() + + @pytest.fixture + def agent(self, config): + class TestAgent(LlmAgent): + def run_iteration(self): + pass + + return TestAgent(config) + + def test_initialization(self, agent: LlmAgent, config: AgentConfig): + assert agent.config == config + assert agent.tools == [] + assert agent.memory == [] + assert isinstance(agent.usage, Usage) + assert agent.name == "Agent" + assert agent.iterations == 0 + + def test_should_continue(self, agent: LlmAgent): + assert agent.should_continue() # Initial state + + agent.iterations = agent.config.max_iterations + agent.memory = [Message(role="assistant", content="STOP")] + assert agent.should_continue() # Max iterations reached + + agent.iterations = 1 + agent.config.stop_message = "STOP" + assert agent.should_continue() # Stop message found + + def test_add_user_message(self, agent: LlmAgent): + agent.add_user_message("Test message") + assert len(agent.memory) == 1 + assert agent.memory[0].role == "user" + assert agent.memory[0].content == "Test message" + + def test_get_last_message_content(self, agent: LlmAgent): + assert agent.get_last_message_content() is None # Empty memory + agent.memory = [Message(role="user", content="Test")] + assert agent.get_last_message_content() == "Test" + + def test_call_tool(self, agent: LlmAgent): + mock_tool_fn = MagicMock(return_value="Tool result") + mock_tool = FunctionTool( + name="test_tool", + description="Test tool", + fn=mock_tool_fn, + parameters=[], + ) + mock_tool.name = "test_tool" + + agent.tools = [mock_tool] + + tool_call = ToolCall(id="123", function="test_tool", args='{"arg": "value"}') + result = agent.call_tool(tool_call) + + assert isinstance(result, Message) + assert result.role == "tool" + assert result.content == "Tool result" + assert result.tool_call_id == "123" + + def test_get_tool_by_name(self, agent: LlmAgent): + tool = FunctionTool( + name="test_tool", description="Test tool", fn=lambda: None, parameters=[] + ) + agent.tools = [tool] + assert agent.get_tool_by_name("test_tool") == tool + + with pytest.raises(StopIteration): + agent.get_tool_by_name("non_existent_tool") + + def test_parse_tool_arguments(self, agent: LlmAgent): + tool = FunctionTool( + name="test_tool", + description="Test tool", + fn=lambda: None, + parameters=[{"name": "arg1"}, {"name": "arg2"}], + ) + args = '{"arg1": "value1", "arg2": "value2", "arg3": "value3"}' + parsed_args = agent.parse_tool_arguments(tool, args) + assert parsed_args == {"arg1": "value1", "arg2": "value2"} + + +class TestGptAgent: + @pytest.fixture + def config(self): + return AgentConfig() + + @pytest.fixture + def mock_client(self): + return Mock() + + @pytest.fixture + def agent(self, config, mock_client): + return GptAgent(config, client=mock_client) + + def test_initialization(self, agent, config, mock_client): + assert agent.config == config + assert agent.client == mock_client + assert agent.chat_completion_kwargs == {} + + def test_run_iteration(self, agent, mock_client): + mock_message = Message(role="assistant", content="Test response") + mock_usage = Usage(completion_tokens=10, prompt_tokens=20, total_tokens=30) + mock_client.completion.return_value = (mock_message, mock_usage) + + agent.run_iteration() + + assert agent.iterations == 1 + assert len(agent.memory) == 1 + assert agent.memory[0] == mock_message + assert agent.usage == mock_usage + + def test_get_completion(self, agent, mock_client): + mock_message = Message(role="assistant", content="Test response") + mock_usage = Usage(completion_tokens=10, prompt_tokens=20, total_tokens=30) + mock_client.completion.return_value = (mock_message, mock_usage) + + message, usage = agent.get_completion() + + assert message == mock_message + assert usage == mock_usage + + def test_process_message(self, agent): + message = Message(role="assistant", content="Test message") + agent.process_message(message) + assert len(agent.memory) == 1 + assert agent.memory[0] == message + assert agent.iterations == 1 + + def test_convert_tool_calls(self, agent): + tool_calls = [ + ChatCompletionMessageToolCall( + id="1", + function=Function(name="test_tool", arguments='{"arg": "value"}'), + type="function", ) ] + converted = agent.convert_tool_calls(tool_calls) + assert len(converted) == 1 + assert isinstance(converted[0], ToolCall) + assert converted[0].id == "1" + assert converted[0].function == "test_tool" + assert converted[0].args == '{"arg": "value"}' + + def test_process_tool_calls(self, agent): + tool_calls = [ToolCall(id="1", function="test_tool", args='{"arg": "value"}')] + with patch.object(agent, "call_tool") as mock_call_tool: + mock_call_tool.return_value = Message( + role="tool", content="Tool result", tool_call_id="1" + ) + agent.process_tool_calls(tool_calls) + assert len(agent.memory) == 1 + assert agent.memory[0].role == "tool" + assert agent.memory[0].content == "Tool result" - tool_call = ToolCall( - id="1", - function="mock_tool", - args='{"arg1": "value1\\nbar(\'\\n\')", "arg2": "value2", "arg3": 123}', + def test_update_usage(self, agent): + initial_usage = Usage(completion_tokens=10, prompt_tokens=20, total_tokens=30) + agent.usage = initial_usage + new_usage = Usage(completion_tokens=5, prompt_tokens=10, total_tokens=15) + agent.update_usage(new_usage) + expected_usage = Usage(completion_tokens=15, prompt_tokens=30, total_tokens=45) + assert agent.usage == expected_usage + + def test_run(self, agent, mock_client): + mock_message = Message(role="assistant", content="Final response") + mock_usage = Usage(completion_tokens=10, prompt_tokens=20, total_tokens=30) + mock_client.completion.return_value = (mock_message, mock_usage) + + result = agent.run("Test prompt") + + assert result == "Final response" + assert len(agent.memory) > 0 + assert agent.memory[0].role == "user" + assert agent.memory[0].content == "Test prompt" + + def test_run_max_iterations_exception(self, agent, mock_client): + agent.config.max_iterations = 1 + tool = FunctionTool( + name="test_tool", description="Test tool", fn=lambda: None, parameters=[] ) + agent.tools = [tool] - # Call the method - result = self.agent.call_tool(tool_call) + mock_message = Message( + role="assistant", + content="Response", + tool_calls=[ + ChatCompletionMessageToolCall( + id="1", + function=Function(name="test_tool", arguments='{"arg": "value"}'), + type="function", + ) + ], + ) + mock_usage = Usage(completion_tokens=10, prompt_tokens=20, total_tokens=30) + mock_client.completion.return_value = (mock_message, mock_usage) - # Assertions - self.assertEqual(result.content, "Tool called successfully") - self.assertEqual(result.role, "tool") - self.assertEqual(result.tool_call_id, "1") - mock_fn.assert_called_once_with(arg1="value1\nbar('\\n')", arg2="value2", arg3=123) + with pytest.raises(MaxIterationsReachedException): + agent.run("Test prompt") From 0aec9ea2aa84252df51eced5d41477f35f55a6ba Mon Sep 17 00:00:00 2001 From: "jenn.muengtaweepongsa" Date: Wed, 17 Jul 2024 22:37:21 +0700 Subject: [PATCH 2/6] dep inject --- src/seer/automation/agent/agent.py | 9 +++++--- src/seer/automation/agent/client.py | 31 +++++++++++++++++++++------- tests/automation/agent/test_agent.py | 24 +++++++++++---------- 3 files changed, 43 insertions(+), 21 deletions(-) diff --git a/src/seer/automation/agent/agent.py b/src/seer/automation/agent/agent.py index 540fab8d9..0fd2933fe 100644 --- a/src/seer/automation/agent/agent.py +++ b/src/seer/automation/agent/agent.py @@ -6,10 +6,11 @@ from openai.types.chat import ChatCompletionMessageToolCall from pydantic import BaseModel, Field -from seer.automation.agent.client import GptClient, LlmClient +from seer.automation.agent.client import GptClient from seer.automation.agent.models import Message, ToolCall, Usage from seer.automation.agent.tools import FunctionTool from seer.automation.agent.utils import parse_json_with_keys +from seer.dependency_injection import inject, injected logger = logging.getLogger("autofix") @@ -107,17 +108,18 @@ def parse_tool_arguments(self, tool: FunctionTool, args: str) -> dict: class GptAgent(LlmAgent): + @inject def __init__( self, config: AgentConfig, - client: Optional[LlmClient] = None, + client: GptClient = injected, tools: Optional[list[FunctionTool]] = None, memory: Optional[list[Message]] = None, name: str = "GptAgent", chat_completion_kwargs: Optional[dict] = None, ): super().__init__(config, tools, memory, name) - self.client = client or GptClient(model=config.model) + self.client = client self.chat_completion_kwargs = chat_completion_kwargs or {} def run_iteration(self): @@ -132,6 +134,7 @@ def run_iteration(self): def get_completion(self): return self.client.completion( messages=self.memory, + model=self.config.model, tools=([tool.to_dict() for tool in self.tools] if self.tools else NotGiven()), **self.chat_completion_kwargs, ) diff --git a/src/seer/automation/agent/client.py b/src/seer/automation/agent/client.py index bca948e51..32c45f717 100644 --- a/src/seer/automation/agent/client.py +++ b/src/seer/automation/agent/client.py @@ -7,6 +7,7 @@ from openai.types.chat import ChatCompletion from seer.automation.agent.models import Message, Usage +from seer.bootup import module, stub_module T = TypeVar("T") @@ -14,7 +15,7 @@ class LlmClient(ABC): @abstractmethod def completion( - self, messages: list[Message], **chat_completion_kwargs + self, messages: list[Message], model: str, **chat_completion_kwargs ) -> tuple[Message, Usage]: pass @@ -22,21 +23,26 @@ def completion_with_parser( self, messages: list[Message], parser: Callable[[str | None], T], + model: str, **chat_completion_kwargs, ) -> tuple[T, Message, Usage]: - message, usage = self.completion(messages, **chat_completion_kwargs) + message, usage = self.completion(messages, model, **chat_completion_kwargs) return parser(message.content), message, usage +DEFAULT_GPT_MODEL = "gpt-4o-2024-05-13" + + class GptClient(LlmClient): - def __init__(self, model: str = "gpt-4o-2024-05-13"): - self.model = model + def __init__(self): self.openai_client = openai.Client() - def completion(self, messages: list[Message], **chat_completion_kwargs): + def completion( + self, messages: list[Message], model=DEFAULT_GPT_MODEL, **chat_completion_kwargs + ): completion: ChatCompletion = self.openai_client.chat.completions.create( - model=self.model, + model=model, messages=[message.to_openai_message() for message in messages], temperature=0.0, **chat_completion_kwargs, @@ -53,16 +59,22 @@ def completion(self, messages: list[Message], **chat_completion_kwargs): return message, usage def json_completion( - self, messages: list[Message], **chat_completion_kwargs + self, messages: list[Message], model=DEFAULT_GPT_MODEL, **chat_completion_kwargs ) -> tuple[dict[str, Any] | None, Message, Usage]: return self.completion_with_parser( messages, parser=lambda x: json.loads(x) if x else None, + model=model, response_format={"type": "json_object"}, **chat_completion_kwargs, ) +@module.provider +def provide_gpt_client() -> GptClient: + return GptClient() + + GptCompletionHandler = Callable[[list[Message], dict[str, Any]], Optional[tuple[Message, Usage]]] @@ -80,3 +92,8 @@ def completion(self, messages: list[Message], **chat_completion_kwargs): return result self.missed_calls.append((messages, chat_completion_kwargs)) return Message(), Usage() + + +@stub_module.provider +def provide_stub_gpt_client() -> GptClient: + return DummyGptClient() diff --git a/tests/automation/agent/test_agent.py b/tests/automation/agent/test_agent.py index 8945cc149..9ff9ea7e4 100644 --- a/tests/automation/agent/test_agent.py +++ b/tests/automation/agent/test_agent.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, patch import pytest from openai.types.chat.chat_completion_message_tool_call import ( @@ -13,8 +13,10 @@ LlmAgent, MaxIterationsReachedException, ) +from seer.automation.agent.client import GptClient from seer.automation.agent.models import Message, ToolCall, Usage from seer.automation.agent.tools import FunctionTool +from seer.dependency_injection import resolve class TestAgentConfig: @@ -61,11 +63,11 @@ def test_should_continue(self, agent: LlmAgent): agent.iterations = agent.config.max_iterations agent.memory = [Message(role="assistant", content="STOP")] - assert agent.should_continue() # Max iterations reached + assert not agent.should_continue() # Max iterations reached agent.iterations = 1 agent.config.stop_message = "STOP" - assert agent.should_continue() # Stop message found + assert not agent.should_continue() # Stop message found def test_add_user_message(self, agent: LlmAgent): agent.add_user_message("Test message") @@ -126,12 +128,12 @@ def config(self): return AgentConfig() @pytest.fixture - def mock_client(self): - return Mock() + def agent(self, config): + return GptAgent(config) @pytest.fixture - def agent(self, config, mock_client): - return GptAgent(config, client=mock_client) + def mock_client(self): + return resolve(GptClient) def test_initialization(self, agent, config, mock_client): assert agent.config == config @@ -141,7 +143,7 @@ def test_initialization(self, agent, config, mock_client): def test_run_iteration(self, agent, mock_client): mock_message = Message(role="assistant", content="Test response") mock_usage = Usage(completion_tokens=10, prompt_tokens=20, total_tokens=30) - mock_client.completion.return_value = (mock_message, mock_usage) + mock_client.completion = MagicMock(return_value=(mock_message, mock_usage)) agent.run_iteration() @@ -153,7 +155,7 @@ def test_run_iteration(self, agent, mock_client): def test_get_completion(self, agent, mock_client): mock_message = Message(role="assistant", content="Test response") mock_usage = Usage(completion_tokens=10, prompt_tokens=20, total_tokens=30) - mock_client.completion.return_value = (mock_message, mock_usage) + mock_client.completion = MagicMock(return_value=(mock_message, mock_usage)) message, usage = agent.get_completion() @@ -204,7 +206,7 @@ def test_update_usage(self, agent): def test_run(self, agent, mock_client): mock_message = Message(role="assistant", content="Final response") mock_usage = Usage(completion_tokens=10, prompt_tokens=20, total_tokens=30) - mock_client.completion.return_value = (mock_message, mock_usage) + mock_client.completion = MagicMock(return_value=(mock_message, mock_usage)) result = agent.run("Test prompt") @@ -232,7 +234,7 @@ def test_run_max_iterations_exception(self, agent, mock_client): ], ) mock_usage = Usage(completion_tokens=10, prompt_tokens=20, total_tokens=30) - mock_client.completion.return_value = (mock_message, mock_usage) + mock_client.completion = MagicMock(return_value=(mock_message, mock_usage)) with pytest.raises(MaxIterationsReachedException): agent.run("Test prompt") From 41f3b94e778566ff6173dd7a3f13a358063b0e7a Mon Sep 17 00:00:00 2001 From: "jenn.muengtaweepongsa" Date: Wed, 17 Jul 2024 22:45:36 +0700 Subject: [PATCH 3/6] fixes & cleanup --- src/seer/automation/agent/agent.py | 6 +++--- .../automation/autofix/components/executor/component.py | 8 +++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/seer/automation/agent/agent.py b/src/seer/automation/agent/agent.py index 0fd2933fe..6c7a43e74 100644 --- a/src/seer/automation/agent/agent.py +++ b/src/seer/automation/agent/agent.py @@ -6,7 +6,7 @@ from openai.types.chat import ChatCompletionMessageToolCall from pydantic import BaseModel, Field -from seer.automation.agent.client import GptClient +from seer.automation.agent.client import DEFAULT_GPT_MODEL, GptClient from seer.automation.agent.models import Message, ToolCall, Usage from seer.automation.agent.tools import FunctionTool from seer.automation.agent.utils import parse_json_with_keys @@ -19,7 +19,7 @@ class AgentConfig(BaseModel): max_iterations: int = Field( default=16, description="Maximum number of iterations the agent can perform" ) - model: str = Field(default="gpt-4-0613", description="The model to be used by the agent") + model: str = Field(default=DEFAULT_GPT_MODEL, description="The model to be used by the agent") stop_message: Optional[str] = Field( default=None, description="Message that signals the agent to stop" ) @@ -111,7 +111,7 @@ class GptAgent(LlmAgent): @inject def __init__( self, - config: AgentConfig, + config: AgentConfig = AgentConfig(), client: GptClient = injected, tools: Optional[list[FunctionTool]] = None, memory: Optional[list[Message]] = None, diff --git a/src/seer/automation/autofix/components/executor/component.py b/src/seer/automation/autofix/components/executor/component.py index 3c5e9d99a..ea04d3bf4 100644 --- a/src/seer/automation/autofix/components/executor/component.py +++ b/src/seer/automation/autofix/components/executor/component.py @@ -1,7 +1,7 @@ from langfuse.decorators import observe from sentry_sdk.ai.monitoring import ai_track -from seer.automation.agent.agent import GptAgent +from seer.automation.agent.agent import AgentConfig, GptAgent from seer.automation.agent.models import Message from seer.automation.autofix.autofix_context import AutofixContext from seer.automation.autofix.components.executor.models import ExecutorOutput, ExecutorRequest @@ -22,7 +22,10 @@ def invoke(self, request: ExecutorRequest) -> None: code_action_tools = CodeActionTools(self.context) execution_agent = GptAgent( - name="executor", + config=AgentConfig( + stop_message="", + ), + name="Executor", tools=code_action_tools.get_tools(), memory=[ Message( @@ -30,7 +33,6 @@ def invoke(self, request: ExecutorRequest) -> None: content=ExecutionPrompts.format_system_msg(), ), ], - stop_message="", ) execution_agent.run( From f6cafbf40b53d5cea90777571ff9f958b1e55e40 Mon Sep 17 00:00:00 2001 From: "jenn.muengtaweepongsa" Date: Wed, 17 Jul 2024 22:50:25 +0700 Subject: [PATCH 4/6] mypy --- src/seer/automation/agent/client.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/seer/automation/agent/client.py b/src/seer/automation/agent/client.py index 32c45f717..cd22f4c09 100644 --- a/src/seer/automation/agent/client.py +++ b/src/seer/automation/agent/client.py @@ -58,6 +58,17 @@ def completion( return message, usage + def completion_with_parser( + self, + messages: list[Message], + parser: Callable[[str | None], T], + model=DEFAULT_GPT_MODEL, + **chat_completion_kwargs, + ) -> tuple[T, Message, Usage]: + message, usage = self.completion(messages, model, **chat_completion_kwargs) + + return parser(message.content), message, usage + def json_completion( self, messages: list[Message], model=DEFAULT_GPT_MODEL, **chat_completion_kwargs ) -> tuple[dict[str, Any] | None, Message, Usage]: @@ -85,7 +96,7 @@ class DummyGptClient(GptClient): default_factory=list ) - def completion(self, messages: list[Message], **chat_completion_kwargs): + def completion(self, messages: list[Message], model="test-gpt", **chat_completion_kwargs): for handler in self.handlers: result = handler(messages, chat_completion_kwargs) if result: From bf681efd90d14fc34567091c5592f9c0d3e8633a Mon Sep 17 00:00:00 2001 From: "jenn.muengtaweepongsa" Date: Wed, 17 Jul 2024 23:01:26 +0700 Subject: [PATCH 5/6] remove brittle test --- tests/automation/agent/test_agent.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tests/automation/agent/test_agent.py b/tests/automation/agent/test_agent.py index 9ff9ea7e4..c5a6d42e0 100644 --- a/tests/automation/agent/test_agent.py +++ b/tests/automation/agent/test_agent.py @@ -5,7 +5,6 @@ ChatCompletionMessageToolCall, Function, ) -from pydantic import ValidationError from seer.automation.agent.agent import ( AgentConfig, @@ -19,24 +18,6 @@ from seer.dependency_injection import resolve -class TestAgentConfig: - def test_default_values(self): - config = AgentConfig() - assert config.max_iterations == 16 - assert config.model == "gpt-4-0613" - assert config.stop_message is None - - def test_custom_values(self): - config = AgentConfig(max_iterations=10, model="gpt-3.5-turbo", stop_message="STOP") - assert config.max_iterations == 10 - assert config.model == "gpt-3.5-turbo" - assert config.stop_message == "STOP" - - def test_validation(self): - with pytest.raises(ValidationError): - AgentConfig(max_iterations="not a number") # type: ignore - - class TestLlmAgent: @pytest.fixture def config(self): From d430b564906e83e11d9acb20ff62a336efcc70b7 Mon Sep 17 00:00:00 2001 From: "jenn.muengtaweepongsa" Date: Wed, 17 Jul 2024 23:37:36 +0700 Subject: [PATCH 6/6] remove insignificant initialization tests --- tests/automation/agent/test_agent.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/automation/agent/test_agent.py b/tests/automation/agent/test_agent.py index c5a6d42e0..9fe322cd3 100644 --- a/tests/automation/agent/test_agent.py +++ b/tests/automation/agent/test_agent.py @@ -31,14 +31,6 @@ def run_iteration(self): return TestAgent(config) - def test_initialization(self, agent: LlmAgent, config: AgentConfig): - assert agent.config == config - assert agent.tools == [] - assert agent.memory == [] - assert isinstance(agent.usage, Usage) - assert agent.name == "Agent" - assert agent.iterations == 0 - def test_should_continue(self, agent: LlmAgent): assert agent.should_continue() # Initial state @@ -116,11 +108,6 @@ def agent(self, config): def mock_client(self): return resolve(GptClient) - def test_initialization(self, agent, config, mock_client): - assert agent.config == config - assert agent.client == mock_client - assert agent.chat_completion_kwargs == {} - def test_run_iteration(self, agent, mock_client): mock_message = Message(role="assistant", content="Test response") mock_usage = Usage(completion_tokens=10, prompt_tokens=20, total_tokens=30)