-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into trillville/second-migration
- Loading branch information
Showing
16 changed files
with
576 additions
and
242 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,150 +1,169 @@ | ||
import logging | ||
from abc import ABC, abstractmethod | ||
from typing import Any | ||
from typing import Optional | ||
|
||
from openai._types import NotGiven | ||
from openai.types.chat import ChatCompletionMessageToolCall | ||
from pydantic import BaseModel, Field | ||
|
||
from seer.automation.agent.client import GptClient, LlmClient | ||
from seer.automation.agent.client import DEFAULT_GPT_MODEL, GptClient | ||
from seer.automation.agent.models import Message, ToolCall, Usage | ||
from seer.automation.agent.tools import FunctionTool | ||
from seer.automation.agent.utils import parse_json_with_keys | ||
from seer.dependency_injection import inject, injected | ||
|
||
logger = logging.getLogger("autofix") | ||
|
||
|
||
class LlmAgent(ABC): | ||
name: str | ||
tools: list[FunctionTool] | ||
memory: list[Message] | ||
class AgentConfig(BaseModel): | ||
max_iterations: int = Field( | ||
default=16, description="Maximum number of iterations the agent can perform" | ||
) | ||
model: str = Field(default=DEFAULT_GPT_MODEL, description="The model to be used by the agent") | ||
stop_message: Optional[str] = Field( | ||
default=None, description="Message that signals the agent to stop" | ||
) | ||
|
||
class Config: | ||
validate_assignment = True | ||
|
||
client: LlmClient | ||
iterations: int = 0 | ||
max_iterations: int = 16 | ||
|
||
class LlmAgent(ABC): | ||
def __init__( | ||
self, | ||
tools: list[FunctionTool] | None = None, | ||
memory: list[Message] | None = None, | ||
name="Agent", | ||
stop_message: str | None = None, | ||
config: AgentConfig, | ||
tools: Optional[list[FunctionTool]] = None, | ||
memory: Optional[list[Message]] = None, | ||
name: str = "Agent", | ||
): | ||
self.config = config | ||
self.tools = tools or [] | ||
self.memory = memory or [] | ||
self.usage = Usage() | ||
self.name = name | ||
self.stop_message = stop_message | ||
self.iterations = 0 | ||
|
||
@abstractmethod | ||
def run_iteration(self): | ||
pass | ||
|
||
def run(self, prompt: str): | ||
self.memory.append( | ||
Message( | ||
role="user", | ||
content=prompt, | ||
) | ||
) | ||
def should_continue(self) -> bool: | ||
# If this is the first iteration or there are no messages, continue | ||
if self.iterations == 0 or not self.memory: | ||
return True | ||
|
||
# Stop if we've reached the maximum number of iterations | ||
if self.iterations >= self.config.max_iterations: | ||
return False | ||
|
||
last_message = self.memory[-1] | ||
if last_message and last_message.role in ["assistant", "model"]: | ||
if last_message.content: | ||
# Stop if the stop message is found in the content | ||
if self.config.stop_message and self.config.stop_message in last_message.content: | ||
return False | ||
# Stop if there are no tool calls | ||
if not last_message.tool_calls: | ||
return False | ||
|
||
# Continue in all other cases | ||
return True | ||
|
||
def run(self, prompt: str): | ||
self.add_user_message(prompt) | ||
logger.debug(f"----[{self.name}] Running Agent----") | ||
logger.debug("Previous messages: ") | ||
for message in self.memory: | ||
logger.debug(f"{message.role}: {message.content}") | ||
|
||
while ( | ||
self.iterations == 0 | ||
or ( | ||
not ( | ||
self.memory[-1].role | ||
in ["assistant", "model"] # Will never end on a message not from the assistant | ||
and ( | ||
self.memory[-1].content | ||
and ( | ||
self.stop_message in self.memory[-1].content | ||
) # If stop message is defined; will end if the assistant response contains the stop message | ||
if self.stop_message | ||
else self.memory[-1].content | ||
is not None # If a stop message is not defined; will end on any non-empty assistant response (OpenAI tool call does not output a message!) | ||
) | ||
) | ||
) | ||
and self.iterations < self.max_iterations # Went above max iterations | ||
): | ||
# runs until the assistant sends a message with no more tool calls. | ||
|
||
while self.should_continue(): | ||
self.run_iteration() | ||
|
||
if self.iterations == self.max_iterations: | ||
raise Exception(f"Agent {self.name} reached maximum iterations without finishing.") | ||
if self.iterations == self.config.max_iterations: | ||
raise MaxIterationsReachedException( | ||
f"Agent {self.name} reached maximum iterations without finishing." | ||
) | ||
|
||
return self.memory[-1].content | ||
return self.get_last_message_content() | ||
|
||
def call_tool(self, tool_call: ToolCall): | ||
logger.debug( | ||
f"[{tool_call.id}] Calling tool {tool_call.function} with arguments {tool_call.args}" | ||
) | ||
def add_user_message(self, content: str): | ||
self.memory.append(Message(role="user", content=content)) | ||
|
||
tool = next(tool for tool in self.tools if tool.name == tool_call.function) | ||
def get_last_message_content(self) -> str | None: | ||
return self.memory[-1].content if self.memory else None | ||
|
||
kwargs = parse_json_with_keys( | ||
tool_call.args, | ||
[param["name"] for param in tool.parameters if isinstance(param["name"], str)], | ||
) | ||
def call_tool(self, tool_call: ToolCall) -> Message: | ||
logger.debug(f"[{tool_call.id}] Calling tool {tool_call.function}") | ||
|
||
tool = self.get_tool_by_name(tool_call.function) | ||
kwargs = self.parse_tool_arguments(tool, tool_call.args) | ||
tool_result = tool.call(**kwargs) | ||
|
||
logger.debug(f"Tool {tool_call.function} returned \n{tool_result}") | ||
return Message(role="tool", content=tool_result, tool_call_id=tool_call.id) | ||
|
||
return Message( | ||
role="tool", | ||
content=tool_result, | ||
tool_call_id=tool_call.id, | ||
def get_tool_by_name(self, name: str) -> FunctionTool: | ||
return next(tool for tool in self.tools if tool.name == name) | ||
|
||
def parse_tool_arguments(self, tool: FunctionTool, args: str) -> dict: | ||
return parse_json_with_keys( | ||
args, [param["name"] for param in tool.parameters if isinstance(param["name"], str)] | ||
) | ||
|
||
|
||
class GptAgent(LlmAgent): | ||
model: str = "gpt-4o-2024-05-13" | ||
|
||
chat_completion_kwargs: dict[str, Any] = {} | ||
|
||
@inject | ||
def __init__( | ||
self, | ||
tools: list[FunctionTool] | None = None, | ||
memory: list[Message] | None = None, | ||
name="GptAgent", | ||
chat_completion_kwargs=None, | ||
stop_message: str | None = None, | ||
config: AgentConfig = AgentConfig(), | ||
client: GptClient = injected, | ||
tools: Optional[list[FunctionTool]] = None, | ||
memory: Optional[list[Message]] = None, | ||
name: str = "GptAgent", | ||
chat_completion_kwargs: Optional[dict] = None, | ||
): | ||
super().__init__(tools, memory, name=name, stop_message=stop_message) | ||
self.client = GptClient(model=self.model) | ||
|
||
super().__init__(config, tools, memory, name) | ||
self.client = client | ||
self.chat_completion_kwargs = chat_completion_kwargs or {} | ||
|
||
def run_iteration(self): | ||
logger.debug(f"----[{self.name}] Running Iteration {self.iterations}----") | ||
|
||
message, usage = self.client.completion( | ||
message, usage = self.get_completion() | ||
self.process_message(message) | ||
self.update_usage(usage) | ||
|
||
return self.memory | ||
|
||
def get_completion(self): | ||
return self.client.completion( | ||
messages=self.memory, | ||
tools=([tool.to_dict() for tool in self.tools] if len(self.tools) > 0 else NotGiven()), | ||
model=self.config.model, | ||
tools=([tool.to_dict() for tool in self.tools] if self.tools else NotGiven()), | ||
**self.chat_completion_kwargs, | ||
) | ||
|
||
def process_message(self, message: Message): | ||
self.memory.append(message) | ||
|
||
logger.debug(f"Message content:\n{message.content}") | ||
logger.debug(f"Message tool calls:\n{message.tool_calls}") | ||
|
||
if message.tool_calls: | ||
for tool_call in message.tool_calls: | ||
tool_response = self.call_tool( | ||
ToolCall( | ||
id=tool_call.id, | ||
function=tool_call.function.name, | ||
args=tool_call.function.arguments, | ||
) | ||
) | ||
|
||
self.memory.append(tool_response) | ||
converted_tool_calls = self.convert_tool_calls(message.tool_calls) | ||
self.process_tool_calls(converted_tool_calls) | ||
|
||
self.iterations += 1 | ||
|
||
def convert_tool_calls(self, tool_calls: list[ChatCompletionMessageToolCall]) -> list[ToolCall]: | ||
return [ | ||
ToolCall( | ||
id=tool_call.id, function=tool_call.function.name, args=tool_call.function.arguments | ||
) | ||
for tool_call in tool_calls | ||
] | ||
|
||
def process_tool_calls(self, tool_calls: list[ToolCall]): | ||
for tool_call in tool_calls: | ||
tool_response = self.call_tool(tool_call) | ||
self.memory.append(tool_response) | ||
|
||
def update_usage(self, usage: Usage): | ||
self.usage += usage | ||
|
||
return self.memory | ||
|
||
class MaxIterationsReachedException(Exception): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.