From 7529c0cb2cb73a6393470e6c0f4a554a2fea5cef Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Wed, 18 Dec 2024 16:28:39 +0000 Subject: [PATCH] fix: improved error handling (#182) --- aidial_adapter_openai/app.py | 13 +- aidial_adapter_openai/exception_handlers.py | 82 +++++---- aidial_adapter_openai/gpt.py | 2 +- .../gpt4_multi_modal/chat_completion.py | 8 +- .../utils/adapter_exception.py | 95 ++++++++++ aidial_adapter_openai/utils/sse_stream.py | 20 ++- aidial_adapter_openai/utils/streaming.py | 28 ++- tests/test_errors.py | 169 +++++++++++++++++- 8 files changed, 349 insertions(+), 68 deletions(-) create mode 100644 aidial_adapter_openai/utils/adapter_exception.py diff --git a/aidial_adapter_openai/app.py b/aidial_adapter_openai/app.py index 5d07629..07efbee 100644 --- a/aidial_adapter_openai/app.py +++ b/aidial_adapter_openai/app.py @@ -1,6 +1,5 @@ from contextlib import asynccontextmanager -import pydantic from aidial_sdk.exceptions import HTTPException as DialException from aidial_sdk.telemetry.init import init_telemetry as sdk_init_telemetry from aidial_sdk.telemetry.types import TelemetryConfig @@ -9,11 +8,7 @@ import aidial_adapter_openai.endpoints as endpoints from aidial_adapter_openai.app_config import ApplicationConfig -from aidial_adapter_openai.exception_handlers import ( - dial_exception_handler, - openai_exception_handler, - pydantic_exception_handler, -) +from aidial_adapter_openai.exception_handlers import adapter_exception_handler from aidial_adapter_openai.utils.http_client import get_http_client from aidial_adapter_openai.utils.log_config import configure_loggers, logger from aidial_adapter_openai.utils.request import set_app_config @@ -45,9 +40,9 @@ def create_app( app.post("/openai/deployments/{deployment_id:path}/chat/completions")( endpoints.chat_completion ) - app.exception_handler(OpenAIError)(openai_exception_handler) - app.exception_handler(pydantic.ValidationError)(pydantic_exception_handler) - app.exception_handler(DialException)(dial_exception_handler) + + for exc_class in [OpenAIError, DialException]: + app.add_exception_handler(exc_class, adapter_exception_handler) return app diff --git a/aidial_adapter_openai/exception_handlers.py b/aidial_adapter_openai/exception_handlers.py index c98c122..55a06e4 100644 --- a/aidial_adapter_openai/exception_handlers.py +++ b/aidial_adapter_openai/exception_handlers.py @@ -1,56 +1,76 @@ -import pydantic -from aidial_sdk._errors import pydantic_validation_exception_handler from aidial_sdk.exceptions import HTTPException as DialException -from fastapi import Request -from fastapi.responses import Response +from aidial_sdk.exceptions import InternalServerError +from fastapi.requests import Request as FastAPIRequest +from fastapi.responses import Response as FastAPIResponse from openai import APIConnectionError, APIError, APIStatusError, APITimeoutError +from aidial_adapter_openai.utils.adapter_exception import ( + AdapterException, + ResponseWrapper, + parse_adapter_exception, +) -def openai_exception_handler(request: Request, e: DialException): - if isinstance(e, APIStatusError): - r = e.response - headers = r.headers - # Avoid encoding the error message when the original response was encoded. 
- if "Content-Encoding" in headers: - del headers["Content-Encoding"] +def to_adapter_exception(exc: Exception) -> AdapterException: - return Response( - content=r.content, + if isinstance(exc, (DialException, ResponseWrapper)): + return exc + + if isinstance(exc, APIStatusError): + # Non-streaming errors reported by `openai` library via this exception + r = exc.response + httpx_headers = r.headers + + # httpx library (used by openai) automatically sets + # "Accept-Encoding:gzip,deflate" header in requests to the upstream. + # Therefore, we may receive from the upstream gzip-encoded + # response along with "Content-Encoding:gzip" header. + # We either need to encode the response, or + # remove the "Content-Encoding" header. + if "Content-Encoding" in httpx_headers: + del httpx_headers["Content-Encoding"] + + return parse_adapter_exception( status_code=r.status_code, - headers=headers, + headers=dict(httpx_headers.items()), + content=r.text, ) - if isinstance(e, APITimeoutError): - raise DialException( + if isinstance(exc, APITimeoutError): + return DialException( status_code=504, type="timeout", message="Request timed out", display_message="Request timed out. Please try again later.", ) - if isinstance(e, APIConnectionError): - raise DialException( + if isinstance(exc, APIConnectionError): + return DialException( status_code=502, type="connection", message="Error communicating with OpenAI", display_message="OpenAI server is not responsive. Please try again later.", ) - if isinstance(e, APIError): - raise DialException( - status_code=getattr(e, "status_code", None) or 500, - message=e.message, - type=e.type, - code=e.code, - param=e.param, - display_message=None, - ) + if isinstance(exc, APIError): + # Streaming errors reported by `openai` library via this exception + status_code: int = 500 + if exc.code: + try: + status_code = int(exc.code) + except Exception: + pass + return parse_adapter_exception( + status_code=status_code, + headers={}, + content={"error": exc.body or {}}, + ) -def pydantic_exception_handler(request: Request, exc: pydantic.ValidationError): - return pydantic_validation_exception_handler(request, exc) + return InternalServerError(str(exc)) -def dial_exception_handler(request: Request, exc: DialException): - return exc.to_fastapi_response() +def adapter_exception_handler( + request: FastAPIRequest, exc: Exception +) -> FastAPIResponse: + return to_adapter_exception(exc).to_fastapi_response() diff --git a/aidial_adapter_openai/gpt.py b/aidial_adapter_openai/gpt.py index d4c6cde..dd909e2 100644 --- a/aidial_adapter_openai/gpt.py +++ b/aidial_adapter_openai/gpt.py @@ -78,12 +78,12 @@ async def gpt_chat_completion( if isinstance(response, AsyncIterator): return generate_stream( + stream=map_stream(chunk_to_dict, response), get_prompt_tokens=lambda: estimated_prompt_tokens or tokenizer.tokenize_request(request, request["messages"]), tokenize_response=tokenizer.tokenize_response, deployment=deployment_id, discarded_messages=discarded_messages, - stream=map_stream(chunk_to_dict, response), eliminate_empty_choices=eliminate_empty_choices, ) else: diff --git a/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py b/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py index 216137d..6e66e59 100644 --- a/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py +++ b/aidial_adapter_openai/gpt4_multi_modal/chat_completion.py @@ -262,14 +262,14 @@ def debug_print(chunk: T) -> T: return map_stream( debug_print, generate_stream( - get_prompt_tokens=lambda: 
estimated_prompt_tokens, - tokenize_response=tokenizer.tokenize_response, - deployment=deployment, - discarded_messages=discarded_messages, stream=map_stream( response_transformer, parse_openai_sse_stream(response), ), + get_prompt_tokens=lambda: estimated_prompt_tokens, + tokenize_response=tokenizer.tokenize_response, + deployment=deployment, + discarded_messages=discarded_messages, eliminate_empty_choices=eliminate_empty_choices, ), ) diff --git a/aidial_adapter_openai/utils/adapter_exception.py b/aidial_adapter_openai/utils/adapter_exception.py new file mode 100644 index 0000000..863f6bd --- /dev/null +++ b/aidial_adapter_openai/utils/adapter_exception.py @@ -0,0 +1,95 @@ +import json +from typing import Any, Dict + +from aidial_sdk.exceptions import HTTPException as DialException +from fastapi.responses import Response as FastAPIResponse + + +class ResponseWrapper(Exception): + content: Any + status_code: int + headers: Dict[str, str] | None + + def __init__( + self, + *, + content: Any, + status_code: int, + headers: Dict[str, str] | None, + ) -> None: + super().__init__(str(content)) + self.content = content + self.status_code = status_code + self.headers = headers + + def __repr__(self): + # headers field is omitted deliberately + # since it may contain sensitive information + return "%s(content=%r, status_code=%r)" % ( + self.__class__.__name__, + self.content, + self.status_code, + ) + + def to_fastapi_response(self) -> FastAPIResponse: + return FastAPIResponse( + status_code=self.status_code, + content=self.content, + headers=self.headers, + ) + + def json_error(self) -> dict: + return { + "error": { + "message": str(self.content), + "code": int(self.status_code), + } + } + + +AdapterException = ResponseWrapper | DialException + + +def _parse_dial_exception( + *, status_code: int, headers: Dict[str, str], content: Any +) -> DialException | None: + if isinstance(content, str): + try: + obj = json.loads(content) + except Exception: + return None + else: + obj = content + + if ( + isinstance(obj, dict) + and (error := obj.get("error")) + and isinstance(error, dict) + ): + message = error.get("message") or "Unknown error" + code = error.get("code") + type = error.get("type") + param = error.get("param") + display_message = error.get("display_message") + + return DialException( + status_code=status_code, + message=message, + type=type, + param=param, + code=code, + display_message=display_message, + headers=headers, + ) + + return None + + +def parse_adapter_exception( + *, status_code: int, headers: Dict[str, str], content: Any +) -> AdapterException: + return _parse_dial_exception( + status_code=status_code, headers=headers, content=content + ) or ResponseWrapper( + status_code=status_code, headers=headers, content=content + ) diff --git a/aidial_adapter_openai/utils/sse_stream.py b/aidial_adapter_openai/utils/sse_stream.py index 3094d00..3b02b29 100644 --- a/aidial_adapter_openai/utils/sse_stream.py +++ b/aidial_adapter_openai/utils/sse_stream.py @@ -3,6 +3,9 @@ from aidial_sdk.exceptions import runtime_server_error +from aidial_adapter_openai.exception_handlers import to_adapter_exception +from aidial_adapter_openai.utils.log_config import logger + DATA_PREFIX = "data: " OPENAI_END_MARKER = "[DONE]" @@ -53,6 +56,19 @@ async def parse_openai_sse_stream( async def to_openai_sse_stream( stream: AsyncIterator[dict], ) -> AsyncIterator[str]: - async for chunk in stream: - yield format_chunk(chunk) + try: + async for chunk in stream: + yield format_chunk(chunk) + except Exception as e: 
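        # A rough sketch of the recovery below: to_adapter_exception() turns any
        # failure into either a DialException (when the error body has the
        # {"error": {...}} shape) or a ResponseWrapper carrying the raw upstream
        # payload, and its json_error() dict is sent as one last SSE data chunk
        # before the terminating "data: [DONE]". For example, a plain
        # ValueError("failed generating the stream") raised mid-stream would
        # surface roughly as (cf. test_adapter_internal_error below):
        #
        #   data: {"error": {"message": "failed generating the stream",
        #                    "type": "internal_server_error", "code": "500"}}
        #   data: [DONE]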
+ logger.exception( + f"caught exception while streaming: {type(e).__module__}.{type(e).__name__}" + ) + + adapter_exception = to_adapter_exception(e) + logger.error( + f"converted to the adapter exception: {adapter_exception!r}" + ) + + yield format_chunk(adapter_exception.json_error()) + yield END_CHUNK diff --git a/aidial_adapter_openai/utils/streaming.py b/aidial_adapter_openai/utils/streaming.py index debcd76..ebfb779 100644 --- a/aidial_adapter_openai/utils/streaming.py +++ b/aidial_adapter_openai/utils/streaming.py @@ -6,7 +6,6 @@ from aidial_sdk.exceptions import HTTPException as DialException from aidial_sdk.utils.merge_chunks import merge_chat_completion_chunks from fastapi.responses import JSONResponse, Response, StreamingResponse -from openai import APIError, APIStatusError from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from pydantic import BaseModel @@ -54,11 +53,11 @@ def build_chunk( async def generate_stream( *, + stream: AsyncIterator[dict], get_prompt_tokens: Callable[[], int], tokenize_response: Callable[[ChatCompletionResponse], int], deployment: str, discarded_messages: Optional[list[int]], - stream: AsyncIterator[dict], eliminate_empty_choices: bool, ) -> AsyncIterator[dict]: @@ -106,7 +105,7 @@ def set_discarded_messages(chunk: dict | None, indices: list[int]) -> dict: buffer_chunk = None response_snapshot = ChatCompletionStreamingChunk() - error = None + error: Exception | None = None try: async for chunk in stream: @@ -129,15 +128,11 @@ def set_discarded_messages(chunk: dict | None, indices: list[int]) -> dict: yield last_chunk last_chunk = chunk - except APIError as e: - status_code = e.status_code if isinstance(e, APIStatusError) else 500 - error = DialException( - status_code=status_code, - message=e.message, - type=e.type, - param=e.param, - code=e.code, - ).json_error() + except Exception as e: + logger.exception( + f"caught exception while streaming: {type(e).__module__}.{type(e).__name__}" + ) + error = e if last_chunk is not None and buffer_chunk is not None: last_chunk = merge_chat_completion_chunks(last_chunk, buffer_chunk) @@ -168,7 +163,7 @@ def set_discarded_messages(chunk: dict | None, indices: list[int]) -> dict: yield last_chunk if error: - yield error + raise error def create_stage_chunk(name: str, content: str, stream: bool) -> dict: @@ -204,7 +199,7 @@ def create_stage_chunk(name: str, content: str, stream: bool) -> dict: def create_response_from_chunk( chunk: dict, exc: DialException | None, stream: bool -) -> Response: +) -> AsyncIterator[dict] | Response: if not stream: if exc is not None: return exc.to_fastapi_response() @@ -216,10 +211,7 @@ async def generator() -> AsyncIterator[dict]: if exc is not None: yield exc.json_error() - return StreamingResponse( - to_openai_sse_stream(generator()), - media_type="text/event-stream", - ) + return generator() def block_response_to_streaming_chunk(response: dict) -> dict: diff --git a/tests/test_errors.py b/tests/test_errors.py index a74c1dd..e4ebf66 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,5 +1,6 @@ import json -from typing import Any, Callable +from typing import Any, AsyncIterable, AsyncIterator, Callable +from unittest.mock import patch import httpx import pytest @@ -447,7 +448,34 @@ async def test_status_error_from_upstream(test_app: httpx.AsyncClient): ) assert response.status_code == 400 - assert response.content == b"Bad request" + assert response.text == "Bad request" + + +@respx.mock +@pytest.mark.asyncio +async def 
test_status_error_from_upstream_with_headers( + test_app: httpx.AsyncClient, +): + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" + ).respond( + status_code=429, + content="Too many requests", + headers={"Retry-After": "42"}, + ) + + response = await test_app.post( + "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", + json={"messages": [{"role": "user", "content": "Test content"}]}, + headers={ + "X-UPSTREAM-KEY": "TEST_API_KEY", + "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions", + }, + ) + + assert response.status_code == 429 + assert response.text == "Too many requests" + assert response.headers["Retry-After"] == "42" @respx.mock @@ -479,7 +507,9 @@ async def test_timeout_error_from_upstream(test_app: httpx.AsyncClient): @respx.mock @pytest.mark.asyncio -async def test_connection_error_from_upstream(test_app: httpx.AsyncClient): +async def test_connection_error_from_upstream_non_streaming( + test_app: httpx.AsyncClient, +): respx.post( "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" ).mock(side_effect=httpx.ConnectError("Connection error")) @@ -504,6 +534,139 @@ async def test_connection_error_from_upstream(test_app: httpx.AsyncClient): } +@respx.mock +@pytest.mark.asyncio +async def test_connection_error_from_upstream_streaming( + test_app: httpx.AsyncClient, +): + async def mock_stream() -> AsyncIterable[bytes]: + yield b'data: {"message": "first chunk"}\n\n' + yield b'data: {"message": "second chunk"}\n\n' + raise httpx.ConnectError("Connection error") + + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" + ).respond( + status_code=200, + content_type="text/event-stream", + content=mock_stream(), + ) + + response = await test_app.post( + "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", + json={ + "stream": True, + "messages": [{"role": "user", "content": "Test content"}], + }, + headers={ + "X-UPSTREAM-KEY": "TEST_API_KEY", + "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions", + }, + ) + + assert response.status_code == 200 + assert response.text == "\n\n".join( + [ + 'data: {"message":"first chunk"}', + 'data: {"message":"second chunk"}', + 'data: {"error":{"message":"Connection error","type":"internal_server_error","code":"500"}}', + "data: [DONE]", + "", + ] + ) + + +@respx.mock +@pytest.mark.asyncio +async def test_adapter_internal_error( + test_app: httpx.AsyncClient, +): + async def mock_generate_stream(stream: AsyncIterator[dict], **kwargs): + yield await stream.__anext__() + raise ValueError("failed generating the stream") + + with patch( + "aidial_adapter_openai.gpt.generate_stream", + side_effect=mock_generate_stream, + ): + + async def mock_stream() -> AsyncIterable[bytes]: + yield b'data: {"message": "first chunk"}\n\n' + yield b'data: {"message": "second chunk"}\n\n' + yield b"data: [DONE]" + + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" + ).respond( + status_code=200, + content_type="text/event-stream", + content=mock_stream(), + ) + + response = await test_app.post( + "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", + json={ + "stream": True, + "messages": [{"role": "user", "content": "Test content"}], + }, + headers={ + "X-UPSTREAM-KEY": "TEST_API_KEY", + 
"X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions", + }, + ) + + assert response.status_code == 200 + assert response.text == "\n\n".join( + [ + 'data: {"message":"first chunk"}', + 'data: {"error":{"message":"failed generating the stream","type":"internal_server_error","code":"500"}}', + "data: [DONE]", + "", + ] + ) + + +@respx.mock +@pytest.mark.asyncio +async def test_invalid_chunk_stream_from_upstream( + test_app: httpx.AsyncClient, +): + async def mock_stream() -> AsyncIterable[bytes]: + yield b"data: chunk1\n\n" + yield b"data: chunk2\n\n" + yield b"data: [DONE]\n\n" + + respx.post( + "http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview" + ).respond( + status_code=200, + content_type="text/event-stream", + content=mock_stream(), + ) + + response = await test_app.post( + "/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview", + json={ + "stream": True, + "messages": [{"role": "user", "content": "Test content"}], + }, + headers={ + "X-UPSTREAM-KEY": "TEST_API_KEY", + "X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions", + }, + ) + + assert response.status_code == 200 + assert response.text == "\n\n".join( + [ + # OpenAI is unable to parse SSE entry with invalid JSON and fails with the following error: + 'data: {"error":{"message":"Expecting value: line 1 column 1 (char 0)","type":"internal_server_error","code":"500"}}', + "data: [DONE]", + "", + ] + ) + + @respx.mock @pytest.mark.asyncio async def test_unexpected_multi_modal_input_streaming(