Skip to content

Commit

Permalink
ref(agents): Better interface for LLM agents and clients + test cover…
Browse files Browse the repository at this point in the history
…age (#918)

A major refactor of our agents and clients to better strengthen key
structural pathways in Autofix & increase test coverage

+ Tested locally e2e with a root cause run
  • Loading branch information
jennmueng committed Jul 17, 2024
1 parent cc9ad01 commit b5d10ae
Show file tree
Hide file tree
Showing 4 changed files with 348 additions and 134 deletions.
199 changes: 109 additions & 90 deletions src/seer/automation/agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,150 +1,169 @@
import logging
from abc import ABC, abstractmethod
from typing import Any
from typing import Optional

from openai._types import NotGiven
from openai.types.chat import ChatCompletionMessageToolCall
from pydantic import BaseModel, Field

from seer.automation.agent.client import GptClient, LlmClient
from seer.automation.agent.client import DEFAULT_GPT_MODEL, GptClient
from seer.automation.agent.models import Message, ToolCall, Usage
from seer.automation.agent.tools import FunctionTool
from seer.automation.agent.utils import parse_json_with_keys
from seer.dependency_injection import inject, injected

logger = logging.getLogger("autofix")


class LlmAgent(ABC):
name: str
tools: list[FunctionTool]
memory: list[Message]
class AgentConfig(BaseModel):
max_iterations: int = Field(
default=16, description="Maximum number of iterations the agent can perform"
)
model: str = Field(default=DEFAULT_GPT_MODEL, description="The model to be used by the agent")
stop_message: Optional[str] = Field(
default=None, description="Message that signals the agent to stop"
)

class Config:
validate_assignment = True

client: LlmClient
iterations: int = 0
max_iterations: int = 16

class LlmAgent(ABC):
def __init__(
self,
tools: list[FunctionTool] | None = None,
memory: list[Message] | None = None,
name="Agent",
stop_message: str | None = None,
config: AgentConfig,
tools: Optional[list[FunctionTool]] = None,
memory: Optional[list[Message]] = None,
name: str = "Agent",
):
self.config = config
self.tools = tools or []
self.memory = memory or []
self.usage = Usage()
self.name = name
self.stop_message = stop_message
self.iterations = 0

@abstractmethod
def run_iteration(self):
pass

def run(self, prompt: str):
self.memory.append(
Message(
role="user",
content=prompt,
)
)
def should_continue(self) -> bool:
# If this is the first iteration or there are no messages, continue
if self.iterations == 0 or not self.memory:
return True

# Stop if we've reached the maximum number of iterations
if self.iterations >= self.config.max_iterations:
return False

last_message = self.memory[-1]
if last_message and last_message.role in ["assistant", "model"]:
if last_message.content:
# Stop if the stop message is found in the content
if self.config.stop_message and self.config.stop_message in last_message.content:
return False
# Stop if there are no tool calls
if not last_message.tool_calls:
return False

# Continue in all other cases
return True

def run(self, prompt: str):
self.add_user_message(prompt)
logger.debug(f"----[{self.name}] Running Agent----")
logger.debug("Previous messages: ")
for message in self.memory:
logger.debug(f"{message.role}: {message.content}")

while (
self.iterations == 0
or (
not (
self.memory[-1].role
in ["assistant", "model"] # Will never end on a message not from the assistant
and (
self.memory[-1].content
and (
self.stop_message in self.memory[-1].content
) # If stop message is defined; will end if the assistant response contains the stop message
if self.stop_message
else self.memory[-1].content
is not None # If a stop message is not defined; will end on any non-empty assistant response (OpenAI tool call does not output a message!)
)
)
)
and self.iterations < self.max_iterations # Went above max iterations
):
# runs until the assistant sends a message with no more tool calls.

while self.should_continue():
self.run_iteration()

if self.iterations == self.max_iterations:
raise Exception(f"Agent {self.name} reached maximum iterations without finishing.")
if self.iterations == self.config.max_iterations:
raise MaxIterationsReachedException(
f"Agent {self.name} reached maximum iterations without finishing."
)

return self.memory[-1].content
return self.get_last_message_content()

def call_tool(self, tool_call: ToolCall):
logger.debug(
f"[{tool_call.id}] Calling tool {tool_call.function} with arguments {tool_call.args}"
)
def add_user_message(self, content: str):
self.memory.append(Message(role="user", content=content))

tool = next(tool for tool in self.tools if tool.name == tool_call.function)
def get_last_message_content(self) -> str | None:
return self.memory[-1].content if self.memory else None

kwargs = parse_json_with_keys(
tool_call.args,
[param["name"] for param in tool.parameters if isinstance(param["name"], str)],
)
def call_tool(self, tool_call: ToolCall) -> Message:
logger.debug(f"[{tool_call.id}] Calling tool {tool_call.function}")

tool = self.get_tool_by_name(tool_call.function)
kwargs = self.parse_tool_arguments(tool, tool_call.args)
tool_result = tool.call(**kwargs)

logger.debug(f"Tool {tool_call.function} returned \n{tool_result}")
return Message(role="tool", content=tool_result, tool_call_id=tool_call.id)

return Message(
role="tool",
content=tool_result,
tool_call_id=tool_call.id,
def get_tool_by_name(self, name: str) -> FunctionTool:
return next(tool for tool in self.tools if tool.name == name)

def parse_tool_arguments(self, tool: FunctionTool, args: str) -> dict:
return parse_json_with_keys(
args, [param["name"] for param in tool.parameters if isinstance(param["name"], str)]
)


class GptAgent(LlmAgent):
model: str = "gpt-4o-2024-05-13"

chat_completion_kwargs: dict[str, Any] = {}

@inject
def __init__(
self,
tools: list[FunctionTool] | None = None,
memory: list[Message] | None = None,
name="GptAgent",
chat_completion_kwargs=None,
stop_message: str | None = None,
config: AgentConfig = AgentConfig(),
client: GptClient = injected,
tools: Optional[list[FunctionTool]] = None,
memory: Optional[list[Message]] = None,
name: str = "GptAgent",
chat_completion_kwargs: Optional[dict] = None,
):
super().__init__(tools, memory, name=name, stop_message=stop_message)
self.client = GptClient(model=self.model)

super().__init__(config, tools, memory, name)
self.client = client
self.chat_completion_kwargs = chat_completion_kwargs or {}

def run_iteration(self):
logger.debug(f"----[{self.name}] Running Iteration {self.iterations}----")

message, usage = self.client.completion(
message, usage = self.get_completion()
self.process_message(message)
self.update_usage(usage)

return self.memory

def get_completion(self):
return self.client.completion(
messages=self.memory,
tools=([tool.to_dict() for tool in self.tools] if len(self.tools) > 0 else NotGiven()),
model=self.config.model,
tools=([tool.to_dict() for tool in self.tools] if self.tools else NotGiven()),
**self.chat_completion_kwargs,
)

def process_message(self, message: Message):
self.memory.append(message)

logger.debug(f"Message content:\n{message.content}")
logger.debug(f"Message tool calls:\n{message.tool_calls}")

if message.tool_calls:
for tool_call in message.tool_calls:
tool_response = self.call_tool(
ToolCall(
id=tool_call.id,
function=tool_call.function.name,
args=tool_call.function.arguments,
)
)

self.memory.append(tool_response)
converted_tool_calls = self.convert_tool_calls(message.tool_calls)
self.process_tool_calls(converted_tool_calls)

self.iterations += 1

def convert_tool_calls(self, tool_calls: list[ChatCompletionMessageToolCall]) -> list[ToolCall]:
return [
ToolCall(
id=tool_call.id, function=tool_call.function.name, args=tool_call.function.arguments
)
for tool_call in tool_calls
]

def process_tool_calls(self, tool_calls: list[ToolCall]):
for tool_call in tool_calls:
tool_response = self.call_tool(tool_call)
self.memory.append(tool_response)

def update_usage(self, usage: Usage):
self.usage += usage

return self.memory

class MaxIterationsReachedException(Exception):
pass
44 changes: 36 additions & 8 deletions src/seer/automation/agent/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,42 @@
from openai.types.chat import ChatCompletion

from seer.automation.agent.models import Message, Usage
from seer.bootup import module, stub_module

T = TypeVar("T")


class LlmClient(ABC):
@abstractmethod
def completion(
self, messages: list[Message], **chat_completion_kwargs
self, messages: list[Message], model: str, **chat_completion_kwargs
) -> tuple[Message, Usage]:
pass

def completion_with_parser(
self,
messages: list[Message],
parser: Callable[[str | None], T],
model: str,
**chat_completion_kwargs,
) -> tuple[T, Message, Usage]:
message, usage = self.completion(messages, **chat_completion_kwargs)
message, usage = self.completion(messages, model, **chat_completion_kwargs)

return parser(message.content), message, usage


DEFAULT_GPT_MODEL = "gpt-4o-2024-05-13"


class GptClient(LlmClient):
def __init__(self, model: str = "gpt-4o-2024-05-13"):
self.model = model
def __init__(self):
self.openai_client = openai.Client()

def completion(self, messages: list[Message], **chat_completion_kwargs):
def completion(
self, messages: list[Message], model=DEFAULT_GPT_MODEL, **chat_completion_kwargs
):
completion: ChatCompletion = self.openai_client.chat.completions.create(
model=self.model,
model=model,
messages=[message.to_openai_message() for message in messages],
temperature=0.0,
**chat_completion_kwargs,
Expand All @@ -52,17 +58,34 @@ def completion(self, messages: list[Message], **chat_completion_kwargs):

return message, usage

def completion_with_parser(
self,
messages: list[Message],
parser: Callable[[str | None], T],
model=DEFAULT_GPT_MODEL,
**chat_completion_kwargs,
) -> tuple[T, Message, Usage]:
message, usage = self.completion(messages, model, **chat_completion_kwargs)

return parser(message.content), message, usage

def json_completion(
self, messages: list[Message], **chat_completion_kwargs
self, messages: list[Message], model=DEFAULT_GPT_MODEL, **chat_completion_kwargs
) -> tuple[dict[str, Any] | None, Message, Usage]:
return self.completion_with_parser(
messages,
parser=lambda x: json.loads(x) if x else None,
model=model,
response_format={"type": "json_object"},
**chat_completion_kwargs,
)


@module.provider
def provide_gpt_client() -> GptClient:
return GptClient()


GptCompletionHandler = Callable[[list[Message], dict[str, Any]], Optional[tuple[Message, Usage]]]


Expand All @@ -73,10 +96,15 @@ class DummyGptClient(GptClient):
default_factory=list
)

def completion(self, messages: list[Message], **chat_completion_kwargs):
def completion(self, messages: list[Message], model="test-gpt", **chat_completion_kwargs):
for handler in self.handlers:
result = handler(messages, chat_completion_kwargs)
if result:
return result
self.missed_calls.append((messages, chat_completion_kwargs))
return Message(), Usage()


@stub_module.provider
def provide_stub_gpt_client() -> GptClient:
return DummyGptClient()
8 changes: 5 additions & 3 deletions src/seer/automation/autofix/components/executor/component.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from langfuse.decorators import observe
from sentry_sdk.ai.monitoring import ai_track

from seer.automation.agent.agent import GptAgent
from seer.automation.agent.agent import AgentConfig, GptAgent
from seer.automation.agent.models import Message
from seer.automation.autofix.autofix_context import AutofixContext
from seer.automation.autofix.components.executor.models import ExecutorOutput, ExecutorRequest
Expand All @@ -22,15 +22,17 @@ def invoke(self, request: ExecutorRequest) -> None:
code_action_tools = CodeActionTools(self.context)

execution_agent = GptAgent(
name="executor",
config=AgentConfig(
stop_message="<DONE>",
),
name="Executor",
tools=code_action_tools.get_tools(),
memory=[
Message(
role="system",
content=ExecutionPrompts.format_system_msg(),
),
],
stop_message="<DONE>",
)

execution_agent.run(
Expand Down
Loading

0 comments on commit b5d10ae

Please sign in to comment.