diff --git a/src/model.py b/src/model.py index 7211b10..0fdbe0c 100644 --- a/src/model.py +++ b/src/model.py @@ -42,6 +42,7 @@ from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid +from vllm.version import __version__ as _VLLM_VERSION from utils.metrics import VllmStatLogger @@ -54,12 +55,6 @@ class TritonPythonModel: def auto_complete_config(auto_complete_model_config): inputs = [ {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]}, - { - "name": "image", - "data_type": "TYPE_STRING", - "dims": [-1], # can be multiple images as separate elements - "optional": True, - }, { "name": "stream", "data_type": "TYPE_BOOL", @@ -79,6 +74,14 @@ def auto_complete_config(auto_complete_model_config): "optional": True, }, ] + if _VLLM_VERSION >= "0.6.3.post1": + inputs.append({ + "name": "image", + "data_type": "TYPE_STRING", + "dims": [-1], # can be multiple images as separate elements + "optional": True, + }) + outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}] # Store the model configuration as a dictionary. @@ -394,22 +397,23 @@ async def generate(self, request): if isinstance(prompt, bytes): prompt = prompt.decode("utf-8") - image_input_tensor = pb_utils.get_input_tensor_by_name( - request, "image" - ) - if image_input_tensor: - image_list = [] - for image_raw in image_input_tensor.as_numpy(): - image_data = base64.b64decode(image_raw.decode("utf-8")) - image = Image.open(BytesIO(image_data)).convert("RGB") - image_list.append(image) - if len(image_list) > 0: - prompt = { - "prompt": prompt, - "multi_modal_data": { - "image": image_list + if _VLLM_VERSION >= "0.6.3.post1": + image_input_tensor = pb_utils.get_input_tensor_by_name( + request, "image" + ) + if image_input_tensor: + image_list = [] + for image_raw in image_input_tensor.as_numpy(): + image_data = base64.b64decode(image_raw.decode("utf-8")) + image = Image.open(BytesIO(image_data)).convert("RGB") + image_list.append(image) + if len(image_list) > 0: + prompt = { + "prompt": prompt, + "multi_modal_data": { + "image": image_list + } } - } stream = pb_utils.get_input_tensor_by_name(request, "stream") if stream: diff --git a/src/utils/metrics.py b/src/utils/metrics.py index 3588a0d..0504eef 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -32,7 +32,7 @@ from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase from vllm.engine.metrics import Stats as VllmStats from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets - +from vllm.version import __version__ as _VLLM_VERSION class TritonMetrics: def __init__(self, labels: List[str], max_model_len: int): @@ -76,6 +76,14 @@ def __init__(self, labels: List[str], max_model_len: int): description="Number of generation tokens processed.", kind=pb_utils.MetricFamily.HISTOGRAM, ) + # 'best_of' metric has been hidden since vllm 0.6.3 + # https://github.com/vllm-project/vllm/commit/cbc2ef55292b2af6ff742095c030e8425124c005 + if _VLLM_VERSION < "0.6.3": + self.histogram_best_of_request_family = pb_utils.MetricFamily( + name="vllm:request_params_best_of", + description="Histogram of the best_of request parameter.", + kind=pb_utils.MetricFamily.HISTOGRAM, + ) self.histogram_n_request_family = pb_utils.MetricFamily( name="vllm:request_params_n", description="Histogram of the n request parameter.", @@ -154,6 +162,11 @@ def __init__(self, labels: List[str], max_model_len: int): buckets=build_1_2_5_buckets(max_model_len), ) ) + if _VLLM_VERSION < "0.6.3": + self.histogram_best_of_request = self.histogram_best_of_request_family.Metric( + labels=labels, + buckets=[1, 2, 5, 10, 20], + ) self.histogram_n_request = self.histogram_n_request_family.Metric( labels=labels, buckets=[1, 2, 5, 10, 20], @@ -240,7 +253,8 @@ def log(self, stats: VllmStats) -> None: ), (self.metrics.histogram_n_request, stats.n_requests), ] - + if _VLLM_VERSION < "0.6.3": + histogram_metrics.append((self.metrics.histogram_best_of_request, stats.best_of_requests)) for metric, data in counter_metrics: self._log_counter(metric, data) for metric, data in histogram_metrics: