From 4538ef6eb2ca8cae0b801a382d706c33e46e8a5b Mon Sep 17 00:00:00 2001
From: Vivek Goel
Date: Thu, 29 Aug 2024 11:23:05 +0530
Subject: [PATCH] Update paddings computed to adjust selected_token_indices
 (#210)

Fixes an assert seen when prompt_logprobs is not None and batch size > 1.
The assert fired because the shape of the paddings being added did not
match the shape of sampling_metadata.selected_token_indices when
prompt_logprobs is configured: in that case selected_token_indices holds
one entry per prompt token rather than one per sequence, so each
sequence's padding offset must be repeated seq_len times.
---
 vllm/worker/habana_model_runner.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index 6627ba1ea5643..a975dba6f5136 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -1012,8 +1012,13 @@ def prepare_input_tensors(
         paddings = [max_len - s for s in seq_lens]
         paddings = [0] + paddings[:-1]
         paddings = list(itertools.accumulate(paddings))
+        paddings_prompt_logprobs = []
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+            if seq_group_metadata.sampling_params.prompt_logprobs is not None \
+                    and seq_group_metadata.is_prompt:
+                paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i])
         paddings = torch.tensor(
-            paddings,
+            paddings_prompt_logprobs if paddings_prompt_logprobs else paddings,
             dtype=sampling_metadata.selected_token_indices.dtype,
             device=sampling_metadata.selected_token_indices.device)
         sampling_metadata.selected_token_indices.add_(paddings)
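
For reference, a minimal standalone sketch of the padding arithmetic above,
under stated assumptions: the two-prompt batch, the seq_lens values, and the
prompt_logprobs_requested flags are hypothetical stand-ins for the real
seq_group_metadata/SamplingParams checks; only the paddings computation
mirrors the patched code.

# Sketch of the offset adjustment in prepare_input_tensors (toy batch).
import itertools

import torch

seq_lens = [3, 5]        # hypothetical prompt lengths
max_len = max(seq_lens)  # padded length per sequence

# Cumulative offset of each sequence inside the flattened padded buffer,
# computed exactly as in the patched code.
paddings = [max_len - s for s in seq_lens]
paddings = [0] + paddings[:-1]
paddings = list(itertools.accumulate(paddings))  # -> [0, 2]

# Default case: selected_token_indices has one entry per sequence (the
# last prompt token), so one offset per sequence lines up shape-wise.
selected_token_indices = torch.tensor([2, 7])            # unpadded positions
print(selected_token_indices + torch.tensor(paddings))   # tensor([2, 9])

# prompt_logprobs case: selected_token_indices has one entry per prompt
# token, so each sequence's offset must be repeated seq_len times.
# This is the shape mismatch the patch fixes.
prompt_logprobs_requested = [True, True]  # hypothetical per-sequence flags
paddings_prompt_logprobs = []
for i, requested in enumerate(prompt_logprobs_requested):
    if requested:
        paddings_prompt_logprobs += [paddings[i]] * seq_lens[i]
print(paddings_prompt_logprobs)  # [0, 0, 0, 2, 2, 2, 2, 2]

selected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
print(selected_token_indices + torch.tensor(paddings_prompt_logprobs))
# tensor([0, 1, 2, 5, 6, 7, 8, 9]) -- positions in the padded buffer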