From 7488c584ddb36a81a900614d434445e1d66dbcf0 Mon Sep 17 00:00:00 2001
From: Marceli Fylcek
Date: Fri, 6 Sep 2024 15:16:33 +0200
Subject: [PATCH] Use PT_COMPILE_ONLY_MODE during warmup (#227)

With the PT_COMPILE_ONLY_MODE flag, graphs can be compiled without
performing synLaunch. The flag has been added to the warmup phase to
decrease its execution time.
---
 vllm/worker/habana_model_runner.py | 125 +++++++++++++++--------------
 1 file changed, 66 insertions(+), 59 deletions(-)

diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index 166ad760d27ca..9dc02fba0213a 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -15,6 +15,7 @@
                     Optional, Set, Tuple, Type, TypeVar, Union)
 
 import habana_frameworks.torch as htorch
+import habana_frameworks.torch.internal.bridge_config as bc
 import torch
 
 from vllm.attention import AttentionMetadata, get_attn_backend
@@ -1402,67 +1403,73 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
         self.profiler.start('internal', 'warmup')
         start_mem = HabanaMemoryProfiler.current_device_memory_usage()
         start_time = time.perf_counter()
-        self.warmup_all_buckets(self.prompt_buckets, True, kv_caches)
-        self.warmup_all_buckets(self.decode_buckets, False, kv_caches)
-
-        if not self.enforce_eager and htorch.utils.internal.is_lazy():
-            assert self.mem_margin is not None, \
-                ("HabanaWorker.determine_num_available_blocks needs "
-                 "to be called before warming up the model.")
-            free_mem = HabanaMemoryProfiler.current_free_device_memory()
-            graph_free_mem = free_mem - self.mem_margin
-            graph_free_mem = align_workers(graph_free_mem,
-                                           torch.distributed.ReduceOp.MIN)
-            prompt_graph_mem_ratio = float(
-                os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5'))
-            prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
-            decode_available_memory = graph_free_mem - prompt_available_memory
-            msg = (f"Using {format_bytes(graph_free_mem)}"
-                   f"/{format_bytes(free_mem)} "
-                   "of free device memory for HPUGraphs, "
-                   f"{format_bytes(prompt_available_memory)} for prompt and "
-                   f"{format_bytes(decode_available_memory)} for decode "
-                   f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})")
-            logger.info(msg)
-            prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY',
-                                             'min_tokens')
-            decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY',
-                                             'max_bs')
-            mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
-                self.warmup_graphs(
-                prompt_strategy, self.prompt_buckets, True, kv_caches,
-                prompt_available_memory)
-            mem_post_decode, decode_batch_seq, decode_captured_all = \
-                self.warmup_graphs(
-                decode_strategy, self.decode_buckets, False, kv_caches,
-                decode_available_memory)
-
-            # Not all prompt buckets were captured, but all decode buckets were
-            # captured and we have some free graph-allocated space left.
-            # Let's try to use it for capturing more prompt buckets.
-            if mem_post_decode + mem_post_prompt < graph_free_mem \
-                and not prompt_captured_all \
-                and decode_captured_all:
-                mem_post_prompt, _, prompt_captured_all = self.warmup_graphs(
+
+        with bc.env_setting("PT_COMPILE_ONLY_MODE", True):
+            self.warmup_all_buckets(self.prompt_buckets, True, kv_caches)
+            self.warmup_all_buckets(self.decode_buckets, False, kv_caches)
+
+            if not self.enforce_eager and htorch.utils.internal.is_lazy():
+                assert self.mem_margin is not None, \
+                    ("HabanaWorker.determine_num_available_blocks needs "
+                     "to be called before warming up the model.")
+                free_mem = HabanaMemoryProfiler.current_free_device_memory()
+                graph_free_mem = free_mem - self.mem_margin
+                graph_free_mem = align_workers(graph_free_mem,
+                                               torch.distributed.ReduceOp.MIN)
+                prompt_graph_mem_ratio = float(
+                    os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5'))
+                prompt_available_memory = (prompt_graph_mem_ratio *
+                                           graph_free_mem)
+                decode_available_memory = (graph_free_mem -
+                                           prompt_available_memory)
+                msg = (
+                    f"Using {format_bytes(graph_free_mem)}"
+                    f"/{format_bytes(free_mem)} "
+                    "of free device memory for HPUGraphs, "
+                    f"{format_bytes(prompt_available_memory)} for prompt and "
+                    f"{format_bytes(decode_available_memory)} for decode "
+                    f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})")
+                logger.info(msg)
+                prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY',
+                                                 'min_tokens')
+                decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY',
+                                                 'max_bs')
+                mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
+                    self.warmup_graphs(
                     prompt_strategy, self.prompt_buckets, True, kv_caches,
-                    graph_free_mem - mem_post_prompt - mem_post_decode,
-                    mem_post_prompt, prompt_batch_seq)
-
-            # Not all decode buckets were captured, but all prompt buckets were
-            # captured and we have some free graph-allocated space left.
-            # Let's try to use it for capturing more decode buckets.
-            if mem_post_decode + mem_post_prompt < graph_free_mem \
-                and not decode_captured_all \
-                and prompt_captured_all:
-                mem_post_decode, _, _ = self.warmup_graphs(
+                    prompt_available_memory)
+                mem_post_decode, decode_batch_seq, decode_captured_all = \
+                    self.warmup_graphs(
                     decode_strategy, self.decode_buckets, False, kv_caches,
-                    graph_free_mem - mem_post_prompt - mem_post_decode,
-                    mem_post_decode, decode_batch_seq)
-
-            self.log_graph_warmup_summary(self.prompt_buckets, True,
-                                          mem_post_prompt)
-            self.log_graph_warmup_summary(self.decode_buckets, False,
-                                          mem_post_decode)
+                    decode_available_memory)
+
+                # Not all prompt buckets were captured, but all decode buckets
+                # were captured and we have some free graph-allocated space
+                # left. Let's try to use it for capturing more prompt buckets.
+                if (mem_post_decode + mem_post_prompt < graph_free_mem
+                        and not prompt_captured_all and decode_captured_all):
+                    mem_post_prompt, _, prompt_captured_all = (
+                        self.warmup_graphs(
+                            prompt_strategy, self.prompt_buckets, True,
+                            kv_caches,
+                            graph_free_mem - mem_post_prompt - mem_post_decode,
+                            mem_post_prompt, prompt_batch_seq))
+
+                # Not all decode buckets were captured, but all prompt buckets
+                # were captured and we have some free graph-allocated space
+                # left. Let's try to use it for capturing more decode buckets.
+                if mem_post_decode + mem_post_prompt < graph_free_mem \
+                    and not decode_captured_all \
+                    and prompt_captured_all:
+                    mem_post_decode, _, _ = self.warmup_graphs(
+                        decode_strategy, self.decode_buckets, False, kv_caches,
+                        graph_free_mem - mem_post_prompt - mem_post_decode,
+                        mem_post_decode, decode_batch_seq)
+
+                self.log_graph_warmup_summary(self.prompt_buckets, True,
+                                              mem_post_prompt)
+                self.log_graph_warmup_summary(self.decode_buckets, False,
+                                              mem_post_decode)
         end_time = time.perf_counter()
         end_mem = HabanaMemoryProfiler.current_device_memory_usage()
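
The change above wraps the whole warmup sequence in
bc.env_setting("PT_COMPILE_ONLY_MODE", True), so the graphs exercised during
warmup are compiled without a synLaunch being issued. The snippet below is a
minimal sketch of that pattern only, not the model runner's actual code:
warm_up_buckets is a hypothetical stand-in for warmup_all_buckets /
warmup_graphs, and the assumption that env_setting restores the flag's
previous value on exit is inferred from its use as a context manager in the
diff rather than from documented behavior.

    import habana_frameworks.torch.internal.bridge_config as bc

    def warm_up_buckets(buckets):
        # Hypothetical stand-in for the runner's warmup calls: replay each
        # bucket configuration once so its graph gets compiled and cached.
        for bucket in buckets:
            pass

    def warmup(prompt_buckets, decode_buckets):
        # While PT_COMPILE_ONLY_MODE is active, graph compilation proceeds
        # but no synLaunch is performed (per the commit message above),
        # which is what shortens the warmup phase.
        with bc.env_setting("PT_COMPILE_ONLY_MODE", True):
            warm_up_buckets(prompt_buckets)
            warm_up_buckets(decode_buckets)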