From cfe231d905fe9e3ecf779eaf62e5d177900a0e6e Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Fri, 4 Oct 2024 18:35:34 +0200 Subject: [PATCH] [Refactor] Rename components *Habana* -> *HPU* (#359) Refactoring Gaudi-specific components to use `hpu` name instead of `habana` (e.g. `habana_model_runner.py` -> `hpu_model_runner.py`, `habana_executor.py` -> `hpu_executor.py`, etc.), as suggested in the upstream PR. --- README_GAUDI.md | 78 +++++++++---------- .../getting_started/gaudi-installation.rst | 78 +++++++++---------- vllm/engine/async_llm_engine.py | 15 ++-- vllm/engine/llm_engine.py | 14 ++-- vllm/engine/multiprocessing/engine.py | 6 +- .../{habana_executor.py => hpu_executor.py} | 8 +- ...habana_executor.py => ray_hpu_executor.py} | 8 +- ...na_model_runner.py => hpu_model_runner.py} | 7 +- .../{habana_worker.py => hpu_worker.py} | 6 +- 9 files changed, 109 insertions(+), 111 deletions(-) rename vllm/executor/{habana_executor.py => hpu_executor.py} (97%) rename vllm/executor/{ray_habana_executor.py => ray_hpu_executor.py} (99%) rename vllm/worker/{habana_model_runner.py => hpu_model_runner.py} (99%) rename vllm/worker/{habana_worker.py => hpu_worker.py} (99%) diff --git a/README_GAUDI.md b/README_GAUDI.md index 04e2ff22f96e5..6ba3bb50d4a04 100644 --- a/README_GAUDI.md +++ b/README_GAUDI.md @@ -195,10 +195,10 @@ batch size and sequence length dimension. These parameters can be observed in logs during vLLM startup: ``` {.} -INFO 08-01 21:37:59 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] -INFO 08-01 21:37:59 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] -INFO 08-01 21:37:59 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] -INFO 08-01 21:37:59 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] +INFO 08-01 21:37:59 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] +INFO 08-01 21:37:59 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] +INFO 08-01 21:37:59 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] +INFO 08-01 21:37:59 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] ``` `min` determines the lowest value of the bucket. `step` determines the @@ -267,17 +267,17 @@ graph compilation overheads within bucket boundaries during server runtime. Each warmup step is logged during vLLM startup: ``` {.} -INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB -INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB -INFO 08-01 22:26:48 habana_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB +INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB +INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB +INFO 08-01 22:26:48 hpu_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB ... -INFO 08-01 22:26:59 habana_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB -INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB -INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB -INFO 08-01 22:27:01 habana_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB +INFO 08-01 22:26:59 hpu_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB +INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB +INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB +INFO 08-01 22:27:01 hpu_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB ... -INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB -INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB +INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB +INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB ``` This example uses the same buckets as in *Bucketing mechanism* section. @@ -374,35 +374,35 @@ Each described step is logged by vLLM server, as follows (negative values correspond to memory being released): ``` {.} -INFO 08-02 17:37:44 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] -INFO 08-02 17:37:44 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] -INFO 08-02 17:37:44 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] -INFO 08-02 17:37:44 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] -INFO 08-02 17:37:52 habana_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) -INFO 08-02 17:37:52 habana_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used) -INFO 08-02 17:37:52 habana_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) -INFO 08-02 17:37:54 habana_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used) -INFO 08-02 17:37:54 habana_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache -INFO 08-02 17:37:54 habana_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0 -INFO 08-02 17:37:54 habana_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used) -INFO 08-02 17:37:54 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB +INFO 08-02 17:37:44 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] +INFO 08-02 17:37:44 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] +INFO 08-02 17:37:44 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] +INFO 08-02 17:37:44 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] +INFO 08-02 17:37:52 hpu_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) +INFO 08-02 17:37:52 hpu_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used) +INFO 08-02 17:37:52 hpu_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) +INFO 08-02 17:37:54 hpu_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used) +INFO 08-02 17:37:54 hpu_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache +INFO 08-02 17:37:54 hpu_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0 +INFO 08-02 17:37:54 hpu_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used) +INFO 08-02 17:37:54 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB ... -INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB -INFO 08-02 17:38:22 habana_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 4.755 GiB for prompt and 11.095 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.3) -INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB +INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB +INFO 08-02 17:38:22 hpu_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 4.755 GiB for prompt and 11.095 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.3) +INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB ... -INFO 08-02 17:38:26 habana_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB -INFO 08-02 17:38:27 habana_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB +INFO 08-02 17:38:26 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB +INFO 08-02 17:38:27 hpu_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB ... -INFO 08-02 17:38:41 habana_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB -INFO 08-02 17:38:41 habana_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB -INFO 08-02 17:38:42 habana_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB -INFO 08-02 17:38:42 habana_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB -INFO 08-02 17:38:43 habana_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB -INFO 08-02 17:38:43 habana_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)] -INFO 08-02 17:38:43 habana_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] -INFO 08-02 17:38:43 habana_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory -INFO 08-02 17:38:43 habana_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used) +INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB +INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB +INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB +INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB +INFO 08-02 17:38:43 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB +INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)] +INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] +INFO 08-02 17:38:43 hpu_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory +INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used) ``` Recommended vLLM Parameters diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index db1d8666e4800..5915de92802d9 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -173,10 +173,10 @@ Bucketing ranges are determined with 3 parameters - ``min``, ``step`` and ``max` .. code-block:: - INFO 08-01 21:37:59 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] - INFO 08-01 21:37:59 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] - INFO 08-01 21:37:59 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] - INFO 08-01 21:37:59 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] + INFO 08-01 21:37:59 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] + INFO 08-01 21:37:59 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] + INFO 08-01 21:37:59 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] + INFO 08-01 21:37:59 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] ``min`` determines the lowest value of the bucket. ``step`` determines the interval between buckets, and ``max`` determines the upper bound of the bucket. Furthermore, interval between ``min`` and ``step`` has special handling - ``min`` gets multiplied by consecutive powers of two, until ``step`` gets reached. We call this the ramp-up phase and it is used for handling lower batch sizes with minimum wastage, while allowing larger padding on larger batch sizes. @@ -216,17 +216,17 @@ Warmup is an optional, but highly recommended step occurring before vLLM server .. code-block:: - INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB - INFO 08-01 22:26:47 habana_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB - INFO 08-01 22:26:48 habana_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB + INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB + INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB + INFO 08-01 22:26:48 hpu_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB ... - INFO 08-01 22:26:59 habana_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB - INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB - INFO 08-01 22:27:00 habana_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB - INFO 08-01 22:27:01 habana_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB + INFO 08-01 22:26:59 hpu_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB + INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB + INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB + INFO 08-01 22:27:01 hpu_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB ... - INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB - INFO 08-01 22:27:16 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB + INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB + INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB This example uses the same buckets as in *Bucketing mechanism* section. Each output line corresponds to execution of a single bucket. When bucket is executed for the first time, its graph is compiled and can be reused later on, skipping further graph compilations. @@ -266,35 +266,35 @@ Each described step is logged by vLLM server, as follows (negative values corres .. code-block:: - INFO 08-02 17:37:44 habana_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] - INFO 08-02 17:37:44 habana_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] - INFO 08-02 17:37:44 habana_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] - INFO 08-02 17:37:44 habana_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] - INFO 08-02 17:37:52 habana_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) - INFO 08-02 17:37:52 habana_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used) - INFO 08-02 17:37:52 habana_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) - INFO 08-02 17:37:54 habana_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used) - INFO 08-02 17:37:54 habana_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache - INFO 08-02 17:37:54 habana_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0 - INFO 08-02 17:37:54 habana_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used) - INFO 08-02 17:37:54 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB + INFO 08-02 17:37:44 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] + INFO 08-02 17:37:44 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)] + INFO 08-02 17:37:44 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048] + INFO 08-02 17:37:44 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] + INFO 08-02 17:37:52 hpu_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) + INFO 08-02 17:37:52 hpu_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used) + INFO 08-02 17:37:52 hpu_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used) + INFO 08-02 17:37:54 hpu_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used) + INFO 08-02 17:37:54 hpu_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache + INFO 08-02 17:37:54 hpu_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0 + INFO 08-02 17:37:54 hpu_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used) + INFO 08-02 17:37:54 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB ... - INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB - INFO 08-02 17:38:22 habana_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.3) - INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB + INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB + INFO 08-02 17:38:22 hpu_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.3) + INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB ... - INFO 08-02 17:38:26 habana_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB - INFO 08-02 17:38:27 habana_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB + INFO 08-02 17:38:26 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB + INFO 08-02 17:38:27 hpu_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB ... - INFO 08-02 17:38:41 habana_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB - INFO 08-02 17:38:41 habana_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB - INFO 08-02 17:38:42 habana_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB - INFO 08-02 17:38:42 habana_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB - INFO 08-02 17:38:43 habana_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB - INFO 08-02 17:38:43 habana_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)] - INFO 08-02 17:38:43 habana_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] - INFO 08-02 17:38:43 habana_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory - INFO 08-02 17:38:43 habana_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used) + INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB + INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB + INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB + INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB + INFO 08-02 17:38:43 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB + INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)] + INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)] + INFO 08-02 17:38:43 hpu_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory + INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used) Recommended vLLM Parameters diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index cb489084f48de..6f3b73dbeee20 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -16,7 +16,7 @@ from vllm.engine.metrics_types import StatLoggerBase from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync -from vllm.executor.habana_executor import HabanaExecutorAsync +from vllm.executor.hpu_executor import HPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType from vllm.logger import init_logger @@ -620,12 +620,11 @@ def _get_executor_cls( elif engine_config.device_config.device_type == "hpu": if distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_habana_executor import ( - RayHabanaExecutorAsync) - executor_class = RayHabanaExecutorAsync + from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync + executor_class = RayHPUExecutorAsync else: - from vllm.executor.habana_executor import HabanaExecutorAsync - executor_class = HabanaExecutorAsync + from vllm.executor.hpu_executor import HPUExecutorAsync + executor_class = HPUExecutorAsync elif engine_config.device_config.device_type == "openvino": assert distributed_executor_backend is None, ( "Distributed execution is not supported with " @@ -1206,7 +1205,7 @@ async def start_profile(self) -> None: # using type instead of isinstance to check to avoid capturing # inherited classes if type(self.engine.model_executor) == GPUExecutorAsync or \ - type(self.engine.model_executor) == HabanaExecutorAsync: # noqa: E721 + type(self.engine.model_executor) == HPUExecutorAsync: # noqa: E721 self.engine.model_executor.start_profile() else: self.engine.model_executor._run_workers("start_profile") @@ -1215,7 +1214,7 @@ async def stop_profile(self) -> None: # using type instead of isinstance to check to avoid capturing # inherited classes if type(self.engine.model_executor) == GPUExecutorAsync or \ - type(self.engine.model_executor) == HabanaExecutorAsync: # noqa: E721 + type(self.engine.model_executor) == HPUExecutorAsync: # noqa: E721 self.engine.model_executor.stop_profile() else: self.engine.model_executor._run_workers("stop_profile") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f41d074ad536c..3635443421e88 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -28,7 +28,7 @@ from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.executor.executor_base import ExecutorBase from vllm.executor.gpu_executor import GPUExecutor -from vllm.executor.habana_executor import HabanaExecutor +from vllm.executor.hpu_executor import HPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, InputRegistry, LLMInputs, PromptType) @@ -533,11 +533,11 @@ def _get_executor_cls(cls, elif engine_config.device_config.device_type == "hpu": if distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_habana_executor import RayHabanaExecutor - executor_class = RayHabanaExecutor + from vllm.executor.ray_hpu_executor import RayHPUExecutor + executor_class = RayHPUExecutor else: - from vllm.executor.habana_executor import HabanaExecutor - executor_class = HabanaExecutor + from vllm.executor.hpu_executor import HPUExecutor + executor_class = HPUExecutor elif engine_config.device_config.device_type == "openvino": from vllm.executor.openvino_executor import OpenVINOExecutor executor_class = OpenVINOExecutor @@ -1796,7 +1796,7 @@ def start_profile(self) -> None: # using type instead of isinstance to check to avoid capturing # inherited classes (MultiprocessingGPUExecutor) if type(self.model_executor) == GPUExecutor or \ - type(self.model_executor) == HabanaExecutor: # noqa: E721 + type(self.model_executor) == HPUExecutor: # noqa: E721 self.model_executor.start_profile() else: self.model_executor._run_workers("start_profile") @@ -1805,7 +1805,7 @@ def stop_profile(self) -> None: # using type instead of isinstance to check to avoid capturing # inherited classes (MultiprocessingGPUExecutor) if type(self.model_executor) == GPUExecutor or \ - type(self.model_executor) == HabanaExecutor: # noqa: E721 + type(self.model_executor) == HPUExecutor: # noqa: E721 self.model_executor.stop_profile() else: self.model_executor._run_workers("stop_profile") diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 49500099fbcaf..3501f12c065cf 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -23,7 +23,7 @@ # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT from vllm.executor.gpu_executor import GPUExecutor -from vllm.executor.habana_executor import HabanaExecutor +from vllm.executor.hpu_executor import HPUExecutor from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -366,14 +366,14 @@ def _alive(self): def start_profile(self) -> None: if type(self.engine.model_executor) is GPUExecutor or \ - type(self.engine.model_executor) is HabanaExecutor: + type(self.engine.model_executor) is HPUExecutor: self.engine.model_executor.start_profile() else: self.engine.model_executor._run_workers("start_profile") def stop_profile(self) -> None: if type(self.engine.model_executor) is GPUExecutor or \ - type(self.engine.model_executor) is HabanaExecutor: + type(self.engine.model_executor) is HPUExecutor: self.engine.model_executor.stop_profile() else: self.engine.model_executor._run_workers("stop_profile") diff --git a/vllm/executor/habana_executor.py b/vllm/executor/hpu_executor.py similarity index 97% rename from vllm/executor/habana_executor.py rename to vllm/executor/hpu_executor.py index e6d0fbc0d431d..cc5609ebe5c8e 100644 --- a/vllm/executor/habana_executor.py +++ b/vllm/executor/hpu_executor.py @@ -21,7 +21,7 @@ logger = init_logger(__name__) -class HabanaExecutor(ExecutorBase): +class HPUExecutor(ExecutorBase): uses_ray: bool = False @@ -57,8 +57,8 @@ def _create_worker(self, rank: int = 0, distributed_init_method: Optional[str] = None): wrapper = WorkerWrapperBase( - worker_module_name="vllm.worker.habana_worker", - worker_class_name="HabanaWorker", + worker_module_name="vllm.worker.hpu_worker", + worker_class_name="HPUWorker", ) wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, distributed_init_method)) @@ -202,7 +202,7 @@ def shutdown(self) -> None: self.driver_worker.shutdown_inc() -class HabanaExecutorAsync(HabanaExecutor, ExecutorAsyncBase): +class HPUExecutorAsync(HPUExecutor, ExecutorAsyncBase): async def execute_model_async( self, diff --git a/vllm/executor/ray_habana_executor.py b/vllm/executor/ray_hpu_executor.py similarity index 99% rename from vllm/executor/ray_habana_executor.py rename to vllm/executor/ray_hpu_executor.py index 645bceb1af446..343fa43b0eda1 100644 --- a/vllm/executor/ray_habana_executor.py +++ b/vllm/executor/ray_hpu_executor.py @@ -29,7 +29,7 @@ logger = init_logger(__name__) -class RayHabanaExecutor(DistributedGPUExecutor): +class RayHPUExecutor(DistributedGPUExecutor): uses_ray: bool = True @@ -90,8 +90,8 @@ def _get_worker_module_and_class( raise NotImplementedError( "Speculative decoding is not implemented for HPU") else: - worker_module_name = "vllm.worker.habana_worker" - worker_class_name = "HabanaWorker" + worker_module_name = "vllm.worker.hpu_worker" + worker_class_name = "HPUWorker" return (worker_module_name, worker_class_name, worker_class_fn) def _get_worker_wrapper_args(self) -> Dict[str, Any]: @@ -479,7 +479,7 @@ def __del__(self): self.shutdown() -class RayHabanaExecutorAsync(RayHabanaExecutor, DistributedGPUExecutorAsync): +class RayHPUExecutorAsync(RayHPUExecutor, DistributedGPUExecutorAsync): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/hpu_model_runner.py similarity index 99% rename from vllm/worker/habana_model_runner.py rename to vllm/worker/hpu_model_runner.py index 2d72be5690664..b1b62e6bde7f6 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -489,7 +489,7 @@ def from_broadcasted_tensor_dict( return cls(**tensor_dict) -class HabanaModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): +class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): """ Helper class for shared methods between GPU model runners. """ @@ -1730,8 +1730,7 @@ def unwrap_model(model): return modules -class HabanaModelRunner( - HabanaModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): +class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): """ GPU model runner with sampling step. """ @@ -1872,7 +1871,7 @@ def execute_model( ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError( - "num_steps > 1 is not supported in HabanaModelRunner") + "num_steps > 1 is not supported in HPUModelRunner") if self.lora_config: assert model_input.lora_requests is not None diff --git a/vllm/worker/habana_worker.py b/vllm/worker/hpu_worker.py similarity index 99% rename from vllm/worker/habana_worker.py rename to vllm/worker/hpu_worker.py index 7fc1e48b8c960..59a5adf65ebc1 100644 --- a/vllm/worker/habana_worker.py +++ b/vllm/worker/hpu_worker.py @@ -25,14 +25,14 @@ from vllm.sequence import ExecuteModelRequest from vllm.utils import hpu_backend_string, hpu_device_string, is_fake_hpu from vllm.worker.cache_engine import CacheEngine -from vllm.worker.habana_model_runner import HabanaModelRunner +from vllm.worker.hpu_model_runner import HPUModelRunner from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput logger = init_logger(__name__) -class HabanaWorker(LocalOrDistributedWorkerBase): +class HPUWorker(LocalOrDistributedWorkerBase): """A worker class that executes (a partition of) the model on a HPU. Each worker is associated with a single HPU. The worker is responsible for @@ -79,7 +79,7 @@ def __init__( from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.model_runner: HabanaModelRunner = HabanaModelRunner( + self.model_runner: HPUModelRunner = HPUModelRunner( model_config, parallel_config, scheduler_config,