Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ref(agents): Better interface for LLM agents and clients + test coverage #918

Merged
merged 6 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 109 additions & 90 deletions src/seer/automation/agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,150 +1,169 @@
import logging
from abc import ABC, abstractmethod
from typing import Any
from typing import Optional

from openai._types import NotGiven
from openai.types.chat import ChatCompletionMessageToolCall
from pydantic import BaseModel, Field

from seer.automation.agent.client import GptClient, LlmClient
from seer.automation.agent.client import DEFAULT_GPT_MODEL, GptClient
from seer.automation.agent.models import Message, ToolCall, Usage
from seer.automation.agent.tools import FunctionTool
from seer.automation.agent.utils import parse_json_with_keys
from seer.dependency_injection import inject, injected

logger = logging.getLogger("autofix")


class LlmAgent(ABC):
name: str
tools: list[FunctionTool]
memory: list[Message]
class AgentConfig(BaseModel):
max_iterations: int = Field(
default=16, description="Maximum number of iterations the agent can perform"
)
model: str = Field(default=DEFAULT_GPT_MODEL, description="The model to be used by the agent")
stop_message: Optional[str] = Field(
default=None, description="Message that signals the agent to stop"
)

class Config:
validate_assignment = True

client: LlmClient
iterations: int = 0
max_iterations: int = 16

class LlmAgent(ABC):
def __init__(
self,
tools: list[FunctionTool] | None = None,
memory: list[Message] | None = None,
name="Agent",
stop_message: str | None = None,
config: AgentConfig,
tools: Optional[list[FunctionTool]] = None,
memory: Optional[list[Message]] = None,
name: str = "Agent",
):
self.config = config
self.tools = tools or []
self.memory = memory or []
self.usage = Usage()
self.name = name
self.stop_message = stop_message
self.iterations = 0

@abstractmethod
def run_iteration(self):
pass

def run(self, prompt: str):
self.memory.append(
Message(
role="user",
content=prompt,
)
)
def should_continue(self) -> bool:
# If this is the first iteration or there are no messages, continue
if self.iterations == 0 or not self.memory:
return True

# Stop if we've reached the maximum number of iterations
if self.iterations >= self.config.max_iterations:
return False

last_message = self.memory[-1]
if last_message and last_message.role in ["assistant", "model"]:
if last_message.content:
# Stop if the stop message is found in the content
if self.config.stop_message and self.config.stop_message in last_message.content:
return False
# Stop if there are no tool calls
if not last_message.tool_calls:
return False

# Continue in all other cases
return True

def run(self, prompt: str):
self.add_user_message(prompt)
logger.debug(f"----[{self.name}] Running Agent----")
logger.debug("Previous messages: ")
for message in self.memory:
logger.debug(f"{message.role}: {message.content}")

while (
self.iterations == 0
or (
not (
self.memory[-1].role
in ["assistant", "model"] # Will never end on a message not from the assistant
and (
self.memory[-1].content
and (
self.stop_message in self.memory[-1].content
) # If stop message is defined; will end if the assistant response contains the stop message
if self.stop_message
else self.memory[-1].content
is not None # If a stop message is not defined; will end on any non-empty assistant response (OpenAI tool call does not output a message!)
)
)
)
and self.iterations < self.max_iterations # Went above max iterations
):
# runs until the assistant sends a message with no more tool calls.

while self.should_continue():
self.run_iteration()

if self.iterations == self.max_iterations:
raise Exception(f"Agent {self.name} reached maximum iterations without finishing.")
if self.iterations == self.config.max_iterations:
raise MaxIterationsReachedException(
f"Agent {self.name} reached maximum iterations without finishing."
)

return self.memory[-1].content
return self.get_last_message_content()

def call_tool(self, tool_call: ToolCall):
logger.debug(
f"[{tool_call.id}] Calling tool {tool_call.function} with arguments {tool_call.args}"
)
def add_user_message(self, content: str):
self.memory.append(Message(role="user", content=content))

