diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 9572588ce6e53..9d0ecb820548e 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -7,7 +7,7 @@ import torch from transformers import GenerationConfig, GenerationMixin -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import ApplyToppTopkScalar, Sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata @@ -700,3 +700,62 @@ def test_sampling_params(sampling_params: List[SamplingParams]): assert tokens1[0] == tokens2[1] assert tokens1[1] == tokens2[0] + + +def test_topk_topk_scalar(): + obj1 = ApplyToppTopkScalar(2) + assert ApplyToppTopkScalar._padded_k == 0 + x = torch.tensor([[9, 9, 8, 8, 8, 8, 7, 7, 7.0], + [10, 10, 9, 9, 9, 8, 5, 5, 5]]) + + retval1 = obj1(x, p=0.9, k=5) + ninf = -float("inf") + expected1 = torch.tensor([[9., 9., 8., 8., 8., 8., ninf, ninf, ninf], + [10., 10., 9., 9., 9., ninf, ninf, ninf, ninf]]) + assert torch.all(retval1 == expected1).item() + assert ApplyToppTopkScalar._padded_k == 9 + + obj2 = ApplyToppTopkScalar(2) + assert obj2._padded_k == 9 + + x = torch.tensor([[2, 2, 9, 9, 2, 2, 1, 1, 1.0], + [10, 9, 9, 5, 9, 9, 5, 9, 10]]) + retval2 = obj2(x, p=0.9, k=5) + expected2 = torch.tensor( + [[ninf, ninf, 9., 9., ninf, ninf, ninf, ninf, ninf], + [10., ninf, 9., ninf, 9., 9., ninf, 9., 10.]]) + assert torch.all(retval2 == expected2).item() + assert obj2._padded_k == 9 + + retval3 = obj2(x, p=1.0, k=5) + expected3 = torch.tensor([[2., 2., 9., 9., 2., 2., ninf, ninf, ninf], + [10., 9., 9., ninf, 9., 9., ninf, 9., 10.]]) + + assert torch.all(retval3 == expected3).item() + + # this should not be done in general, doing it here for testing purposes + ApplyToppTopkScalar._padded_k = 0 + x = torch.tensor([[1, 1, 1, 9, 8, 1, 1, 1, 1.0], + [2, 1, 2, 2, 1, 1, 1, 1, 1]]) + obj3 = ApplyToppTopkScalar(2) + retval4 = obj3(x, p=0.9, k=2) + expected4 = torch.tensor( + [[ninf, ninf, ninf, 9., 8., ninf, ninf, ninf, ninf], + [2., ninf, 2., 2., ninf, ninf, ninf, ninf, ninf]]) + assert torch.all(retval4 == expected4).item() + assert obj3._padded_k == 4 + y = torch.tensor([[8, 8, 8, 9, 8, 1, 1, 1, 1.0], + [2, 1, 2, 2, 1, 1, 1, 1, 1]]) + retval5 = obj3(y, p=0.9, k=2) + assert obj3._padded_k == 8 + expected5 = torch.tensor([[8., 8., 8., 9., 8., ninf, ninf, ninf, ninf], + [2., ninf, 2., 2., ninf, ninf, ninf, ninf, + ninf]]) + assert torch.all(retval5 == expected5).item() + y = torch.tensor([[8, 8, 8, 9, 8, 8, 1, 1, 1.0], + [2, 1, 2, 2, 3, 1, 1, 1, 1]]) + retval6 = obj3(y, p=0.9, k=2) + expected6 = torch.tensor([[8., 8., 8., 9., 8., 8., ninf, ninf, ninf], + [2., ninf, 2., 2., 3., ninf, ninf, ninf, ninf]]) + assert torch.all(retval6 == expected6).item() + assert obj3._padded_k == 8 diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 6632b1c434582..c712f219150a4 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1,5 +1,6 @@ """A layer that samples the next tokens from the model's outputs.""" import itertools +import math from math import inf from typing import Dict, List, Optional, Tuple @@ -77,6 +78,13 @@ def _init_sampling_tensors( self._do_penalties = do_penalties self._do_top_p_top_k = do_top_p_top_k self._do_min_p = do_min_p + self._top_p_scalar = sampling_tensors.top_ps[0].item() + self._top_k_scalar = sampling_tensors.top_ks[0].item() + scalar_p = torch.all(sampling_tensors.top_ps == self._top_p_scalar) + scalar_k = torch.all(sampling_tensors.top_ks == self._top_k_scalar) + self._scalar_p_and_k = (scalar_p and scalar_k).item() + if self._scalar_p_and_k and self._do_top_p_top_k: + self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5) def forward( self, @@ -122,8 +130,13 @@ def forward( logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) if do_top_p_top_k: - logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, - sampling_tensors.top_ks) + if self._scalar_p_and_k: + logits = self._apply_top_k_top_p_opt(logits, + self._top_p_scalar, + self._top_k_scalar) + else: + logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, + sampling_tensors.top_ks) if do_min_p: logits = _apply_min_p(logits, sampling_tensors.min_ps) @@ -198,6 +211,105 @@ def _get_bin_counts_and_mask( return bin_counts, mask +class ApplyToppTopkScalar(): + """ + The original implementation of _apply_top_k_top_p is more general + as it uses vector topp, topk + However in a lot of cases, topp and topk is same for all batch elements + For such "scalar" topp, topk cases, we can use this class + + The main optimizations in this class is: + Use topk instead of sort, which is much faster especially for small k. + However just using topk might not suffice in cases as shown below + Consider a tensor: 9 9 8 8 8 8 7 7 7 + Topk, with k=5, on this yields 9 9 8 8 8 + The value "8" is on the boundary, hence the last "8" gets snipped off + However the original implementation accepts all the "8", + so it should output: + 9 9 8 8 8 8 (6 values, even though k=5) + To ensure these semantics, we perform topk with __padded_k elements + If we find more boundary elements left over, + then we keep incrementing __padded_k + and in future calls use the expanded value of __padded_k + + The increments to __padded_k should be done + with value > 1 to prevent excessive recompilations + due to dynamic shapes (the output shape of the topk) + + The main logic of this is in __call__ + This is a class instead of a function, just to keep track of + the monotonic non-decreasing state __padded_k + """ + _padded_k = 0 + + def __init__(self, increment: int): + self._increment = increment + + def __call__(self, logits: torch.Tensor, p: float, k: int): + if k > ApplyToppTopkScalar._padded_k: + ApplyToppTopkScalar._padded_k = min(k + self._increment, + logits.shape[1]) + #print("Increment padded_k to ", ApplyToppTopkScalar._padded_k) + while (True): + topk_results = torch.topk(logits, + k=ApplyToppTopkScalar._padded_k, + dim=1, + sorted=True) + # TODO, we may not need this flip, + # which might make this slightly faster + idx = torch.fliplr(topk_results.indices) + vals = torch.fliplr(topk_results.values) + + if ApplyToppTopkScalar._padded_k == logits.shape[1]: + break + + smallest_of_top_k = vals[:, -k] + num_duplicates_of_smallest_of_topk = torch.sum( + logits == smallest_of_top_k.unsqueeze(1), 1) + max_num_duplicates_of_smallest_of_topk = torch.max( + num_duplicates_of_smallest_of_topk).item() + + # there are n repeats for a border + # (border meaning the smallest value of the top k). + # we do not know if only 1 or 2 or (n-1) + # of them lie outside the kth border, + # so we choose to conservatively increase by n-1 + # when num_duplicates > _padded_k - k + if max_num_duplicates_of_smallest_of_topk - 1 > ( + ApplyToppTopkScalar._padded_k - k): + incr = int( + math.ceil((max_num_duplicates_of_smallest_of_topk - 1) / + self._increment) * self._increment) + # this while loop should be traversed at most twice, + # because we dont increment by self._increment and retry + # instead we compute incr in one go + ApplyToppTopkScalar._padded_k = min( + ApplyToppTopkScalar._padded_k + incr, logits.shape[1]) + #print("Increment padded_k to ", ApplyToppTopkScalar._padded_k) + else: + break + + top_k_mask = vals.size(1) - k + + top_k_smallest_val_idx = vals.size(1) - k + top_k_mask = vals[:, top_k_smallest_val_idx].unsqueeze(1) + top_k_mask = vals < top_k_mask + vals.masked_fill_(top_k_mask, -float("inf")) + + probs_sort = vals.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum <= (1 - p) + top_p_mask[:, -1] = False + vals.masked_fill_(top_p_mask, -float("inf")) + + new_logits = torch.full(logits.shape, + -float("inf"), + device=logits.device) + new_logits.scatter_(1, idx, vals) + + return new_logits + + def _apply_min_tokens_penalty( logits: torch.Tensor, sampling_metadata: SamplingMetadata,