fix: improved error handling (#182)
adubovik committed Dec 18, 2024
1 parent 3977441 commit 7529c0c
Showing 8 changed files with 349 additions and 68 deletions.
13 changes: 4 additions & 9 deletions aidial_adapter_openai/app.py
@@ -1,6 +1,5 @@
 from contextlib import asynccontextmanager
 
-import pydantic
 from aidial_sdk.exceptions import HTTPException as DialException
 from aidial_sdk.telemetry.init import init_telemetry as sdk_init_telemetry
 from aidial_sdk.telemetry.types import TelemetryConfig
@@ -9,11 +8,7 @@
 
 import aidial_adapter_openai.endpoints as endpoints
 from aidial_adapter_openai.app_config import ApplicationConfig
-from aidial_adapter_openai.exception_handlers import (
-    dial_exception_handler,
-    openai_exception_handler,
-    pydantic_exception_handler,
-)
+from aidial_adapter_openai.exception_handlers import adapter_exception_handler
 from aidial_adapter_openai.utils.http_client import get_http_client
 from aidial_adapter_openai.utils.log_config import configure_loggers, logger
 from aidial_adapter_openai.utils.request import set_app_config
@@ -45,9 +40,9 @@ def create_app(
     app.post("/openai/deployments/{deployment_id:path}/chat/completions")(
         endpoints.chat_completion
     )
-    app.exception_handler(OpenAIError)(openai_exception_handler)
-    app.exception_handler(pydantic.ValidationError)(pydantic_exception_handler)
-    app.exception_handler(DialException)(dial_exception_handler)
+
+    for exc_class in [OpenAIError, DialException]:
+        app.add_exception_handler(exc_class, adapter_exception_handler)
 
     return app

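Note: the app now routes several exception classes through one handler instead of registering three separate ones. A minimal self-contained sketch of the same FastAPI pattern (assuming only that fastapi is installed; the exception names here are illustrative, not from this repo):

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse


class UpstreamError(Exception):
    pass


def shared_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    # Normalize any registered exception into a uniform JSON error body.
    return JSONResponse(
        status_code=500,
        content={"error": {"message": str(exc)}},
    )


app = FastAPI()

# Starlette allows binding the same callable to several exception classes,
# mirroring the `for exc_class in [...]` loop in create_app above.
for exc_class in [UpstreamError, ValueError]:
    app.add_exception_handler(exc_class, shared_exception_handler)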
82 changes: 51 additions & 31 deletions aidial_adapter_openai/exception_handlers.py
@@ -1,56 +1,76 @@
-import pydantic
-from aidial_sdk._errors import pydantic_validation_exception_handler
 from aidial_sdk.exceptions import HTTPException as DialException
-from fastapi import Request
-from fastapi.responses import Response
+from aidial_sdk.exceptions import InternalServerError
+from fastapi.requests import Request as FastAPIRequest
+from fastapi.responses import Response as FastAPIResponse
 from openai import APIConnectionError, APIError, APIStatusError, APITimeoutError
 
+from aidial_adapter_openai.utils.adapter_exception import (
+    AdapterException,
+    ResponseWrapper,
+    parse_adapter_exception,
+)
 
-def openai_exception_handler(request: Request, e: DialException):
-    if isinstance(e, APIStatusError):
-        r = e.response
-        headers = r.headers
-
-        # Avoid encoding the error message when the original response was encoded.
-        if "Content-Encoding" in headers:
-            del headers["Content-Encoding"]
-
-        return Response(
-            content=r.content,
+def to_adapter_exception(exc: Exception) -> AdapterException:
+
+    if isinstance(exc, (DialException, ResponseWrapper)):
+        return exc
+
+    if isinstance(exc, APIStatusError):
+        # Non-streaming errors reported by `openai` library via this exception
+        r = exc.response
+        httpx_headers = r.headers
+
+        # httpx library (used by openai) automatically sets
+        # "Accept-Encoding:gzip,deflate" header in requests to the upstream.
+        # Therefore, we may receive from the upstream gzip-encoded
+        # response along with "Content-Encoding:gzip" header.
+        # We either need to encode the response, or
+        # remove the "Content-Encoding" header.
+        if "Content-Encoding" in httpx_headers:
+            del httpx_headers["Content-Encoding"]
+
+        return parse_adapter_exception(
             status_code=r.status_code,
-            headers=headers,
+            headers=dict(httpx_headers.items()),
+            content=r.text,
         )
 
-    if isinstance(e, APITimeoutError):
-        raise DialException(
+    if isinstance(exc, APITimeoutError):
+        return DialException(
             status_code=504,
             type="timeout",
             message="Request timed out",
             display_message="Request timed out. Please try again later.",
         )
 
-    if isinstance(e, APIConnectionError):
-        raise DialException(
+    if isinstance(exc, APIConnectionError):
+        return DialException(
             status_code=502,
             type="connection",
             message="Error communicating with OpenAI",
             display_message="OpenAI server is not responsive. Please try again later.",
         )
 
-    if isinstance(e, APIError):
-        raise DialException(
-            status_code=getattr(e, "status_code", None) or 500,
-            message=e.message,
-            type=e.type,
-            code=e.code,
-            param=e.param,
-            display_message=None,
-        )
+    if isinstance(exc, APIError):
+        # Streaming errors reported by `openai` library via this exception
+        status_code: int = 500
+        if exc.code:
+            try:
+                status_code = int(exc.code)
+            except Exception:
+                pass
+
+        return parse_adapter_exception(
+            status_code=status_code,
+            headers={},
+            content={"error": exc.body or {}},
+        )
 
-
-def pydantic_exception_handler(request: Request, exc: pydantic.ValidationError):
-    return pydantic_validation_exception_handler(request, exc)
+    return InternalServerError(str(exc))
 
 
-def dial_exception_handler(request: Request, exc: DialException):
-    return exc.to_fastapi_response()
+def adapter_exception_handler(
+    request: FastAPIRequest, exc: Exception
+) -> FastAPIResponse:
+    return to_adapter_exception(exc).to_fastapi_response()
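Note: a rough usage sketch of to_adapter_exception for the timeout branch (assuming openai, httpx, and aidial_sdk are installed; the request object is a stand-in):

import httpx
from openai import APITimeoutError

from aidial_adapter_openai.exception_handlers import to_adapter_exception

exc = APITimeoutError(request=httpx.Request("POST", "http://upstream"))
adapter_exc = to_adapter_exception(exc)

# The timeout branch returns DialException(status_code=504, type="timeout", ...),
# which can be rendered as a plain FastAPI response:
print(adapter_exc.to_fastapi_response().status_code)  # 504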
2 changes: 1 addition & 1 deletion aidial_adapter_openai/gpt.py
@@ -78,12 +78,12 @@ async def gpt_chat_completion(
 
     if isinstance(response, AsyncIterator):
         return generate_stream(
+            stream=map_stream(chunk_to_dict, response),
             get_prompt_tokens=lambda: estimated_prompt_tokens
             or tokenizer.tokenize_request(request, request["messages"]),
             tokenize_response=tokenizer.tokenize_response,
             deployment=deployment_id,
             discarded_messages=discarded_messages,
-            stream=map_stream(chunk_to_dict, response),
             eliminate_empty_choices=eliminate_empty_choices,
         )
     else:
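Note: this reorder is behavior-neutral. generate_stream declares keyword-only parameters (the bare `*` in its signature in aidial_adapter_openai/utils/streaming.py below), so call-site argument order is purely cosmetic. A tiny illustration with a hypothetical function:

async def demo(*, stream: object, deployment: str) -> None:
    ...

# Keyword-only arguments may be passed in any order; these calls are equivalent:
#   await demo(stream=..., deployment="gpt-4")
#   await demo(deployment="gpt-4", stream=...)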
8 changes: 4 additions & 4 deletions aidial_adapter_openai/gpt4_multi_modal/chat_completion.py
@@ -262,14 +262,14 @@ def debug_print(chunk: T) -> T:
     return map_stream(
         debug_print,
         generate_stream(
-            get_prompt_tokens=lambda: estimated_prompt_tokens,
-            tokenize_response=tokenizer.tokenize_response,
-            deployment=deployment,
-            discarded_messages=discarded_messages,
             stream=map_stream(
                 response_transformer,
                 parse_openai_sse_stream(response),
             ),
+            get_prompt_tokens=lambda: estimated_prompt_tokens,
+            tokenize_response=tokenizer.tokenize_response,
+            deployment=deployment,
+            discarded_messages=discarded_messages,
             eliminate_empty_choices=eliminate_empty_choices,
         ),
     )
95 changes: 95 additions & 0 deletions aidial_adapter_openai/utils/adapter_exception.py
@@ -0,0 +1,95 @@
+import json
+from typing import Any, Dict
+
+from aidial_sdk.exceptions import HTTPException as DialException
+from fastapi.responses import Response as FastAPIResponse
+
+
+class ResponseWrapper(Exception):
+    content: Any
+    status_code: int
+    headers: Dict[str, str] | None
+
+    def __init__(
+        self,
+        *,
+        content: Any,
+        status_code: int,
+        headers: Dict[str, str] | None,
+    ) -> None:
+        super().__init__(str(content))
+        self.content = content
+        self.status_code = status_code
+        self.headers = headers
+
+    def __repr__(self):
+        # headers field is omitted deliberately
+        # since it may contain sensitive information
+        return "%s(content=%r, status_code=%r)" % (
+            self.__class__.__name__,
+            self.content,
+            self.status_code,
+        )
+
+    def to_fastapi_response(self) -> FastAPIResponse:
+        return FastAPIResponse(
+            status_code=self.status_code,
+            content=self.content,
+            headers=self.headers,
+        )
+
+    def json_error(self) -> dict:
+        return {
+            "error": {
+                "message": str(self.content),
+                "code": int(self.status_code),
+            }
+        }
+
+
+AdapterException = ResponseWrapper | DialException
+
+
+def _parse_dial_exception(
+    *, status_code: int, headers: Dict[str, str], content: Any
+) -> DialException | None:
+    if isinstance(content, str):
+        try:
+            obj = json.loads(content)
+        except Exception:
+            return None
+    else:
+        obj = content
+
+    if (
+        isinstance(obj, dict)
+        and (error := obj.get("error"))
+        and isinstance(error, dict)
+    ):
+        message = error.get("message") or "Unknown error"
+        code = error.get("code")
+        type = error.get("type")
+        param = error.get("param")
+        display_message = error.get("display_message")
+
+        return DialException(
+            status_code=status_code,
+            message=message,
+            type=type,
+            param=param,
+            code=code,
+            display_message=display_message,
+            headers=headers,
+        )
+
+    return None
+
+
+def parse_adapter_exception(
+    *, status_code: int, headers: Dict[str, str], content: Any
+) -> AdapterException:
+    return _parse_dial_exception(
+        status_code=status_code, headers=headers, content=content
+    ) or ResponseWrapper(
+        status_code=status_code, headers=headers, content=content
+    )
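Note: a quick sketch of the parsing fallback defined above, using only names from this file. A body matching the DIAL error schema is promoted to a structured DialException; anything else is replayed verbatim through ResponseWrapper:

from aidial_adapter_openai.utils.adapter_exception import parse_adapter_exception

dial_like = parse_adapter_exception(
    status_code=429,
    headers={},
    content='{"error": {"message": "Too many requests", "type": "rate_limit"}}',
)
# -> DialException carrying the message/type parsed from the body

opaque = parse_adapter_exception(
    status_code=503,
    headers={},
    content="<html>Service Unavailable</html>",
)
# -> ResponseWrapper; its repr() deliberately omits headers
print(repr(opaque))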
20 changes: 18 additions & 2 deletions aidial_adapter_openai/utils/sse_stream.py
@@ -3,6 +3,9 @@
 
 from aidial_sdk.exceptions import runtime_server_error
 
+from aidial_adapter_openai.exception_handlers import to_adapter_exception
+from aidial_adapter_openai.utils.log_config import logger
+
 DATA_PREFIX = "data: "
 OPENAI_END_MARKER = "[DONE]"
 
@@ -53,6 +56,19 @@ async def parse_openai_sse_stream(
 async def to_openai_sse_stream(
     stream: AsyncIterator[dict],
 ) -> AsyncIterator[str]:
-    async for chunk in stream:
-        yield format_chunk(chunk)
+    try:
+        async for chunk in stream:
+            yield format_chunk(chunk)
+    except Exception as e:
+        logger.exception(
+            f"caught exception while streaming: {type(e).__module__}.{type(e).__name__}"
+        )
+
+        adapter_exception = to_adapter_exception(e)
+        logger.error(
+            f"converted to the adapter exception: {adapter_exception!r}"
+        )
+
+        yield format_chunk(adapter_exception.json_error())
+
     yield END_CHUNK
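Note: a hypothetical end-to-end check of the new behavior (assuming the adapter package is importable). A stream that raises midway still yields well-formed SSE: the data chunk, then an {"error": ...} chunk, then the [DONE] marker:

import asyncio
from typing import AsyncIterator

from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream


async def failing_stream() -> AsyncIterator[dict]:
    yield {"choices": [{"delta": {"content": "Hi"}}]}
    raise RuntimeError("upstream dropped the connection")


async def main() -> None:
    # RuntimeError is logged, converted to InternalServerError by
    # to_adapter_exception, and serialized as the final data chunk.
    async for sse_line in to_openai_sse_stream(failing_stream()):
        print(sse_line, end="")


asyncio.run(main())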
28 changes: 10 additions & 18 deletions aidial_adapter_openai/utils/streaming.py
@@ -6,7 +6,6 @@
 from aidial_sdk.exceptions import HTTPException as DialException
 from aidial_sdk.utils.merge_chunks import merge_chat_completion_chunks
 from fastapi.responses import JSONResponse, Response, StreamingResponse
-from openai import APIError, APIStatusError
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from pydantic import BaseModel
 
@@ -54,11 +53,11 @@ def build_chunk(
 
 async def generate_stream(
     *,
+    stream: AsyncIterator[dict],
     get_prompt_tokens: Callable[[], int],
     tokenize_response: Callable[[ChatCompletionResponse], int],
     deployment: str,
     discarded_messages: Optional[list[int]],
-    stream: AsyncIterator[dict],
     eliminate_empty_choices: bool,
 ) -> AsyncIterator[dict]:
 
@@ -106,7 +105,7 @@ def set_discarded_messages(chunk: dict | None, indices: list[int]) -> dict:
     buffer_chunk = None
     response_snapshot = ChatCompletionStreamingChunk()
 
-    error = None
+    error: Exception | None = None
 
     try:
         async for chunk in stream:
@@ -129,15 +128,11 @@ def set_discarded_messages(chunk: dict | None, indices: list[int]) -> dict:
                 yield last_chunk
             last_chunk = chunk
 
-    except APIError as e:
-        status_code = e.status_code if isinstance(e, APIStatusError) else 500
-        error = DialException(
-            status_code=status_code,
-            message=e.message,
-            type=e.type,
-            param=e.param,
-            code=e.code,
-        ).json_error()
+    except Exception as e:
+        logger.exception(
+            f"caught exception while streaming: {type(e).__module__}.{type(e).__name__}"
+        )
+        error = e
 
     if last_chunk is not None and buffer_chunk is not None:
         last_chunk = merge_chat_completion_chunks(last_chunk, buffer_chunk)
@@ -168,7 +163,7 @@ def set_discarded_messages(chunk: dict | None, indices: list[int]) -> dict:
         yield last_chunk
 
     if error:
-        yield error
+        raise error
 
 
 def create_stage_chunk(name: str, content: str, stream: bool) -> dict:
@@ -204,7 +199,7 @@ def create_stage_chunk(name: str, content: str, stream: bool) -> dict:
 
 def create_response_from_chunk(
     chunk: dict, exc: DialException | None, stream: bool
-) -> Response:
+) -> AsyncIterator[dict] | Response:
     if not stream:
         if exc is not None:
             return exc.to_fastapi_response()
@@ -216,10 +211,7 @@ async def generator() -> AsyncIterator[dict]:
         if exc is not None:
            yield exc.json_error()
 
-    return StreamingResponse(
-        to_openai_sse_stream(generator()),
-        media_type="text/event-stream",
-    )
+    return generator()
 
 
 def block_response_to_streaming_chunk(response: dict) -> dict:
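Note: generate_stream now re-raises mid-stream exceptions instead of yielding a pre-serialized error chunk, and create_response_from_chunk returns a bare generator rather than a StreamingResponse. The SSE wrapping therefore has to happen closer to the endpoint, where to_openai_sse_stream converts any raised exception into a terminal error chunk. A sketch of that wrapping under these assumptions (the helper name to_streaming_response is hypothetical):

from typing import AsyncIterator

from fastapi.responses import StreamingResponse

from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream


def to_streaming_response(chunks: AsyncIterator[dict]) -> StreamingResponse:
    # to_openai_sse_stream owns error handling now: an exception raised by
    # `chunks` is logged, converted via to_adapter_exception, and emitted
    # as a final {"error": ...} chunk before "data: [DONE]".
    return StreamingResponse(
        to_openai_sse_stream(chunks),
        media_type="text/event-stream",
    )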
