From 670d0cf3ff4fd41cd9e8ca1a3df93f52238f38e3 Mon Sep 17 00:00:00 2001 From: David Kleiven Date: Tue, 3 Sep 2024 12:03:59 +0200 Subject: [PATCH] feat!: allow users to specify arbitrary retry stopping criteria" -m "BREAKING CHANGE: num_retries is no longer available as ServiceConfig --- cimsparql/graphdb.py | 6 ++++-- tests/test_retry.py | 4 +++- tests/test_retry_cb.py | 6 +++++- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/cimsparql/graphdb.py b/cimsparql/graphdb.py index fd6fa132..778dd545 100644 --- a/cimsparql/graphdb.py +++ b/cimsparql/graphdb.py @@ -29,6 +29,8 @@ if TYPE_CHECKING: from collections.abc import Callable + from tenacity.stop import stop_base + from cimsparql.sparql_result_json import SparqlResultValue @@ -80,6 +82,7 @@ class ServiceConfig: rest_api: RestApi = field(default=RestApi(os.getenv("SPARQL_REST_API", "RDF4J"))) ca_bundle: str | None = field(default=None) retry_callback_factory: RetryCallbackFactory = field(default=RetryCallback) + retry_stop_criteria: stop_base = field(default=tenacity.stop_after_attempt(1)) # Parameters for rest api # https://rdf4j.org/documentation/reference/rest-api/ @@ -88,7 +91,6 @@ class ServiceConfig: limit: int | None = None offset: int | None = None timeout: int | None = None - num_retries: int = 0 max_delay_seconds: int = 60 validate: bool = False @@ -241,7 +243,7 @@ def exec_query(self, query: str) -> SparqlResultJson: sparql_result = None for attempt in tenacity.Retrying( - stop=tenacity.stop_after_attempt(self.service_cfg.num_retries + 1), + stop=self.service_cfg.retry_stop_criteria, wait=tenacity.wait_exponential(max=self.service_cfg.max_delay_seconds), before=retry_cb.before, after=retry_cb.after, diff --git a/tests/test_retry.py b/tests/test_retry.py index ff6ab784..25fe845a 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -42,7 +42,9 @@ def test_retry_sync( url = fail_first_ok_second_server(httpserver) config = ServiceConfig( - server=url, rest_api=RestApi.DIRECT_SPARQL_ENDPOINT, num_retries=num_retries + server=url, + rest_api=RestApi.DIRECT_SPARQL_ENDPOINT, + retry_stop_criteria=tenacity.stop_after_attempt(num_retries + 1), ) client = GraphDBClient(config) diff --git a/tests/test_retry_cb.py b/tests/test_retry_cb.py index d5dcc25d..da766bfc 100644 --- a/tests/test_retry_cb.py +++ b/tests/test_retry_cb.py @@ -1,6 +1,7 @@ from typing import Any import pytest +import tenacity from SPARQLWrapper import SPARQLWrapper from cimsparql.graphdb import GraphDBClient, ServiceConfig @@ -26,7 +27,10 @@ def queryAndConvert(self) -> dict[str, Any]: # noqa: N802 def test_after_callback(caplog: pytest.LogCaptureFixture): wrapper = FailFirstSparqlWrapper("http://some-sparql-endpint") - client = GraphDBClient(service_cfg=ServiceConfig(num_retries=1), sparql_wrapper=wrapper) + client = GraphDBClient( + service_cfg=ServiceConfig(retry_stop_criteria=tenacity.stop_after_attempt(2)), + sparql_wrapper=wrapper, + ) client.exec_query("# Name: Select everything\nselect * where {?s ?p ?o}") # Expect one message to contain be logged with the query name