diff --git a/serving_utils/client.py b/serving_utils/client.py index 0527ece..e5a0bc0 100644 --- a/serving_utils/client.py +++ b/serving_utils/client.py @@ -259,6 +259,8 @@ async def async_predict( try: stub = self.get_round_robin_stub(is_async_stub=True) response = await stub.Predict(request) + except asyncio.CancelledError: + raise except EmptyPool: self.logger.warning("serving_utils.Client -- empty pool") self._setup_connections() diff --git a/serving_utils/tests/test_client.py b/serving_utils/tests/test_client.py index 2e10d95..b60de85 100644 --- a/serving_utils/tests/test_client.py +++ b/serving_utils/tests/test_client.py @@ -228,6 +228,20 @@ def server_fails_to_Predict_because_model_doesnt_exist(request): assert exc_info.value.message == "Model XXX not found" +@pytest.mark.asyncio +async def test_asyncio_cancel_during_async_predict(): + t = test_asyncio_cancel_during_async_predict + t.mock_gethostbyname_ex.return_value = ('localhost', [], ['1.2.3.4']) + + mock_logger = mock.Mock() + c = Client(host='localhost', port=9999, n_trys=1, logger=mock_logger) + for stub in t.created_async_stubs: + stub.Predict.side_effect = aio.CancelledError() + + with pytest.raises(aio.CancelledError): + await client_async_predict(c) + + @pytest.mark.asyncio async def test_model_not_found_error_passes_through_sync_predict():