From c20a95e781d0d49c457417b361cbc2bdcae56ac8 Mon Sep 17 00:00:00 2001 From: Jenn Mueng <30991498+jennmueng@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:14:05 -0700 Subject: [PATCH] feat(autofix): Reproduction in root causes (#1067) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Introduces reproduction step in root cause output. 2. Uses gpt4o w/ structured output for the root cause for increased reliability 3. (Tacking on) Adds claude agent back into the plan+code agent ![Screenshot 2024-09-10 at 10 45 47 AM](https://github.com/user-attachments/assets/8a1ceb0f-6fec-4c30-8fcf-f4e34e259375) ### Eval results: (Before is above, after is below) ![Screenshot 2024-09-10 at 10 48 23 AM](https://github.com/user-attachments/assets/190c8919-56c1-4ab8-b19f-683b1c8a511c) --- src/seer/automation/agent/client.py | 13 ++ .../autofix/components/coding/component.py | 4 +- .../autofix/components/coding/models.py | 2 + .../components/root_cause/component.py | 121 +++++------ .../autofix/components/root_cause/models.py | 95 +++------ .../autofix/components/root_cause/prompts.py | 22 +- src/seer/automation/autofix/evaluations.py | 6 +- .../components/coding/test_coding_models.py | 1 + .../autofix/components/test_root_cause.py | 189 +++++------------- .../autofix/test_autofix_evaluations.py | 62 +++++- tests/automation/autofix/test_models.py | 7 +- 11 files changed, 243 insertions(+), 279 deletions(-) diff --git a/src/seer/automation/agent/client.py b/src/seer/automation/agent/client.py index 0da80a46f..8b275e36e 100644 --- a/src/seer/automation/agent/client.py +++ b/src/seer/automation/agent/client.py @@ -138,6 +138,19 @@ def json_completion( ) -> tuple[dict[str, Any] | None, Message, Usage]: return super().json_completion(messages, model, system_prompt) + def clean_tool_call_assistant_messages(self, messages: list[Message]) -> list[Message]: + new_messages = [] + for message in messages: + if message.role == "assistant" and message.tool_calls: + new_messages.append( + Message(role="assistant", content=message.content, tool_calls=[]) + ) + elif message.role == "tool": + new_messages.append(Message(role="user", content=message.content, tool_calls=[])) + else: + new_messages.append(message) + return new_messages + class ClaudeClient(LlmClient): @inject diff --git a/src/seer/automation/autofix/components/coding/component.py b/src/seer/automation/autofix/components/coding/component.py index 3ab94a249..3c8fbb44c 100644 --- a/src/seer/automation/autofix/components/coding/component.py +++ b/src/seer/automation/autofix/components/coding/component.py @@ -3,7 +3,7 @@ from langfuse.decorators import observe from sentry_sdk.ai.monitoring import ai_track -from seer.automation.agent.agent import AgentConfig, GptAgent +from seer.automation.agent.agent import AgentConfig, ClaudeAgent from seer.automation.autofix.autofix_context import AutofixContext from seer.automation.autofix.components.coding.models import ( CodingOutput, @@ -38,7 +38,7 @@ def _append_file_change(self, repo_external_id: str, file_change: FileChange): def invoke(self, request: CodingRequest) -> CodingOutput | None: tools = BaseTools(self.context) - agent = GptAgent( + agent = ClaudeAgent( tools=tools.get_tools(), config=AgentConfig(system_prompt=CodingPrompts.format_system_msg()), ) diff --git a/src/seer/automation/autofix/components/coding/models.py b/src/seer/automation/autofix/components/coding/models.py index a6c5ed013..8fb265412 100644 --- a/src/seer/automation/autofix/components/coding/models.py +++ b/src/seer/automation/autofix/components/coding/models.py @@ -46,6 +46,7 @@ def from_root_cause_context(cls, context: RootCauseRelevantContext): class RootCausePlanTaskPromptXml(PromptXmlModel, tag="root_cause", skip_empty=True): title: str = element() description: str = element() + reproduction: str = element() contexts: list[CodeContextXml] @classmethod @@ -53,6 +54,7 @@ def from_root_cause(cls, root_cause: RootCauseAnalysisItem): return cls( title=root_cause.title, description=root_cause.description, + reproduction=root_cause.reproduction, contexts=( [ CodeContextXml.from_root_cause_context(context) diff --git a/src/seer/automation/autofix/components/root_cause/component.py b/src/seer/automation/autofix/components/root_cause/component.py index ee6466e61..95e73686e 100644 --- a/src/seer/automation/autofix/components/root_cause/component.py +++ b/src/seer/automation/autofix/components/root_cause/component.py @@ -4,16 +4,19 @@ from sentry_sdk.ai.monitoring import ai_track from seer.automation.agent.agent import AgentConfig, GptAgent +from seer.automation.agent.client import GptClient +from seer.automation.agent.models import Message from seer.automation.autofix.autofix_context import AutofixContext from seer.automation.autofix.components.root_cause.models import ( + MultipleRootCauseAnalysisOutputPrompt, RootCauseAnalysisOutput, - RootCauseAnalysisOutputPromptXml, RootCauseAnalysisRequest, ) from seer.automation.autofix.components.root_cause.prompts import RootCauseAnalysisPrompts from seer.automation.autofix.tools import BaseTools from seer.automation.component import BaseComponent -from seer.automation.utils import escape_multi_xml, extract_text_inside_tags +from seer.automation.utils import extract_parsed_model +from seer.dependency_injection import inject, injected logger = logging.getLogger(__name__) @@ -23,7 +26,10 @@ class RootCauseAnalysisComponent(BaseComponent[RootCauseAnalysisRequest, RootCau @observe(name="Root Cause Analysis") @ai_track(description="Root Cause Analysis") - def invoke(self, request: RootCauseAnalysisRequest) -> RootCauseAnalysisOutput | None: + @inject + def invoke( + self, request: RootCauseAnalysisRequest, gpt_client: GptClient = injected + ) -> RootCauseAnalysisOutput | None: tools = BaseTools(self.context) agent = GptAgent( @@ -35,56 +41,59 @@ def invoke(self, request: RootCauseAnalysisRequest) -> RootCauseAnalysisOutput | state = self.context.state.get() - response = agent.run( - RootCauseAnalysisPrompts.format_default_msg( - event=request.event_details.format_event(), - summary=request.summary, - instruction=request.instruction, - repo_names=[repo.full_name for repo in state.request.repos], - ), - context=self.context, - ) - - original_usage = agent.usage - with self.context.state.update() as cur: - cur.usage += agent.usage - - if not response: - logger.warning("Root Cause Analysis agent did not return a valid response") - return None - - if "" in response: - return None - - formatter_response = agent.run(RootCauseAnalysisPrompts.root_cause_formatter_msg()) - - with self.context.state.update() as cur: - cur.usage += agent.usage - original_usage - - if not formatter_response: - logger.warning("Root Cause Analysis formatter did not return a valid response") - return None - - extracted_text = extract_text_inside_tags(formatter_response, "potential_root_causes") - - xml_response = RootCauseAnalysisOutputPromptXml.from_xml( - f"{escape_multi_xml(extracted_text, ['thoughts', 'title', 'description', 'code'])}" - ) - - if not xml_response.potential_root_causes.causes: - logger.warning("Root Cause Analysis formatter did not return causes") - return None - - # Assign the ids to be the numerical indices of the causes and relevant code context - causes = [] - for i, cause in enumerate(xml_response.potential_root_causes.causes): - cause_model = cause.to_model() - cause_model.id = i - - if cause_model.code_context: - for j, snippet in enumerate(cause_model.code_context): - snippet.id = j - - causes.append(cause_model) - - return RootCauseAnalysisOutput(causes=causes) + try: + response = agent.run( + RootCauseAnalysisPrompts.format_default_msg( + event=request.event_details.format_event(), + summary=request.summary, + instruction=request.instruction, + repo_names=[repo.full_name for repo in state.request.repos], + ), + context=self.context, + ) + + if not response: + logger.warning("Root Cause Analysis agent did not return a valid response") + return None + + if "" in response: + return None + + # Ask for reproduction + agent.run( + RootCauseAnalysisPrompts.reproduction_prompt_msg(), + ) + + response = gpt_client.openai_client.beta.chat.completions.parse( + messages=[ + message.to_message() + for message in gpt_client.clean_tool_call_assistant_messages(agent.memory) + ] + + [ + Message( + role="user", + content=RootCauseAnalysisPrompts.root_cause_formatter_msg(), + ).to_message(), + ], + model="gpt-4o-2024-08-06", + response_format=MultipleRootCauseAnalysisOutputPrompt, + ) + + parsed = extract_parsed_model(response) + + # Assign the ids to be the numerical indices of the causes and relevant code context + causes = [] + for i, cause in enumerate(parsed.causes): + cause_model = cause.to_model() + cause_model.id = i + + if cause_model.code_context: + for j, snippet in enumerate(cause_model.code_context): + snippet.id = j + + causes.append(cause_model) + + return RootCauseAnalysisOutput(causes=causes) + finally: + with self.context.state.update() as cur: + cur.usage += agent.usage diff --git a/src/seer/automation/autofix/components/root_cause/models.py b/src/seer/automation/autofix/components/root_cause/models.py index 84d21f57e..20c98300e 100644 --- a/src/seer/automation/autofix/components/root_cause/models.py +++ b/src/seer/automation/autofix/components/root_cause/models.py @@ -3,7 +3,7 @@ from johen import gen from johen.examples import Examples from pydantic import BaseModel, Field, StringConstraints -from pydantic_xml import attr, element +from pydantic_xml import attr from seer.automation.component import BaseComponentOutput, BaseComponentRequest from seer.automation.models import EventDetails, PromptXmlModel @@ -26,64 +26,38 @@ def get_example(cls): class RootCauseRelevantCodeSnippet(BaseModel): file_path: str - repo_name: Optional[str] = None + repo_name: Optional[str] snippet: str class RootCauseRelevantContext(BaseModel): - id: int = -1 + id: int title: str description: str - snippet: Optional[RootCauseRelevantCodeSnippet] = None - - -class RootCauseRelevantContextPromptXml(PromptXmlModel, tag="code_snippet", skip_empty=True): - title: Annotated[str, StringConstraints(strip_whitespace=True)] = element() - description: Annotated[str, StringConstraints(strip_whitespace=True)] = element() - snippet: Optional[SnippetPromptXml] = None - - @classmethod - def get_example(cls): - return cls( - title="`foo()` returns the wrong value", - description="The issue happens because `foo()` always returns `bar`, as seen in this snippet, when it should return `baz`.", - snippet=SnippetPromptXml.get_example(), - ) + snippet: Optional[RootCauseRelevantCodeSnippet] class RootCauseAnalysisItem(BaseModel): id: int = -1 title: str description: str + reproduction: str likelihood: Annotated[float, Examples(r.uniform(0, 1) for r in gen)] = Field(..., ge=0, le=1) actionability: Annotated[float, Examples(r.uniform(0, 1) for r in gen)] = Field(..., ge=0, le=1) code_context: Optional[list[RootCauseRelevantContext]] = None -class RootCauseAnalysisRelevantContextPromptXml(PromptXmlModel, tag="code_context"): - snippets: list[RootCauseRelevantContextPromptXml] - - @classmethod - def get_example(cls): - return cls(snippets=[RootCauseRelevantContextPromptXml.get_example()]) - +class RootCauseAnalysisRelevantContext(BaseModel): + snippets: list[RootCauseRelevantContext] -class RootCauseAnalysisItemPromptXml(PromptXmlModel, tag="potential_cause", skip_empty=True): - title: Annotated[str, StringConstraints(strip_whitespace=True)] = element() - description: Annotated[str, StringConstraints(strip_whitespace=True)] = element() - likelihood: float = attr() - actionability: float = attr() - relevant_code: Optional[RootCauseAnalysisRelevantContextPromptXml] = None - @classmethod - def get_example(cls): - return cls( - title="Summarize the root cause here in a few words.", - likelihood=0.8, - actionability=1.0, - description="Explain the root cause in full detail here with the full chain of reasoning behind it.", - relevant_code=RootCauseAnalysisRelevantContextPromptXml.get_example(), - ) +class RootCauseAnalysisItemPrompt(BaseModel): + title: str + description: str + likelihood: float + actionability: float + reproduction: str + relevant_code: Optional[RootCauseAnalysisRelevantContext] @classmethod def from_model(cls, model: RootCauseAnalysisItem): @@ -92,21 +66,15 @@ def from_model(cls, model: RootCauseAnalysisItem): likelihood=model.likelihood, actionability=model.actionability, description=model.description, + reproduction=model.reproduction, relevant_code=( - RootCauseAnalysisRelevantContextPromptXml( + RootCauseAnalysisRelevantContext( snippets=[ - RootCauseRelevantContextPromptXml( + RootCauseRelevantContext( + id=snippet.id, title=snippet.title, description=snippet.description, - snippet=( - SnippetPromptXml( - file_path=snippet.snippet.file_path, - snippet=snippet.snippet.snippet, - repo_name=snippet.snippet.repo_name, - ) - if snippet.snippet - else None - ), + snippet=snippet.snippet, ) for snippet in model.code_context ] @@ -127,28 +95,13 @@ def to_model(self): ) -class MultipleRootCauseAnalysisOutputPromptXml(PromptXmlModel, tag="potential_root_causes"): - causes: list[RootCauseAnalysisItemPromptXml] = [] - - @classmethod - def get_example(cls): - return cls( - causes=[ - RootCauseAnalysisItemPromptXml.get_example(), - RootCauseAnalysisItemPromptXml( - title="Summarize the root cause here in a few words.", - likelihood=0.2, - actionability=1.0, - description="Explain the root cause in full detail here with the full chain of reasoning behind it.", - relevant_code=RootCauseAnalysisRelevantContextPromptXml.get_example(), - ), - ] - ) +class MultipleRootCauseAnalysisOutputPrompt(BaseModel): + causes: list[RootCauseAnalysisItemPrompt] -class RootCauseAnalysisOutputPromptXml(PromptXmlModel, tag="root"): - thoughts: Optional[str] = element(default=None) - potential_root_causes: MultipleRootCauseAnalysisOutputPromptXml +class RootCauseAnalysisOutputPrompt(BaseModel): + thoughts: Optional[str] + potential_root_causes: MultipleRootCauseAnalysisOutputPrompt class RootCauseAnalysisRequest(BaseComponentRequest): diff --git a/src/seer/automation/autofix/components/root_cause/prompts.py b/src/seer/automation/autofix/components/root_cause/prompts.py index 26f3f054a..de94511b2 100644 --- a/src/seer/automation/autofix/components/root_cause/prompts.py +++ b/src/seer/automation/autofix/components/root_cause/prompts.py @@ -1,9 +1,6 @@ import textwrap from typing import Optional -from seer.automation.autofix.components.root_cause.models import ( - MultipleRootCauseAnalysisOutputPromptXml, -) from seer.automation.autofix.prompts import format_instruction, format_repo_names, format_summary from seer.automation.summarize.issue import IssueSummary @@ -26,8 +23,6 @@ def format_system_msg(): - You also MUST think step-by-step before giving the final answer. It is important that we find all the potential root causes of the issue, so provide as many possibilities as you can for the root cause, ordered from most likely to least likely.""" - ).format( - root_cause_output_example_str=MultipleRootCauseAnalysisOutputPromptXml.get_example().to_prompt_str(), ) @staticmethod @@ -65,13 +60,20 @@ def format_default_msg( def root_cause_formatter_msg(): return textwrap.dedent( """\ - Please format the output properly to match the following example: - - {root_cause_output_example_str} + Please format the output properly. Note: If the provided root cause analysis is not formatted properly, such as code snippets missing descriptions, you can derive them from the provided root cause analysis. Return only the formatted root cause analysis:""" - ).format( - root_cause_output_example_str=MultipleRootCauseAnalysisOutputPromptXml.get_example().to_prompt_str(), + ) + + @staticmethod + def reproduction_prompt_msg(): + return textwrap.dedent( + """\ + Given all the above potential root causes you just gave, please provide a 1-2 sentence concise instruction on how to reproduce the issue for each root cause. + - Assume the user is an experienced developer well-versed in the codebase, simply give the reproduction steps. + - You must use the local variables provided to you in the stacktrace to give your reproduction steps. + - Try to be open ended to allow for the most flexibility in reproducing the issue. Avoid being too confident. + - This step is optional, if you're not sure about the reproduction steps for a root cause, just skip it.""" ) diff --git a/src/seer/automation/autofix/evaluations.py b/src/seer/automation/autofix/evaluations.py index 6b827afab..4a23dc7a7 100644 --- a/src/seer/automation/autofix/evaluations.py +++ b/src/seer/automation/autofix/evaluations.py @@ -8,9 +8,9 @@ from seer.automation.agent.client import GptClient from seer.automation.agent.models import Message +from seer.automation.autofix.components.coding.models import RootCausePlanTaskPromptXml from seer.automation.autofix.components.root_cause.models import ( RootCauseAnalysisItem, - RootCauseAnalysisItemPromptXml, RootCauseRelevantCodeSnippet, RootCauseRelevantContext, ) @@ -103,8 +103,10 @@ def sync_run_execution(item: DatasetItemClient): description="", likelihood=1.0, actionability=1.0, + reproduction="", code_context=[ RootCauseRelevantContext( + id=-1, title=expected_output.solution_summary, description="", snippet=RootCauseRelevantCodeSnippet( @@ -278,7 +280,7 @@ def score_root_cause_single_it( raise ValueError("Expected output is missing from dataset item") expected_output = RootCauseExpectedOutput.model_validate(dataset_item.expected_output) - causes_xml = [RootCauseAnalysisItemPromptXml.from_model(cause) for cause in causes] + causes_xml = [RootCausePlanTaskPromptXml.from_root_cause(cause) for cause in causes] solution_strs: list[str] = [] for i, cause in enumerate(causes_xml): diff --git a/tests/automation/autofix/components/coding/test_coding_models.py b/tests/automation/autofix/components/coding/test_coding_models.py index 8cb1465c5..0a99ff2ef 100644 --- a/tests/automation/autofix/components/coding/test_coding_models.py +++ b/tests/automation/autofix/components/coding/test_coding_models.py @@ -13,6 +13,7 @@ def test_root_cause_conversion(self): id=1, title="title", description="description", + reproduction="reproduction", likelihood=0.5, actionability=0.75, code_context=[ diff --git a/tests/automation/autofix/components/test_root_cause.py b/tests/automation/autofix/components/test_root_cause.py index 277558e5c..9e6bad6ee 100644 --- a/tests/automation/autofix/components/test_root_cause.py +++ b/tests/automation/autofix/components/test_root_cause.py @@ -1,16 +1,16 @@ from unittest.mock import MagicMock, patch -from xml.etree.ElementTree import ParseError import pytest +from openai.types.chat import ParsedChatCompletion, ParsedChatCompletionMessage, ParsedChoice +from seer.automation.agent.client import GptClient from seer.automation.autofix.autofix_context import AutofixContext from seer.automation.autofix.components.root_cause.component import RootCauseAnalysisComponent from seer.automation.autofix.components.root_cause.models import ( - RootCauseAnalysisItem, - RootCauseAnalysisOutput, - RootCauseRelevantCodeSnippet, - RootCauseRelevantContext, + MultipleRootCauseAnalysisOutputPrompt, + RootCauseAnalysisItemPrompt, ) +from seer.dependency_injection import Module class TestRootCauseComponent: @@ -29,140 +29,59 @@ def mock_gpt_agent(self): def test_root_cause_simple_response_parsing(self, component, mock_gpt_agent): mock_gpt_agent.return_value.run.side_effect = [ "Anything really", - "Missing Null CheckThe root cause of the issue is ...", + "Reproduction steps", ] - output = component.invoke(MagicMock()) - - assert output is not None - assert len(output.causes) == 1 - assert output.causes[0].title == "Missing Null Check" - assert output.causes[0].description == "The root cause of the issue is ..." - assert output.causes[0].likelihood == 0.9 - assert output.causes[0].actionability == 1.0 - assert output.causes[0].code_context is None - - def test_root_cause_code_context_response_parsing(self, component, mock_gpt_agent): - mock_gpt_agent.return_value.run.side_effect = [ - "Anything really", - "Missing Null CheckThe root cause of the issue is ...Add Null CheckThis fix involves adding ...def foo():", - ] - - output = component.invoke(MagicMock()) - - assert output is not None - assert len(output.causes) == 1 - assert output.causes[0].title == "Missing Null Check" - assert output.causes[0].description == "The root cause of the issue is ..." - assert output.causes[0].likelihood == 0.9 - assert output.causes[0].actionability == 1.0 - assert output.causes[0].code_context is not None - assert len(output.causes[0].code_context) == 1 - assert output.causes[0].code_context[0].title == "Add Null Check" - assert output.causes[0].code_context[0].description == "This fix involves adding ..." - assert output.causes[0].code_context[0].snippet.file_path == "some/app/path.py" - assert output.causes[0].code_context[0].snippet.repo_name == "owner/repo" - assert output.causes[0].code_context[0].snippet.snippet == "def foo():" - - def test_root_cause_multiple_causes_response_parsing(self, component, mock_gpt_agent): - mock_gpt_agent.return_value.run.side_effect = [ - "Anything really", - "Missing Null CheckThe root cause of the issue is ...Incorrect API UsageAnother potential cause is ...", - ] - output = component.invoke(MagicMock()) - - assert output is not None - assert len(output.causes) == 2 - assert output.causes[0].title == "Missing Null Check" - assert output.causes[1].title == "Incorrect API Usage" - - def test_root_cause_empty_response_parsing(self, component, mock_gpt_agent): - mock_gpt_agent.return_value.run.side_effect = [ - "Anything really", - "", - ] - - output = component.invoke(MagicMock()) - - assert output is None - - def test_root_cause_invalid_xml_response(self, component, mock_gpt_agent): - mock_gpt_agent.return_value.run.side_effect = [ - "Anything really", - "", - ] - - with pytest.raises(ParseError): - component.invoke(MagicMock()) - - def test_root_cause_missing_required_fields(self, component, mock_gpt_agent): - mock_gpt_agent.return_value.run.side_effect = [ - "Anything really", - "Missing Null Check", - ] - - with pytest.raises(ValueError): - component.invoke(MagicMock()) - - def test_root_cause_invalid_likelihood_actionability(self, component, mock_gpt_agent): - mock_gpt_agent.return_value.run.side_effect = [ - "Anything really", - "Invalid ValuesTest", - ] - - with pytest.raises(ValueError): - component.invoke(MagicMock()) - - def test_root_cause_no_formatter_response(self, component, mock_gpt_agent): - mock_gpt_agent.return_value.run.side_effect = [ - "TestTest", - None, - ] - - output = component.invoke(MagicMock()) - - assert output is None - - def test_root_cause_analysis_output_model(self): - output = RootCauseAnalysisOutput( - causes=[ - RootCauseAnalysisItem( - id=0, - title="Test Cause", - description="Test Description", - likelihood=0.8, - actionability=0.9, - code_context=[ - RootCauseRelevantContext( - id=0, - title="Test Fix", - description="Test Fix Description", - snippet=RootCauseRelevantCodeSnippet( - file_path="test.py", - snippet="def test():\n pass", - repo_name="owner/repo", + mock_gpt_client = MagicMock() + mock_gpt_client.openai_client.beta.chat.completions.parse.return_value = ( + ParsedChatCompletion( + id="1", + choices=[ + ParsedChoice( + index=0, + message=ParsedChatCompletionMessage( + role="assistant", + content=None, + function_call=None, + tool_calls=None, + parsed=MultipleRootCauseAnalysisOutputPrompt( + causes=[ + RootCauseAnalysisItemPrompt( + title="Missing Null Check", + description="The root cause of the issue is ...", + likelihood=0.9, + actionability=1.0, + reproduction="Steps to reproduce", + relevant_code=None, + ) + ] ), - ) - ], - ) - ] + refusal=None, + ), + finish_reason="stop", + ) + ], + created=1234567890, + model="gpt-4o-2024-08-06", + object="chat.completion", + system_fingerprint="test", + usage=None, + ) ) - assert len(output.causes) == 1 - assert output.causes[0].id == 0 - assert output.causes[0].title == "Test Cause" - assert output.causes[0].description == "Test Description" - assert output.causes[0].likelihood == 0.8 - assert output.causes[0].actionability == 0.9 - assert len(output.causes[0].code_context or []) == 1 - if output.causes[0].code_context: - assert output.causes[0].code_context[0].id == 0 - assert output.causes[0].code_context[0].title == "Test Fix" - assert output.causes[0].code_context[0].description == "Test Fix Description" - if output.causes[0].code_context[0].snippet: - assert output.causes[0].code_context[0].snippet.file_path == "test.py" - assert output.causes[0].code_context[0].snippet.snippet == "def test():\n pass" - assert output.causes[0].code_context[0].snippet.repo_name == "owner/repo" + module = Module() + module.constant(GptClient, mock_gpt_client) + with module: + output = component.invoke(MagicMock()) + + assert output is not None + assert len(output.causes) == 1 + assert output.causes[0].title == "Missing Null Check" + assert output.causes[0].description == "The root cause of the issue is ..." + assert output.causes[0].likelihood == 0.9 + assert output.causes[0].actionability == 1.0 + assert output.causes[0].reproduction == "Steps to reproduce" + assert output.causes[0].code_context is None def test_no_root_causes_response(self, component, mock_gpt_agent): mock_gpt_agent.return_value.run.return_value = "" @@ -170,5 +89,5 @@ def test_no_root_causes_response(self, component, mock_gpt_agent): output = component.invoke(MagicMock()) assert output is None - # Ensure that the second run (formatter) is not called when is returned + # Ensure that the second run (reproduction) and the formatter are not called when is returned assert mock_gpt_agent.return_value.run.call_count == 1 diff --git a/tests/automation/autofix/test_autofix_evaluations.py b/tests/automation/autofix/test_autofix_evaluations.py index d8c5ddc42..8f9acfacd 100644 --- a/tests/automation/autofix/test_autofix_evaluations.py +++ b/tests/automation/autofix/test_autofix_evaluations.py @@ -6,6 +6,7 @@ from seer.automation.autofix.components.root_cause.models import ( RootCauseAnalysisItem, + RootCauseRelevantCodeSnippet, RootCauseRelevantContext, ) from seer.automation.autofix.evaluations import ( @@ -86,7 +87,28 @@ def teardown_method(self): def test_sync_run_evaluation_on_item_happy_path(self): # Setup state changes for root cause step root_cause_model = next(generate(RootCauseStepModel)) - root_cause_model.causes = [Mock(id=1, code_context=[Mock(id=2)])] + root_cause_model.causes = [ + RootCauseAnalysisItem( + id=1, + title="Test Cause", + description="Test Description", + likelihood=0.8, + actionability=0.7, + reproduction="Steps to reproduce", + code_context=[ + RootCauseRelevantContext( + id=2, + title="Test Fix", + description="Test fix description", + snippet=RootCauseRelevantCodeSnippet( + file_path="test.py", + snippet="def test():\n pass", + repo_name="owner/repo", + ), + ) + ], + ) + ] def root_cause_apply_side_effect(): with self.test_state.update() as cur: @@ -180,6 +202,7 @@ def test_sync_run_evaluation_on_item_no_code_context(self): description="Test cause description", likelihood=0.8, actionability=0.7, + reproduction="Steps to reproduce", code_context=[], ) ], @@ -211,11 +234,17 @@ def test_sync_run_evaluation_on_item_no_changes(self): description="Test cause description", likelihood=0.8, actionability=0.7, + reproduction="Steps to reproduce", code_context=[ RootCauseRelevantContext( id=2, title="Test Fix", description="Test fix description", + snippet=RootCauseRelevantCodeSnippet( + file_path="test.py", + snippet="def test():\n pass", + repo_name="owner/repo", + ), ) ], ) @@ -288,7 +317,28 @@ def teardown_method(self): def test_sync_run_root_cause_happy_path(self): root_cause_model = next(generate(RootCauseStepModel)) - root_cause_model.causes = [Mock(id=1, code_context=[Mock(id=2)])] + root_cause_model.causes = [ + RootCauseAnalysisItem( + id=1, + title="Test Cause", + description="Test Description", + likelihood=0.8, + actionability=0.7, + reproduction="Steps to reproduce", + code_context=[ + RootCauseRelevantContext( + id=2, + title="Test Fix", + description="Test fix description", + snippet=RootCauseRelevantCodeSnippet( + file_path="test.py", + snippet="def test():\n pass", + repo_name="owner/repo", + ), + ) + ], + ) + ] def root_cause_apply_side_effect(): with self.test_state.update() as cur: @@ -365,6 +415,7 @@ def test_score_root_cause_single_it(self, mock_gpt_client, mock_dataset_item): description="Description 1", likelihood=0.8, actionability=0.7, + reproduction="Steps to reproduce 1", ), RootCauseAnalysisItem( id=2, @@ -372,6 +423,7 @@ def test_score_root_cause_single_it(self, mock_gpt_client, mock_dataset_item): description="Description 2", likelihood=0.6, actionability=0.5, + reproduction="Steps to reproduce 2", ), ] @@ -393,6 +445,7 @@ def test_score_root_cause_single_it_no_score(self, mock_gpt_client, mock_dataset description="Description 1", likelihood=0.8, actionability=0.7, + reproduction="Steps to reproduce", ) ] @@ -410,6 +463,7 @@ def test_score_root_cause_single_it_missing_expected_output(self, mock_dataset_i description="Description 1", likelihood=0.8, actionability=0.7, + reproduction="Steps to reproduce", ) ] @@ -433,6 +487,7 @@ def test_score_root_causes(self, mock_score_root_cause_single_it, mock_dataset_i description="Description 1", likelihood=0.8, actionability=0.7, + reproduction="Steps to reproduce 1", ), RootCauseAnalysisItem( id=2, @@ -440,6 +495,7 @@ def test_score_root_causes(self, mock_score_root_cause_single_it, mock_dataset_i description="Description 2", likelihood=0.6, actionability=0.5, + reproduction="Steps to reproduce 2", ), ] @@ -461,6 +517,7 @@ def test_score_root_causes_custom_n_panel( description="Description 1", likelihood=0.8, actionability=0.7, + reproduction="Steps to reproduce 1", ), RootCauseAnalysisItem( id=2, @@ -468,6 +525,7 @@ def test_score_root_causes_custom_n_panel( description="Description 2", likelihood=0.6, actionability=0.5, + reproduction="Steps to reproduce 2", ), ] diff --git a/tests/automation/autofix/test_models.py b/tests/automation/autofix/test_models.py index 18dbcdcad..4bce8fad2 100644 --- a/tests/automation/autofix/test_models.py +++ b/tests/automation/autofix/test_models.py @@ -543,7 +543,12 @@ def test_set_last_step_completed_message(self): def test_get_selected_root_cause_and_fix(self): root_cause_step = RootCauseStep(key="root_cause_analysis", title="test") cause = RootCauseAnalysisItem( - id=1, title="test", description="test", likelihood=0.5, actionability=0.5 + id=1, + title="test", + description="test", + reproduction="test", + likelihood=0.5, + actionability=0.5, ) root_cause_step.causes = [cause] root_cause_step.selection = CodeContextRootCauseSelection(cause_id=1)