Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/private/kzawora/hpu_bf16_default…
Browse files Browse the repository at this point in the history
…' into private/kzawora/pruned_habana_main
  • Loading branch information
kzawora-intel committed Oct 4, 2024
2 parents eed1b05 + 76cbbb5 commit 5c3e29c
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1631,6 +1631,13 @@ def _get_and_verify_dtype(
torch_dtype = torch.float16
else:
torch_dtype = config_dtype

if current_platform.is_hpu() and config_dtype == torch.float16:
logger.info(
"For HPU, we cast models to bfloat16 instead of"
"using float16 by default. Please specify `dtype` if you "
"want to use float16.")
torch_dtype = torch.bfloat16
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")
Expand Down

0 comments on commit 5c3e29c

Please sign in to comment.