diff --git a/src/sentry/api/endpoints/group_ai_summary.py b/src/sentry/api/endpoints/group_ai_summary.py index 8e171b7892719d..3a96982b2fe7ab 100644 --- a/src/sentry/api/endpoints/group_ai_summary.py +++ b/src/sentry/api/endpoints/group_ai_summary.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from datetime import timedelta from typing import Any import orjson @@ -20,6 +21,7 @@ from sentry.models.group import Group from sentry.seer.signed_seer_api import sign_with_seer_secret from sentry.types.ratelimit import RateLimit, RateLimitCategory +from sentry.utils.cache import cache logger = logging.getLogger(__name__) @@ -30,6 +32,7 @@ class SummarizeIssueResponse(BaseModel): group_id: str summary: str impact: str + headline: str @region_silo_endpoint @@ -89,6 +92,7 @@ def _call_seer( }, option=orjson.OPT_NON_STR_KEYS, ) + response = requests.post( f"{settings.SEER_AUTOFIX_URL}{path}", data=body, @@ -109,10 +113,10 @@ def post(self, request: Request, group: Group) -> Response: if not features.has("organizations:ai-summary", group.organization, actor=request.user): return Response({"detail": "Feature flag not enabled"}, status=400) - if group.data.get("issue_summary"): - return Response( - convert_dict_key_case(group.data["issue_summary"], snake_to_camel_case), status=200 - ) + cache_key = "ai-group-summary:" + str(group.id) + + if cached_summary := cache.get(cache_key): + return Response(convert_dict_key_case(cached_summary, snake_to_camel_case), status=200) serialized_event = self._get_event(group, request.user) @@ -121,8 +125,7 @@ def post(self, request: Request, group: Group) -> Response: issue_summary = self._call_seer(group, serialized_event) - group.data.update({"issue_summary": issue_summary.dict()}) - group.save() + cache.set(cache_key, issue_summary.dict(), timeout=int(timedelta(days=7).total_seconds())) return Response( convert_dict_key_case(issue_summary.dict(), snake_to_camel_case), status=200 diff --git a/tests/sentry/api/endpoints/test_group_ai_summary.py b/tests/sentry/api/endpoints/test_group_ai_summary.py index 4c0ef5992c44b6..9a7b24768d35c6 100644 --- a/tests/sentry/api/endpoints/test_group_ai_summary.py +++ b/tests/sentry/api/endpoints/test_group_ai_summary.py @@ -5,28 +5,40 @@ from sentry.testutils.cases import APITestCase, SnubaTestCase from sentry.testutils.helpers.features import apply_feature_flag_on_cls from sentry.testutils.skips import requires_snuba +from sentry.utils.cache import cache pytestmark = [requires_snuba] @apply_feature_flag_on_cls("organizations:ai-summary") class GroupAiSummaryEndpointTest(APITestCase, SnubaTestCase): + def setUp(self): + super().setUp() + self.group = self.create_group() + self.url = self._get_url(self.group.id) + self.login_as(user=self.user) + + def tearDown(self): + super().tearDown() + # Clear the cache after each test + cache.delete(f"ai-group-summary:{self.group.id}") + def _get_url(self, group_id: int): return f"/api/0/issues/{group_id}/summarize/" @patch("sentry.api.endpoints.group_ai_summary.GroupAiSummaryEndpoint._call_seer") def test_ai_summary_get_endpoint_with_existing_summary(self, mock_call_seer): - group = self.create_group() existing_summary = { - "group_id": str(group.id), + "group_id": str(self.group.id), "summary": "Existing summary", "impact": "Existing impact", + "headline": "Existing headline", } - group.data["issue_summary"] = existing_summary - group.save() - self.login_as(user=self.user) - response = self.client.post(self._get_url(group.id), format="json") + # Set the cache with the existing summary + cache.set(f"ai-group-summary:{self.group.id}", existing_summary, timeout=60 * 60 * 24 * 7) + + response = self.client.post(self.url, format="json") assert response.status_code == 200 assert response.data == convert_dict_key_case(existing_summary, snake_to_camel_case) @@ -35,54 +47,99 @@ def test_ai_summary_get_endpoint_with_existing_summary(self, mock_call_seer): @patch("sentry.api.endpoints.group_ai_summary.GroupAiSummaryEndpoint._get_event") def test_ai_summary_get_endpoint_without_event(self, mock_get_event): mock_get_event.return_value = None - group = self.create_group() - self.login_as(user=self.user) - response = self.client.post(self._get_url(group.id), format="json") + response = self.client.post(self.url, format="json") assert response.status_code == 400 assert response.data == {"detail": "Could not find an event for the issue"} + assert cache.get(f"ai-group-summary:{self.group.id}") is None @patch("sentry.api.endpoints.group_ai_summary.GroupAiSummaryEndpoint._call_seer") @patch("sentry.api.endpoints.group_ai_summary.GroupAiSummaryEndpoint._get_event") def test_ai_summary_get_endpoint_without_existing_summary(self, mock_get_event, mock_call_seer): - group = self.create_group() mock_event = {"id": "test_event_id", "data": "test_event_data"} mock_get_event.return_value = mock_event mock_summary = SummarizeIssueResponse( - group_id=str(group.id), + group_id=str(self.group.id), summary="Test summary", impact="Test impact", + headline="Test headline", ) mock_call_seer.return_value = mock_summary - self.login_as(user=self.user) - response = self.client.post(self._get_url(group.id), format="json") + response = self.client.post(self.url, format="json") assert response.status_code == 200 assert response.data == convert_dict_key_case(mock_summary.dict(), snake_to_camel_case) - mock_get_event.assert_called_once_with(group, ANY) - mock_call_seer.assert_called_once_with(group, mock_event) + mock_get_event.assert_called_once_with(self.group, ANY) + mock_call_seer.assert_called_once_with(self.group, mock_event) + + # Check if the cache was set correctly + cached_summary = cache.get(f"ai-group-summary:{self.group.id}") + assert cached_summary == mock_summary.dict() @patch("sentry.api.endpoints.group_ai_summary.requests.post") @patch("sentry.api.endpoints.group_ai_summary.GroupAiSummaryEndpoint._get_event") def test_ai_summary_call_seer(self, mock_get_event, mock_post): - group = self.create_group() serialized_event = {"id": "test_event_id", "data": "test_event_data"} mock_get_event.return_value = serialized_event mock_response = Mock() mock_response.json.return_value = { - "group_id": str(group.id), + "group_id": str(self.group.id), "summary": "Test summary", "impact": "Test impact", + "headline": "Test headline", } mock_post.return_value = mock_response - self.login_as(user=self.user) - response = self.client.post(self._get_url(group.id), format="json") + response = self.client.post(self.url, format="json") assert response.status_code == 200 assert response.data == convert_dict_key_case( mock_response.json.return_value, snake_to_camel_case ) mock_post.assert_called_once() + + assert cache.get(f"ai-group-summary:{self.group.id}") == mock_response.json.return_value + + def test_ai_summary_cache_write_read(self): + # First request to populate the cache + with ( + patch( + "sentry.api.endpoints.group_ai_summary.GroupAiSummaryEndpoint._get_event" + ) as mock_get_event, + patch( + "sentry.api.endpoints.group_ai_summary.GroupAiSummaryEndpoint._call_seer" + ) as mock_call_seer, + ): + mock_event = {"id": "test_event_id", "data": "test_event_data"} + mock_get_event.return_value = mock_event + + mock_summary = SummarizeIssueResponse( + group_id=str(self.group.id), + summary="Test summary", + impact="Test impact", + headline="Test headline", + ) + mock_call_seer.return_value = mock_summary + + response = self.client.post(self.url, format="json") + assert response.status_code == 200 + assert response.data == convert_dict_key_case(mock_summary.dict(), snake_to_camel_case) + + # Second request should use cached data + with ( + patch( + "sentry.api.endpoints.group_ai_summary.GroupAiSummaryEndpoint._get_event" + ) as mock_get_event, + patch( + "sentry.api.endpoints.group_ai_summary.GroupAiSummaryEndpoint._call_seer" + ) as mock_call_seer, + ): + response = self.client.post(self.url, format="json") + assert response.status_code == 200 + assert response.data == convert_dict_key_case(mock_summary.dict(), snake_to_camel_case) + + # Verify that _get_event and _call_seer were not called for the second request + mock_get_event.assert_not_called() + mock_call_seer.assert_not_called()