From 9abadba502916eeb0432c6a8c300e09d0c3a5a48 Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Wed, 28 Aug 2024 11:39:33 +0200
Subject: [PATCH] Make max_num_batched_tokens behavior more verbose, add legacy mode (#208)

Addressing issues from https://github.com/HabanaAI/vllm-fork/pull/207

Filtering behavior is now more verbose: common configuration errors are handled explicitly, and the number of buckets omitted due to the token budget is reported (at debug log level, the omitted buckets themselves are printed):

```
INFO 08-27 20:57:27 profiler.py:62] Profiler enabled for: vllm-instance-1ab4f6c4d726480d8825044cf74e9af1
WARNING 08-27 20:57:27 utils.py:566] Pin memory is not supported on HPU.
INFO 08-27 20:57:27 selector.py:85] Using HabanaAttention backend.
INFO 08-27 20:57:27 habana_model_runner.py:563] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 64], seq:[128, 128, 1024]
INFO 08-27 20:57:27 habana_model_runner.py:576] Generated 23 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (8, 128), (8, 256), (16, 128)]
INFO 08-27 20:57:27 habana_model_runner.py:581] Omitted 33 prompt buckets due to exceeded token budget (max_num_batched_tokens=2048)
INFO 08-27 20:57:27 habana_model_runner.py:589] Decode bucket config (min, step, max_warmup) bs:[1, 128, 256], seq:[128, 128, 2048]
INFO 08-27 20:57:27 habana_model_runner.py:600] Generated 31 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (8, 128), (8, 256), (16, 128)]
INFO 08-27 20:57:27 habana_model_runner.py:605] Omitted 113 decode buckets due to exceeded token budget (max_num_batched_tokens=2048)
```
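For reference, here is a minimal sketch of how the counts above come about, including the legacy fallback described further below. It is not the patched warmup_buckets itself; the helper name filter_buckets and the bs_values/seq_values lists are hypothetical, chosen only to be consistent with the prompt bucket log:

```
from itertools import product
import logging

logger = logging.getLogger(__name__)


def filter_buckets(buckets, max_num_batched_tokens):
    # Keep only buckets whose worst-case token count (bs * seq) fits the budget.
    captured = [b for b in buckets if b[0] * b[1] <= max_num_batched_tokens]
    if not captured:
        # Legacy fallback: even the smallest bucket exceeds the budget,
        # so report it and keep every bucket unfiltered.
        min_bs, min_seq = min(buckets, key=lambda b: b[0] * b[1])
        logger.error("Smallest bucket (%d) exceeds max_num_batched_tokens (%d)",
                     min_bs * min_seq, max_num_batched_tokens)
        return sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])), []
    omitted = sorted(b for b in buckets if b not in captured)
    return sorted(captured, key=lambda b: (b[0] * b[1], b[1], b[0])), omitted


# Assumed warmup values, consistent with the prompt bucket log above
# (bs config [1, 32, 64], seq config [128, 128, 1024]).
bs_values = [1, 2, 4, 8, 16, 32, 64]
seq_values = [128, 256, 384, 512, 640, 768, 896, 1024]
captured, omitted = filter_buckets(list(product(bs_values, seq_values)), 2048)
print(len(captured), len(omitted))  # 23 33
```

Sorting by (bs * seq, seq, bs), as the patched warmup_buckets below does for the captured buckets, orders them from the smallest to the largest token footprint.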
Legacy mode was also added: it throws a nasty error message whenever the token budget is set too low, but then skips the filtering and works as it did previously (run with ``VLLM_DECODE_BS_BUCKET_MIN=128 VLLM_DECODE_SEQ_BUCKET_MIN=1024 python vllm_test.py --max-num-batched-tokens=2048``):

```
INFO 08-27 21:01:02 profiler.py:62] Profiler enabled for: vllm-instance-51f60d3978d347e992436f1dc0aa4702
WARNING 08-27 21:01:02 utils.py:566] Pin memory is not supported on HPU.
INFO 08-27 21:01:02 selector.py:85] Using HabanaAttention backend.
INFO 08-27 21:01:02 habana_model_runner.py:563] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 64], seq:[128, 128, 1024]
INFO 08-27 21:01:02 habana_model_runner.py:576] Generated 23 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (8, 128), (8, 256), (16, 128)]
INFO 08-27 21:01:02 habana_model_runner.py:581] Omitted 33 prompt buckets due to exceeded token budget (max_num_batched_tokens=2048)
INFO 08-27 21:01:02 habana_model_runner.py:589] Decode bucket config (min, step, max_warmup) bs:[128, 128, 256], seq:[1024, 128, 2048]
ERROR 08-27 21:01:02 habana_model_runner.py:128] The current bucketing configuration (min, step, max_warmup): bs:[128, 128, 256], seq:[1024, 128, 2048] cannot be used with specified max_num_batched_tokens (2048), as the smallest bucket (16384) would exceed token budget. Please increase max_num_batched_tokens or decrease bucket minimum Ignoring max_num_batched_tokens at risk of out-of-memory errors.
INFO 08-27 21:01:02 habana_model_runner.py:600] Generated 32 decode buckets: [(128, 128), (128, 256), (128, 384), (128, 512), (128, 640), (128, 768), (128, 896), (128, 1024), (128, 1152), (128, 1280), (128, 1408), (128, 1536), (128, 1664), (128, 1792), (128, 1920), (128, 2048), (256, 128), (256, 256), (256, 384), (256, 512), (256, 640), (256, 768), (256, 896), (256, 1024), (256, 1152), (256, 1280), (256, 1408), (256, 1536), (256, 1664), (256, 1792), (256, 1920), (256, 2048)]
INFO 08-27 21:01:02 habana_model_runner.py:605] Omitted 0 decode buckets due to exceeded token budget (max_num_batched_tokens=2048)
```
---
 vllm/worker/habana_model_runner.py | 70 +++++++++++++++++++++++++-----
 1 file changed, 58 insertions(+), 12 deletions(-)

diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index 62a9e814a5ac4..6627ba1ea5643 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -96,14 +96,44 @@ def warmup_range(config: Tuple[int, int, int]):
 
 def warmup_buckets(bs_bucket_config, seq_bucket_config,
                    max_num_batched_tokens):
-    buckets = itertools.product(warmup_range(bs_bucket_config),
-                                warmup_range(seq_bucket_config))
+    buckets = list(
+        itertools.product(warmup_range(bs_bucket_config),
+                          warmup_range(seq_bucket_config)))
+    if len(buckets) == 0:
+        msg = ("No buckets could be captured with following config "
+               f"(min, step, max_warmup): "
+               f"bs:{bs_bucket_config}, "
+               f"seq:{seq_bucket_config}")
+        raise ValueError(msg)
+
     # Remove buckets exceeding batch token budget
-    filtered_buckets = filter(
-        lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens,
-        buckets)
-    return list(
+    filtered_buckets = list(
+        filter(lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens,
+               buckets))
+
+    if len(filtered_buckets) == 0:
+        # legacy case - we can handle this if we ignore max_num_batched_tokens
+        min_bucket_bs, min_bucket_seq = min(buckets,
+                                            key=lambda b: (b[0] * b[1]))
+        min_reqd_budget = min_bucket_bs * min_bucket_seq
+        msg = (
+            "The current bucketing configuration "
+            f"(min, step, max_warmup): "
+            f"bs:{bs_bucket_config}, "
+            f"seq:{seq_bucket_config} cannot be used with specified "
+            f"max_num_batched_tokens ({max_num_batched_tokens}), as the "
+            f"smallest bucket ({min_reqd_budget}) would exceed token budget. "
+            "Please increase max_num_batched_tokens or decrease bucket minimum "
+            "Ignoring max_num_batched_tokens at risk of out-of-memory errors.")
+        logger.error(msg)
+        return list(sorted(buckets, key=lambda b:
+                           (b[0] * b[1], b[1], b[0]))), []
+
+    captured_buckets = list(
         sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
+    omitted_buckets = list(
+        sorted([x for x in buckets if x not in filtered_buckets]))
+    return captured_buckets, omitted_buckets
 
 
 def next_pow2(value: int):
@@ -531,9 +561,9 @@ def _setup_buckets(self) -> None:
                f"bs:{self.prompt_bs_bucket_cfg}, "
                f"seq:{self.prompt_seq_bucket_cfg}")
         logger.info(msg)
-        self.prompt_buckets = warmup_buckets(self.prompt_bs_bucket_cfg,
-                                             self.prompt_seq_bucket_cfg,
-                                             self.max_num_batched_tokens)
+        self.prompt_buckets, prompt_omitted_buckets = warmup_buckets(
+            self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg,
+            self.max_num_batched_tokens)
 
         if self.lora_config:
             self.prompt_buckets[:] = [
@@ -545,13 +575,21 @@ def _setup_buckets(self) -> None:
                f"prompt buckets: {list(sorted(self.prompt_buckets))}")
         logger.info(msg)
 
+        msg = (f"Omitted {len(prompt_omitted_buckets)} "
+               "prompt buckets due to exceeded token budget "
+               f"(max_num_batched_tokens={self.max_num_batched_tokens})")
+        logger.info(msg)
+
+        msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
+        logger.debug(msg)
+
         msg = ("Decode bucket config (min, step, max_warmup) "
                f"bs:{self.decode_bs_bucket_cfg}, "
                f"seq:{self.decode_seq_bucket_cfg}")
         logger.info(msg)
-        self.decode_buckets = warmup_buckets(self.decode_bs_bucket_cfg,
-                                             self.decode_seq_bucket_cfg,
-                                             self.max_num_batched_tokens)
+        self.decode_buckets, decode_omitted_buckets = warmup_buckets(
+            self.decode_bs_bucket_cfg, self.decode_seq_bucket_cfg,
+            self.max_num_batched_tokens)
         if self.lora_config:
             self.decode_buckets[:] = [
                 bucket for bucket in self.decode_buckets
@@ -561,6 +599,14 @@ def _setup_buckets(self) -> None:
                        f"{list(sorted(self.decode_buckets))}")
         logger.info(msg)
 
+        msg = (f"Omitted {len(decode_omitted_buckets)} "
+               "decode buckets due to exceeded token budget "
+               f"(max_num_batched_tokens={self.max_num_batched_tokens})")
+        logger.info(msg)
+
+        msg = f"Omitted decode buckets: {list(sorted(decode_omitted_buckets))}"
+        logger.debug(msg)
+
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],