Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: reworking api versions logic (#67) #79

Merged
merged 8 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
AZURE_API_VERSION=2023-03-15-preview
MODEL_ALIASES={}
LOG_LEVEL=INFO
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ Copy `.env.example` to `.env` and customize it for your environment:
|---|---|---|
|LOG_LEVEL|INFO|Log level. Use DEBUG for dev purposes and INFO in prod|
|WEB_CONCURRENCY|1|Number of workers for the server|
|AZURE_API_VERSION|2023-03-15-preview|The version API for requests to Azure OpenAI API|
|MODEL_ALIASES|{}|Mapping request's deployment_id to [model name of tiktoken](https://github.com/openai/tiktoken/blob/main/tiktoken/model.py) for correct calculate of tokens. Example: `{"gpt-35-turbo":"gpt-3.5-turbo-0301"}`|
|MODEL_ALIASES|`{}`|Mapping of a request's deployment_id to a [tiktoken model name](https://github.com/openai/tiktoken/blob/main/tiktoken/model.py) for correct token counting. Example: `{"gpt-35-turbo":"gpt-3.5-turbo-0301"}`|
|DIAL_USE_FILE_STORAGE|False|Save image model artifacts to DIAL File storage (DALL-E images are uploaded to the files storage and its base64 encodings are replaced with links to the storage)|
|DIAL_URL||URL of the core DIAL server (required when DIAL_USE_FILE_STORAGE=True)|
|DALLE3_DEPLOYMENTS|``|Comma-separated list of deployments that support DALL-E 3 API. Example: `dall-e-3,dalle3,dall-e`|
|GPT4_VISION_DEPLOYMENTS|``|Comma-separated list of deployments that support GPT-4V API. Example: `gpt-4-vision-preview,gpt-4-vision`|
|GPT4_VISION_MAX_TOKENS|1024|Default value of `max_tokens` parameter for GPT-4V when it wasn't provided in the request|
|ACCESS_TOKEN_EXPIRATION_WINDOW|10|Expiration window of access token in seconds|
|AZURE_OPEN_AI_SCOPE|https://cognitiveservices.azure.com/.default|Provided scope of access token to Azure OpenAI services|
|API_VERSIONS_MAPPING|`{}`|Mapping of API versions for requests to the Azure OpenAI API. Example: `{"2023-03-15-preview": "2023-05-15", "": "2024-02-15-preview"}`. An empty key sets the default API version for the case when the user didn't pass it in the request|

### Docker

Expand Down
52 changes: 34 additions & 18 deletions aidial_adapter_openai/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,23 @@
parse_deployment_list,
parse_upstream,
)
from aidial_adapter_openai.utils.request_classifier import (
does_request_use_functions_or_tools,
)
from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream
from aidial_adapter_openai.utils.storage import create_file_storage
from aidial_adapter_openai.utils.streaming import generate_stream, map_stream
from aidial_adapter_openai.utils.tokens import Tokenizer, discard_messages
from aidial_adapter_openai.utils.versions import compare_versions

logging.config.dictConfig(LogConfig().dict())
app = FastAPI()
model_aliases: Dict[str, str] = json.loads(os.getenv("MODEL_ALIASES", "{}"))
azure_api_version = os.getenv("AZURE_API_VERSION", "2023-03-15-preview")
dalle3_deployments = parse_deployment_list(
os.getenv("DALLE3_DEPLOYMENTS") or ""
)
gpt4_vision_deployments = parse_deployment_list(
os.getenv("GPT4_VISION_DEPLOYMENTS") or ""
)
api_versions_mapping: Dict[str, str] = json.loads(
os.getenv("API_VERSIONS_MAPPING", "{}")
)


async def handle_exceptions(call):
Expand All @@ -58,6 +56,12 @@ async def handle_exceptions(call):
)


def get_api_version(request: Request):
    """Resolve the effective upstream API version for *request*.

    The ``api-version`` query parameter (empty string when absent) is
    looked up in ``api_versions_mapping``; versions without a mapping
    entry pass through unchanged, and the empty-string key lets the
    mapping define a default for requests that omitted the parameter.
    """
    requested = request.query_params.get("api-version", "")
    return api_versions_mapping.get(requested, requested)


@app.post("/openai/deployments/{deployment_id}/chat/completions")
async def chat_completion(deployment_id: str, request: Request):
data = await parse_body(request)
Expand All @@ -67,11 +71,25 @@ async def chat_completion(deployment_id: str, request: Request):
api_type, api_key = await get_credentials(request)

upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"]
api_version = get_api_version(request)

if api_version == "":
Allob marked this conversation as resolved.
Show resolved Hide resolved
raise HTTPException(
"Api version is a required query parameter",
400,
"invalid_request_error",
)

if deployment_id in dalle3_deployments:
storage = create_file_storage("images", request.headers)
return await dalle3_chat_completion(
data, upstream_endpoint, api_key, is_stream, storage, api_type
data,
upstream_endpoint,
api_key,
is_stream,
storage,
api_type,
api_version,
)
elif deployment_id in gpt4_vision_deployments:
storage = create_file_storage("images", request.headers)
Expand All @@ -83,6 +101,7 @@ async def chat_completion(deployment_id: str, request: Request):
is_stream,
storage,
api_type,
api_version,
)

