diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 676a51ce67f9..d659d0a3f112 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -517,6 +517,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+        if current_platform.is_hpu():
+            torch.hpu.synchronize()
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
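
For context: the added lines gate a device synchronization behind a platform check, so only Intel Gaudi (HPU) runs are affected; other backends skip the branch entirely. On HPU, work is enqueued lazily, so a single synchronize after the last weight_loader call ensures every queued host-to-device weight copy has completed before the model is used. Below is a minimal sketch of the same pattern, assuming vLLM's current_platform helper, default_weight_loader, and the torch.hpu module provided by habana_frameworks.torch; the function name load_weights_sketch is hypothetical, not vLLM's actual method:

    import torch
    from vllm.platforms import current_platform
    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader)

    def load_weights_sketch(params_dict, weights):
        # Copy each checkpoint tensor into its parameter, preferring a
        # custom weight_loader when the parameter defines one.
        for name, loaded_weight in weights:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        # HPU executes lazily; one synchronize after the loop makes sure
        # all queued copies have finished. Non-HPU platforms never take
        # this branch, so the change is a no-op for them.
        if current_platform.is_hpu():
            torch.hpu.synchronize()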