diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 6627ba1ea564..a975dba6f513 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -1012,8 +1012,13 @@ def prepare_input_tensors( paddings = [max_len - s for s in seq_lens] paddings = [0] + paddings[:-1] paddings = list(itertools.accumulate(paddings)) + paddings_prompt_logprobs = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + if seq_group_metadata.sampling_params.prompt_logprobs is not None \ + and seq_group_metadata.is_prompt: + paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i]) paddings = torch.tensor( - paddings, + paddings_prompt_logprobs if paddings_prompt_logprobs else paddings, dtype=sampling_metadata.selected_token_indices.dtype, device=sampling_metadata.selected_token_indices.device) sampling_metadata.selected_token_indices.add_(paddings)