From ec38d6714eeaa6ea3d3e1a928f84d434d142dde9 Mon Sep 17 00:00:00 2001 From: huanxing Date: Wed, 13 Nov 2024 09:29:35 +0800 Subject: [PATCH] fix yapf --- vllm/model_executor/sampling_metadata.py | 51 ++++++++++++++++-------- vllm/utils.py | 3 +- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 500cbf6e0b2dd..e87999a77838a 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -9,7 +9,8 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, SequenceGroupMetadata) from vllm.utils import (PyObjectCache, async_tensor_h2d, - is_pin_memory_available, make_tensor_with_pad, + is_pin_memory_available, + make_tensor_with_pad, make_tensor_with_pad_align) _SAMPLING_EPS = 1e-5 @@ -523,22 +524,38 @@ def from_lists( do_penalties = prompt_tokens or output_tokens if do_penalties: - prompt_t = make_tensor_with_pad_align( - prompt_tokens, - vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=pin_memory, - max_len_align=1024, - ) - output_t = make_tensor_with_pad_align( - output_tokens, - vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=pin_memory, - max_len_align=1024, - ) + if current_platform.is_hpu(): + prompt_t = make_tensor_with_pad_align( + prompt_tokens, + vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=pin_memory, + max_len_align=1024, + ) + output_t = make_tensor_with_pad_align( + output_tokens, + vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=pin_memory, + max_len_align=1024, + ) + else: + prompt_t = make_tensor_with_pad( + prompt_tokens, + vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=pin_memory, + ) + output_t = make_tensor_with_pad( + output_tokens, + vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=pin_memory, + ) else: empty_tensor = torch.empty(0, device=device, dtype=torch.long) prompt_t = empty_tensor diff --git 
a/vllm/utils.py b/vllm/utils.py index 902df324bc090..f916fcd9c07b0 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -877,7 +877,8 @@ def make_tensor_with_pad_align( `max_len`. """ np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] - padded_x = make_ndarray_with_pad_align(x, pad, np_dtype, max_len_align=max_len_align) + padded_x = make_ndarray_with_pad_align(x, pad, np_dtype, + max_len_align=max_len_align) tensor = torch.from_numpy(padded_x).to(device) if pin_memory: