From 1f1e98199cc570baa3f406e3cfff0e4b95ec14d8 Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Wed, 21 Aug 2024 09:33:09 +0300 Subject: [PATCH 1/4] Handle compile-mode unwrap bug for indices length fix in LoRA --- vllm/worker/habana_model_runner.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index d129bb5cbc0ca..7f7f15bea86fa 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -1441,6 +1441,15 @@ def get_counter_dict(self, cache_config, duration, seq_len, return counters +def unwrap_model(model): + if isinstance(model, torch._dynamo.eval_frame.OptimizedModule): + return unwrap_model(model._orig_mod) + else: + model = list(vars(model)['_modules'].values())[0] + modules = list(vars(model)['_modules'].values()) + return modules + + class HabanaModelRunner( HabanaModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): """ @@ -1558,13 +1567,10 @@ def execute_model( if self.lora_config: from vllm.lora.layers import VocabParallelEmbeddingWithLoRA - property = vars(self.model.model) - model = list(property['_modules'].values())[0] - property = vars(model) - modules = list(property['_modules'].values()) + modules = unwrap_model(self.model.model) for module in modules: if isinstance(module, VocabParallelEmbeddingWithLoRA): - for i in range(0, 4): + for i in range(0, len(module.indices_len)): module.indices_len[ i] = sampling_metadata.selected_token_indices.numel( ) From aefd336798248d519ddc4cc5662c9aa03a9dbfad Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Tue, 27 Aug 2024 14:42:57 +0200 Subject: [PATCH 2/4] Ensure buckets do not exceed the batch token limit (#206) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR ensures we don't capture buckets that are above the specified token budget (as set by `max_num_batched_tokens` argument) Example for token budget of 2048 (`--max-num-batched-tokens 2048`): ``` $ python vllm_test.py --max-num-batched-tokens 2048 WARNING 08-27 14:48:55 _custom_ops.py:14] Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'") /usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:366: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead warnings.warn( No chat template is set for this tokenizer, falling back to a default class-level template. This is very error-prone, because models are often trained with templates different from the class default! Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which point any code depending on them will stop working. We recommend setting a valid chat template before then to ensure that this model continues working without issues. INFO 08-27 14:48:56 llm_engine.py:176] Initializing an LLM engine (v0.5.3.post1) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, weights_load_device=hpu, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=hpu, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=facebook/opt-125m, use_v2_block_manager=False, enable_prefix_caching=False) generation_config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████| 137/137 [00:00<00:00, 1.91MB/s] INFO 08-27 14:48:57 profiler.py:62] Profiler enabled for: vllm-instance-d356a015eeb349f7a4650e00bf6ce976 WARNING 08-27 14:48:57 utils.py:566] Pin memory is not supported on HPU. INFO 08-27 14:48:57 selector.py:85] Using HabanaAttention backend. INFO 08-27 14:48:57 habana_model_runner.py:532] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 64], seq:[128, 128, 1024] INFO 08-27 14:48:57 habana_model_runner.py:545] Generated 23 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (8, 128), (8, 256), (16, 128)] INFO 08-27 14:48:57 habana_model_runner.py:550] Decode bucket config (min, step, max_warmup) bs:[1, 128, 256], seq:[128, 128, 2048] INFO 08-27 14:48:57 habana_model_runner.py:561] Generated 31 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (8, 128), (8, 256), (16, 128)] ============================= HABANA PT BRIDGE CONFIGURATION =========================== PT_HPU_LAZY_MODE = 1 PT_RECIPE_CACHE_PATH = PT_CACHE_FOLDER_DELETE = 0 PT_HPU_RECIPE_CACHE_CONFIG = PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807 PT_HPU_LAZY_ACC_PAR_MODE = 1 PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0 ---------------------------: System Configuration :--------------------------- Num CPU Cores : 160 CPU RAM : 1056398260 KB ------------------------------------------------------------------------------ INFO 08-27 14:49:00 selector.py:85] Using HabanaAttention backend. INFO 08-27 14:49:00 loader.py:284] Loading weights on hpu ... INFO 08-27 14:49:00 weight_utils.py:224] Using model weights format ['*.bin'] pytorch_model.bin: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 251M/251M [00:06<00:00, 35.9MB/s] Loading pt checkpoint shards: 0% Completed | 0/1 [00:00 None: f"seq:{self.prompt_seq_bucket_cfg}") logger.info(msg) self.prompt_buckets = warmup_buckets(self.prompt_bs_bucket_cfg, - self.prompt_seq_bucket_cfg) + self.prompt_seq_bucket_cfg, + self.max_num_batched_tokens) if self.lora_config: self.prompt_buckets[:] = [ @@ -543,7 +550,8 @@ def _setup_buckets(self) -> None: f"seq:{self.decode_seq_bucket_cfg}") logger.info(msg) self.decode_buckets = warmup_buckets(self.decode_bs_bucket_cfg, - self.decode_seq_bucket_cfg) + self.decode_seq_bucket_cfg, + self.max_num_batched_tokens) if self.lora_config: self.decode_buckets[:] = [ bucket for bucket in self.decode_buckets From 9abadba502916eeb0432c6a8c300e09d0c3a5a48 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 28 Aug 2024 11:39:33 +0200 Subject: [PATCH 3/4] Make max_num_batched_tokens behavior more verbose, add legacy mode (#208) Addressing issues from https://github.com/HabanaAI/vllm-fork/pull/207 Now, filtering behavior is more verbose, handling common errors and displaying numbers of omitted buckets due to token budget (in debug log level, buckets are printed): ``` INFO 08-27 20:57:27 profiler.py:62] Profiler enabled for: vllm-instance-1ab4f6c4d726480d8825044cf74e9af1 WARNING 08-27 20:57:27 utils.py:566] Pin memory is not supported on HPU. INFO 08-27 20:57:27 selector.py:85] Using HabanaAttention backend. INFO 08-27 20:57:27 habana_model_runner.py:563] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 64], seq:[128, 128, 1024] INFO 08-27 20:57:27 habana_model_runner.py:576] Generated 23 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (8, 128), (8, 256), (16, 128)] INFO 08-27 20:57:27 habana_model_runner.py:581] Omitted 33 prompt buckets due to exceeded token budget (max_num_batched_tokens=2048) INFO 08-27 20:57:27 habana_model_runner.py:589] Decode bucket config (min, step, max_warmup) bs:[1, 128, 256], seq:[128, 128, 2048] INFO 08-27 20:57:27 habana_model_runner.py:600] Generated 31 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (8, 128), (8, 256), (16, 128)] INFO 08-27 20:57:27 habana_model_runner.py:605] Omitted 113 decode buckets due to exceeded token budget (max_num_batched_tokens=2048) ``` Legacy mode was also added, which throws a nasty error message whenever token budget is set too low, but then it omits filtering and works as it did previously (ran with ``VLLM_DECODE_BS_BUCKET_MIN=128 VLLM_DECODE_SEQ_BUCKET_MIN=1024 python vllm_test.py --max-num-batched-tokens=2048``): ``` INFO 08-27 21:01:02 profiler.py:62] Profiler enabled for: vllm-instance-51f60d3978d347e992436f1dc0aa4702 WARNING 08-27 21:01:02 utils.py:566] Pin memory is not supported on HPU. INFO 08-27 21:01:02 selector.py:85] Using HabanaAttention backend. INFO 08-27 21:01:02 habana_model_runner.py:563] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 64], seq:[128, 128, 1024] INFO 08-27 21:01:02 habana_model_runner.py:576] Generated 23 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (8, 128), (8, 256), (16, 128)] INFO 08-27 21:01:02 habana_model_runner.py:581] Omitted 33 prompt buckets due to exceeded token budget (max_num_batched_tokens=2048) INFO 08-27 21:01:02 habana_model_runner.py:589] Decode bucket config (min, step, max_warmup) bs:[128, 128, 256], seq:[1024, 128, 2048] ERROR 08-27 21:01:02 habana_model_runner.py:128] The current bucketing configuration (min, step, max_warmup): bs:[128, 128, 256], seq:[1024, 128, 2048] cannot be used with specified max_num_batched_tokens (2048), as the smallest bucket (16384) would exceed token budget. Please increase max_num_batched_tokens or decrease bucket minimum Ignoring max_num_batched_tokens at risk of out-of-memory errors. INFO 08-27 21:01:02 habana_model_runner.py:600] Generated 32 decode buckets: [(128, 128), (128, 256), (128, 384), (128, 512), (128, 640), (128, 768), (128, 896), (128, 1024), (128, 1152), (128, 1280), (128, 1408), (128, 1536), (128, 1664), (128, 1792), (128, 1920), (128, 2048), (256, 128), (256, 256), (256, 384), (256, 512), (256, 640), (256, 768), (256, 896), (256, 1024), (256, 1152), (256, 1280), (256, 1408), (256, 1536), (256, 1664), (256, 1792), (256, 1920), (256, 2048)] INFO 08-27 21:01:02 habana_model_runner.py:605] Omitted 0 decode buckets due to exceeded token budget (max_num_batched_tokens=2048) ``` --- vllm/worker/habana_model_runner.py | 70 +++++++++++++++++++++++++----- 1 file changed, 58 insertions(+), 12 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 62a9e814a5ac4..6627ba1ea5643 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -96,14 +96,44 @@ def warmup_range(config: Tuple[int, int, int]): def warmup_buckets(bs_bucket_config, seq_bucket_config, max_num_batched_tokens): - buckets = itertools.product(warmup_range(bs_bucket_config), - warmup_range(seq_bucket_config)) + buckets = list( + itertools.product(warmup_range(bs_bucket_config), + warmup_range(seq_bucket_config))) + if len(buckets) == 0: + msg = ("No buckets could be captured with following config " + f"(min, step, max_warmup): " + f"bs:{bs_bucket_config}, " + f"seq:{seq_bucket_config}") + raise ValueError(msg) + # Remove buckets exceeding batch token budget - filtered_buckets = filter( - lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens, - buckets) - return list( + filtered_buckets = list( + filter(lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens, + buckets)) + + if len(filtered_buckets) == 0: + # legacy case - we can handle this if we ignore max_num_batched_tokens + min_bucket_bs, min_bucket_seq = min(buckets, + key=lambda b: (b[0] * b[1])) + min_reqd_budget = min_bucket_bs * min_bucket_seq + msg = ( + "The current bucketing configuration " + f"(min, step, max_warmup): " + f"bs:{bs_bucket_config}, " + f"seq:{seq_bucket_config} cannot be used with specified " + f"max_num_batched_tokens ({max_num_batched_tokens}), as the " + f"smallest bucket ({min_reqd_budget}) would exceed token budget. " + "Please increase max_num_batched_tokens or decrease bucket minimum " + "Ignoring max_num_batched_tokens at risk of out-of-memory errors.") + logger.error(msg) + return list(sorted(buckets, key=lambda b: + (b[0] * b[1], b[1], b[0]))), [] + + captured_buckets = list( sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) + omitted_buckets = list( + sorted([x for x in buckets if x not in filtered_buckets])) + return captured_buckets, omitted_buckets def next_pow2(value: int): @@ -531,9 +561,9 @@ def _setup_buckets(self) -> None: f"bs:{self.prompt_bs_bucket_cfg}, " f"seq:{self.prompt_seq_bucket_cfg}") logger.info(msg) - self.prompt_buckets = warmup_buckets(self.prompt_bs_bucket_cfg, - self.prompt_seq_bucket_cfg, - self.max_num_batched_tokens) + self.prompt_buckets, prompt_omitted_buckets = warmup_buckets( + self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, + self.max_num_batched_tokens) if self.lora_config: self.prompt_buckets[:] = [ @@ -545,13 +575,21 @@ def _setup_buckets(self) -> None: f"prompt buckets: {list(sorted(self.prompt_buckets))}") logger.info(msg) + msg = (f"Omitted {len(prompt_omitted_buckets)} " + "prompt buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})") + logger.info(msg) + + msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" + logger.debug(msg) + msg = ("Decode bucket config (min, step, max_warmup) " f"bs:{self.decode_bs_bucket_cfg}, " f"seq:{self.decode_seq_bucket_cfg}") logger.info(msg) - self.decode_buckets = warmup_buckets(self.decode_bs_bucket_cfg, - self.decode_seq_bucket_cfg, - self.max_num_batched_tokens) + self.decode_buckets, decode_omitted_buckets = warmup_buckets( + self.decode_bs_bucket_cfg, self.decode_seq_bucket_cfg, + self.max_num_batched_tokens) if self.lora_config: self.decode_buckets[:] = [ bucket for bucket in self.decode_buckets @@ -561,6 +599,14 @@ def _setup_buckets(self) -> None: f"{list(sorted(self.decode_buckets))}") logger.info(msg) + msg = (f"Omitted {len(decode_omitted_buckets)} " + "decode buckets due to exceeded token budget " + f"(max_num_batched_tokens={self.max_num_batched_tokens})") + logger.info(msg) + + msg = f"Omitted decode buckets: {list(sorted(decode_omitted_buckets))}" + logger.debug(msg) + def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], From 17cd6251924ef66246eeca224bb2cb09da23217b Mon Sep 17 00:00:00 2001 From: Vivek Goel Date: Thu, 29 Aug 2024 11:23:05 +0530 Subject: [PATCH 4/4] Update paddings computed to adjust selected_token_indices (#210) Fixes assert seen when "prompt_logprobs is not None" and BS > 1. Assert was due to shape of paddings being added to matching sampling_metadata.selected_token_indices shape for the case where prompt_logprobs is configured. --- vllm/worker/habana_model_runner.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 6627ba1ea5643..a975dba6f5136 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -1012,8 +1012,13 @@ def prepare_input_tensors( paddings = [max_len - s for s in seq_lens] paddings = [0] + paddings[:-1] paddings = list(itertools.accumulate(paddings)) + paddings_prompt_logprobs = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + if seq_group_metadata.sampling_params.prompt_logprobs is not None \ + and seq_group_metadata.is_prompt: + paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i]) paddings = torch.tensor( - paddings, + paddings_prompt_logprobs if paddings_prompt_logprobs else paddings, dtype=sampling_metadata.selected_token_indices.dtype, device=sampling_metadata.selected_token_indices.device) sampling_metadata.selected_token_indices.add_(paddings)