From 301bda22c877930b0d3c774118500f5b0c5a2ac9 Mon Sep 17 00:00:00 2001 From: Jodi Jang Date: Fri, 20 Sep 2024 13:13:25 -0700 Subject: [PATCH] ref(similarity): Handle optional message --- src/seer/grouping/grouping.py | 10 ++-- tests/seer/grouping/test_grouping.py | 37 ++++++++++++--- tests/test_seer.py | 69 +++++++++++++++++++++++++++- 3 files changed, 103 insertions(+), 13 deletions(-) diff --git a/src/seer/grouping/grouping.py b/src/seer/grouping/grouping.py index c96294d0b..c136f4f5b 100644 --- a/src/seer/grouping/grouping.py +++ b/src/seer/grouping/grouping.py @@ -26,8 +26,8 @@ class GroupingRequest(BaseModel): project_id: int stacktrace: str - message: str hash: str + message: Optional[str] = None exception_type: Optional[str] = None k: int = 1 threshold: float = NN_GROUPING_DISTANCE @@ -36,7 +36,7 @@ class GroupingRequest(BaseModel): hnsw_distance: float = NN_GROUPING_HNSW_DISTANCE use_reranking: bool = False - @field_validator("stacktrace", "message") + @field_validator("stacktrace") @classmethod def check_field_is_not_empty(cls, v, info: ValidationInfo): if not v: @@ -59,7 +59,7 @@ class CreateGroupingRecordData(BaseModel): group_id: int hash: str project_id: int - message: str + message: Optional[str] = None exception_type: Optional[str] = None @@ -90,7 +90,7 @@ class DeleteGroupingRecordsByHashResponse(BaseModel): class GroupingRecord(BaseModel): project_id: int - message: str + message: Optional[str] = None stacktrace_embedding: np.ndarray hash: str error_type: Optional[str] = None @@ -334,7 +334,7 @@ def get_nearest_neighbors(self, issue: GroupingRequest) -> SimilarityResponse: :param issue: The issue containing the stacktrace, similarity threshold, and number of nearest neighbors to find (k). :return: A SimilarityResponse object containing a list of GroupingResponse objects with the nearest group IDs, - stacktrace similarity scores, message similarity scores, and grouping flags. + stacktrace similarity scores, and grouping flags. 
""" with Session() as session: embedding = self.encode_text(issue.stacktrace).astype("float32") diff --git a/tests/seer/grouping/test_grouping.py b/tests/seer/grouping/test_grouping.py index d9fd2e3ff..0c6159945 100644 --- a/tests/seer/grouping/test_grouping.py +++ b/tests/seer/grouping/test_grouping.py @@ -161,6 +161,29 @@ def test_insert_new_grouping_record_group_record_cross_project(self): ) assert len(matching_record) == 2 + def test_insert_new_grouping_no_message(self): + """ + Tests that insert_new_grouping_record can create records without messages. + """ + with Session() as session: + embedding = grouping_lookup().encode_text("stacktrace") + project_id, hash = 1, "QYK7aNYNnp5FgSev9Np1soqb1SdtyahD" + grouping_request = GroupingRequest( + project_id=project_id, + stacktrace="stacktrace", + hash=hash, + ) + # Insert the grouping record + grouping_lookup().insert_new_grouping_record(session, grouping_request, embedding) + session.commit() + + created_record = ( + session.query(DbGroupingRecord).filter_by(hash=hash, project_id=project_id).first() + ) + + assert created_record + assert created_record.message is None + def test_bulk_create_and_insert_grouping_records_valid(self): """Test bulk creating and inserting grouping records""" hashes = [str(i) * 32 for i in range(10)] @@ -347,7 +370,6 @@ def test_bulk_create_and_insert_grouping_records_has_neighbor_in_batch(self): group_id=i, hash=hashes[i], project_id=1, - message="message", ) for i in range(10) ], @@ -469,10 +491,11 @@ def test_GroupingLookup_insert_batch_grouping_records_duplicates( ): orig_record.project_id = project_1_id orig_record.hash = hash_1 + orig_record.exception_type = "error" project_2_id = project_1_id + 1 hash_2 = hash_1 + "_2" - updated_duplicate = orig_record.copy(update=dict(message=orig_record.message + " updated?")) + updated_duplicate = orig_record.copy(update=dict(exception_type="transaction")) grouping_request.data = [ orig_record, @@ -492,17 +515,17 @@ def query_created(record: 
CreateGroupingRecordData) -> DbGroupingRecord | None: ) @change_watcher - def updated_message_for_orig(): + def updated_exception_type_for_orig(): db_record = query_created(orig_record) if db_record: - return db_record.message + return db_record.error_type return None - with updated_message_for_orig as changed: + with updated_exception_type_for_orig as changed: grouping_lookup().insert_batch_grouping_records(grouping_request) assert changed - assert changed.to_value(orig_record.message) + assert changed.to_value(orig_record.exception_type) # ensure that atleast a record was made for each item for item in grouping_request.data: @@ -511,7 +534,7 @@ def updated_message_for_orig(): # Again, ensuring that duplicates are ignored grouping_request.data = [updated_duplicate] grouping_request.stacktrace_list = ["does not matter" for _ in grouping_request.data] - with updated_message_for_orig as changed: + with updated_exception_type_for_orig as changed: grouping_lookup().insert_batch_grouping_records(grouping_request) assert not changed diff --git a/tests/test_seer.py b/tests/test_seer.py index d57fb9d1b..25da97799 100644 --- a/tests/test_seer.py +++ b/tests/test_seer.py @@ -24,7 +24,11 @@ from seer.configuration import AppConfig, provide_test_defaults from seer.db import DbGroupingRecord, DbSmokeTest, ProcessRequest, Session from seer.dependency_injection import Module, resolve -from seer.grouping.grouping import CreateGroupingRecordData, CreateGroupingRecordsRequest +from seer.grouping.grouping import ( + CreateGroupingRecordData, + CreateGroupingRecordsRequest, + GroupingRequest, +) from seer.inference_models import dummy_deferred, reset_loading_state, start_loading from seer.smoke_test import smoke_test @@ -320,6 +324,40 @@ def test_no_data_after_request_start(self): output = json.loads(response.get_data(as_text=True)) assert output == {"data": []} + def test_similarity_endpoint_no_message(self): + """Test the similarity endpoint with no message and the no message 
response""" + # Create a record + hashes = ["hash1", "hash2"] + request = GroupingRequest(project_id=1, stacktrace="stacktrace", hash=hashes[0]) + response = app.test_client().post( + "/v0/issues/similar-issues", + data=request.json(), + content_type="application/json", + ) + + # Call endpoint with similar issue to the previous call + request = GroupingRequest(project_id=1, stacktrace="stacktrace", hash=hashes[1]) + response = app.test_client().post( + "/v0/issues/similar-issues", + data=request.json(), + content_type="application/json", + ) + output = json.loads(response.get_data(as_text=True)) + assert output == { + "responses": [ + { + "message_distance": 0.0, + "parent_hash": hashes[0], + "should_group": True, + "stacktrace_distance": 0.0, + } + ] + } + with Session() as session: + records = session.query(DbGroupingRecord).filter(DbGroupingRecord.hash.in_(hashes)) + assert records.count() == 1 + assert records.first().hash == hashes[0] + def test_similarity_grouping_record_endpoint_valid(self): """Test the similarity grouping record endpoint""" hashes = [str(i) * 32 for i in range(5)] @@ -348,6 +386,35 @@ def test_similarity_grouping_record_endpoint_valid(self): for i in range(5): assert records[i] is not None + def test_similarity_grouping_record_endpoint_no_message(self): + """Test the similarity grouping record endpoint without providing message""" + hashes = [str(i) * 32 for i in range(5)] + record_requests = CreateGroupingRecordsRequest( + data=[ + CreateGroupingRecordData( + group_id=i, + hash=hashes[i], + project_id=1, + ) + for i in range(5) + ], + stacktrace_list=["stacktrace " + str(i) for i in range(5)], + ) + + response = app.test_client().post( + "/v0/issues/similar-issues/grouping-record", + data=record_requests.json(), + content_type="application/json", + ) + output = json.loads(response.get_data(as_text=True)) + assert output == {"success": True, "groups_with_neighbor": {}} + with Session() as session: + records = 
session.query(DbGroupingRecord).filter(DbGroupingRecord.hash.in_(hashes)) + for i in range(5): + assert records[i] is not None + assert records[i].message is None + + # test regular similarity endpoint with no message def test_similarity_grouping_record_endpoint_invalid(self): """ Test the similarity grouping record endpoint is unsuccessful when input lists are of