From db1a0036dedf564e3828dd177774c677cb454b40 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Fri, 23 Aug 2024 16:04:14 -0700 Subject: [PATCH] Feature add openai api for vllm integration (#3287) * Forward additional url segments as url_paths in request header to model * Fix vllm test and clean preproc * First attept to enable OpenAI api for models served via vllm * fix streaming in openai api * Add OpenAIServingCompletion usage example * Add lora modules to vllm engine * Finish openai completion integration; removed req openai client; updated lora example to llama 3.1 * fix lint * Update mistral + llama3 vllm example * Remove openai client from url path test * Add openai chat api to vllm example * Added v1/models endpoint for vllm example * Remove accidential breakpoint() * Add comment to new url_path --- .../utils/test_llm_streaming_response.py | 80 ++++- examples/large_models/vllm/llama3/Readme.md | 11 +- examples/large_models/vllm/llama3/chat.json | 11 + .../vllm/llama3/model-config.yaml | 5 +- examples/large_models/vllm/llama3/prompt.json | 4 +- examples/large_models/vllm/lora/Readme.md | 45 ++- .../large_models/vllm/lora/model-config.yaml | 8 +- examples/large_models/vllm/lora/prompt.json | 8 +- examples/large_models/vllm/mistral/Readme.md | 2 +- .../vllm/mistral/model-config.yaml | 2 + .../large_models/vllm/mistral/prompt.json | 3 +- .../api/rest/InferenceRequestHandler.java | 14 +- test/pytest/test_example_vllm.py | 275 ++++++++++++++---- test/pytest/test_url_path.py | 151 ++++++++++ ts/torch_handler/vllm_handler.py | 161 ++++++---- ts_scripts/spellcheck_conf/wordlist.txt | 2 + 16 files changed, 633 insertions(+), 149 deletions(-) create mode 100644 examples/large_models/vllm/llama3/chat.json create mode 100644 test/pytest/test_url_path.py diff --git a/examples/large_models/utils/test_llm_streaming_response.py b/examples/large_models/utils/test_llm_streaming_response.py index 8aab203fc8..55f9129bc3 100644 --- a/examples/large_models/utils/test_llm_streaming_response.py +++ b/examples/large_models/utils/test_llm_streaming_response.py @@ -27,25 +27,67 @@ def _predict(self): combined_text = "" for chunk in response.iter_content(chunk_size=None): if chunk: - data = json.loads(chunk) + text = self._extract_text(chunk) if self.args.demo_streaming: - print(data["text"], end="", flush=True) + print(text, end="", flush=True) else: - combined_text += data.get("text", "") + combined_text += text if not self.args.demo_streaming: self.queue.put_nowait(f"payload={payload}\n, output={combined_text}\n") + def _extract_completion(self, chunk): + chunk = chunk.decode("utf-8") + if chunk.startswith("data:"): + chunk = chunk[len("data:") :].split("\n")[0].strip() + if chunk.startswith("[DONE]"): + return "" + return json.loads(chunk)["choices"][0]["text"] + + def _extract_chat(self, chunk): + chunk = chunk.decode("utf-8") + if chunk.startswith("data:"): + chunk = chunk[len("data:") :].split("\n")[0].strip() + if chunk.startswith("[DONE]"): + return "" + try: + return json.loads(chunk)["choices"][0].get("message", {})["content"] + except KeyError: + return json.loads(chunk)["choices"][0].get("delta", {}).get("content", "") + + def _extract_text(self, chunk): + if self.args.openai_api: + if "chat" in self.args.api_endpoint: + return self._extract_chat(chunk) + else: + return self._extract_completion(chunk) + else: + return json.loads(chunk).get("text", "") + def _get_url(self): - return f"http://localhost:8080/predictions/{self.args.model}" + if 
self.args.openai_api: + return f"http://localhost:8080/predictions/{self.args.model}/{self.args.model_version}/{self.args.api_endpoint}" + else: + return f"http://localhost:8080/predictions/{self.args.model}" def _format_payload(self): prompt_input = _load_curl_like_data(self.args.prompt_text) + if "chat" in self.args.api_endpoint: + assert self.args.prompt_json, "Use prompt json file for chat interface" + assert self.args.openai_api, "Chat only work with openai api" + prompt_input = json.loads(prompt_input) + messages = prompt_input.get("messages", None) + assert messages is not None + rt = int(prompt_input.get("max_tokens", self.args.max_tokens)) + prompt_input["max_tokens"] = rt + if self.args.demo_streaming: + prompt_input["stream"] = True + return prompt_input if self.args.prompt_json: prompt_input = json.loads(prompt_input) prompt = prompt_input.get("prompt", None) assert prompt is not None prompt_list = prompt.split(" ") - rt = int(prompt_input.get("max_new_tokens", self.args.max_tokens)) + rt = int(prompt_input.get("max_tokens", self.args.max_tokens)) else: prompt_list = prompt_input.split(" ") rt = self.args.max_tokens @@ -58,13 +100,15 @@ def _format_payload(self): cur_prompt = " ".join(prompt_list) if self.args.prompt_json: prompt_input["prompt"] = cur_prompt - prompt_input["max_new_tokens"] = rt - return prompt_input + prompt_input["max_tokens"] = rt else: - return { + prompt_input = { "prompt": cur_prompt, - "max_new_tokens": rt, + "max_tokens": rt, } + if self.args.demo_streaming and self.args.openai_api: + prompt_input["stream"] = True + return prompt_input def _load_curl_like_data(text): @@ -136,6 +180,24 @@ def parse_args(): default=False, help="Demo streaming response, force num-requests-per-thread=1 and num-threads=1", ) + parser.add_argument( + "--openai-api", + action=argparse.BooleanOptionalAction, + default=False, + help="Use OpenAI compatible API", + ) + parser.add_argument( + "--api-endpoint", + type=str, + default="v1/completions", + help="OpenAI endpoint suffix", + ) + parser.add_argument( + "--model-version", + type=str, + default="1.0", + help="Model vesion. Default: 1.0", + ) return parser.parse_args() diff --git a/examples/large_models/vllm/llama3/Readme.md b/examples/large_models/vllm/llama3/Readme.md index 65b87c687c..fb80f7a3e3 100644 --- a/examples/large_models/vllm/llama3/Readme.md +++ b/examples/large_models/vllm/llama3/Readme.md @@ -1,6 +1,6 @@ # Example showing inference with vLLM on LoRA model -This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `meta-llama/Meta-Llama-3-8B-Instruct` with continuous batching. +This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `meta-llama/Meta-Llama-3.1-8B-Instruct` with continuous batching. 
This examples supports distributed inference by following [this instruction](../Readme.md#distributed-inference) ### Step 0: Install vLLM @@ -21,7 +21,7 @@ huggingface-cli login --token $HUGGINGFACE_TOKEN ``` ```bash -python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3-8B-Instruct --use_auth_token True +python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3.1-8B-Instruct --use_auth_token True ``` ### Step 2: Generate model artifacts @@ -47,7 +47,12 @@ torchserve --start --ncs --ts-config ../config.properties --model-store model_st ``` ### Step 5: Run inference +Run a text completion: +```bash +python ../../utils/test_llm_streaming_response.py -m llama3-8b -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json --openai-api +``` +Or use the chat interface: ```bash -python ../../utils/test_llm_streaming_response.py -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json +python ../../utils/test_llm_streaming_response.py -m llama3-8b -o 50 -t 2 -n 4 --prompt-text "@chat.json" --prompt-json --openai-api --demo-streaming --api-endpoint "v1/chat/completions" ``` diff --git a/examples/large_models/vllm/llama3/chat.json b/examples/large_models/vllm/llama3/chat.json new file mode 100644 index 0000000000..db4289e1a5 --- /dev/null +++ b/examples/large_models/vllm/llama3/chat.json @@ -0,0 +1,11 @@ +{ + "model": "llama3-8b", + "messages":[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}, + {"role": "user", "content": "Where was it played?"} + ], + "temperature":0.0, + "max_tokens": 50 +} diff --git a/examples/large_models/vllm/llama3/model-config.yaml b/examples/large_models/vllm/llama3/model-config.yaml index f3e63bf50d..1fea5ff658 100644 --- a/examples/large_models/vllm/llama3/model-config.yaml +++ b/examples/large_models/vllm/llama3/model-config.yaml @@ -7,7 +7,10 @@ deviceType: "gpu" asyncCommunication: true handler: - model_path: "model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/" + model_path: "model/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/8c22764a7e3675c50d4c7c9a4edb474456022b16" vllm_engine_config: max_num_seqs: 16 max_model_len: 250 + served_model_name: + - "meta-llama/Meta-Llama-3.1-8B" + - "llama3-8b" diff --git a/examples/large_models/vllm/llama3/prompt.json b/examples/large_models/vllm/llama3/prompt.json index bc27314191..cba00fe04c 100644 --- a/examples/large_models/vllm/llama3/prompt.json +++ b/examples/large_models/vllm/llama3/prompt.json @@ -1,9 +1,7 @@ { "prompt": "A robot may not injure a human being", - "max_new_tokens": 50, "temperature": 0.8, "logprobs": 1, - "prompt_logprobs": 1, "max_tokens": 128, - "adapter": "adapter_1" + "model": "llama3-8b" } diff --git a/examples/large_models/vllm/lora/Readme.md b/examples/large_models/vllm/lora/Readme.md index c5339e9b33..c592f23a73 100644 --- a/examples/large_models/vllm/lora/Readme.md +++ b/examples/large_models/vllm/lora/Readme.md @@ -1,6 +1,6 @@ # Example showing inference with vLLM on LoRA model -This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `Llama-2-7b-hf` + LoRA model `llama-2-7b-sql-lora-test` with continuous batching. 
+This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `meta-llama/Meta-Llama-3.1-8B` + LoRA model `llama-duo/llama3.1-8b-summarize-gpt4o-128k` with continuous batching. This examples supports distributed inference by following [this instruction](../Readme.md#distributed-inference) ### Step 0: Install vLLM @@ -21,9 +21,9 @@ huggingface-cli login --token $HUGGINGFACE_TOKEN ``` ```bash -python ../../utils/Download_model.py --model_path model --model_name meta-llama/Llama-2-7b-chat-hf --use_auth_token True +python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3.1-8B --use_auth_token True mkdir adapters && cd adapters -python ../../../utils/Download_model.py --model_path model --model_name yard1/llama-2-7b-sql-lora-test --use_auth_token True +python ../../../utils/Download_model.py --model_path model --model_name llama-duo/llama3.1-8b-summarize-gpt4o-128k --use_auth_token True cd .. ``` @@ -32,26 +32,53 @@ cd .. Add the downloaded path to "model_path:" and "adapter_1:" in `model-config.yaml` and run the following. ```bash -torch-model-archiver --model-name llama-7b-lora --version 1.0 --handler vllm_handler --config-file model-config.yaml --archive-format no-archive -mv model llama-7b-lora -mv adapters llama-7b-lora +torch-model-archiver --model-name llama-8b-lora --version 1.0 --handler vllm_handler --config-file model-config.yaml --archive-format no-archive +mv model llama-8b-lora +mv adapters llama-8b-lora ``` ### Step 3: Add the model artifacts to model store ```bash mkdir model_store -mv llama-7b-lora model_store +mv llama-8b-lora model_store ``` ### Step 4: Start torchserve ```bash -torchserve --start --ncs --ts-config ../config.properties --model-store model_store --models llama-7b-lora --disable-token-auth --enable-model-api +torchserve --start --ncs --ts-config ../config.properties --model-store model_store --models llama-8b-lora --disable-token-auth --enable-model-api ``` ### Step 5: Run inference +The vllm integration uses an OpenAI compatible interface which lets you perform inference with curl or the openai library client and supports streaming. 
+Curl:
 ```bash
-python ../../utils/test_llm_streaming_response.py -m lora -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json
+curl --header "Content-Type: application/json" --request POST --data @prompt.json http://localhost:8080/predictions/llama-8b-lora/1.0/v1/completions
+```
+
+Python + Request:
+```bash
+ python ../../utils/test_llm_streaming_response.py -m llama-8b-lora -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json --openai-api --demo-streaming
+ ```
+
+OpenAI client:
+```python
+from openai import OpenAI
+model_name = "llama-8b-lora"
+stream=True
+openai_api_key = "EMPTY"
+openai_api_base = f"http://localhost:8080/predictions/{model_name}/1.0/v1"
+
+client = OpenAI(
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+response = client.completions.create(
+    model=model_name, prompt="Hello world", temperature=0.0, stream=stream
+)
+for chunk in response:
+    print(f"{chunk=}")
 ```
diff --git a/examples/large_models/vllm/lora/model-config.yaml b/examples/large_models/vllm/lora/model-config.yaml
index f61a31b535..85db70338e 100644
--- a/examples/large_models/vllm/lora/model-config.yaml
+++ b/examples/large_models/vllm/lora/model-config.yaml
@@ -7,13 +7,17 @@ deviceType: "gpu"
 asyncCommunication: true
 
 handler:
-    model_path: "model/models--meta-llama--Llama-2-7b-chat-hf/snapshots/f5db02db724555f92da89c216ac04704f23d4590/"
+    model_path: "model/models--meta-llama--Meta-Llama-3.1-8B/snapshots/48d6d0fc4e02fb1269b36940650a1b7233035cbb"
 vllm_engine_config:
 enable_lora: true
 max_loras: 4
 max_cpu_loras: 4
+    max_lora_rank: 32
 max_num_seqs: 16
 max_model_len: 250
+    served_model_name:
+      - "meta-llama/Meta-Llama-3.1-8B"
+      - "llama-8b-lora"
 
 adapters:
-    adapter_1: "adapters/model/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/"
+    adapter_1: "adapters/model/models--llama-duo--llama3.1-8b-summarize-gpt4o-128k/snapshots/4ba83353f24fa38946625c8cc49bf21c80a22825"
diff --git a/examples/large_models/vllm/lora/prompt.json b/examples/large_models/vllm/lora/prompt.json
index bc27314191..46f1db0519 100644
--- a/examples/large_models/vllm/lora/prompt.json
+++ b/examples/large_models/vllm/lora/prompt.json
@@ -1,9 +1,7 @@
 {
+  "model": "adapter_1",
 "prompt": "A robot may not injure a human being",
-  "max_new_tokens": 50,
-  "temperature": 0.8,
+  "temperature": 0.0,
 "logprobs": 1,
-  "prompt_logprobs": 1,
-  "max_tokens": 128,
-  "adapter": "adapter_1"
+  "max_tokens": 128
 }
diff --git a/examples/large_models/vllm/mistral/Readme.md b/examples/large_models/vllm/mistral/Readme.md
index d3e9d3f9f4..4816adcae5 100644
--- a/examples/large_models/vllm/mistral/Readme.md
+++ b/examples/large_models/vllm/mistral/Readme.md
@@ -49,5 +49,5 @@ torchserve --start --ncs --ts-config ../config.properties --model-store model_st
 ### Step 5: Run inference
 
 ```bash
-python ../../utils/test_llm_streaming_response.py -m mistral -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json
+python ../../utils/test_llm_streaming_response.py -m mistral -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json --openai-api
 ```
diff --git a/examples/large_models/vllm/mistral/model-config.yaml b/examples/large_models/vllm/mistral/model-config.yaml
index 0237ef85df..7aac9b01b1 100644
--- a/examples/large_models/vllm/mistral/model-config.yaml
+++ b/examples/large_models/vllm/mistral/model-config.yaml
@@ -12,3 +12,5 @@ handler:
 max_model_len: 250
 max_num_seqs: 16
 tensor_parallel_size: 4
+  served_model_name:
+    - "mistral"
diff --git a/examples/large_models/vllm/mistral/prompt.json
b/examples/large_models/vllm/mistral/prompt.json index 5b004673b7..91c8cb2c80 100644 --- a/examples/large_models/vllm/mistral/prompt.json +++ b/examples/large_models/vllm/mistral/prompt.json @@ -1,8 +1,7 @@ { + "model": "mistral", "prompt": "A robot may not injure a human being", - "max_new_tokens": 50, "temperature": 0.8, "logprobs": 1, - "prompt_logprobs": 1, "max_tokens": 128 } diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java index 2bb9102969..ed12b8b992 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java @@ -162,9 +162,21 @@ private void handlePredictions( String modelVersion = null; - if (segments.length == 4) { + if (segments.length >= 4) { modelVersion = segments[3]; } + req.headers().add("url_path", ""); + /** + * If url provides more segments as model_name/version we provide these as url_path in the + * request header This way users can leverage them in the custom handler to e.g. influence + * handler behavior + */ + if (segments.length > 4) { + String joinedSegments = + String.join("/", Arrays.copyOfRange(segments, 4, segments.length)); + req.headers().add("url_path", joinedSegments); + } + req.headers().add("explain", "False"); if (explain) { req.headers().add("explain", "True"); diff --git a/test/pytest/test_example_vllm.py b/test/pytest/test_example_vllm.py index e6a43e6a3c..d48029bed6 100644 --- a/test/pytest/test_example_vllm.py +++ b/test/pytest/test_example_vllm.py @@ -15,9 +15,13 @@ LORA_SRC_PATH = VLLM_PATH / "lora" CONFIG_PROPERTIES_PATH = CURR_FILE_PATH.parents[1] / "test" / "config_ts.properties" -LLAMA_MODEL_PATH = "model/models--meta-llama--Llama-2-7b-chat-hf/snapshots/f5db02db724555f92da89c216ac04704f23d4590/" +LLAMA_MODEL_PATH = "model/models--meta-llama--Meta-Llama-3.1-8B/snapshots/48d6d0fc4e02fb1269b36940650a1b7233035cbb" -ADAPTER_PATH = "adapters/model/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c" +ADAPTER_PATH = "adapters/model/models--llama-duo--llama3.1-8b-summarize-gpt4o-128k/snapshots/4ba83353f24fa38946625c8cc49bf21c80a22825" + +TOKENIZER_CONFIG = LORA_SRC_PATH / LLAMA_MODEL_PATH / "tokenizer_config.json" + +CHAT_TEMPLATE = '{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = "26 Jul 2024" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. 
#}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- "Environment: ipython\\n" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- "Tools: " + builtin_tools | reject(\'equalto\', \'code_interpreter\') | join(", ") + "\\n\\n"}}\n{%- endif %}\n{{- "Cutting Knowledge Date: December 2023\\n" }}\n{{- "Today Date: " + date_string + "\\n\\n" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- "<|eot_id|>" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\'+ message[\'content\'] | trim + \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- "<|python_tag|>" + tool_call.name + ".call(" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + \'="\' + arg_val + \'"\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- ")" }}\n {%- else %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we\'re in ipython mode #}\n {{- "<|eom_id|>" }}\n {%- else %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if 
message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n)' YAML_CONFIG = f""" # TorchServe frontend parameters @@ -36,10 +40,15 @@ vllm_engine_config: enable_lora: true max_loras: 4 + max_lora_rank: 32 max_cpu_loras: 4 max_num_seqs: 16 max_model_len: 250 tensor_parallel_size: {torch.cuda.device_count()} + served_model_name: + - "Meta-Llama-31-8B" + + chat_template: "chat_template.txt" adapters: adapter_1: "{(LORA_SRC_PATH / ADAPTER_PATH).as_posix()}" @@ -48,29 +57,37 @@ PROMPTS = [ { "prompt": "A robot may not injure a human being", - "max_new_tokens": 50, - "temperature": 0.8, + "temperature": 0.0, "logprobs": 1, - "prompt_logprobs": 1, - "max_tokens": 128, - "adapter": "adapter_1", + "max_tokens": 20, + "model": "Meta-Llama-31-8B", + "stream": True, + }, + { + "prompt": "Paris is,", + "logprobs": 1, + "max_tokens": 20, + "temperature": 0.0, + "top_p": 0.1, + "model": "Meta-Llama-31-8B", + "seed": 42, + "stream": True, }, { - "prompt": "Paris is, ", - "max_new_tokens": 50, + "prompt": "Paris is,", "logprobs": 1, - "prompt_logprobs": 1, - "max_tokens": 128, + "max_tokens": 20, "temperature": 0.0, - "top_k": 1, - "top_p": 0, - "adapter": "adapter_1", + "top_p": 0.1, + "model": "adapter_1", "seed": 42, + "stream": True, }, ] EXPECTED = [ - " or, ", # through inaction", # edit to pass see https://github.com/vllm-project/vllm/issues/5404 - "1900.\n\nThe city is", # bathed", + " or, through inaction, allow a human being to come to harm.\nA robot must obey the", + " without a doubt, one of the most beautiful cities in the world. It is a city that is", + " without a doubt, one of the most beautiful cities in the world. 
Its rich history, stunning architecture", ] try: @@ -80,14 +97,32 @@ except ImportError: VLLM_MISSING = True +try: + from openai import OpenAI # noqa + + OPENAI_MISSING = False +except ImportError: + OPENAI_MISSING = True + def necessary_files_unavailable(): LLAMA = LORA_SRC_PATH / LLAMA_MODEL_PATH ADAPTER = LORA_SRC_PATH / ADAPTER_PATH - return { - "condition": not (LLAMA.exists() and ADAPTER.exists()) or VLLM_MISSING, - "reason": f"Required files are not present or vllm is not installed (see README): {LLAMA.as_posix()} + {ADAPTER.as_posix()}", - } + if not (LLAMA.exists() and ADAPTER.exists()): + return { + "condition": True, + "reason": f"Required files are not present (see README): {LLAMA.as_posix()} + {ADAPTER.as_posix()}", + } + elif VLLM_MISSING: + return { + "condition": True, + "reason": f"VLLM is not installed", + } + else: + return { + "condition": False, + "reason": "None", + } @pytest.fixture @@ -101,7 +136,7 @@ def add_paths(): @pytest.fixture(scope="module") def model_name(): - yield "test_lora" + yield "Meta-Llama-31-8B" @pytest.fixture(scope="module") @@ -116,16 +151,20 @@ def create_mar_file(work_dir, model_archiver, model_name, request): model_config_yaml = Path(work_dir) / "model-config.yaml" model_config_yaml.write_text(YAML_CONFIG) + chat_template_txt = Path(work_dir) / "chat_template.txt" + chat_template_txt.write_text(CHAT_TEMPLATE) + config = ModelArchiverConfig( model_name=model_name, version="1.0", - handler=(VLLM_PATH / "base_vllm_handler.py").as_posix(), + handler="vllm_handler", serialized_file=None, export_path=work_dir, requirements_file=None, runtime="python", force=False, config_file=model_config_yaml.as_posix(), + extra_files=chat_template_txt.as_posix(), archive_format="no-archive", ) @@ -139,60 +178,174 @@ def create_mar_file(work_dir, model_archiver, model_name, request): shutil.rmtree(mar_file_path) -@pytest.mark.skipif(**necessary_files_unavailable()) -def test_vllm_lora_mar(mar_file_path, model_store, torchserve): +@pytest.fixture(scope="module", name="model_name") +def register_model(mar_file_path, model_store, torchserve): """ Register the model in torchserve """ - file_name = Path(mar_file_path).name - model_name = Path(file_name).stem - shutil.copytree(mar_file_path, Path(model_store) / model_name) + shutil.copytree(mar_file_path, model_store + f"/{model_name}") params = ( ("model_name", model_name), - ("url", Path(model_store) / model_name), + ("url", file_name), ("initial_workers", "1"), ("synchronous", "true"), ("batch_size", "1"), ) + test_utils.reg_resp = test_utils.register_model_with_params(params) + + yield model_name + + test_utils.unregister_model(model_name) + + # Clean up files + shutil.rmtree(Path(model_store) / model_name) + + +def extract_text(chunk): + if not isinstance(chunk, str): + chunk = chunk.decode("utf-8") + if chunk.startswith("data:"): + chunk = chunk[len("data:") :].split("\n")[0].strip() + if chunk.startswith("[DONE]"): + return "" + return json.loads(chunk)["choices"][0]["text"] + + +def extract_chat(chunk): + if not isinstance(chunk, str): + chunk = chunk.decode("utf-8") + if chunk.startswith("data:"): + chunk = chunk[len("data:") :].split("\n")[0].strip() + if chunk.startswith("[DONE]"): + return "" try: - test_utils.reg_resp = test_utils.register_model_with_params(params) - responses = [] - - for _ in range(10): - idx = random.randint(0, 1) - response = requests.post( - url=f"http://localhost:8080/predictions/{model_name}", - json=PROMPTS[idx], - stream=True, - ) - - assert response.status_code == 200 - - assert 
response.headers["Transfer-Encoding"] == "chunked" - responses += [(response, EXPECTED[idx])] - - predictions = [] - expected_result = [] - for response, expected in responses: - prediction = [] - for chunk in response.iter_content(chunk_size=None): - if chunk: - data = json.loads(chunk) - prediction += [data.get("text", "")] - predictions += [prediction] - expected_result += [expected] - assert all(len(p) > 1 for p in predictions) - assert all( - "".join(p).startswith(e) for p, e in zip(predictions, expected_result) + return json.loads(chunk)["choices"][0].get("message", {})["content"] + except KeyError: + return json.loads(chunk)["choices"][0].get("delta", {}).get("content", "") + + +@pytest.mark.skipif(**necessary_files_unavailable()) +def test_vllm_lora(model_name): + """ + Register the model in torchserve + """ + + base_url = f"http://localhost:8080/predictions/{model_name}/1.0/v1/completions" + + responses = [] + + for _ in range(10): + idx = random.randint(0, len(PROMPTS) - 1) + + response = requests.post(base_url, json=PROMPTS[idx], stream=True) + + assert response.status_code == 200 + + assert response.headers["Transfer-Encoding"] == "chunked" + responses += [(response, EXPECTED[idx])] + + predictions = [] + expected_result = [] + for response, expected in responses: + prediction = "" + for chunk in response.iter_content(chunk_size=None): + if chunk: + prediction += extract_text(chunk) + predictions += [prediction] + expected_result += [expected] + + assert all(len(p) > 1 for p in predictions) + assert all("".join(p) == e for p, e in zip(predictions, expected_result)) + + +@pytest.mark.skipif(**necessary_files_unavailable()) +@pytest.mark.parametrize("stream", [True, False]) +def test_openai_api_completions(model_name, stream): + base_url = f"http://localhost:8080/predictions/{model_name}/1.0/v1/completions" + + data = { + "model": model_name, + "prompt": "Hello world", + "temperature": 0.0, + "stream": stream, + } + + response = requests.post(base_url, json=data, stream=stream) + + if stream: + assert response.headers["Transfer-Encoding"] == "chunked" + + EXPECTED = ( + "! I’m a new blogger and I’m excited to share my thoughts and experiences" + ) + i = 0 + + for chunk in response.iter_content(chunk_size=None): + if chunk: + text = extract_text(chunk) + assert text == EXPECTED[i : i + len(text)] + i += len(text) + assert i > 0 + + else: + assert ( + extract_text(response.text) + == "! I’m a new blogger and I’m excited to share my thoughts and experiences" ) - finally: - test_utils.unregister_model(model_name) - # Clean up files - shutil.rmtree(Path(model_store) / model_name) +@pytest.mark.skipif(**necessary_files_unavailable()) +@pytest.mark.parametrize("stream", [True, False]) +def test_openai_api_chat_complations(model_name, stream): + base_url = f"http://localhost:8080/predictions/{model_name}/1.0/v1/chat/completions" + + data = { + "model": model_name, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + { + "role": "assistant", + "content": "The Los Angeles Dodgers won the World Series in 2020.", + }, + {"role": "user", "content": "Where was it played?"}, + ], + "temperature": 0.0, + "max_tokens": 50, + "stream": stream, + } + + response = requests.post(base_url, json=data, stream=stream) + + EXPECTED = " The World Series was played in Arlington, Texas. The Dodgers defeated the Tampa Bay Rays in six games." 
+ if stream: + assert response.headers["Transfer-Encoding"] == "chunked" + + text = "" + for chunk in response.iter_content(chunk_size=None): + if chunk: + text += extract_chat(chunk) + assert text.startswith(EXPECTED) + + else: + assert extract_chat(response.text).startswith(EXPECTED) + + +@pytest.mark.skipif(**necessary_files_unavailable()) +def test_openai_api_models(model_name): + base_url = f"http://localhost:8080/predictions/{model_name}/1.0/v1/models" + + response = requests.post(base_url) + + data = json.loads(response.text) + + models = [m["id"] for m in data["data"]] + + assert model_name in models + + assert "adapter_1" in models diff --git a/test/pytest/test_url_path.py b/test/pytest/test_url_path.py new file mode 100644 index 0000000000..2938472cf1 --- /dev/null +++ b/test/pytest/test_url_path.py @@ -0,0 +1,151 @@ +import shutil +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from unittest.mock import patch + +import pytest +import requests +import test_utils +from model_archiver import ModelArchiverConfig + +CURR_FILE_PATH = Path(__file__).parent +REPO_ROOT_DIR = CURR_FILE_PATH.parents[1] + +HANDLER_PY = """ +import logging +import torch +from ts.torch_handler.base_handler import BaseHandler + +logger = logging.getLogger(__file__) + +class customHandler(BaseHandler): + + def initialize(self, context): + self.context = context + pass + + def preprocess(self, data): + reqs = [] + for i in range(len(data)): + reqs.append(self.context.get_request_header(i, "url_path")) + return reqs + + def inference(self, data): + return data + + def postprocess(self, data): + return data +""" + +MODEL_CONFIG_YAML = """ + #frontend settings + # TorchServe frontend parameters + minWorkers: 1 + batchSize: 2 + maxBatchDelay: 2000 + maxWorkers: 1 + """ + + +@pytest.fixture(scope="module") +def model_name(): + yield "some_model" + + +@pytest.fixture(scope="module") +def work_dir(tmp_path_factory, model_name): + return Path(tmp_path_factory.mktemp(model_name)) + + +@pytest.fixture(scope="module", name="mar_file_path") +def create_mar_file(work_dir, model_archiver, model_name): + mar_file_path = work_dir.joinpath(model_name + ".mar") + + model_config_yaml_file = work_dir / "model_config.yaml" + model_config_yaml_file.write_text(MODEL_CONFIG_YAML) + + handler_py_file = work_dir / "handler.py" + handler_py_file.write_text(HANDLER_PY) + + config = ModelArchiverConfig( + model_name=model_name, + version="1.0", + serialized_file=None, + model_file=None, + handler=handler_py_file.as_posix(), + extra_files=None, + export_path=work_dir, + requirements_file=None, + runtime="python", + force=False, + archive_format="default", + config_file=model_config_yaml_file.as_posix(), + ) + + with patch("archiver.ArgParser.export_model_args_parser", return_value=config): + model_archiver.generate_model_archive() + + assert mar_file_path.exists() + + yield mar_file_path.as_posix() + + # Clean up files + + mar_file_path.unlink(missing_ok=True) + + # Clean up files + + +@pytest.fixture(scope="module", name="model_name") +def register_model(mar_file_path, model_store, torchserve): + """ + Register the model in torchserve + """ + shutil.copy(mar_file_path, model_store) + + file_name = Path(mar_file_path).name + + model_name = Path(file_name).stem + + params = ( + ("model_name", model_name), + ("url", file_name), + ("initial_workers", "1"), + ("synchronous", "true"), + ("batch_size", "2"), + ("max_batch_delay", "2000"), + ) + + test_utils.reg_resp = test_utils.register_model_with_params(params) + + 
yield model_name + + test_utils.unregister_model(model_name) + + +def test_url_paths(model_name): + response = requests.get(f"http://localhost:8081/models/{model_name}") + assert response.status_code == 200, "Describe Failed" + + response = requests.get( + f"http://localhost:8080/predictions/{model_name}/1.0/v1/chat/completion", + json={"prompt": "Hello world"}, + ) + + url_paths = ["v1/chat/completion", "v1/completion"] + + with ThreadPoolExecutor(max_workers=2) as e: + futures = [] + for p in url_paths: + + def send_file(url): + return requests.post( + f"http://localhost:8080/predictions/{model_name}/1.0/" + url, + json={"prompt": "Hello world"}, + ) + + futures += [e.submit(send_file, p)] + + for i, f in enumerate(futures): + prediction = f.result() + assert prediction.content.decode("utf-8") == url_paths[i], "Wrong prediction" diff --git a/ts/torch_handler/vllm_handler.py b/ts/torch_handler/vllm_handler.py index 35695d69ef..927efe93e2 100644 --- a/ts/torch_handler/vllm_handler.py +++ b/ts/torch_handler/vllm_handler.py @@ -1,12 +1,21 @@ -import json +import asyncio import logging import pathlib import time - -from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams -from vllm.lora.request import LoRARequest +from unittest.mock import MagicMock + +from vllm import AsyncEngineArgs, AsyncLLMEngine +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + CompletionRequest, + ErrorResponse, +) +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_engine import LoRAModulePath from ts.handler_utils.utils import send_intermediate_predict_response +from ts.service import PredictionException from ts.torch_handler.base_handler import BaseHandler logger = logging.getLogger(__name__) @@ -17,10 +26,13 @@ def __init__(self): super().__init__() self.vllm_engine = None - self.model = None + self.model_name = None self.model_dir = None self.lora_ids = {} self.adapters = None + self.chat_completion_service = None + self.completion_service = None + self.raw_request = None self.initialized = False def initialize(self, ctx): @@ -28,9 +40,51 @@ def initialize(self, ctx): vllm_engine_config = self._get_vllm_engine_config( ctx.model_yaml_config.get("handler", {}) ) - self.adapters = ctx.model_yaml_config.get("handler", {}).get("adapters", {}) self.vllm_engine = AsyncLLMEngine.from_engine_args(vllm_engine_config) + + self.adapters = ctx.model_yaml_config.get("handler", {}).get("adapters", {}) + lora_modules = [LoRAModulePath(n, p) for n, p in self.adapters.items()] + + if vllm_engine_config.served_model_name: + served_model_names = vllm_engine_config.served_model_name + else: + served_model_names = [vllm_engine_config.model] + + chat_template = ctx.model_yaml_config.get("handler", {}).get( + "chat_template", None + ) + + loop = asyncio.get_event_loop() + model_config = loop.run_until_complete(self.vllm_engine.get_model_config()) + + self.completion_service = OpenAIServingCompletion( + self.vllm_engine, + model_config, + served_model_names, + lora_modules=lora_modules, + prompt_adapters=None, + request_logger=None, + ) + + self.chat_completion_service = OpenAIServingChat( + self.vllm_engine, + model_config, + served_model_names, + "assistant", + lora_modules=lora_modules, + prompt_adapters=None, + request_logger=None, + chat_template=chat_template, + ) + + async def isd(): + return False + + self.raw_request = MagicMock() + self.raw_request.headers = {} + 
self.raw_request.is_disconnected = isd + self.initialized = True async def handle(self, data, context): @@ -38,7 +92,7 @@ async def handle(self, data, context): metrics = context.metrics - data_preprocess = await self.preprocess(data) + data_preprocess = await self.preprocess(data, context) output = await self.inference(data_preprocess, context) output = await self.postprocess(output) @@ -48,38 +102,57 @@ async def handle(self, data, context): ) return output - async def preprocess(self, requests): - input_batch = [] + async def preprocess(self, requests, context): assert len(requests) == 1, "Expecting batch_size = 1" - for req_data in requests: - data = req_data.get("data") or req_data.get("body") - if isinstance(data, (bytes, bytearray)): - data = data.decode("utf-8") + req_data = requests[0] + data = req_data.get("data") or req_data.get("body") + if isinstance(data, (bytes, bytearray)): + data = data.decode("utf-8") - prompt = data.get("prompt") - sampling_params = self._get_sampling_params(data) - lora_request = self._get_lora_request(data) - input_batch += [(prompt, sampling_params, lora_request)] - return input_batch + return [data] async def inference(self, input_batch, context): - logger.debug(f"Inputs: {input_batch[0]}") - prompt, params, lora = input_batch[0] - generator = self.vllm_engine.generate( - prompt, params, context.request_ids[0], lora + url_path = context.get_request_header(0, "url_path") + + if url_path == "v1/models": + models = await self.chat_completion_service.show_available_models() + return [models.model_dump()] + + directory = { + "v1/completions": ( + CompletionRequest, + self.completion_service, + "create_completion", + ), + "v1/chat/completions": ( + ChatCompletionRequest, + self.chat_completion_service, + "create_chat_completion", + ), + } + + RequestType, service, func = directory.get(url_path, (None, None, None)) + + if RequestType is None: + raise PredictionException(f"Unknown API endpoint: {url_path}", 404) + + request = RequestType.model_validate(input_batch[0]) + g = await getattr(service, func)( + request, + self.raw_request, ) - text_len = 0 - async for output in generator: - result = { - "text": output.outputs[0].text[text_len:], - "tokens": output.outputs[0].token_ids[-1], - } - text_len = len(output.outputs[0].text) - if not output.finished: - send_intermediate_predict_response( - [json.dumps(result)], context.request_ids, "Result", 200, context - ) - return [json.dumps(result)] + + if isinstance(g, ErrorResponse): + return [g.model_dump()] + if request.stream: + async for response in g: + if response != "data: [DONE]\n\n": + send_intermediate_predict_response( + [response], context.request_ids, "Result", 200, context + ) + return [response] + else: + return [g.model_dump()] async def postprocess(self, inference_outputs): return inference_outputs @@ -98,29 +171,13 @@ def _get_vllm_engine_config(self, handler_config: dict): f"Model path ({model}) does not exist locally. Trying to give without model_dir as prefix." 
) model = model_path + else: + model = model.as_posix() logger.debug(f"EngineArgs model: {model}") vllm_engine_config = AsyncEngineArgs(model=model) self._set_attr_value(vllm_engine_config, vllm_engine_params) return vllm_engine_config - def _get_sampling_params(self, req_data: dict): - sampling_params = SamplingParams() - self._set_attr_value(sampling_params, req_data) - - return sampling_params - - def _get_lora_request(self, req_data: dict): - adapter_name = req_data.get("lora_adapter", "") - - if len(adapter_name) > 0: - adapter_path = self.adapters.get(adapter_name, "") - assert len(adapter_path) > 0, f"{adapter_name} misses adapter path" - lora_id = self.lora_ids.setdefault(adapter_name, len(self.lora_ids) + 1) - adapter_path = str(pathlib.Path(self.model_dir).joinpath(adapter_path)) - return LoRARequest(adapter_name, lora_id, adapter_path) - - return None - def _set_attr_value(self, obj, config: dict): items = vars(obj) for k, v in config.items(): diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index 0e3cf7ee17..0242be57a1 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1297,3 +1297,5 @@ photorealistic miniconda torchaudio ln +OpenAI +openai