tool = next(tool for tool in self.tools if tool.name == tool_call.function)
def get_last_message_content(self) -> str | None:
return self.memory[-1].content if self.memory else None

kwargs = parse_json_with_keys(
tool_call.args,
[param["name"] for param in tool.parameters if isinstance(param["name"], str)],
)
def call_tool(self, tool_call: ToolCall) -> Message:
logger.debug(f"[{tool_call.id}] Calling tool {tool_call.function}")

tool = self.get_tool_by_name(tool_call.function)
kwargs = self.parse_tool_arguments(tool, tool_call.args)
tool_result = tool.call(**kwargs)

logger.debug(f"Tool {tool_call.function} returned \n{tool_result}")
return Message(role="tool", content=tool_result, tool_call_id=tool_call.id)

return Message(
role="tool",
content=tool_result,
tool_call_id=tool_call.id,
def get_tool_by_name(self, name: str) -> FunctionTool:
return next(tool for tool in self.tools if tool.name == name)

def parse_tool_arguments(self, tool: FunctionTool, args: str) -> dict:
return parse_json_with_keys(
args, [param["name"] for param in tool.parameters if isinstance(param["name"], str)]
)


class GptAgent(LlmAgent):
model: str = "gpt-4o-2024-05-13"

chat_completion_kwargs: dict[str, Any] = {}

@inject
def __init__(
self,
tools: list[FunctionTool] | None = None,
memory: list[Message] | None = None,
name="GptAgent",
chat_completion_kwargs=None,
stop_message: str | None = None,
config: AgentConfig = AgentConfig(),
client: GptClient = injected,
tools: Optional[list[FunctionTool]] = None,
memory: Optional[list[Message]] = None,
name: str = "GptAgent",
chat_completion_kwargs: Optional[dict] = None,
):
super().__init__(tools, memory, name=name, stop_message=stop_message)
self.client = GptClient(model=self.model)

super().__init__(config, tools, memory, name)
self.client = client
self.chat_completion_kwargs = chat_completion_kwargs or {}

def run_iteration(self):
logger.debug(f"----[{self.name}] Running Iteration {self.iterations}----")

message, usage = self.client.completion(
message, usage = self.get_completion()
self.process_message(message)
self.update_usage(usage)

return self.memory

def get_completion(self):
return self.client.completion(
messages=self.memory,
tools=([tool.to_dict() for tool in self.tools] if len(self.tools) > 0 else NotGiven()),
model=self.config.model,
tools=([tool.to_dict() for tool in self.tools] if self.tools else NotGiven()),
**self.chat_completion_kwargs,
)

def process_message(self, message: Message):
self.memory.append(message)

logger.debug(f"Message content:\n{message.content}")
logger.debug(f"Message tool calls:\n{message.tool_calls}")

if message.tool_calls:
for tool_call in message.tool_calls:
tool_response = self.call_tool(
ToolCall(
id=tool_call.id,
function=tool_call.function.name,
args=tool_call.function.arguments,
)
)

self.memory.append(tool_response)
converted_tool_calls = self.convert_tool_calls(message.tool_calls)
self.process_tool_calls(converted_tool_calls)

self.iterations += 1

def convert_tool_calls(self, tool_calls: list[ChatCompletionMessageToolCall]) -> list[ToolCall]:
return [
ToolCall(
id=tool_call.id, function=tool_call.function.name, args=tool_call.function.arguments
)
for tool_call in tool_calls
]

def process_tool_calls(self, tool_calls: list[ToolCall]):
for tool_call in tool_calls:
tool_response = self.call_tool(tool_call)
self.memory.append(tool_response)

def update_usage(self, usage: Usage):
self.usage += usage

return self.memory

class MaxIterationsReachedException(Exception):
pass
44 changes: 36 additions & 8 deletions src/seer/automation/agent/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,42 @@
from openai.types.chat import ChatCompletion

from seer.automation.agent.models import Message, Usage
from seer.bootup import module, stub_module

T = TypeVar("T")


class LlmClient(ABC):
@abstractmethod
def completion(
self, messages: list[Message], **chat_completion_kwargs
self, messages: list[Message], model: str, **chat_completion_kwargs
) -> tuple[Message, Usage]:
pass

def completion_with_parser(
self,
messages: list[Message],
parser: Callable[[str | None], T],
model: str,
**chat_completion_kwargs,
) -> tuple[T, Message, Usage]:
message, usage = self.completion(messages, **chat_completion_kwargs)
message, usage = self.completion(messages, model, **chat_completion_kwargs)

return parser(message.content), message, usage


DEFAULT_GPT_MODEL = "gpt-4o-2024-05-13"


class GptClient(LlmClient):
def __init__(self, model: str = "gpt-4o-2024-05-13"):
self.model = model
def __init__(self):
self.openai_client = openai.Client()

def completion(self, messages: list[Message], **chat_completion_kwargs):
def completion(
self, messages: list[Message], model=DEFAULT_GPT_MODEL, **chat_completion_kwargs
):
completion: ChatCompletion = self.openai_client.chat.completions.create(
model=self.model,
model=model,
messages=[message.to_openai_message() for message in messages],
temperature=0.0,
**chat_completion_kwargs,
Expand All @@ -52,17 +58,34 @@ def completion(self, messages: list[Message], **chat_completion_kwargs):

return message, usage

def completion_with_parser(
self,
messages: list[Message],
parser: Callable[[str | None], T],
model=DEFAULT_GPT_MODEL,
**chat_completion_kwargs,
) -> tuple[T, Message, Usage]:
message, usage = self.completion(messages, model, **chat_completion_kwargs)

return parser(message.content), message, usage

def json_completion(
self, messages: list[Message], **chat_completion_kwargs
self, messages: list[Message], model=DEFAULT_GPT_MODEL, **chat_completion_kwargs
) -> tuple[dict[str, Any] | None, Message, Usage]:
return self.completion_with_parser(
messages,
parser=lambda x: json.loads(x) if x else None,
model=model,
response_format={"type": "json_object"},
**chat_completion_kwargs,
)


@module.provider
def provide_gpt_client() -> GptClient:
return GptClient()


GptCompletionHandler = Callable[[list[Message], dict[str, Any]], Optional[tuple[Message, Usage]]]


Expand All @@ -73,10 +96,15 @@ class DummyGptClient(GptClient):
default_factory=list
)

def completion(self, messages: list[Message], **chat_completion_kwargs):
def completion(self, messages: list[Message], model="test-gpt", **chat_completion_kwargs):
for handler in self.handlers:
result = handler(messages, chat_completion_kwargs)
if result:
return result
self.missed_calls.append((messages, chat_completion_kwargs))
return Message(), Usage()


@stub_module.provider
def provide_stub_gpt_client() -> GptClient:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

return DummyGptClient()
8 changes: 5 additions & 3 deletions src/seer/automation/autofix/components/executor/component.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from langfuse.decorators import observe
from sentry_sdk.ai.monitoring import ai_track

from seer.automation.agent.agent import GptAgent
from seer.automation.agent.agent import AgentConfig, GptAgent
from seer.automation.agent.models import Message
from seer.automation.autofix.autofix_context import AutofixContext
from seer.automation.autofix.components.executor.models import ExecutorOutput, ExecutorRequest
Expand All @@ -22,15 +22,17 @@ def invoke(self, request: ExecutorRequest) -> None:
code_action_tools = CodeActionTools(self.context)

execution_agent = GptAgent(
name="executor",
config=AgentConfig(
stop_message="<DONE>",
),
name="Executor",
tools=code_action_tools.get_tools(),
memory=[
Message(
role="system",
content=ExecutionPrompts.format_system_msg(),
),
],
stop_message="<DONE>",
)

execution_agent.run(
Expand Down
Loading
Loading