Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Async output process for HPU #342

Merged
merged 2 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,9 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
self.use_async_output_proc = False
return

if device_config.device_type not in ("cuda", "tpu"):
if device_config.device_type not in ("cuda", "tpu", "hpu"):
logger.warning(
"Async output processing is only supported for CUDA or TPU. "
"Async output processing is only supported for CUDA, TPU and HPU. "
"Disabling it for other platforms.")
self.use_async_output_proc = False
return
Expand Down
4 changes: 4 additions & 0 deletions vllm/worker/habana_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ class ModelInputForHPU(ModelRunnerInputBase):
virtual_engine: int = 0
lora_mask: Optional[torch.Tensor] = None
lora_logits_mask: Optional[torch.Tensor] = None
async_callback: Optional[Callable] = None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this set?
Vllm code names it output_proc_callback_fn shouldnt we keep the name?

Copy link
Author

@zhouyu5 zhouyu5 Sep 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@michalkuligowski
It is first set in llm_engine.py, see:

if model_config.use_async_output_proc:
    process_model_outputs = weak_bind(self._process_model_outputs)
    self.async_callbacks = [
        partial(process_model_outputs,
                ctx=self.scheduler_contexts[v_id])
        for v_id in range(self.parallel_config.pipeline_parallel_size)
    ]
...
if allow_async_output_proc:
    execute_model_req.async_callback = self.async_callbacks[
        virtual_engine]

then pass to worker_base.py, which is inherited by HabanaWorker,

if execute_model_req.async_callback:
    model_input = dataclasses.replace(  # type: ignore
        model_input,
        async_callback=execute_model_req.async_callback)

For its name, it is initially called output_proc_callback_fn, but in vllm's latest code, it's changed to async_callback, since it could involve other operations, not only output processing, see this comment in PR#7911,
image

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Formatting issues

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Formatted now. @michalkuligowski


def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
Expand Down Expand Up @@ -1920,6 +1921,9 @@ def execute_model(
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return []

if model_input.async_callback is not None:
model_input.async_callback()

# Sample the next token.
with self.profiler.record_event(
Expand Down
Loading