diff --git a/src/seer/automation/autofix/autofix_context.py b/src/seer/automation/autofix/autofix_context.py index 7b4b18df7..7124bfad8 100644 --- a/src/seer/automation/autofix/autofix_context.py +++ b/src/seer/automation/autofix/autofix_context.py @@ -2,6 +2,7 @@ import textwrap from typing import Mapping, cast +from pydantic import ValidationError import sentry_sdk from seer.automation.autofix.event_manager import AutofixEventManager @@ -20,8 +21,9 @@ from seer.automation.models import EventDetails, FileChange, FilePatch, RepoDefinition, Stacktrace from seer.automation.pipeline import PipelineContext from seer.automation.state import State +from seer.automation.summarize.issue import IssueSummary from seer.automation.utils import AgentError, get_sentry_client -from seer.db import DbPrIdToAutofixRunIdMapping, Session +from seer.db import DbIssueSummary, DbPrIdToAutofixRunIdMapping, Session from seer.rpc import RpcClient logger = logging.getLogger(__name__) @@ -104,6 +106,17 @@ def signals(self, value: list[str]): with self.state.update() as state: state.signals = value + def get_issue_summary(self) -> IssueSummary | None: + group_id = self.state.get().request.issue.id + with Session() as session: + group_summary = session.get(DbIssueSummary, group_id) + if group_summary: + try: + return IssueSummary.model_validate(group_summary.summary) + except ValidationError: + return None + return None + def repos_by_key(self) -> Mapping[RepoKey, RepoDefinition]: repos_by_key: dict[RepoKey, RepoDefinition] = { repo.external_id: repo for repo in self.repos diff --git a/src/seer/automation/autofix/steps/coding_step.py b/src/seer/automation/autofix/steps/coding_step.py index 19d0c0fe6..7aa2018bf 100644 --- a/src/seer/automation/autofix/steps/coding_step.py +++ b/src/seer/automation/autofix/steps/coding_step.py @@ -70,16 +70,9 @@ def _invoke(self, **kwargs): event_details = EventDetails.from_event(state.request.issue.events[0]) self.context.process_event_paths(event_details) - group_id = state.request.issue.id summary = state.request.issue_summary if not summary: - with Session() as session: - group_summary = session.get(DbIssueSummary, group_id) - if group_summary: - try: - summary = IssueSummary.model_validate(group_summary.summary) - except ValidationError: - pass + summary = self.context.get_issue_summary() coding_output = CodingComponent(self.context).invoke( CodingRequest( diff --git a/src/seer/automation/autofix/steps/root_cause_step.py b/src/seer/automation/autofix/steps/root_cause_step.py index c27311953..fb6fd0ac0 100644 --- a/src/seer/automation/autofix/steps/root_cause_step.py +++ b/src/seer/automation/autofix/steps/root_cause_step.py @@ -57,16 +57,9 @@ def _invoke(self, **kwargs): event_details = EventDetails.from_event(state.request.issue.events[0]) self.context.process_event_paths(event_details) - group_id = state.request.issue.id summary = state.request.issue_summary if not summary: - with Session() as session: - group_summary = session.get(DbIssueSummary, group_id) - if group_summary: - try: - summary = IssueSummary.model_validate(group_summary.summary) - except ValidationError: - pass + summary = self.context.get_issue_summary() root_cause_output = RootCauseAnalysisComponent(self.context).invoke( RootCauseAnalysisRequest(