diff --git a/src/seer/grouping/grouping.py b/src/seer/grouping/grouping.py
index 65627d977..aeb5fa570 100644
--- a/src/seer/grouping/grouping.py
+++ b/src/seer/grouping/grouping.py
@@ -363,8 +363,7 @@ def get_nearest_neighbors(self, issue: GroupingRequest) -> SimilarityResponse:
                     "stacktrace_length": len(issue.stacktrace),
                 },
             )
-            self.insert_new_grouping_record(session, issue, embedding)
-            session.commit()
+            self.insert_new_grouping_record(issue, embedding)
 
         similarity_response = SimilarityResponse(responses=[])
         for record, distance in results:
@@ -473,9 +472,7 @@ def insert_batch_grouping_records(
         return groups_with_neighbor
 
     @sentry_sdk.tracing.trace
-    def insert_new_grouping_record(
-        self, session, issue: GroupingRequest, embedding: np.ndarray
-    ) -> None:
+    def insert_new_grouping_record(self, issue: GroupingRequest, embedding: np.ndarray) -> None:
         """
         Inserts a new GroupingRecord into the database if the group_hash does not already exist.
 
@@ -483,48 +480,49 @@ def insert_new_grouping_record(
         :param issue: The issue to insert as a new GroupingRecord.
         :param embedding: The embedding of the stacktrace.
         """
-        existing_record = (
-            session.query(DbGroupingRecord)
-            .filter_by(hash=issue.hash, project_id=issue.project_id)
-            .first()
-        )
-
-        extra = {
-            "project_id": issue.project_id,
-            "stacktrace_length": len(issue.stacktrace),
-            "input_hash": issue.hash,
-        }
+        with Session() as session:
+            existing_record = (
+                session.query(DbGroupingRecord)
+                .filter_by(hash=issue.hash, project_id=issue.project_id)
+                .first()
+            )
 
-        if existing_record is None:
-            new_record = GroupingRecord(
-                project_id=issue.project_id,
-                message=issue.message,
-                stacktrace_embedding=embedding,
-                hash=issue.hash,
-                error_type=issue.exception_type,
-            ).to_db_model()
-            session.add(new_record)
-
-            try:
-                session.flush()
-            except IntegrityError:
-                session.expunge(new_record)
-                existing_record = (
-                    session.query(DbGroupingRecord)
-                    .filter_by(hash=issue.hash, project_id=issue.project_id)
-                    .first()
-                )
+            extra = {
+                "project_id": issue.project_id,
+                "stacktrace_length": len(issue.stacktrace),
+                "input_hash": issue.hash,
+            }
+
+            if existing_record is None:
+                new_record = GroupingRecord(
+                    project_id=issue.project_id,
+                    message=issue.message,
+                    stacktrace_embedding=embedding,
+                    hash=issue.hash,
+                    error_type=issue.exception_type,
+                ).to_db_model()
+                session.add(new_record)
+
+                try:
+                    session.commit()
+                except IntegrityError:
+                    session.rollback()
+                    existing_record = (
+                        session.query(DbGroupingRecord)
+                        .filter_by(hash=issue.hash, project_id=issue.project_id)
+                        .first()
+                    )
+                    extra["existing_hash"] = existing_record.hash
+                    logger.info(
+                        "group_already_exists_in_seer_db",
+                        extra=extra,
+                    )
+            else:
                 extra["existing_hash"] = existing_record.hash
                 logger.info(
                     "group_already_exists_in_seer_db",
                     extra=extra,
                 )
-        else:
-            extra["existing_hash"] = existing_record.hash
-            logger.info(
-                "group_already_exists_in_seer_db",
-                extra=extra,
-            )
 
     @sentry_sdk.tracing.trace
     def delete_grouping_records_for_project(self, project_id: int) -> bool:
diff --git a/tests/seer/grouping/test_grouping.py b/tests/seer/grouping/test_grouping.py
index d9fd2e3ff..22d6f5f89 100644
--- a/tests/seer/grouping/test_grouping.py
+++ b/tests/seer/grouping/test_grouping.py
@@ -34,7 +34,7 @@ def test_get_nearest_neighbors_has_neighbor(self):
                 message="message",
                 hash="QYK7aNYNnp5FgSev9Np1soqb1SdtyahD",
             )
-            grouping_lookup().insert_new_grouping_record(session, grouping_request, embedding)
+            grouping_lookup().insert_new_grouping_record(grouping_request, embedding)
             session.commit()
 
             grouping_request = GroupingRequest(
@@ -122,11 +122,9 @@ def test_insert_new_grouping_record_group_record_exists(self):
                 hash="QYK7aNYNnp5FgSev9Np1soqb1SdtyahD",
             )
             # Insert the grouping record
-            grouping_lookup().insert_new_grouping_record(session, grouping_request, embedding)
-            session.commit()
+            grouping_lookup().insert_new_grouping_record(grouping_request, embedding)
             # Re-insert the grouping record
-            grouping_lookup().insert_new_grouping_record(session, grouping_request, embedding)
-            session.commit()
+            grouping_lookup().insert_new_grouping_record(grouping_request, embedding)
             matching_record = (
                 session.query(DbGroupingRecord)
                 .filter_by(hash="QYK7aNYNnp5FgSev9Np1soqb1SdtyahD")
@@ -150,10 +148,8 @@ def test_insert_new_grouping_record_group_record_cross_project(self):
                 hash="QYK7aNYNnp5FgSev9Np1soqb1SdtyahD",
             )
             # Insert the grouping record
-            grouping_lookup().insert_new_grouping_record(session, grouping_request1, embedding)
-            session.commit()
-            grouping_lookup().insert_new_grouping_record(session, grouping_request2, embedding)
-            session.commit()
+            grouping_lookup().insert_new_grouping_record(grouping_request1, embedding)
+            grouping_lookup().insert_new_grouping_record(grouping_request2, embedding)
             matching_record = (
                 session.query(DbGroupingRecord)
                 .filter_by(hash="QYK7aNYNnp5FgSev9Np1soqb1SdtyahD")
@@ -293,8 +289,7 @@ def test_bulk_create_and_insert_grouping_records_has_neighbor_in_existing_record
                 message="message",
                 hash="QYK7aNYNnp5FgSev9Np1soqb1SdtyahD",
             )
-            grouping_lookup().insert_new_grouping_record(session, grouping_request, embedding)
-            session.commit()
+            grouping_lookup().insert_new_grouping_record(grouping_request, embedding)
 
         # Create record data to attempt to be inserted, create 5 with the stacktrace "stacktrace"
         hashes = [str(i) * 32 for i in range(10)]
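
The patch above makes insert_new_grouping_record own its database transaction: it opens its own Session, commits instead of flushing, and treats an IntegrityError as "a concurrent writer already inserted this hash", rolling back and re-reading the winning row rather than failing. The sketch below illustrates that insert-or-reuse pattern in isolation against an in-memory SQLite database; the Record model, insert_if_missing helper, and unique constraint are hypothetical stand-ins, not the repo's actual DbGroupingRecord or API.

# Minimal sketch of the commit/rollback-on-IntegrityError pattern adopted by this patch.
# Assumes SQLAlchemy 1.4+; Record and insert_if_missing are illustrative stand-ins.
from sqlalchemy import Column, Integer, String, UniqueConstraint, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Record(Base):
    __tablename__ = "records"
    id = Column(Integer, primary_key=True)
    project_id = Column(Integer, nullable=False)
    hash = Column(String, nullable=False)
    # Mirrors the (project_id, hash) uniqueness the grouping table relies on.
    __table_args__ = (UniqueConstraint("project_id", "hash"),)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)


def insert_if_missing(project_id: int, hash_: str) -> int:
    """Insert a row for (project_id, hash_) unless it already exists; return the row's id."""
    with Session() as session:
        existing = (
            session.query(Record).filter_by(project_id=project_id, hash=hash_).first()
        )
        if existing is not None:
            return existing.id

        record = Record(project_id=project_id, hash=hash_)
        session.add(record)
        try:
            # Commit here rather than flush: the caller no longer has to hold a
            # session or remember to commit, which is the point of the patch.
            session.commit()
        except IntegrityError:
            # A concurrent insert won the race; discard ours and read back theirs.
            session.rollback()
            record = (
                session.query(Record).filter_by(project_id=project_id, hash=hash_).first()
            )
        return record.id


first = insert_if_missing(1, "QYK7aNYNnp5FgSev9Np1soqb1SdtyahD")
second = insert_if_missing(1, "QYK7aNYNnp5FgSev9Np1soqb1SdtyahD")  # reuses the existing row
assert first == second

Compared with the old flush/expunge approach, committing inside the helper shrinks the caller-facing API to just the issue and embedding, at the cost of the insert no longer participating in a caller-managed transaction.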