Support Torch profiler in Habana Worker
mswiniarsk committed Oct 3, 2024
1 parent 25f4ed9 commit af7cfc6
Showing 5 changed files with 46 additions and 6 deletions.
7 changes: 5 additions & 2 deletions vllm/engine/async_llm_engine.py
@@ -16,6 +16,7 @@
 from vllm.engine.metrics_types import StatLoggerBase
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutorAsync
+from vllm.executor.habana_executor import HabanaExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import PromptType
 from vllm.logger import init_logger
@@ -1204,15 +1205,17 @@ def remove_logger(self, logger_name: str) -> None:
     async def start_profile(self) -> None:
         # using type instead of isinstance to check to avoid capturing
         # inherited classes
-        if type(self.engine.model_executor) == GPUExecutorAsync:  # noqa: E721
+        if type(self.engine.model_executor) == GPUExecutorAsync or \
+                type(self.engine.model_executor) == HabanaExecutorAsync:  # noqa: E721
             self.engine.model_executor.start_profile()
         else:
             self.engine.model_executor._run_workers("start_profile")
 
     async def stop_profile(self) -> None:
         # using type instead of isinstance to check to avoid capturing
         # inherited classes
-        if type(self.engine.model_executor) == GPUExecutorAsync:  # noqa: E721
+        if type(self.engine.model_executor) == GPUExecutorAsync or \
+                type(self.engine.model_executor) == HabanaExecutorAsync:  # noqa: E721
             self.engine.model_executor.stop_profile()
         else:
             self.engine.model_executor._run_workers("stop_profile")
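Note: all of the patched engine entry points compare with type(...) rather than isinstance so that executor subclasses (such as the Ray- or multiprocessing-based ones) still take the _run_workers branch. A minimal sketch of the distinction, with illustrative class names that are not vLLM's:

    class Executor:
        pass

    class DistributedExecutor(Executor):  # stands in for a multi-worker subclass
        pass

    executor = DistributedExecutor()
    print(isinstance(executor, Executor))  # True  -- would also match the subclass
    print(type(executor) == Executor)      # False -- the subclass falls to the else branch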
5 changes: 3 additions & 2 deletions vllm/engine/llm_engine.py
@@ -28,6 +28,7 @@
 from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.gpu_executor import GPUExecutor
+from vllm.executor.habana_executor import HabanaExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
                          InputRegistry, LLMInputs, PromptType)
@@ -1794,15 +1795,15 @@ def check_health(self) -> None:
     def start_profile(self) -> None:
         # using type instead of isinstance to check to avoid capturing
         # inherited classes (MultiprocessingGPUExecutor)
-        if type(self.model_executor) == GPUExecutor:  # noqa: E721
+        if type(self.model_executor) == GPUExecutor or type(self.model_executor) == HabanaExecutor:  # noqa: E721
             self.model_executor.start_profile()
         else:
             self.model_executor._run_workers("start_profile")
 
     def stop_profile(self) -> None:
         # using type instead of isinstance to check to avoid capturing
         # inherited classes (MultiprocessingGPUExecutor)
-        if type(self.model_executor) == GPUExecutor:  # noqa: E721
+        if type(self.model_executor) == GPUExecutor or type(self.model_executor) == HabanaExecutor:  # noqa: E721
             self.model_executor.stop_profile()
         else:
             self.model_executor._run_workers("stop_profile")
7 changes: 5 additions & 2 deletions vllm/engine/multiprocessing/engine.py
@@ -23,6 +23,7 @@
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
 from vllm.executor.gpu_executor import GPUExecutor
+from vllm.executor.habana_executor import HabanaExecutor
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.usage.usage_lib import UsageContext
@@ -364,13 +365,15 @@ def _alive(self):
         self._last_alive_time = time.time()
 
     def start_profile(self) -> None:
-        if type(self.engine.model_executor) is GPUExecutor:
+        if type(self.engine.model_executor) is GPUExecutor or \
+                type(self.engine.model_executor) is HabanaExecutor:
             self.engine.model_executor.start_profile()
         else:
             self.engine.model_executor._run_workers("start_profile")
 
     def stop_profile(self) -> None:
-        if type(self.engine.model_executor) is GPUExecutor:
+        if type(self.engine.model_executor) is GPUExecutor or \
+                type(self.engine.model_executor) is HabanaExecutor:
             self.engine.model_executor.stop_profile()
         else:
             self.engine.model_executor._run_workers("stop_profile")
6 changes: 6 additions & 0 deletions vllm/executor/habana_executor.py
@@ -192,6 +192,12 @@ def check_health(self) -> None:
         # it's running.
         return
 
+    def start_profile(self) -> None:
+        self.driver_worker.start_profile()
+
+    def stop_profile(self) -> None:
+        self.driver_worker.stop_profile()
+
     def shutdown(self) -> None:
         self.driver_worker.shutdown_inc()
 
27 changes: 27 additions & 0 deletions vllm/worker/habana_worker.py
@@ -11,6 +11,7 @@
 import torch.distributed
 from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
 
+import vllm.envs as envs
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig,
@@ -95,6 +96,32 @@ def __init__(
         self.cache_engine: List[CacheEngine]
         # Initialize gpu_cache as embedding models don't initialize kv_caches
         self.hpu_cache: Optional[List[List[torch.tensor]]] = None
+        # Torch profiler. Enabled and configured through env vars:
+        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+        if envs.VLLM_TORCH_PROFILER_DIR:
+            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+            logger.info("Profiling enabled. Traces will be saved to: %s",
+                        torch_profiler_trace_dir)
+            self.profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.HPU,
+                ],
+                with_stack=True,
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                    torch_profiler_trace_dir, use_gzip=True))
+        else:
+            self.profiler = None
+
+    def start_profile(self):
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.start()
+
+    def stop_profile(self):
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.stop()
+
     def _set_env_vars(self):
         local_rank = self.local_rank
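With all five files in place, HPU profiling is driven the same way as on GPU: set VLLM_TORCH_PROFILER_DIR before launching, then call start_profile()/stop_profile() around the region of interest. A usage sketch, assuming an HPU-enabled vLLM build that includes this commit (model name and prompt are placeholders):

    # Launch with: VLLM_TORCH_PROFILER_DIR=/tmp/hpu_traces python profile_demo.py
    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # placeholder model

    # LLMEngine.start_profile()/stop_profile() are the entry points patched
    # above; on HPU they now reach HabanaExecutor and the worker's profiler.
    llm.llm_engine.start_profile()
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    llm.llm_engine.stop_profile()  # gzipped trace written to the trace dir

Since the worker registers tensorboard_trace_handler as the on_trace_ready callback, the resulting trace should be viewable in TensorBoard's profiler plugin.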
