Skip to content

Commit

Permalink
feat(autofix): Reproduction in root causes (#1067)
Browse files Browse the repository at this point in the history
1. Introduces reproduction step in root cause output.
2. Uses gpt4o w/ structured output for the root cause for increased
reliability
3. (Tacking on) Adds claude agent back into the plan+code agent

![Screenshot 2024-09-10 at 10 45
47 AM](https://github.com/user-attachments/assets/8a1ceb0f-6fec-4c30-8fcf-f4e34e259375)

### Eval results:
(Before is above, after is below)
![Screenshot 2024-09-10 at 10 48
23 AM](https://github.com/user-attachments/assets/190c8919-56c1-4ab8-b19f-683b1c8a511c)
  • Loading branch information
jennmueng committed Sep 10, 2024
1 parent 950dfa3 commit c20a95e
Show file tree
Hide file tree
Showing 11 changed files with 243 additions and 279 deletions.
13 changes: 13 additions & 0 deletions src/seer/automation/agent/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,19 @@ def json_completion(
) -> tuple[dict[str, Any] | None, Message, Usage]:
return super().json_completion(messages, model, system_prompt)

def clean_tool_call_assistant_messages(self, messages: list[Message]) -> list[Message]:
new_messages = []
for message in messages:
if message.role == "assistant" and message.tool_calls:
new_messages.append(
Message(role="assistant", content=message.content, tool_calls=[])
)
elif message.role == "tool":
new_messages.append(Message(role="user", content=message.content, tool_calls=[]))
else:
new_messages.append(message)
return new_messages


class ClaudeClient(LlmClient):
@inject
Expand Down
4 changes: 2 additions & 2 deletions src/seer/automation/autofix/components/coding/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from langfuse.decorators import observe
from sentry_sdk.ai.monitoring import ai_track

from seer.automation.agent.agent import AgentConfig, GptAgent
from seer.automation.agent.agent import AgentConfig, ClaudeAgent
from seer.automation.autofix.autofix_context import AutofixContext
from seer.automation.autofix.components.coding.models import (
CodingOutput,
Expand Down Expand Up @@ -38,7 +38,7 @@ def _append_file_change(self, repo_external_id: str, file_change: FileChange):
def invoke(self, request: CodingRequest) -> CodingOutput | None:
tools = BaseTools(self.context)

agent = GptAgent(
agent = ClaudeAgent(
tools=tools.get_tools(),
config=AgentConfig(system_prompt=CodingPrompts.format_system_msg()),
)
Expand Down
2 changes: 2 additions & 0 deletions src/seer/automation/autofix/components/coding/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ def from_root_cause_context(cls, context: RootCauseRelevantContext):
class RootCausePlanTaskPromptXml(PromptXmlModel, tag="root_cause", skip_empty=True):
title: str = element()
description: str = element()
reproduction: str = element()
contexts: list[CodeContextXml]

@classmethod
def from_root_cause(cls, root_cause: RootCauseAnalysisItem):
return cls(
title=root_cause.title,
description=root_cause.description,
reproduction=root_cause.reproduction,
contexts=(
[
CodeContextXml.from_root_cause_context(context)
Expand Down
121 changes: 65 additions & 56 deletions src/seer/automation/autofix/components/root_cause/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
from sentry_sdk.ai.monitoring import ai_track

from seer.automation.agent.agent import AgentConfig, GptAgent
from seer.automation.agent.client import GptClient
from seer.automation.agent.models import Message
from seer.automation.autofix.autofix_context import AutofixContext
from seer.automation.autofix.components.root_cause.models import (
MultipleRootCauseAnalysisOutputPrompt,
RootCauseAnalysisOutput,
RootCauseAnalysisOutputPromptXml,
RootCauseAnalysisRequest,
)
from seer.automation.autofix.components.root_cause.prompts import RootCauseAnalysisPrompts
from seer.automation.autofix.tools import BaseTools
from seer.automation.component import BaseComponent
from seer.automation.utils import escape_multi_xml, extract_text_inside_tags
from seer.automation.utils import extract_parsed_model
from seer.dependency_injection import inject, injected

logger = logging.getLogger(__name__)

Expand All @@ -23,7 +26,10 @@ class RootCauseAnalysisComponent(BaseComponent[RootCauseAnalysisRequest, RootCau

@observe(name="Root Cause Analysis")
@ai_track(description="Root Cause Analysis")
def invoke(self, request: RootCauseAnalysisRequest) -> RootCauseAnalysisOutput | None:
@inject
def invoke(
self, request: RootCauseAnalysisRequest, gpt_client: GptClient = injected
) -> RootCauseAnalysisOutput | None:
tools = BaseTools(self.context)

agent = GptAgent(
Expand All @@ -35,56 +41,59 @@ def invoke(self, request: RootCauseAnalysisRequest) -> RootCauseAnalysisOutput |

state = self.context.state.get()

response = agent.run(
RootCauseAnalysisPrompts.format_default_msg(
event=request.event_details.format_event(),
summary=request.summary,
instruction=request.instruction,
repo_names=[repo.full_name for repo in state.request.repos],
),
context=self.context,
)

original_usage = agent.usage
with self.context.state.update() as cur:
cur.usage += agent.usage

if not response:
logger.warning("Root Cause Analysis agent did not return a valid response")
return None

if "<NO_ROOT_CAUSES>" in response:
return None

formatter_response = agent.run(RootCauseAnalysisPrompts.root_cause_formatter_msg())

with self.context.state.update() as cur:
cur.usage += agent.usage - original_usage

if not formatter_response:
logger.warning("Root Cause Analysis formatter did not return a valid response")
return None

extracted_text = extract_text_inside_tags(formatter_response, "potential_root_causes")

xml_response = RootCauseAnalysisOutputPromptXml.from_xml(
f"<root><potential_root_causes>{escape_multi_xml(extracted_text, ['thoughts', 'title', 'description', 'code'])}</potential_root_causes></root>"
)

if not xml_response.potential_root_causes.causes:
logger.warning("Root Cause Analysis formatter did not return causes")
return None

# Assign the ids to be the numerical indices of the causes and relevant code context
causes = []
for i, cause in enumerate(xml_response.potential_root_causes.causes):
cause_model = cause.to_model()
cause_model.id = i

if cause_model.code_context:
for j, snippet in enumerate(cause_model.code_context):
snippet.id = j

causes.append(cause_model)

return RootCauseAnalysisOutput(causes=causes)
try:
response = agent.run(
RootCauseAnalysisPrompts.format_default_msg(
event=request.event_details.format_event(),
summary=request.summary,
instruction=request.instruction,
repo_names=[repo.full_name for repo in state.request.repos],
),
context=self.context,
)

if not response:
logger.warning("Root Cause Analysis agent did not return a valid response")
return None

if "<NO_ROOT_CAUSES>" in response:
return None

# Ask for reproduction
agent.run(
RootCauseAnalysisPrompts.reproduction_prompt_msg(),
)

response = gpt_client.openai_client.beta.chat.completions.parse(
messages=[
message.to_message()
for message in gpt_client.clean_tool_call_assistant_messages(agent.memory)
]
+ [
Message(
role="user",
content=RootCauseAnalysisPrompts.root_cause_formatter_msg(),
).to_message(),
],
model="gpt-4o-2024-08-06",
response_format=MultipleRootCauseAnalysisOutputPrompt,
)

parsed = extract_parsed_model(response)

# Assign the ids to be the numerical indices of the causes and relevant code context
causes = []
for i, cause in enumerate(parsed.causes):
cause_model = cause.to_model()
cause_model.id = i

if cause_model.code_context:
for j, snippet in enumerate(cause_model.code_context):
snippet.id = j

causes.append(cause_model)

return RootCauseAnalysisOutput(causes=causes)
finally:
with self.context.state.update() as cur:
cur.usage += agent.usage
95 changes: 24 additions & 71 deletions src/seer/automation/autofix/components/root_cause/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from johen import gen
from johen.examples import Examples
from pydantic import BaseModel, Field, StringConstraints
from pydantic_xml import attr, element
from pydantic_xml import attr

from seer.automation.component import BaseComponentOutput, BaseComponentRequest
from seer.automation.models import EventDetails, PromptXmlModel
Expand All @@ -26,64 +26,38 @@ def get_example(cls):

class RootCauseRelevantCodeSnippet(BaseModel):
file_path: str
repo_name: Optional[str] = None
repo_name: Optional[str]
snippet: str


class RootCauseRelevantContext(BaseModel):
id: int = -1
id: int
title: str
description: str
snippet: Optional[RootCauseRelevantCodeSnippet] = None


class RootCauseRelevantContextPromptXml(PromptXmlModel, tag="code_snippet", skip_empty=True):
title: Annotated[str, StringConstraints(strip_whitespace=True)] = element()
description: Annotated[str, StringConstraints(strip_whitespace=True)] = element()
snippet: Optional[SnippetPromptXml] = None

@classmethod
def get_example(cls):
return cls(
title="`foo()` returns the wrong value",
description="The issue happens because `foo()` always returns `bar`, as seen in this snippet, when it should return `baz`.",
snippet=SnippetPromptXml.get_example(),
)
snippet: Optional[RootCauseRelevantCodeSnippet]


class RootCauseAnalysisItem(BaseModel):
id: int = -1
title: str
description: str
reproduction: str
likelihood: Annotated[float, Examples(r.uniform(0, 1) for r in gen)] = Field(..., ge=0, le=1)
actionability: Annotated[float, Examples(r.uniform(0, 1) for r in gen)] = Field(..., ge=0, le=1)
code_context: Optional[list[RootCauseRelevantContext]] = None


class RootCauseAnalysisRelevantContextPromptXml(PromptXmlModel, tag="code_context"):
snippets: list[RootCauseRelevantContextPromptXml]

@classmethod
def get_example(cls):
return cls(snippets=[RootCauseRelevantContextPromptXml.get_example()])

class RootCauseAnalysisRelevantContext(BaseModel):
snippets: list[RootCauseRelevantContext]

class RootCauseAnalysisItemPromptXml(PromptXmlModel, tag="potential_cause", skip_empty=True):
title: Annotated[str, StringConstraints(strip_whitespace=True)] = element()
description: Annotated[str, StringConstraints(strip_whitespace=True)] = element()
likelihood: float = attr()
actionability: float = attr()
relevant_code: Optional[RootCauseAnalysisRelevantContextPromptXml] = None

@classmethod
def get_example(cls):
return cls(
title="Summarize the root cause here in a few words.",
likelihood=0.8,
actionability=1.0,
description="Explain the root cause in full detail here with the full chain of reasoning behind it.",
relevant_code=RootCauseAnalysisRelevantContextPromptXml.get_example(),
)
class RootCauseAnalysisItemPrompt(BaseModel):
title: str
description: str
likelihood: float
actionability: float
reproduction: str
relevant_code: Optional[RootCauseAnalysisRelevantContext]

@classmethod
def from_model(cls, model: RootCauseAnalysisItem):
Expand All @@ -92,21 +66,15 @@ def from_model(cls, model: RootCauseAnalysisItem):
likelihood=model.likelihood,
actionability=model.actionability,
description=model.description,
reproduction=model.reproduction,
relevant_code=(
RootCauseAnalysisRelevantContextPromptXml(
RootCauseAnalysisRelevantContext(
snippets=[
RootCauseRelevantContextPromptXml(
RootCauseRelevantContext(
id=snippet.id,
title=snippet.title,
description=snippet.description,
snippet=(
SnippetPromptXml(
file_path=snippet.snippet.file_path,
snippet=snippet.snippet.snippet,
repo_name=snippet.snippet.repo_name,
)
if snippet.snippet
else None
),
snippet=snippet.snippet,
)
for snippet in model.code_context
]
Expand All @@ -127,28 +95,13 @@ def to_model(self):
)


class MultipleRootCauseAnalysisOutputPromptXml(PromptXmlModel, tag="potential_root_causes"):
causes: list[RootCauseAnalysisItemPromptXml] = []

@classmethod
def get_example(cls):
return cls(
causes=[
RootCauseAnalysisItemPromptXml.get_example(),
RootCauseAnalysisItemPromptXml(
title="Summarize the root cause here in a few words.",
likelihood=0.2,
actionability=1.0,
description="Explain the root cause in full detail here with the full chain of reasoning behind it.",
relevant_code=RootCauseAnalysisRelevantContextPromptXml.get_example(),
),
]
)
class MultipleRootCauseAnalysisOutputPrompt(BaseModel):
causes: list[RootCauseAnalysisItemPrompt]


class RootCauseAnalysisOutputPromptXml(PromptXmlModel, tag="root"):
thoughts: Optional[str] = element(default=None)
potential_root_causes: MultipleRootCauseAnalysisOutputPromptXml
class RootCauseAnalysisOutputPrompt(BaseModel):
thoughts: Optional[str]
potential_root_causes: MultipleRootCauseAnalysisOutputPrompt


class RootCauseAnalysisRequest(BaseComponentRequest):
Expand Down
22 changes: 12 additions & 10 deletions src/seer/automation/autofix/components/root_cause/prompts.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import textwrap
from typing import Optional

from seer.automation.autofix.components.root_cause.models import (
MultipleRootCauseAnalysisOutputPromptXml,
)
from seer.automation.autofix.prompts import format_instruction, format_repo_names, format_summary
from seer.automation.summarize.issue import IssueSummary

Expand All @@ -26,8 +23,6 @@ def format_system_msg():
- You also MUST think step-by-step before giving the final answer.
It is important that we find all the potential root causes of the issue, so provide as many possibilities as you can for the root cause, ordered from most likely to least likely."""
).format(
root_cause_output_example_str=MultipleRootCauseAnalysisOutputPromptXml.get_example().to_prompt_str(),
)

@staticmethod
Expand Down Expand Up @@ -65,13 +60,20 @@ def format_default_msg(
def root_cause_formatter_msg():
return textwrap.dedent(
"""\
Please format the output properly to match the following example:
{root_cause_output_example_str}
Please format the output properly.
Note: If the provided root cause analysis is not formatted properly, such as code snippets missing descriptions, you can derive them from the provided root cause analysis.
Return only the formatted root cause analysis:"""
).format(
root_cause_output_example_str=MultipleRootCauseAnalysisOutputPromptXml.get_example().to_prompt_str(),
)

@staticmethod
def reproduction_prompt_msg():
return textwrap.dedent(
"""\
Given all the above potential root causes you just gave, please provide a 1-2 sentence concise instruction on how to reproduce the issue for each root cause.
- Assume the user is an experienced developer well-versed in the codebase, simply give the reproduction steps.
- You must use the local variables provided to you in the stacktrace to give your reproduction steps.
- Try to be open ended to allow for the most flexibility in reproducing the issue. Avoid being too confident.
- This step is optional, if you're not sure about the reproduction steps for a root cause, just skip it."""
)
Loading

0 comments on commit c20a95e

Please sign in to comment.