
fix(model): Adapt previous proxy_server_url configuration for openai and support azure openai model #668

Merged
merged 1 commit on Oct 12, 2023
Changes from all commits
21 changes: 21 additions & 0 deletions pilot/model/parameter.py
@@ -282,9 +282,30 @@ class ProxyModelParameters(BaseModelParameters):
             "help": "Proxy server url, such as: https://api.openai.com/v1/chat/completions"
         },
     )
+
     proxy_api_key: str = field(
         metadata={"tags": "privacy", "help": "The api key of the current proxy LLM"},
     )
+
+    proxy_api_base: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The base api address, such as: https://api.openai.com/v1. If None, it is derived from proxy_server_url for backward compatibility"
+        },
+    )
+
+    proxy_api_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The api type of the current proxy model; if you use Azure, set it to: azure"
+        },
+    )
+
+    proxy_api_version: Optional[str] = field(
+        default=None,
+        metadata={"help": "The api version of the current proxy model"},
+    )
+
     http_proxy: Optional[str] = field(
         default=os.environ.get("http_proxy") or os.environ.get("https_proxy"),
         metadata={"help": "The http or https proxy to use openai"},
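For illustration, a minimal sketch (not part of the PR) of how the new fields might be filled in for an Azure OpenAI deployment. The field names come from the diff above; model_name and model_path are assumed to be inherited from BaseModelParameters, and every concrete value is a placeholder.

from pilot.model.parameter import ProxyModelParameters

# Hypothetical Azure configuration; all values are placeholders.
params = ProxyModelParameters(
    model_name="chatgpt_proxyllm",                         # assumed BaseModelParameters field
    model_path="chatgpt_proxyllm",                         # assumed BaseModelParameters field
    proxy_api_key="<azure-api-key>",                       # tagged "privacy" in the metadata above
    proxy_api_base="https://<resource>.openai.azure.com",  # base address, no /chat/completions suffix
    proxy_api_type="azure",                                # switches chatgpt.py into Azure mode
    proxy_api_version="2023-05-15",                        # Azure requires an explicit api version
)

Leaving proxy_api_type and proxy_api_version unset keeps the previous plain-OpenAI behavior, and an old proxy_server_url-only configuration still works through the fallback in chatgpt.py below.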
95 changes: 80 additions & 15 deletions pilot/model/proxy/llms/chatgpt.py
@@ -1,31 +1,63 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-import json
 import os
 from typing import List
+import logging
 
 import openai
 
 from pilot.model.proxy.llms.proxy_model import ProxyModel
+from pilot.model.parameter import ProxyModelParameters
 from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
 
+logger = logging.getLogger(__name__)
 
-def chatgpt_generate_stream(
-    model: ProxyModel, tokenizer, params, device, context_len=2048
-):
+
+def _initialize_openai(params: ProxyModelParameters):
+    api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
+
+    api_base = params.proxy_api_base or os.getenv(
+        "OPENAI_API_BASE",
+        os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None,
+    )
+    api_key = params.proxy_api_key or os.getenv(
+        "OPENAI_API_KEY",
+        os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None,
+    )
+    api_version = params.proxy_api_version or os.getenv("OPENAI_API_VERSION")
+
+    if not api_base and params.proxy_server_url:
+        # Adapt the previous proxy_server_url configuration (strip the /chat/completions suffix)
+        api_base = params.proxy_server_url.split("/chat/completions")[0]
+    if api_type:
+        openai.api_type = api_type
+    if api_base:
+        openai.api_base = api_base
+    if api_key:
+        openai.api_key = api_key
+    if api_version:
+        openai.api_version = api_version
+    if params.http_proxy:
+        openai.proxy = params.http_proxy
+
+    openai_params = {
+        "api_type": api_type,
+        "api_base": api_base,
+        "api_version": api_version,
+        "proxy": params.http_proxy,
+    }
+
+    return openai_params
+
+
+def _build_request(model: ProxyModel, params):
     history = []
 
     model_params = model.get_params()
-    print(f"Model: {model}, model_params: {model_params}")
+    logger.info(f"Model: {model}, model_params: {model_params}")
 
-    proxy_api_key = model_params.proxy_api_key
-    if model_params.http_proxy:
-        openai.proxy = model_params.http_proxy
-    openai.api_key = os.getenv("OPENAI_API_KEY") or proxy_api_key
-    proxyllm_backend = model_params.proxyllm_backend
-    if not proxyllm_backend:
-        proxyllm_backend = "gpt-3.5-turbo"
+    openai_params = _initialize_openai(model_params)
 
     messages: List[ModelMessage] = params["messages"]
     # Add history conversation
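Before the second hunk, a standalone sketch (not code from this PR) of the backward-compatibility rule in _initialize_openai above: each setting resolves from the explicit ProxyModelParameters field first, then the generic OPENAI_* environment variable, then, when api_type is "azure", the AZURE_OPENAI_* variable; finally, a legacy proxy_server_url pointing at the full /chat/completions endpoint is trimmed back to a bare api base.

# Sketch of the proxy_server_url fallback: trim the endpoint path back
# to the base address that openai.api_base expects.
legacy_url = "https://api.openai.com/v1/chat/completions"
api_base = legacy_url.split("/chat/completions")[0]
assert api_base == "https://api.openai.com/v1"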
@@ -51,18 +83,51 @@ def chatgpt_generate_stream(
     history.append(last_user_input)
 
     payloads = {
-        "model": proxyllm_backend,  # just for test, remove this later
         "temperature": params.get("temperature"),
         "max_tokens": params.get("max_new_tokens"),
         "stream": True,
     }
-    res = openai.ChatCompletion.create(messages=history, **payloads)
+    proxyllm_backend = model_params.proxyllm_backend
+
+    if openai_params["api_type"] == "azure":
+        # Azure addresses the deployment by name via the "engine" field.
+        proxyllm_backend = proxyllm_backend or "gpt-35-turbo"
+        payloads["engine"] = proxyllm_backend
+    else:
+        proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
+        payloads["model"] = proxyllm_backend
+
-    print(f"Send request to real model {proxyllm_backend}")
+    logger.info(
+        f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
+    )
+    return history, payloads
+
+
+def chatgpt_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
+    history, payloads = _build_request(model, params)
+
+    res = openai.ChatCompletion.create(messages=history, **payloads)
 
     text = ""
     for r in res:
         if r["choices"][0]["delta"].get("content") is not None:
             content = r["choices"][0]["delta"]["content"]
             text += content
             yield text
+
+
+async def async_chatgpt_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
+    history, payloads = _build_request(model, params)
+
+    res = await openai.ChatCompletion.acreate(messages=history, **payloads)
+
+    text = ""
+    async for r in res:
+        if r["choices"][0]["delta"].get("content") is not None:
+            content = r["choices"][0]["delta"]["content"]
+            text += content
+            yield text
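To show how the two generators are consumed, a hypothetical driver (not part of the PR). The shape of the params dict is inferred from _build_request above; loading a real ProxyModel is elided, and the tokenizer and device arguments are unused on this proxy path.

import asyncio

from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

# Hypothetical request params; _build_request reads "messages",
# "temperature" and "max_new_tokens" from this dict.
request_params = {
    "messages": [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello!")],
    "temperature": 0.7,
    "max_new_tokens": 256,
}

def run_sync(model):
    # Each yielded value is the accumulated text so far, not a single delta.
    for partial in chatgpt_generate_stream(model, None, request_params, None):
        print(partial)

async def run_async(model):
    async for partial in async_chatgpt_generate_stream(model, None, request_params, None):
        print(partial)

# Given a loaded ProxyModel instance:
# run_sync(model)
# asyncio.run(run_async(model))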