Commit: Add vllm version check for compatibility

jibxie committed Nov 6, 2024
1 parent c85d972 commit 566e0cc
Showing 2 changed files with 41 additions and 23 deletions.
src/model.py (25 additions, 21 deletions)
@@ -42,6 +42,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
+from vllm.version import __version__ as _VLLM_VERSION
 
 from utils.metrics import VllmStatLogger
 
@@ -54,12 +55,6 @@ class TritonPythonModel:
     def auto_complete_config(auto_complete_model_config):
         inputs = [
             {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
-            {
-                "name": "image",
-                "data_type": "TYPE_STRING",
-                "dims": [-1],  # can be multiple images as separate elements
-                "optional": True,
-            },
             {
                 "name": "stream",
                 "data_type": "TYPE_BOOL",
@@ -79,6 +74,14 @@ def auto_complete_config(auto_complete_model_config):
                 "optional": True,
             },
         ]
+        if _VLLM_VERSION >= "0.6.3.post1":
+            inputs.append({
+                "name": "image",
+                "data_type": "TYPE_STRING",
+                "dims": [-1],  # can be multiple images as separate elements
+                "optional": True,
+            })
+
         outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}]
 
         # Store the model configuration as a dictionary.
@@ -394,22 +397,23 @@ async def generate(self, request):
             if isinstance(prompt, bytes):
                 prompt = prompt.decode("utf-8")
 
-            image_input_tensor = pb_utils.get_input_tensor_by_name(
-                request, "image"
-            )
-            if image_input_tensor:
-                image_list = []
-                for image_raw in image_input_tensor.as_numpy():
-                    image_data = base64.b64decode(image_raw.decode("utf-8"))
-                    image = Image.open(BytesIO(image_data)).convert("RGB")
-                    image_list.append(image)
-                if len(image_list) > 0:
-                    prompt = {
-                        "prompt": prompt,
-                        "multi_modal_data": {
-                            "image": image_list
-                        }
-                    }
+            if _VLLM_VERSION >= "0.6.3.post1":
+                image_input_tensor = pb_utils.get_input_tensor_by_name(
+                    request, "image"
+                )
+                if image_input_tensor:
+                    image_list = []
+                    for image_raw in image_input_tensor.as_numpy():
+                        image_data = base64.b64decode(image_raw.decode("utf-8"))
+                        image = Image.open(BytesIO(image_data)).convert("RGB")
+                        image_list.append(image)
+                    if len(image_list) > 0:
+                        prompt = {
+                            "prompt": prompt,
+                            "multi_modal_data": {
+                                "image": image_list
+                            }
+                        }
 
             stream = pb_utils.get_input_tensor_by_name(request, "stream")
             if stream:
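For context, the new optional "image" input carries one or more base64-encoded images, which generate() decodes into PIL images and hands to vLLM as multi_modal_data. Below is a minimal client sketch; the endpoint localhost:8001 and the model name vllm_model are illustrative assumptions, the prompt format (e.g. image placeholder tokens) depends on the multimodal model being served, and a deployment with a decoupled transaction policy (the usual streaming setup) would need a streaming call instead of a plain infer().

import base64

import numpy as np
import tritonclient.grpc as grpcclient

# Base64-encode the image, mirroring the base64.b64decode(...) call
# that generate() performs on the server side.
with open("example.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

client = grpcclient.InferenceServerClient(url="localhost:8001")  # assumed endpoint

text_input = grpcclient.InferInput("text_input", [1], "BYTES")
text_input.set_data_from_numpy(np.array(["What is in this image?"], dtype=np.object_))

image_input = grpcclient.InferInput("image", [1], "BYTES")
image_input.set_data_from_numpy(np.array([image_b64], dtype=np.object_))

# "vllm_model" is a hypothetical model name for this sketch.
result = client.infer(model_name="vllm_model", inputs=[text_input, image_input])
print(result.as_numpy("text_output"))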
src/utils/metrics.py (16 additions, 2 deletions)
@@ -32,7 +32,7 @@
 from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
 from vllm.engine.metrics import Stats as VllmStats
 from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets
-
+from vllm.version import __version__ as _VLLM_VERSION
 
 class TritonMetrics:
     def __init__(self, labels: List[str], max_model_len: int):
@@ -76,6 +76,14 @@ def __init__(self, labels: List[str], max_model_len: int):
             description="Number of generation tokens processed.",
             kind=pb_utils.MetricFamily.HISTOGRAM,
         )
+        # 'best_of' metric has been hidden since vllm 0.6.3
+        # https://github.com/vllm-project/vllm/commit/cbc2ef55292b2af6ff742095c030e8425124c005
+        if _VLLM_VERSION < "0.6.3":
+            self.histogram_best_of_request_family = pb_utils.MetricFamily(
+                name="vllm:request_params_best_of",
+                description="Histogram of the best_of request parameter.",
+                kind=pb_utils.MetricFamily.HISTOGRAM,
+            )
         self.histogram_n_request_family = pb_utils.MetricFamily(
             name="vllm:request_params_n",
             description="Histogram of the n request parameter.",
@@ -154,6 +162,11 @@ def __init__(self, labels: List[str], max_model_len: int):
                 buckets=build_1_2_5_buckets(max_model_len),
             )
         )
+        if _VLLM_VERSION < "0.6.3":
+            self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
+                labels=labels,
+                buckets=[1, 2, 5, 10, 20],
+            )
         self.histogram_n_request = self.histogram_n_request_family.Metric(
             labels=labels,
             buckets=[1, 2, 5, 10, 20],
@@ -240,7 +253,8 @@ def log(self, stats: VllmStats) -> None:
             ),
             (self.metrics.histogram_n_request, stats.n_requests),
         ]
-
+        if _VLLM_VERSION < "0.6.3":
+            histogram_metrics.append((self.metrics.histogram_best_of_request, stats.best_of_requests))
         for metric, data in counter_metrics:
             self._log_counter(metric, data)
         for metric, data in histogram_metrics:
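One caveat worth flagging: both files compare _VLLM_VERSION as a plain string, which is lexicographic. That works for the versions involved here, but breaks once a version component reaches two digits (for example, "0.10.0" compares as less than "0.6.3"). A sketch of a more robust gate using packaging.version follows; the flag names are illustrative, not part of this commit.

from packaging.version import Version

from vllm.version import __version__ as _VLLM_VERSION

# Lexicographic string comparison would misorder "0.10.0" vs "0.6.3";
# Version() parses each component numerically (and handles .post1).
VLLM_SUPPORTS_IMAGE_INPUT = Version(_VLLM_VERSION) >= Version("0.6.3.post1")
VLLM_HIDES_BEST_OF_METRIC = Version(_VLLM_VERSION) >= Version("0.6.3")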
