From 76cbbb5deeff3e3d760aff2487f284234c4fd5bb Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Fri, 4 Oct 2024 19:45:47 +0300 Subject: [PATCH] Use BF16 on HPU by default --- vllm/config.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index 786ed1586a3ea..b3329f1c449ff 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1635,6 +1635,13 @@ def _get_and_verify_dtype( torch_dtype = torch.float16 else: torch_dtype = config_dtype + + if current_platform.is_hpu() and config_dtype == torch.float16: + logger.info( + "For HPU, we cast models to bfloat16 instead of " + "using float16 by default. Please specify `dtype` if you " + "want to use float16.") + torch_dtype = torch.bfloat16 else: if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: raise ValueError(f"Unknown dtype: {dtype}")