From 15e5d799ab50e4b533528d25b23437cd280155fb Mon Sep 17 00:00:00 2001 From: Huanxing Date: Mon, 28 Oct 2024 08:44:13 +0000 Subject: [PATCH 1/5] to make repetition penalty faster: first, enable pin memory;second use masked_fill instead of boolean index; third add paddings to prompt tokens and output tokens to reduce re-compile. --- vllm/model_executor/layers/sampler.py | 2 +- vllm/model_executor/sampling_metadata.py | 34 ++++++++---- vllm/utils.py | 66 ++++++++++++++++++++++-- vllm/worker/cache_engine.py | 18 +++++-- 4 files changed, 99 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index bb025bfd819d5..4c777b7d3d6ce 100755 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -519,7 +519,7 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, output_tokens_tensor, vocab_size, num_seqs) repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) - repetition_penalties[~(prompt_mask | output_mask)] = 1.0 + repetition_penalties.masked_fill_(~(prompt_mask | output_mask), 1.0) logits = torch.where(logits > 0, logits / repetition_penalties, logits * repetition_penalties) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 9fda807d29236..500cbf6e0b2dd 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -9,7 +9,8 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, SequenceGroupMetadata) from vllm.utils import (PyObjectCache, async_tensor_h2d, - is_pin_memory_available, make_tensor_with_pad) + is_pin_memory_available, make_tensor_with_pad, + make_tensor_with_pad_align) _SAMPLING_EPS = 1e-5 @@ -522,19 +523,21 @@ def from_lists( do_penalties = prompt_tokens or output_tokens if do_penalties: - prompt_t = make_tensor_with_pad( + prompt_t = make_tensor_with_pad_align( prompt_tokens, vocab_size, device="cpu", dtype=torch.int64, pin_memory=pin_memory, + max_len_align=1024, ) - output_t = make_tensor_with_pad( + output_t = make_tensor_with_pad_align( output_tokens, vocab_size, device="cpu", dtype=torch.int64, pin_memory=pin_memory, + max_len_align=1024, ) else: empty_tensor = torch.empty(0, device=device, dtype=torch.long) @@ -545,47 +548,58 @@ def from_lists( temperatures, device="cpu", dtype=dtype, - pin_memory=pin_memory, ) top_ps_t = torch.tensor( top_ps, device="cpu", dtype=dtype, - pin_memory=pin_memory, ) min_ps_t = torch.tensor( min_ps, device="cpu", dtype=dtype, - pin_memory=pin_memory, ) presence_penalties_t = torch.tensor( presence_penalties, device="cpu", dtype=dtype, - pin_memory=pin_memory, ) frequency_penalties_t = torch.tensor( frequency_penalties, device="cpu", dtype=dtype, - pin_memory=pin_memory, ) repetition_penalties_t = torch.tensor( repetition_penalties, device="cpu", dtype=dtype, - pin_memory=pin_memory, ) top_ks_t = torch.tensor( top_ks, device="cpu", dtype=torch.int, - pin_memory=pin_memory, ) # Because the memory is pinned, we can do non-blocking # transfer to device. + if pin_memory: + if not current_platform.is_hpu(): + temperatures_t.pin_memory() + top_ps_t.pin_memory() + min_ps_t.pin_memory() + frequency_penalties_t.pin_memory() + presence_penalties_t.pin_memory() + repetition_penalties_t.pin_memory() + top_ks_t.pin_memory() + else: + temperatures_t.pin_memory(device="hpu") + top_ps_t.pin_memory(device="hpu") + min_ps_t.pin_memory(device="hpu") + frequency_penalties_t.pin_memory(device="hpu") + presence_penalties_t.pin_memory(device="hpu") + repetition_penalties_t.pin_memory(device="hpu") + top_ks_t.pin_memory(device="hpu") + return cls( temperatures=temperatures_t.to(device=device, non_blocking=True), top_ps=top_ps_t.to(device=device, non_blocking=True), diff --git a/vllm/utils.py b/vllm/utils.py index 886946f285ba8..902df324bc090 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -6,6 +6,7 @@ import gc import inspect import ipaddress +import math import os import socket import subprocess @@ -752,9 +753,6 @@ def is_pin_memory_available() -> bool: elif current_platform.is_neuron(): print_warning_once("Pin memory is not supported on Neuron.") return False - elif current_platform.is_hpu(): - print_warning_once("Pin memory is not supported on HPU.") - return False elif current_platform.is_cpu() or current_platform.is_openvino(): return False return True @@ -812,6 +810,29 @@ def make_ndarray_with_pad( return padded_x +def make_ndarray_with_pad_align( + x: List[List[T]], + pad: T, + dtype: npt.DTypeLike, + *, + max_len_align: Optional[int] = None, +) -> npt.NDArray: + """ + Make a padded array from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ + # Unlike for most functions, map is faster than a genexpr over `len` + max_len = max(map(len, x), default=0) + max_len_aligned = math.ceil(max_len / max_len_align) * max_len_align + padded_x = np.full((len(x), max_len_aligned), pad, dtype=dtype) + + for ind, blocktb in enumerate(x): + assert len(blocktb) <= max_len_aligned + padded_x[ind, :len(blocktb)] = blocktb + + return padded_x def make_tensor_with_pad( x: List[List[T]], @@ -833,10 +854,39 @@ def make_tensor_with_pad( tensor = torch.from_numpy(padded_x).to(device) if pin_memory: - tensor = tensor.pin_memory() + if not current_platform.is_hpu(): + tensor = tensor.pin_memory() + else: + tensor = tensor.pin_memory("hpu") return tensor +def make_tensor_with_pad_align( + x: List[List[T]], + pad: T, + dtype: torch.dtype, + *, + max_len_align: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + pin_memory: bool = False, +) -> torch.Tensor: + """ + Make a padded tensor from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ + np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] + padded_x = make_ndarray_with_pad_align(x, pad, np_dtype, max_len_align=max_len_align) + + tensor = torch.from_numpy(padded_x).to(device) + if pin_memory: + if not current_platform.is_hpu(): + tensor = tensor.pin_memory() + else: + tensor = tensor.pin_memory("hpu") + + return tensor def async_tensor_h2d( data: list, @@ -845,7 +895,13 @@ def async_tensor_h2d( pin_memory: bool, ) -> torch.Tensor: """Asynchronously create a tensor and copy it from host to device.""" - t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") + t = torch.tensor(data, dtype=dtype, device="cpu") + if pin_memory: + if not current_platform.is_hpu(): + t.pin_memory() + else: + t.pin_memory(device="hpu") + return t.to(device=target_device, non_blocking=True) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index ac3270d1c9909..2841b390589fe 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -8,6 +8,7 @@ from vllm.logger import init_logger from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_pin_memory_available) +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -77,11 +78,18 @@ def _allocate_kv_cache( # null block in CpuGpuBlockAllocator requires at least that # block to be zeroed-out. # We zero-out everything for simplicity. - kv_cache.append( - torch.zeros(kv_cache_shape, - dtype=self.dtype, - pin_memory=pin_memory, - device=device)) + if pin_memory: + if current_platform.is_hpu(): + kv_cache.append( + torch.zeros(kv_cache_shape, + dtype=self.dtype, + pin_memory=pin_memory, + device=device)) + else: + kv_cache.append( + torch.zeros(kv_cache_shape, + dtype=dtype, + device=device).pin_memory(device="hpu")) return kv_cache def swap_in(self, src_to_dst: torch.Tensor) -> None: From 0004cc5db33c786e074c64bda4a6de2ad904bd70 Mon Sep 17 00:00:00 2001 From: Huanxing Date: Wed, 30 Oct 2024 01:19:27 +0000 Subject: [PATCH 2/5] revert change in cache_engine as it is not used in HPU --- vllm/worker/cache_engine.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 2841b390589fe..56a1dcf61afd2 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -78,18 +78,11 @@ def _allocate_kv_cache( # null block in CpuGpuBlockAllocator requires at least that # block to be zeroed-out. # We zero-out everything for simplicity. - if pin_memory: - if current_platform.is_hpu(): - kv_cache.append( - torch.zeros(kv_cache_shape, - dtype=self.dtype, - pin_memory=pin_memory, - device=device)) - else: - kv_cache.append( - torch.zeros(kv_cache_shape, - dtype=dtype, - device=device).pin_memory(device="hpu")) + kv_cache.append( + torch.zeros(kv_cache_shape, + dtype=self.dtype, + pin_memory=pin_memory, + device=device)) return kv_cache def swap_in(self, src_to_dst: torch.Tensor) -> None: From 6bb5cb99aa1cee5e02b23cf32b40908bc61d57d0 Mon Sep 17 00:00:00 2001 From: Huanxing Date: Wed, 30 Oct 2024 01:29:40 +0000 Subject: [PATCH 3/5] remove unnecessary import --- vllm/worker/cache_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 56a1dcf61afd2..ac3270d1c9909 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -8,7 +8,6 @@ from vllm.logger import init_logger from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_pin_memory_available) -from vllm.platforms import current_platform logger = init_logger(__name__) From ec38d6714eeaa6ea3d3e1a928f84d434d142dde9 Mon Sep 17 00:00:00 2001 From: huanxing Date: Wed, 13 Nov 2024 09:29:35 +0800 Subject: [PATCH 4/5] fix yaff --- vllm/model_executor/sampling_metadata.py | 51 ++++++++++++++++-------- vllm/utils.py | 3 +- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 500cbf6e0b2dd..e87999a77838a 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -9,7 +9,8 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, SequenceGroupMetadata) from vllm.utils import (PyObjectCache, async_tensor_h2d, - is_pin_memory_available, make_tensor_with_pad, + is_pin_memory_available, + make_tensor_with_pad, make_tensor_with_pad_align) _SAMPLING_EPS = 1e-5 @@ -523,22 +524,38 @@ def from_lists( do_penalties = prompt_tokens or output_tokens if do_penalties: - prompt_t = make_tensor_with_pad_align( - prompt_tokens, - vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=pin_memory, - max_len_align=1024, - ) - output_t = make_tensor_with_pad_align( - output_tokens, - vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=pin_memory, - max_len_align=1024, - ) + if current_platform.is_hpu(): + prompt_t = make_tensor_with_pad_align( + prompt_tokens, + vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=pin_memory, + max_len_align=1024, + ) + output_t = make_tensor_with_pad_align( + output_tokens, + vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=pin_memory, + max_len_align=1024, + ) + else: + prompt_t = make_tensor_with_pad( + prompt_tokens, + vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=pin_memory, + ) + output_t = make_tensor_with_pad( + output_tokens, + vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=pin_memory, + ) else: empty_tensor = torch.empty(0, device=device, dtype=torch.long) prompt_t = empty_tensor diff --git a/vllm/utils.py b/vllm/utils.py index 902df324bc090..f916fcd9c07b0 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -877,7 +877,8 @@ def make_tensor_with_pad_align( `max_len`. """ np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] - padded_x = make_ndarray_with_pad_align(x, pad, np_dtype, max_len_align=max_len_align) + padded_x = make_ndarray_with_pad_align(x, pad, np_dtype, + max_len_align=max_len_align) tensor = torch.from_numpy(padded_x).to(device) if pin_memory: From 5016cc41745bdbcbda068efda760b2fb65a29a14 Mon Sep 17 00:00:00 2001 From: huanxing Date: Thu, 14 Nov 2024 09:26:32 +0800 Subject: [PATCH 5/5] remove device check in the hpu only method --- vllm/utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index f916fcd9c07b0..f9903d59cf149 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -882,10 +882,7 @@ def make_tensor_with_pad_align( tensor = torch.from_numpy(padded_x).to(device) if pin_memory: - if not current_platform.is_hpu(): - tensor = tensor.pin_memory() - else: - tensor = tensor.pin_memory("hpu") + tensor = tensor.pin_memory("hpu") return tensor