diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 7f7f15bea86fa..4fb68741386d8 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -958,8 +958,13 @@ def prepare_input_tensors( paddings = [max_len - s for s in seq_lens] paddings = [0] + paddings[:-1] paddings = list(itertools.accumulate(paddings)) + paddings_prompt_logprobs = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + if seq_group_metadata.sampling_params.prompt_logprobs is not None \ + and seq_group_metadata.is_prompt: + paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i]) paddings = torch.tensor( - paddings, + paddings_prompt_logprobs if paddings_prompt_logprobs else paddings, dtype=sampling_metadata.selected_token_indices.dtype, device=sampling_metadata.selected_token_indices.device) sampling_metadata.selected_token_indices.add_(paddings)