Skip to content

Commit

Permalink
clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
vladisavvv committed Mar 14, 2024
1 parent 465bb00 commit a4f6dcb
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
26 changes: 20 additions & 6 deletions aidial_adapter_openai/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
APITimeoutError,
AsyncAzureOpenAI,
)
from openai.types.chat import ChatCompletion

from aidial_adapter_openai.dalle3 import (
chat_completion as dalle3_chat_completion,
Expand Down Expand Up @@ -147,16 +148,23 @@ async def chat_completion(deployment_id: str, request: Request):
if "model" in data:
del data["model"]

authorization = (
{"api_key": api_key}
if api_type == "azure"
else {"azure_ad_token": api_key}
)

response = await handle_exceptions(
AsyncAzureOpenAI(
api_version=api_version,
azure_endpoint=api_base,
api_key=api_key,
timeout=httpx.Timeout(timeout=300, connect=10),
timeout=httpx.Timeout(timeout=600, connect=10),
max_retries=0,
**authorization,
).chat.completions.create(
model=upstream_deployment,
messages=[],
stream=data.get("stream", False),
extra_body=data,
)
)
Expand All @@ -181,8 +189,8 @@ async def chat_completion(deployment_id: str, request: Request):
)
else:
if discarded_messages is not None:
assert type(response) == OpenAIObject
response = response.to_dict() | {
assert type(response) == ChatCompletion
response = response.dict() | {
"statistics": {"discarded_messages": discarded_messages}
}

Expand All @@ -199,13 +207,19 @@ async def embedding(deployment_id: str, request: Request):
)
api_version = get_api_version(request)

authorization = (
{"api_key": api_key}
if api_type == "azure"
else {"azure_ad_token": api_key}
)

return await handle_exceptions(
AsyncAzureOpenAI(
api_version=api_version,
azure_endpoint=api_base,
api_key=api_key,
timeout=httpx.Timeout(timeout=300, connect=10),
timeout=httpx.Timeout(timeout=600, connect=10),
max_retries=0,
**authorization,
).embeddings.create(
model=upstream_deployment,
input=[],
Expand Down
5 changes: 0 additions & 5 deletions aidial_adapter_openai/utils/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from aidial_sdk.utils.merge_chunks import merge
from fastapi.responses import JSONResponse, Response, StreamingResponse

# from aidial_adapter_openai.openai_override import OpenAIException
from aidial_adapter_openai.utils.env import get_env_bool
from aidial_adapter_openai.utils.log_config import logger
from aidial_adapter_openai.utils.sse_stream import END_CHUNK, format_chunk
Expand Down Expand Up @@ -55,7 +54,6 @@ async def generate_stream(
last_chunk, temp_chunk = None, None
stream_finished = False

# try:
total_content = ""
async for chunk in stream:
if len(chunk["choices"]) > 0:
Expand Down Expand Up @@ -89,9 +87,6 @@ async def generate_stream(
yield chunk

last_chunk = chunk
# except OpenAIException as e:
# yield e.body
# return

if not stream_finished:
if last_chunk is not None:
Expand Down

0 comments on commit a4f6dcb

Please sign in to comment.