From d00654791697288500dca223d0863324383c6cb0 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Tue, 28 Nov 2023 14:27:38 +0800 Subject: [PATCH] Make tests for validation rules stricter --- graphene_django/tests/test_views.py | 58 +++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/graphene_django/tests/test_views.py b/graphene_django/tests/test_views.py index 648261cb..c2b42bce 100644 --- a/graphene_django/tests/test_views.py +++ b/graphene_django/tests/test_views.py @@ -829,13 +829,13 @@ def test_query_errors_non_atomic(set_rollback_mock, client): set_rollback_mock.assert_not_called() -validation_urls = [ +VALIDATION_URLS = [ "/graphql/validation/", "/graphql/validation/alternative/", "/graphql/validation/inherited/", ] -query_with_two_introspections = """ +QUERY_WITH_TWO_INTROSPECTIONS = """ query Instrospection { queryType: __schema { queryType {name} @@ -846,8 +846,10 @@ def test_query_errors_non_atomic(set_rollback_mock, client): } """ -introspection_disallow_error_message = "introspection is disabled" -max_validation_errors_exceeded_message = "too many validation errors" +N_INTROSPECTIONS = 2 + +INTROSPECTION_DISALLOWED_ERROR_MESSAGE = "introspection is disabled" +MAX_VALIDATION_ERRORS_EXCEEDED_MESSAGE = "too many validation errors" @pytest.mark.urls("graphene_django.tests.urls_validation") @@ -862,34 +864,60 @@ def test_allow_introspection(client): } -@pytest.mark.parametrize("url", validation_urls) +@pytest.mark.parametrize("url", VALIDATION_URLS) @pytest.mark.urls("graphene_django.tests.urls_validation") def test_validation_disallow_introspection(client, url): response = client.post(url_string(url, query="{__schema {queryType {name}}}")) assert response.status_code == 400 - assert introspection_disallow_error_message in response.content.decode() + + json_response = response_json(response) + assert "data" not in json_response + assert "errors" in json_response + assert len(json_response["errors"]) == 1 + + error_message = json_response["errors"][0]["message"] + assert INTROSPECTION_DISALLOWED_ERROR_MESSAGE in error_message -@pytest.mark.parametrize("url", validation_urls) +@pytest.mark.parametrize("url", VALIDATION_URLS) @pytest.mark.urls("graphene_django.tests.urls_validation") -@patch("graphene_django.settings.graphene_settings.MAX_VALIDATION_ERRORS", 2) +@patch( + "graphene_django.settings.graphene_settings.MAX_VALIDATION_ERRORS", N_INTROSPECTIONS +) def test_within_max_validation_errors(client, url): - response = client.post(url_string(url, query=query_with_two_introspections)) + response = client.post(url_string(url, query=QUERY_WITH_TWO_INTROSPECTIONS)) assert response.status_code == 400 - text_response = response.content.decode().lower() + json_response = response_json(response) + assert "data" not in json_response + assert "errors" in json_response + assert len(json_response["errors"]) == N_INTROSPECTIONS + + error_messages = [error["message"].lower() for error in json_response["errors"]] - assert text_response.count(introspection_disallow_error_message) == 2 - assert max_validation_errors_exceeded_message not in text_response + n_introspection_error_messages = sum( + INTROSPECTION_DISALLOWED_ERROR_MESSAGE in msg for msg in error_messages + ) + assert n_introspection_error_messages == N_INTROSPECTIONS + + assert all( + MAX_VALIDATION_ERRORS_EXCEEDED_MESSAGE not in msg for msg in error_messages + ) -@pytest.mark.parametrize("url", validation_urls) +@pytest.mark.parametrize("url", VALIDATION_URLS) @pytest.mark.urls("graphene_django.tests.urls_validation") @patch("graphene_django.settings.graphene_settings.MAX_VALIDATION_ERRORS", 1) def test_exceeds_max_validation_errors(client, url): - response = client.post(url_string(url, query=query_with_two_introspections)) + response = client.post(url_string(url, query=QUERY_WITH_TWO_INTROSPECTIONS)) assert response.status_code == 400 - assert max_validation_errors_exceeded_message in response.content.decode().lower() + + json_response = response_json(response) + assert "data" not in json_response + assert "errors" in json_response + + error_messages = (error["message"].lower() for error in json_response["errors"]) + assert any(MAX_VALIDATION_ERRORS_EXCEEDED_MESSAGE in msg for msg in error_messages)