From 4c8a6c6092532d8df3f45831d2bfa2715a06507f Mon Sep 17 00:00:00 2001
From: yuwenzho
Date: Thu, 26 Sep 2024 20:26:16 +0800
Subject: [PATCH] Fix torch.compile issue of dispatch key set mismatch (#299)

### Issue:

torch.compile recompiles after warmup because of a guard failure:

```
tensor 'L['input_ids']' dispatch key set mismatch. expected DispatchKeySet(HPU, BackendSelect), actual DispatchKeySet(HPU, BackendSelect, ADInplaceOrView).
```

### Detail:

Running the script with `TORCH_LOGS="guards"` shows different dispatch key set info before and after warmup:

- warmup:
```
TENSOR_MATCH: check_tensor(L['input_ids'], Tensor, DispatchKeySet(HPU, BackendSelect), torch.int64, device=0, requires_grad=False, size=[2, 1], stride=[1, 1])  # masked_input = input_  # ome/zyuwen/workspace/vllm/habana_main_g3_v2/vllm/model_executor/layers/vocab_parallel_embedding.py:358 in forward
```
- after warmup:
```
TENSOR_MATCH: check_tensor(L['input_ids'], Tensor, DispatchKeySet(HPU, BackendSelect, ADInplaceOrView), torch.int64, device=0, requires_grad=False, size=[2, 1], stride=[1, 1])  # masked_input = input_  # ome/zyuwen/workspace/vllm/habana_main_g3_v2/vllm/model_executor/layers/vocab_parallel_embedding.py:358 in forward
```

### Solution:

The difference in dispatch key set is caused by the `torch.inference_mode()` decoration. Here is a simple comparison:

```python
import torch
import habana_frameworks.torch as htorch

@torch.inference_mode()
def func():
    x = torch.rand(3, 3).to("hpu")
    print(torch._C._dispatch_key_set(x))

func()
# output: DispatchKeySet(HPU, AutocastHPU)
```

```python
import torch
import habana_frameworks.torch as htorch

def func():
    x = torch.rand(3, 3).to("hpu")
    print(torch._C._dispatch_key_set(x))

func()
# output: DispatchKeySet(HPU, ADInplaceOrView, AutogradHPU, AutocastHPU)
```

In vllm-fork, the warmup phase is decorated with `torch.inference_mode()` in [habana_model_runner.py#L1487-L1488](https://github.com/HabanaAI/vllm-fork/blob/b62fba85ac03326e9f466d8d37e91ae1b14a6511/vllm/worker/habana_model_runner.py#L1487-L1488), but
the after-warmup phase is not. This PR adds the same decorator to the `prepare_model_input` function so that the dispatch key set stays consistent across both phases.

---
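The effect of `torch.inference_mode()` on a tensor's dispatch key set can also be reproduced on a plain CPU build of PyTorch, without `habana_frameworks`. This is a minimal sketch (the exact keys printed differ from the HPU output above, but the `ADInplaceOrView`/`Autograd*` difference is the same):

```python
import torch

@torch.inference_mode()
def make_inference_tensor():
    # Tensors created under inference_mode carry no ADInplaceOrView
    # or Autograd* dispatch keys.
    return torch.rand(3, 3)

def make_regular_tensor():
    return torch.rand(3, 3)

print(torch._C._dispatch_key_set(make_inference_tensor()))
print(torch._C._dispatch_key_set(make_regular_tensor()))
```

Because torch.compile guards on the tensor's dispatch key set (the `TENSOR_MATCH` lines above), feeding it tensors produced under different modes forces a recompile.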
Signed-off-by: yuwenzho
---
 vllm/worker/habana_model_runner.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index c43acdf04923b..f3bda39ec4822 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -1790,6 +1790,7 @@ def make_model_input_from_broadcasted_tensor_dict(
                 attn_backend=self.attn_backend,
             ))
 
+    @torch.inference_mode()
     def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
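The recompilation itself can be demonstrated without HPU hardware. The sketch below is a hypothetical standalone repro (not part of the patch): a custom `torch.compile` backend counts compilations, the first call runs under `torch.inference_mode()` like the vLLM warmup phase, and the second does not, so the guards fail and Dynamo recompiles:

```python
import torch

compile_count = 0

def counting_backend(gm, example_inputs):
    # Invoked once per (re)compilation; run the traced graph as-is.
    global compile_count
    compile_count += 1
    return gm.forward

@torch.compile(backend=counting_backend)
def f(x):
    return x + 1

with torch.inference_mode():
    f(torch.rand(2, 1))  # first compilation (inference-mode dispatch keys)

f(torch.rand(2, 1))      # guard mismatch -> second compilation
print(compile_count)     # 2
```

Decorating both the warmup and the after-warmup path with `torch.inference_mode()`, as this patch does, keeps those guards stable and avoids the extra compile.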