diff --git a/vllm/config.py b/vllm/config.py
index 011563038e6b..6be993bdc388 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -349,9 +349,10 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
             self.use_async_output_proc = False
             return
 
-        if device_config.device_type not in ("cuda", "tpu"):
+        if device_config.device_type not in ("cuda", "tpu", "hpu"):
             logger.warning(
-                "Async output processing is only supported for CUDA or TPU. "
+                "Async output processing is only supported for CUDA, TPU "
+                "and HPU. "
                 "Disabling it for other platforms.")
             self.use_async_output_proc = False
             return
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index c43acdf04923..4b3ee10417ff 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -417,6 +417,7 @@ class ModelInputForHPU(ModelRunnerInputBase):
     virtual_engine: int = 0
     lora_mask: Optional[torch.Tensor] = None
     lora_logits_mask: Optional[torch.Tensor] = None
+    async_callback: Optional[Callable] = None
 
     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -1921,6 +1922,9 @@ def execute_model(
         if not self.is_driver_worker:
            return []
 
+        if model_input.async_callback is not None:
+            model_input.async_callback()
+
         # Sample the next token.
         with self.profiler.record_event(
                 'internal', ('sample_'
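
For context, here is a minimal sketch of the callback pattern this diff wires into the HPU model runner. Only the `async_callback` field on `ModelInputForHPU` and its invocation before sampling on the driver worker come from the diff; everything else (`StepState`, `make_async_callback`, `FakeModelInput`, `execute_model_step`) is a hypothetical stand-in to show the ordering, not vLLM's actual engine code.

```python
# Sketch only: illustrates how an engine-side callback attached to the model
# input can run (e.g. processing the previous step's outputs) right before
# sampling, which is what the diff's hook enables on HPU.
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class StepState:
    """Hypothetical stand-in for engine-side state of one decode step."""
    step_id: int
    outputs_processed: bool = False


def make_async_callback(state: StepState) -> Callable[[], None]:
    """Build the callable that the engine would attach to the model input."""
    def _callback() -> None:
        # Placeholder for the output-processing work (detokenization, etc.)
        # that overlaps with sampling of the current step.
        state.outputs_processed = True
    return _callback


@dataclass
class FakeModelInput:
    """Hypothetical reduced ModelInputForHPU: just the new field."""
    async_callback: Optional[Callable[[], None]] = None


def execute_model_step(model_input: FakeModelInput,
                       is_driver_worker: bool) -> None:
    """Mirrors the order of operations added by the diff."""
    if not is_driver_worker:
        return
    # New hook: invoke the callback before sampling, as in the diff.
    if model_input.async_callback is not None:
        model_input.async_callback()
    # ... sampling of the next token would happen here ...


state = StepState(step_id=0)
execute_model_step(FakeModelInput(async_callback=make_async_callback(state)),
                   is_driver_worker=True)
assert state.outputs_processed
```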