diff --git a/aidial_adapter_openai/app.py b/aidial_adapter_openai/app.py
index d452742..8ae0810 100644
--- a/aidial_adapter_openai/app.py
+++ b/aidial_adapter_openai/app.py
@@ -12,6 +12,7 @@
     APITimeoutError,
     AsyncAzureOpenAI,
 )
+from openai.types.chat import ChatCompletion
 
 from aidial_adapter_openai.dalle3 import (
     chat_completion as dalle3_chat_completion,
@@ -147,16 +148,23 @@ async def chat_completion(deployment_id: str, request: Request):
     if "model" in data:
         del data["model"]
 
+    authorization = (
+        {"api_key": api_key}
+        if api_type == "azure"
+        else {"azure_ad_token": api_key}
+    )
+
     response = await handle_exceptions(
         AsyncAzureOpenAI(
             api_version=api_version,
             azure_endpoint=api_base,
-            api_key=api_key,
-            timeout=httpx.Timeout(timeout=300, connect=10),
+            timeout=httpx.Timeout(timeout=600, connect=10),
             max_retries=0,
+            **authorization,
         ).chat.completions.create(
             model=upstream_deployment,
             messages=[],
+            stream=data.get("stream", False),
             extra_body=data,
         )
     )
@@ -181,8 +189,8 @@
         )
     else:
         if discarded_messages is not None:
-            assert type(response) == OpenAIObject
-            response = response.to_dict() | {
+            assert type(response) == ChatCompletion
+            response = response.dict() | {
                 "statistics": {"discarded_messages": discarded_messages}
             }
 
@@ -199,13 +207,19 @@ async def embedding(deployment_id: str, request: Request):
     )
     api_version = get_api_version(request)
 
+    authorization = (
+        {"api_key": api_key}
+        if api_type == "azure"
+        else {"azure_ad_token": api_key}
+    )
+
     return await handle_exceptions(
         AsyncAzureOpenAI(
             api_version=api_version,
             azure_endpoint=api_base,
-            api_key=api_key,
-            timeout=httpx.Timeout(timeout=300, connect=10),
+            timeout=httpx.Timeout(timeout=600, connect=10),
             max_retries=0,
+            **authorization,
         ).embeddings.create(
             model=upstream_deployment,
             input=[],
diff --git a/aidial_adapter_openai/utils/streaming.py b/aidial_adapter_openai/utils/streaming.py
index f8434be..7e613f4 100644
--- a/aidial_adapter_openai/utils/streaming.py
+++ b/aidial_adapter_openai/utils/streaming.py
@@ -5,7 +5,6 @@
 from aidial_sdk.utils.merge_chunks import merge
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 
-# from aidial_adapter_openai.openai_override import OpenAIException
 from aidial_adapter_openai.utils.env import get_env_bool
 from aidial_adapter_openai.utils.log_config import logger
 from aidial_adapter_openai.utils.sse_stream import END_CHUNK, format_chunk
@@ -55,7 +54,6 @@ async def generate_stream(
     last_chunk, temp_chunk = None, None
     stream_finished = False
 
-    # try:
     total_content = ""
     async for chunk in stream:
         if len(chunk["choices"]) > 0:
@@ -89,9 +87,6 @@
             yield chunk
             last_chunk = chunk
 
-    # except OpenAIException as e:
-    #     yield e.body
-    #     return
 
     if not stream_finished:
         if last_chunk is not None:
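
Note on the authorization change: both app.py hunks add the same kwargs-selection pattern, so the openai 1.x client is constructed with either a static API key or an Azure AD token depending on the upstream api_type. A minimal sketch of the pattern in isolation (make_client and its parameter list are illustrative helpers, not part of this diff):

    import httpx
    from openai import AsyncAzureOpenAI

    def make_client(
        api_type: str, api_key: str, api_base: str, api_version: str
    ) -> AsyncAzureOpenAI:
        # "azure" means a static key; anything else is treated as an
        # Azure AD (Entra ID) bearer token, mirroring the hunks above.
        authorization = (
            {"api_key": api_key}
            if api_type == "azure"
            else {"azure_ad_token": api_key}
        )
        return AsyncAzureOpenAI(
            api_version=api_version,
            azure_endpoint=api_base,
            timeout=httpx.Timeout(timeout=600, connect=10),
            max_retries=0,
            **authorization,
        )

The explicit stream=data.get("stream", False) also ties into the assert change: in the 1.x SDK, chat.completions.create returns an AsyncStream of chunks when stream=True and a ChatCompletion pydantic model otherwise, so the 0.x OpenAIObject check and its to_dict() call no longer apply.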