openai_model_name = model_aliases.get(deployment_id, deployment_id)
Expand All @@ -92,17 +111,6 @@ async def chat_completion(deployment_id: str, request: Request):
upstream_endpoint, ApiType.CHAT_COMPLETION
)

api_version = azure_api_version

if does_request_use_functions_or_tools(data):
request_api_version = request.query_params.get("api-version")

if request_api_version is not None:
# 2023-07-01-preview is the first azure api version that supports functions
compare_result = compare_versions(request_api_version, "2023-07-01")
if compare_result == 0 or compare_result == 1:
api_version = request_api_version

discarded_messages = None
if "max_prompt_tokens" in data:
max_prompt_tokens = data["max_prompt_tokens"]
Expand Down Expand Up @@ -172,14 +180,22 @@ async def embedding(deployment_id: str, request: Request):
api_base, upstream_deployment = parse_upstream(
request.headers["X-UPSTREAM-ENDPOINT"], ApiType.EMBEDDING
)
api_version = get_api_version(request)

if api_version == "":
raise HTTPException(
"Api version is a required query parameter",
400,
"invalid_request_error",
)

return await handle_exceptions(
Embedding().acreate(
deployment_id=upstream_deployment,
api_key=api_key,
api_base=api_base,
api_type=api_type,
api_version=azure_api_version,
api_version=api_version,
request_timeout=(10, 600), # connect timeout and total timeout
**data,
)
Expand Down
3 changes: 2 additions & 1 deletion aidial_adapter_openai/dalle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ async def chat_completion(
is_stream: bool,
file_storage: Optional[FileStorage],
api_type: str,
api_version: str,
) -> Response:
if data.get("n", 1) > 1:
raise HTTPException(
Expand All @@ -126,7 +127,7 @@ async def chat_completion(
type="invalid_request_error",
)

api_url = upstream_endpoint + "?api-version=2023-12-01-preview"
api_url = f"{upstream_endpoint}?api-version={api_version}"
user_prompt = get_user_prompt(data)
model_response = await generate_image(
api_url, api_key, user_prompt, api_type
Expand Down
3 changes: 2 additions & 1 deletion aidial_adapter_openai/gpt4_vision/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ async def chat_completion(
is_stream: bool,
file_storage: Optional[FileStorage],
api_type: str,
api_version,
) -> Response:
if request.get("n", 1) > 1:
raise HTTPException(
Expand All @@ -283,7 +284,7 @@ async def chat_completion(
type="invalid_request_error",
)

api_url = upstream_endpoint + "?api-version=2023-12-01-preview"
api_url = f"{upstream_endpoint}?api-version={api_version}"

result = await transform_messages(file_storage, messages)

Expand Down
24 changes: 0 additions & 24 deletions aidial_adapter_openai/utils/request_classifier.py

This file was deleted.

14 changes: 0 additions & 14 deletions aidial_adapter_openai/utils/versions.py

This file was deleted.

8 changes: 4 additions & 4 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def test_error_during_streaming(aioresponses: aioresponses):
test_app = AsyncClient(app=app, base_url="http://test.com")

response = await test_app.post(
"/openai/deployments/gpt-4/chat/completions",
"/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview",
json={
"messages": [{"role": "user", "content": "Test content"}],
"stream": True,
Expand Down Expand Up @@ -94,7 +94,7 @@ async def test_incorrect_upstream_url(aioresponses: aioresponses):
test_app = AsyncClient(app=app, base_url="http://test.com")

response = await test_app.post(
"/openai/deployments/gpt-4/chat/completions",
"/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview",
json={"messages": [{"role": "user", "content": "Test content"}]},
headers={
"X-UPSTREAM-KEY": "TEST_API_KEY",
Expand Down Expand Up @@ -123,7 +123,7 @@ async def test_incorrect_format(aioresponses: aioresponses):
test_app = AsyncClient(app=app, base_url="http://test.com")

response = await test_app.post(
"/openai/deployments/gpt-4/chat/completions",
"/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview",
json={"messages": [{"role": "user", "content": "Test content"}]},
headers={
"X-UPSTREAM-KEY": "TEST_API_KEY",
Expand Down Expand Up @@ -156,7 +156,7 @@ async def test_incorrect_streaming_request(aioresponses: aioresponses):
test_app = AsyncClient(app=app, base_url="http://test.com")

response = await test_app.post(
"/openai/deployments/gpt-4/chat/completions",
"/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview",
json={
"messages": [{"role": "user", "content": "Test content"}],
"stream": True,
Expand Down
24 changes: 0 additions & 24 deletions tests/test_request_classifier.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
@pytest.mark.asyncio
async def test_streaming(aioresponses: aioresponses):
aioresponses.post(
"http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview",
"http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-06-15",
status=200,
body="data: "
+ json.dumps(
Expand Down
20 changes: 0 additions & 20 deletions tests/test_versions.py

This file was deleted.

Loading