From 174b706f33131568103dd8321c3b904d8434368b Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Tue, 27 Aug 2024 19:03:21 +0300 Subject: [PATCH] fix case where max_num_batched_tokens is not divisible by bucket batch size --- vllm/worker/habana_model_runner.py | 68 ++++++++++++++++++------------ 1 file changed, 42 insertions(+), 26 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index bf266ca768850..e4b483d557d7e 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -191,6 +191,9 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], fill=True): def warmup_buckets(bs_bucket_config, seq_bucket_config, max_num_batched_tokens, block_size, max_model_len): + # NOTE(kzawora): If either max_model_len or max_num_batched_tokens are + # not divisible by block_size, we round them up here to block_size + bs_buckets = warmup_range(bs_bucket_config[:3]) seq_len_buckets = warmup_range(seq_bucket_config[:3]) if bs_bucket_config[3] != 0 and len(bs_buckets) > bs_bucket_config[3]: @@ -199,49 +202,62 @@ def warmup_buckets(bs_bucket_config, seq_bucket_config, max_num_batched_tokens, seq_len_buckets) > seq_bucket_config[3]: seq_len_buckets = warmup_range_with_limit(seq_bucket_config) - # NOTE(kzawora): When using sequence limits, we need to make sure we still - # cover the largest valid scenario (max_num_batched_tokens/bs). If we drop - # out-of-bound buckets, we might be left with last bucket that's too small - # and encounter recompilations when exceeding that bucket and still being - # in valid range - + buckets = list(itertools.product(bs_buckets, seq_len_buckets)) + + # Remove buckets that either exceed token budget + # or are greater than max model length + filtered_buckets = list( + filter( + lambda bucket: bucket[0] * bucket[1] <= math.ceil( + max_num_batched_tokens / block_size) * block_size and bucket[1] + <= max_model_len, buckets)) + + # NOTE(kzawora): We need to make sure we cover the largest valid scenario + # (max_num_batched_tokens/bs). If we drop out-of-bound buckets, we might + # be left with last bucket that's too small and encounter recompilations + # when exceeding that bucket and still being in the valid range. + # Example: - # max_model_len = 768 - # max_num_batched_tokens = 1536 - # bs_buckets = [1, 2, 4] + # block_size = 128 + # max_model_len = 768 (6 blocks) + # max_num_batched_tokens = 1536 (12 blocks) + # bs_buckets = [1, 2, 4, 5] # seq_len_buckets = [128, 256, 512, 1024] # buckets = [ # (1,128), (1,256), (1,512), (1,1024), # last one is invalid, 1024 > 768 # (2,128), (2,256), (2,512), (2,1024) # last one is invalid, 2048 > 1536 # (4,128), (4,256), (4,512), (4,1024) # last two are invalid, 2048 > 1536 + # (5,128), (5,256), (5,512), (5,1024) # last two are invalid, 2560 > 1536 # ] # filtered_buckets = [ # (1,128), (1,256), (1,512), # 512-768 is valid and not bucketed # (2,128), (2,256), (2,512), # 512-768 is valid and not bucketed # (4,128), (4,256), # 256-384 is valid and not bucketed + # (5,128), (5,256), # 256-307 is valid and not bucketed # ] # boundary_buckets = [ # (1,768), # covers 512-768 range, constrained by max_model_len # (2,768) # covers 512-768 range, constrained by budget and max_model_len # (4,384) # covers 256-384 range, constrained by token budget - # ] - - # Generate the largest valid bucket for all batch sizes. - boundary_buckets = [(bs, - min( - max_model_len, - math.ceil(max_num_batched_tokens / - (bs * block_size)) * block_size)) - for bs in bs_buckets] - buckets = list(itertools.product(bs_buckets, seq_len_buckets)) - buckets = list(set([*buckets, *boundary_buckets])) + # (5,384) # covers 256-307 range, constrained by token budget, + # # rounded up to block_size(=128) + # ] + # While (5,384) bucket exceeds the token budget of 1536 (5*386=1920), + # it is the smallest bucket that can handle all valid use cases on bs=5, + # as all buckets for decode phase need to be aligned to block_size. + + # Generate the largest valid bucket for all valid batch sizes. + # DO NOT use bs_buckets in here as it can include invalid batches + valid_bs_buckets = set(list(zip(*filtered_buckets))[0]) + boundary_buckets = [ + (bs, + min( + math.ceil(max_model_len / block_size) * block_size, + math.ceil(max_num_batched_tokens / (bs * block_size)) * + block_size)) for bs in valid_bs_buckets + ] + filtered_buckets = list(set([*filtered_buckets, *boundary_buckets])) - # Remove buckets that either exceed token budget - # or are greater than max model length - filtered_buckets = filter( - lambda bucket: bucket[0] * bucket[1] <= math.ceil( - max_num_batched_tokens / block_size) * block_size and bucket[1] <= - max_model_len, buckets) return list( sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))