diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index 62a9e814a5ac..e4b483d557d7 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -15,6 +15,7 @@
     Optional, Set, Tuple, Type, TypeVar, Union)
 
 import habana_frameworks.torch as htorch
+import numpy as np
 import torch
 
 from vllm.attention import AttentionMetadata, get_attn_backend
@@ -60,7 +61,7 @@ def read_bucket_settings(phase: str, dim: str, **defaults):
-    param is either 'min', 'step' or 'max'
+    param is either 'min', 'step', 'max' or 'limit'
     example env variable: VLLM_DECODE_BS_BUCKET_STEP=128
     """
-    params = ['min', 'step', 'max']
+    params = ['min', 'step', 'max', 'limit']
     values = [
         int(
             os.environ.get(f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper(),
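For orientation, the lookup this hunk extends resolves each bucket parameter from an environment variable, falling back to the supplied default. The following standalone sketch re-types the visible logic for illustration; it is not the patched function itself:

    import os

    def read_bucket_settings(phase: str, dim: str, **defaults) -> list:
        # Each parameter can be overridden with an env var such as
        # VLLM_DECODE_BS_BUCKET_STEP=128 or, with this patch,
        # VLLM_PROMPT_SEQ_BUCKET_LIMIT=10; otherwise the default is used.
        params = ['min', 'step', 'max', 'limit']
        return [
            int(os.environ.get(f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper(),
                               defaults[p])) for p in params
        ]

    # With no overrides set, the defaults pass straight through:
    # read_bucket_settings('decode', 'bs', min=1, step=128, max=256, limit=0)
    # -> [1, 128, 256, 0]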
@@ -94,14 +95,169 @@ def warmup_range(config: Tuple[int, int, int]):
     return list(ramp_up_tw) + list(stable)
 
 
-def warmup_buckets(bs_bucket_config, seq_bucket_config,
-                   max_num_batched_tokens):
-    buckets = itertools.product(warmup_range(bs_bucket_config),
-                                warmup_range(seq_bucket_config))
-    # Remove buckets exceeding batch token budget
-    filtered_buckets = filter(
-        lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens,
-        buckets)
+def warmup_range_with_limit(config: Tuple[int, int, int, int], fill=True):
+    """
+    NOTE(kzawora): we'll use exponential spacing for buckets in which the
+    scaled power will return bmin for the first bucket iteration and bmax for
+    the last iteration, with the elements in between determined by the
+    exponent, while the base stays unchanged. Note that after padding to
+    bstep, duplicates may occur. Handling of duplicates is configured by the
+    fill parameter.
+
+    If fill is False, duplicates are removed and fewer buckets are returned.
+
+    If fill is True, duplicates are resolved by selecting the closest (larger
+    or smaller) bucket. If duplicate resolution is not possible, fewer buckets
+    are returned. In that case, buckets are guaranteed to be linearly spaced.
+
+    Example (bmin=128, bstep=128, bmax=2048, num_buckets=10):
+
+    There are 16 possible buckets (2048/128), and we'll attempt to select 10
+    of them with exponential spacing.
+
+    base = (bmax/bmin) ** (1/(num_buckets-1)); (2048/128) ** (1/9) = 1.36079
+    exponent = i
+    power = base ** exponent
+    scaled_power = bmin * power
+
+    For i == 0 (first bucket), power is 1.36079 ** 0 = 1;
+    scaled_power is 1 * 128 = 128 (==bmin)
+    For i == 9 (last bucket), power is 1.36079 ** 9 = 16;
+    scaled_power is 16 * 128 = 2048 (==bmax)
+
+    So, computing for all buckets:
+    scaled_powers_unpadded = [bmin*base^0(==bmin), bmin*base^1, bmin*base^2, ..., bmin*base^9(==bmax)]
+    scaled_powers_unpadded = [128.00, 174.18, 237.02, 322.54, 438.91, 597.26, 812.75, 1105.98, 1505.01, 2048.00]
+
+    if fill is False:
+    scaled_powers_padded   = [ 128,  256,  256,  384,  512,  640,  896, 1152, 1536, 2048]
+                                      ^_______^
+                                      duplicates
+
+    buckets                = [ 128,  256,        384,  512,  640,  896, 1152, 1536, 2048]
+                                            ^
+                                  duplicate bucket removed
+
+    len(buckets) = 9, num_buckets = 10
+
+    if fill is True:
+    buckets                = [ 128,  256,  384,  512,  640,  768,  896, 1152, 1536, 2048]
+                                            ^_______^_______^_______^
+                                       closest unused buckets selected
+                                                        ^_______^_______^
+                         these become duplicates once previous duplicates are resolved
+
+    In this case we'll have four duplicated buckets:
+
+    174.18 -> 256, optimal bucket,
+    237.02 -> (256) -> 384, taking the closest available bucket,
+        as the optimal bucket 256 was already captured by 174.18,
+    322.54 -> (384) -> 512, taking the closest available bucket,
+        as the optimal bucket 384 was already captured by 237.02,
+    438.91 -> (512) -> 640, taking the closest available bucket,
+        as the optimal bucket 512 was already captured by 322.54,
+    597.26 -> (640) -> 768, taking the closest available bucket,
+        as the optimal bucket 640 was already captured by 438.91,
+    812.75 -> 896, optimal bucket
+
+    len(buckets) = 10, num_buckets = 10
+
+    In this case, the end result has the same buckets as fill=False,
+    but with the additional bucket 768 added.
+    The difference is more pronounced for larger ranges and a larger number
+    of buckets.
+    """  # noqa: E501
+
+    bmin, bstep, bmax, num_buckets = config
+    linear_buckets = set(np.arange(bmin, bmax + 1, step=bstep))
+    assert num_buckets > 0, "num_buckets must be a positive integer"
+    if num_buckets == 1:
+        return [bmax]
+    buckets: Set[int] = set()
+    for i in range(num_buckets):
+        power_unpadded = bmin * np.float_power(
+            bmax / bmin, (1. / float(num_buckets - 1)) * i)
+        bucket = math.ceil(power_unpadded / bstep) * bstep
+        if fill and bucket in buckets:
+            available_buckets = linear_buckets.difference(buckets)
+            if len(available_buckets) == 0:
+                break  # there are no more unique buckets, let's exit now
+            new_bucket = min(available_buckets,
+                             key=lambda x: abs(x - power_unpadded))
+            buckets.add(new_bucket)
+        else:
+            buckets.add(bucket)
+    return list(sorted(buckets))
+
+
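The docstring's arithmetic can be reproduced in a few standalone lines, using only the values from the example above. This is a verification sketch, not part of the patch; it mirrors the same formula with the exponent written as i/(num_buckets-1):

    import math

    bmin, bstep, bmax, num_buckets = 128, 128, 2048, 10
    unpadded = [
        bmin * (bmax / bmin)**(i / (num_buckets - 1))
        for i in range(num_buckets)
    ]
    padded = [math.ceil(p / bstep) * bstep for p in unpadded]
    # unpadded ~ [128.0, 174.18, 237.02, 322.54, 438.91, 597.26,
    #             812.75, 1105.98, 1505.01, 2048.0]
    # padded   = [128, 256, 256, 384, 512, 640, 896, 1152, 1536, 2048]
    # The duplicated 256 is what fill=True resolves by walking to the
    # closest unused linear bucket, which eventually pulls in 768.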
+def warmup_buckets(bs_bucket_config, seq_bucket_config, max_num_batched_tokens,
+                   block_size, max_model_len):
+    # NOTE(kzawora): If max_model_len or max_num_batched_tokens is not
+    # divisible by block_size, it is rounded up here to a multiple of
+    # block_size.
+
+    bs_buckets = warmup_range(bs_bucket_config[:3])
+    seq_len_buckets = warmup_range(seq_bucket_config[:3])
+    if bs_bucket_config[3] != 0 and len(bs_buckets) > bs_bucket_config[3]:
+        bs_buckets = warmup_range_with_limit(bs_bucket_config)
+    if seq_bucket_config[3] != 0 and len(
+            seq_len_buckets) > seq_bucket_config[3]:
+        seq_len_buckets = warmup_range_with_limit(seq_bucket_config)
+
+    buckets = list(itertools.product(bs_buckets, seq_len_buckets))
+
+    # Remove buckets that either exceed the token budget
+    # or whose sequence length exceeds max_model_len
+    filtered_buckets = list(
+        filter(
+            lambda bucket: bucket[0] * bucket[1] <= math.ceil(
+                max_num_batched_tokens / block_size) * block_size and bucket[1]
+            <= max_model_len, buckets))
+
+    # NOTE(kzawora): We need to make sure we cover the largest valid scenario
+    # (max_num_batched_tokens/bs). If we drop out-of-bound buckets, we might
+    # be left with a last bucket that's too small and encounter recompilations
+    # when exceeding that bucket while still being in the valid range.
+
+    # Example:
+    # block_size = 128
+    # max_model_len = 768 (6 blocks)
+    # max_num_batched_tokens = 1536 (12 blocks)
+    # bs_buckets = [1, 2, 4, 5]
+    # seq_len_buckets = [128, 256, 512, 1024]
+    # buckets = [
+    #   (1,128), (1,256), (1,512), (1,1024), # last one is invalid, 1024 > 768
+    #   (2,128), (2,256), (2,512), (2,1024), # last one is invalid, 2048 > 1536
+    #   (4,128), (4,256), (4,512), (4,1024), # last two are invalid, 2048 > 1536
+    #   (5,128), (5,256), (5,512), (5,1024), # last two are invalid, 2560 > 1536
+    # ]
+    # filtered_buckets = [
+    #   (1,128), (1,256), (1,512), # 512-768 is valid and not bucketed
+    #   (2,128), (2,256), (2,512), # 512-768 is valid and not bucketed
+    #   (4,128), (4,256),          # 256-384 is valid and not bucketed
+    #   (5,128), (5,256),          # 256-307 is valid and not bucketed
+    # ]
+    # boundary_buckets = [
+    #   (1,768), # covers 512-768 range, constrained by max_model_len
+    #   (2,768), # covers 512-768 range, constrained by budget and max_model_len
+    #   (4,384), # covers 256-384 range, constrained by token budget
+    #   (5,384), # covers 256-307 range, constrained by token budget,
+    #            # rounded up to block_size(=128)
+    # ]
+    # While the (5,384) bucket exceeds the token budget of 1536 (5*384=1920),
+    # it is the smallest bucket that can handle all valid use cases at bs=5,
+    # as all buckets used in the decode phase need to be aligned to block_size.
+
+    # Generate the largest valid bucket for all valid batch sizes.
+    # DO NOT use bs_buckets here, as it can include batch sizes
+    # for which no valid bucket remains.
+    valid_bs_buckets = set(list(zip(*filtered_buckets))[0])
+    boundary_buckets = [
+        (bs,
+         min(
+             math.ceil(max_model_len / block_size) * block_size,
+             math.ceil(max_num_batched_tokens / (bs * block_size)) *
+             block_size)) for bs in valid_bs_buckets
+    ]
+    filtered_buckets = list(set([*filtered_buckets, *boundary_buckets]))
+
     return list(
         sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
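The worked example in the comment above can be replayed in isolation. This sketch only re-derives the filtered and boundary buckets from the documented numbers; it is not the patched function:

    import itertools
    import math

    block_size, max_model_len, max_num_batched_tokens = 128, 768, 1536
    bs_buckets = [1, 2, 4, 5]
    seq_len_buckets = [128, 256, 512, 1024]

    # Token budget rounded up to a multiple of block_size.
    budget = math.ceil(max_num_batched_tokens / block_size) * block_size
    filtered = [(bs, sl)
                for bs, sl in itertools.product(bs_buckets, seq_len_buckets)
                if bs * sl <= budget and sl <= max_model_len]
    # Largest valid bucket per batch size, capped by both constraints.
    boundary = [(bs,
                 min(math.ceil(max_model_len / block_size) * block_size,
                     math.ceil(max_num_batched_tokens / (bs * block_size)) *
                     block_size)) for bs in sorted({bs for bs, _ in filtered})]
    # boundary -> [(1, 768), (2, 768), (4, 384), (5, 384)]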
@@ -118,15 +274,6 @@ def round_up(value: int, k: int):
     return (value + k - 1) // k * k
 
 
-def find_bucket(value: int, config: Tuple[int, int, int]):
-    bmin, bstep, bmax = config
-    if value < bstep:
-        result = min(next_pow2(value), bstep)
-    else:
-        result = round_up(value, bstep)
-    return result
-
-
 def subtuple(obj: object,
              typename: str,
              to_copy: List[str],
@@ -509,31 +656,39 @@ def _setup_buckets(self) -> None:
                                                  step=32,
                                                  max=min(
                                                      self.max_num_seqs,
-                                                     max_bucket_cfg))
+                                                     max_bucket_cfg),
+                                                 limit=0)
         self.decode_bs_bucket_cfg = read_bucket_settings('decode',
                                                          'bs',
                                                          min=1,
                                                          step=128,
-                                                         max=self.max_num_seqs)
-        self.prompt_seq_bucket_cfg = read_bucket_settings('prompt',
-                                                          'seq',
-                                                          min=self.block_size,
-                                                          step=self.block_size,
-                                                          max=1024)
-        self.decode_seq_bucket_cfg = read_bucket_settings('decode',
-                                                          'seq',
-                                                          min=self.block_size,
-                                                          step=self.block_size,
-                                                          max=2048)
+                                                         max=self.max_num_seqs,
+                                                         limit=0)
+        self.prompt_seq_bucket_cfg = read_bucket_settings(
+            'prompt',
+            'seq',
+            min=self.block_size,
+            step=self.block_size,
+            max=self.max_model_len,
+            limit=0)
+        self.decode_seq_bucket_cfg = read_bucket_settings(
+            'decode',
+            'seq',
+            min=self.block_size,
+            step=self.block_size,
+            max=self.max_model_len,
+            limit=0)
         self.graphed_buckets: Set[Any] = set()
 
-        msg = ("Prompt bucket config (min, step, max_warmup) "
+        msg = ("Prompt bucket config (min, step, max_warmup, limit) "
               f"bs:{self.prompt_bs_bucket_cfg}, "
               f"seq:{self.prompt_seq_bucket_cfg}")
         logger.info(msg)
 
         self.prompt_buckets = warmup_buckets(self.prompt_bs_bucket_cfg,
                                              self.prompt_seq_bucket_cfg,
-                                             self.max_num_batched_tokens)
+                                             self.max_num_batched_tokens,
+                                             self.block_size,
+                                             self.max_model_len)
 
         if self.lora_config:
             self.prompt_buckets[:] = [
@@ -545,18 +700,32 @@ def _setup_buckets(self) -> None:
              f"prompt buckets: {list(sorted(self.prompt_buckets))}")
         logger.info(msg)
 
-        msg = ("Decode bucket config (min, step, max_warmup) "
+        msg = ("Decode bucket config (min, step, max_warmup, limit) "
               f"bs:{self.decode_bs_bucket_cfg}, "
               f"seq:{self.decode_seq_bucket_cfg}")
         logger.info(msg)
         self.decode_buckets = warmup_buckets(self.decode_bs_bucket_cfg,
                                              self.decode_seq_bucket_cfg,
-                                             self.max_num_batched_tokens)
+                                             self.max_num_batched_tokens,
+                                             self.block_size,
+                                             self.max_model_len)
         if self.lora_config:
             self.decode_buckets[:] = [
                 bucket for bucket in self.decode_buckets
                 if self._is_valid_bucket(bucket)
             ]
+
+        find_bucket = lambda bucket, x: next(p for p in sorted(bucket)
+                                             if p >= x)
+        get_bucket_dim = lambda bucket, dim: [b[dim] for b in bucket]
+        self.find_bs_bucket = lambda x, is_prompt: find_bucket(
+            get_bucket_dim(
+                self.prompt_buckets
+                if is_prompt else self.decode_buckets, 0), x)
+        self.find_seq_bucket = lambda x, is_prompt: find_bucket(
+            get_bucket_dim(
+                self.prompt_buckets
+                if is_prompt else self.decode_buckets, 1), x)
         msg = (f"Generated {len(self.decode_buckets)} decode buckets: "
                f"{list(sorted(self.decode_buckets))}")
         logger.info(msg)
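The removed stride-based find_bucket is replaced by a next-largest-entry lookup over the precomputed bucket lists. A minimal standalone equivalent of the new lambdas (names are local to this sketch):

    def find_bucket(buckets, x):
        # Smallest bucket that fits x. The boundary buckets generated in
        # warmup_buckets are what guarantee a match for every valid x;
        # without them, next() would raise StopIteration for out-of-range x.
        return next(p for p in sorted(buckets) if p >= x)

    assert find_bucket([1, 2, 4, 8, 16], 5) == 8
    assert find_bucket([128, 256, 384], 384) == 384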
@@ -688,9 +857,8 @@ def _prepare_prompt(
             multi_modal_input = None
 
         max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
-        max_prompt_len = max(
-            find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg),
-            self.block_size)
+        max_prompt_len = max(self.find_seq_bucket(max(seq_lens), True),
+                             self.block_size)
 
         for seq_group_metadata, context_len in zip(seq_group_metadata_list,
                                                    context_lens):
@@ -896,9 +1064,7 @@ def prepare_input_tensors(
             self.profiler.start('internal', base_event_name)
 
             real_batch_size = len(seq_group_metadata_list)
-            bucket_cfg = self.prompt_bs_bucket_cfg if is_prompt else \
-                self.decode_bs_bucket_cfg
-            batch_size_padded = find_bucket(real_batch_size, bucket_cfg)
+            batch_size_padded = self.find_bs_bucket(real_batch_size, is_prompt)
             batch_size_padding = batch_size_padded - real_batch_size
             seq_group_metadata_list = seq_group_metadata_list.copy()
             seq_group_metadata_list.extend(seq_group_metadata_list[0]
@@ -1087,16 +1253,30 @@ def create_dummy_seq_group_metadata(self,
 
     def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
-        max_batch_size = self.prompt_bs_bucket_cfg[-1]
-        max_seq_len = self.prompt_seq_bucket_cfg[-1]
-
-        if self.lora_config:
-            max_seq_len = self.max_num_batched_tokens // max_batch_size
-
-        self.warmup_scenario(max_batch_size,
-                             max_seq_len,
-                             True,
-                             kv_caches,
-                             is_profile_run=True)
+        # Take the two largest buckets in terms of token count: first with
+        # the largest batch size, second with the largest seq length.
+        # If both are the same, only a single warmup run is executed.
+        warmup_scenarios = set()
+        warmup_scenarios.add(
+            max(self.prompt_buckets,
+                key=lambda item: (item[0], item[0] * item[1])))
+        warmup_scenarios.add(
+            max(self.prompt_buckets,
+                key=lambda item: (item[1], item[0] * item[1])))
+        for idx, (max_batch_size, max_seq_len) in enumerate(warmup_scenarios):
+            msg = (f'[{idx+1}/{len(warmup_scenarios)}] '
+                   f'Executing profile run for bs={max_batch_size}, '
+                   f'seq_len={max_seq_len}, '
+                   f'num_batched_tokens={max_batch_size*max_seq_len}')
+            logger.info(msg)
+            if self.lora_config:
+                max_seq_len = self.max_num_batched_tokens // max_batch_size
+            self.warmup_scenario(max_batch_size,
+                                 max_seq_len,
+                                 True,
+                                 kv_caches,
+                                 is_profile_run=True)
 
     def warmup_scenario(self,
                        batch_size,