Skip to content

Commit

Permalink
feat: support functions and function_call params (#22) (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
vladisavvv authored Nov 15, 2023
1 parent 35b1c99 commit 603a95f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 1 deletion.
14 changes: 13 additions & 1 deletion aidial_adapter_openai/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from aidial_adapter_openai.utils.streaming import generate_stream
from aidial_adapter_openai.utils.tokens import discard_messages
from aidial_adapter_openai.utils.versions import compare_versions

logging.config.dictConfig(LogConfig().dict())
app = FastAPI()
Expand Down Expand Up @@ -51,6 +52,17 @@ async def chat_completion(deployment_id: str, request: Request):
request.headers["X-UPSTREAM-ENDPOINT"], ApiType.CHAT_COMPLETION
)

api_version = azure_api_version

if "functions" in data or "function_call" in data:
request_api_version = request.query_params.get("api-version")

if request_api_version is not None:
# 2023-07-01-preview is the first azure api version that supports functions
compare_result = compare_versions(request_api_version, "2023-07-01")
if compare_result == 0 or compare_result == 1:
api_version = request_api_version

discarded_messages = None
if "max_prompt_tokens" in data:
max_prompt_tokens = data["max_prompt_tokens"]
Expand Down Expand Up @@ -78,7 +90,7 @@ async def chat_completion(deployment_id: str, request: Request):
api_key=dial_api_key,
api_base=api_base,
api_type="azure",
api_version=azure_api_version,
api_version=api_version,
request_timeout=(10, 600), # connect timeout and total timeout
**data,
)
Expand Down
14 changes: 14 additions & 0 deletions aidial_adapter_openai/utils/versions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Azure API versions follow the next formats: YYYY-MM-DD or YYYY-MM-DD-preview
def compare_versions(v1: str, v2: str):
if len(v1) < 10 or len(v2) < 10:
return None

v1 = v1[0:10]
v2 = v2[0:10]

if v1 < v2:
return -1
elif v1 > v2:
return 1
else:
return 0
20 changes: 20 additions & 0 deletions tests/test_versions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Optional

import pytest

from aidial_adapter_openai.utils.versions import compare_versions

compare_versions_dataset = [
("2023-07-01", "2023-07-01", 0),
("2023-07-01", "2023-07-01-preview", 0),
("2023-07-01-preview", "2023-07-01", 0),
("2023-07-01-preview", "2023-07-01-preview", 0),
("2023-07-0", "2023-07-01", None),
("2022-12-01", "2023-06-01-preview", -1),
("2023-09-01-preview", "2023-05-15", 1),
]


@pytest.mark.parametrize("v1, v2, result", compare_versions_dataset)
def test_compare_versions(v1: str, v2: str, result: Optional[int]):
assert compare_versions(v1, v2) == result

0 comments on commit 603a95f

Please sign in to comment.