Merge branch 'habana_main' into private/kzawora/insert_or_update_cache_opt
kzawora-intel authored Sep 26, 2024
2 parents e7e87d5 + 4c8a6c6 commit d5cdb42
Showing 10 changed files with 102 additions and 51 deletions.
12 changes: 6 additions & 6 deletions README_GAUDI.md
@@ -321,7 +321,7 @@ for graph capture (later referred to as "usable graph memory"), and
the remaining 90% will be utilized for KV cache. Environment variable
`VLLM_GRAPH_PROMPT_RATIO` determines the ratio of usable graph memory
reserved for prefill and decode graphs. By default
(`VLLM_GRAPH_PROMPT_RATIO=0.5`), both stages have equal memory constraints.
(`VLLM_GRAPH_PROMPT_RATIO=0.3`), 30% of usable graph memory is reserved
for prefill graphs and 70% for decode graphs. A lower value corresponds
to less usable graph memory
reserved for prefill stage, e.g. `VLLM_GRAPH_PROMPT_RATIO=0.2` will
reserve 20% of usable graph memory for prefill graphs, and 80% of usable
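As a quick sanity check of the split described above, here is a small sketch (not from the repository; the numbers follow the warmup log quoted below):

```python
# Sketch: how usable graph memory is divided by VLLM_GRAPH_PROMPT_RATIO.
# 15.85 GiB matches the "Using 15.85 GiB ... for HPUGraphs" log line below.
usable_graph_mem_gib = 15.85
prompt_ratio = 0.3  # VLLM_GRAPH_PROMPT_RATIO

prompt_mem = prompt_ratio * usable_graph_mem_gib    # 4.755 GiB
decode_mem = usable_graph_mem_gib - prompt_mem      # 11.095 GiB
print(f"{prompt_mem:.3f} GiB for prompt, {decode_mem:.3f} GiB for decode")
```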
@@ -388,7 +388,7 @@ INFO 08-02 17:37:54 habana_worker.py:190] Initializing cache engine took 23.73 G
INFO 08-02 17:37:54 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB
...
INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB
INFO 08-02 17:38:22 habana_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.5)
INFO 08-02 17:38:22 habana_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 4.755 GiB for prompt and 11.095 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.3)
INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB
...
INFO 08-02 17:38:26 habana_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB
@@ -448,7 +448,7 @@ Environment variables
- `VLLM_GRAPH_RESERVED_MEM`: percentage of memory dedicated for
HPUGraph capture, `0.1` by default
- `VLLM_GRAPH_PROMPT_RATIO`: percentage of reserved graph memory
dedicated for prompt graphs, `0.5` by default
dedicated for prompt graphs, `0.3` by default
- `VLLM_GRAPH_PROMPT_STRATEGY`: strategy determining order of prompt
graph capture, `min_tokens` or `max_bs`, `min_tokens` by default
- `VLLM_GRAPH_DECODE_STRATEGY`: strategy determining order of decode
@@ -472,15 +472,15 @@ Environment variables
`max_model_len`

- Decode:
- batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `min(max_num_seqs, 32)`
- batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `1`
- batch size step (`VLLM_DECODE_BS_BUCKET_STEP`):
`min(max_num_seqs, 32)`
- batch size max (`VLLM_DECODE_BS_BUCKET_MAX`):
`max_num_seqs`
- block size min (`VLLM_DECODE_BLOCK_BUCKET_MIN`):
`128`
`block_size`
- block size step
(`VLLM_DECODE_BLOCK_BUCKET_STEP`): `128`
(`VLLM_DECODE_BLOCK_BUCKET_STEP`): `block_size`
- block size max (`VLLM_DECODE_BLOCK_BUCKET_MAX`):
`max(128, (max_num_seqs*max_model_len)/block_size)`
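To make the min/step/max scheme above concrete, the sketch below enumerates bucket values under one assumed ramp-up pattern (doubling from min until step, then stepping linearly); the exact pattern used by the warmup code is not spelled out in this diff:

```python
def bucket_range(bmin: int, bstep: int, bmax: int) -> list:
    # Assumed ramp-up: double from bmin until bstep, then increase
    # linearly by bstep up to bmax (illustrative, not from this commit).
    buckets = []
    b = bmin
    while b < bstep:
        buckets.append(b)
        b *= 2
    buckets.extend(range(bstep, bmax + 1, bstep))
    return buckets

# Decode batch-size buckets with the new defaults (min=1, step=32)
# and max_num_seqs=64: [1, 2, 4, 8, 16, 32, 64]
print(bucket_range(1, 32, 64))
```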

12 changes: 6 additions & 6 deletions docs/source/getting_started/gaudi-installation.rst
@@ -245,7 +245,7 @@ Only after that, ``gpu_memory_utilization`` flag is utilized - at its default va
Next, KV cache gets allocated, model is warmed up, and HPU Graphs are captured.
Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture.
With its default value (``VLLM_GRAPH_RESERVED_MEM=0.1``), 10% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 90% will be utilized for KV cache.
Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.5``), both stages have equal memory constraints.
Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. By default (``VLLM_GRAPH_PROMPT_RATIO=0.3``), 30% of usable graph memory is reserved for prefill graphs and 70% for decode graphs.
Lower value corresponds to less usable graph memory reserved for prefill stage, e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` will reserve 20% of usable graph memory for prefill graphs, and 80% of usable graph memory for decode graphs.
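For instance, the 20/80 split described above can be requested before the server initializes (a sketch; any mechanism that sets the environment variable works):

```python
import os

# Reserve 20% of usable graph memory for prefill graphs, 80% for decode.
os.environ["VLLM_GRAPH_PROMPT_RATIO"] = "0.2"
```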

.. note::
@@ -280,7 +280,7 @@ Each described step is logged by vLLM server, as follows (negative values corres
INFO 08-02 17:37:54 habana_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB
...
INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB
INFO 08-02 17:38:22 habana_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.5)
INFO 08-02 17:38:22 habana_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 4.755 GiB for prompt and 11.095 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.3)
INFO 08-02 17:38:22 habana_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB
...
INFO 08-02 17:38:26 habana_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB
@@ -324,7 +324,7 @@ Environment variables

- ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default
- ``VLLM_GRAPH_RESERVED_MEM``: percentage of memory dedicated for HPUGraph capture, ``0.1`` by default
- ``VLLM_GRAPH_PROMPT_RATIO``: percentage of reserved graph memory dedicated for prompt graphs, ``0.5`` by default
- ``VLLM_GRAPH_PROMPT_RATIO``: percentage of reserved graph memory dedicated for prompt graphs, ``0.3`` by default
- ``VLLM_GRAPH_PROMPT_STRATEGY``: strategy determining order of prompt graph capture, ``min_tokens`` or ``max_bs``, ``min_tokens`` by default
- ``VLLM_GRAPH_DECODE_STRATEGY``: strategy determining order of decode graph capture, ``min_tokens`` or ``max_bs``, ``max_bs`` by default
- ``VLLM_{phase}_{dim}_BUCKET_{param}`` - collection of 12 environment variables configuring ranges of bucketing mechanism
@@ -343,11 +343,11 @@ Environment variables
- sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``max_model_len``

- Decode:
- batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``min(max_num_seqs, 32)``
- batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1``
- batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)``
- batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs``
- sequence length min (``VLLM_DECODE_BLOCK_BUCKET_MIN``): ``128``
- sequence length step (``VLLM_DECODE_BLOCK_BUCKET_STEP``): ``128``
- block size min (``VLLM_DECODE_BLOCK_BUCKET_MIN``): ``block_size``
- block size step (``VLLM_DECODE_BLOCK_BUCKET_STEP``): ``block_size``
- block size max (``VLLM_DECODE_BLOCK_BUCKET_MAX``): ``max(128, (max_num_seqs*max_model_len)/block_size)``
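All twelve variables share one naming pattern, so a single helper can read them; the sketch below mirrors the ``read_bucket_settings`` calls visible in the ``habana_model_runner.py`` hunks later in this diff, with the exact details assumed:

```python
import os

def read_bucket_settings(phase: str, dim: str, **defaults):
    # Read VLLM_{PHASE}_{DIM}_BUCKET_{MIN,STEP,MAX} from the environment,
    # falling back to the given defaults (sketch; details assumed).
    return [
        int(os.environ.get(f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper(),
                           defaults[p]))
        for p in ('min', 'step', 'max')
    ]

# e.g. the new decode batch-size defaults:
# read_bucket_settings('decode', 'bs', min=1, step=32, max=256)
```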


3 changes: 0 additions & 3 deletions vllm/executor/habana_executor.py
@@ -195,9 +195,6 @@ def check_health(self) -> None:
    def shutdown(self) -> None:
        self.driver_worker.shutdown_inc()

    def __del__(self):
        self.shutdown()


class HabanaExecutorAsync(HabanaExecutor, ExecutorAsyncBase):

4 changes: 4 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -13,6 +13,10 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform

if current_platform.is_hpu():
    from vllm_hpu_extension.ops import scaled_fp8_quant
    ops.scaled_fp8_quant = scaled_fp8_quant

logger = init_logger(__name__)


@@ -243,8 +243,10 @@ def _get_scheme_from_parts(
        # TODO @dsikka: clean-up conditions
        if is_activation_quantization_format(self.quant_format):
            if self._is_fp8_w8a8(weight_quant, input_quant):
                is_fp8_w8a8_supported = self._check_scheme_supported(
                    CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
                is_fp8_w8a8_supported = current_platform.is_hpu() or \
                    self._check_scheme_supported(
                        CompressedTensorsW8A8Fp8.get_min_capability(),
                        error=False)
                if is_fp8_w8a8_supported:
                    return CompressedTensorsW8A8Fp8(
                        strategy=weight_quant.strategy,
@@ -314,7 +316,8 @@ def get_scheme(

        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
        self._check_scheme_supported(scheme.get_min_capability())
        if not current_platform.is_hpu():
            self._check_scheme_supported(scheme.get_min_capability())

        return scheme

@@ -13,6 +13,7 @@
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.utils import is_hip

__all__ = ["CompressedTensorsW8A8Fp8"]
@@ -23,7 +24,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
    def __init__(self, strategy: str, is_static_input_scheme: bool):
        self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme
        self.cutlass_fp8_supported = cutlass_fp8_supported()
        self.cutlass_fp8_supported = not current_platform.is_hpu() and \
            cutlass_fp8_supported()

    @classmethod
    def get_min_capability(cls) -> int:
24 changes: 16 additions & 8 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -28,6 +28,10 @@
from vllm.platforms import current_platform
from vllm.utils import is_hip, print_warning_once

if current_platform.is_hpu():
    from vllm_hpu_extension.ops import scaled_fp8_quant
    ops.scaled_fp8_quant = scaled_fp8_quant

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)
@@ -116,14 +120,18 @@ class Fp8LinearMethod(LinearMethodBase):

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if is_hip():
        if current_platform.is_cuda_alike():
            self.cutlass_fp8_supported = cutlass_fp8_supported()

            # For GPUs that lack FP8 hardware support, we can leverage the
            # Marlin kernel for fast weight-only FP8 quantization
            self.use_marlin = (not current_platform.has_device_capability(89)
                               or envs.VLLM_TEST_FORCE_FP8_MARLIN)
            # Disable marlin for rocm
            if is_hip():
                self.use_marlin = False
        else:
            self.cutlass_fp8_supported = False
            self.use_marlin = False

    def create_weights(
40 changes: 32 additions & 8 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -10,6 +10,11 @@
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None

if current_platform.is_hpu():
    import habana_frameworks.torch.utils.experimental as htexp
    from vllm_hpu_extension.ops import scaled_fp8_quant
    ops.scaled_fp8_quant = scaled_fp8_quant


def cutlass_fp8_supported() -> bool:
    # cutlass is not supported on Rocm
@@ -25,7 +30,15 @@ def cutlass_fp8_supported() -> bool:
def per_tensor_dequantize(
        tensor: torch.Tensor, inv_scale: Union[float,
                                               torch.Tensor]) -> torch.Tensor:
    fake_qweight = tensor.to(torch.float16)
    dtype = torch.float16
    device = tensor.device
    if current_platform.is_hpu():
        dtype = torch.bfloat16
        if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
            # dequantize on CPU to avoid NaN on Gaudi2
            tensor = tensor.to('cpu')

    fake_qweight = tensor.to(dtype).to(device)
    dq_weight = fake_qweight * inv_scale
    return dq_weight
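The HPU branch above boils down to a device round-trip; distilled into a standalone helper (a sketch, not code from the commit):

```python
import torch

def upcast_on_cpu(t: torch.Tensor,
                  dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # Remember the original device, upcast on the CPU (sidestepping the
    # device-side NaN issue seen on Gaudi2), then move the result back.
    device = t.device
    return t.to('cpu').to(dtype).to(device)
```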

@@ -58,7 +71,10 @@ def requantize_with_max_scale(
        logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    # Max scale to be used for requantization.
    max_w_scale = weight_scale.max()

    if current_platform.is_hpu() and htexp._get_device_type(
    ) == htexp.synDeviceType.synDeviceGaudi2:
        max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max /
                                     torch.finfo(torch.float8_e4m3fnuz).max)
    # QKV / MLP is fused in the on disk checkpoint if any of the
    # weight scales are still set to the default since we initialize
    # N weight scales for N shards but we only load 1 weight scale
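The rescale factor applied above is the ratio between the OCP float8_e4m3fn range that checkpoint scales assume and the narrower float8_e4m3fnuz range Gaudi2 operates with; a quick check of the constant:

```python
import torch

# 448.0 (e4m3fn max) / 240.0 (e4m3fnuz max) ~= 1.8667
ratio = (torch.finfo(torch.float8_e4m3fn).max /
         torch.finfo(torch.float8_e4m3fnuz).max)
print(ratio)
```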
@@ -129,12 +145,20 @@ def apply_fp8_linear(

    if per_tensor_weights and per_tensor_activations:
        # Fused GEMM_DQ
        output = torch._scaled_mm(qinput,
                                  weight,
                                  out_dtype=input.dtype,
                                  scale_a=x_scale,
                                  scale_b=weight_scale,
                                  bias=bias)
        if current_platform.is_hpu():
            # HPU does not support torch._scaled_mm (SW-197036)
            output = torch.ops.hpu.fp8_gemm_v2(qinput, False, weight,
                                               False, None, input.dtype,
                                               x_scale, weight_scale, None,
                                               False)
        else:
            output = torch._scaled_mm(qinput,
                                      weight,
                                      out_dtype=input.dtype,
                                      scale_a=x_scale,
                                      scale_b=weight_scale,
                                      bias=bias)

        # A fix for discrepancy in scaled_mm which returns tuple
        # for torch < 2.5 and a single value in torch >= 2.5
        if type(output) is tuple and len(output) == 2:
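Both branches above compute the same fused GEMM with per-tensor dequantization scales; as a plain-fp32 reference for the semantics (a sketch, not either kernel):

```python
import torch

# output = (qinput * scale_a) @ (weight * scale_b) + bias, per-tensor scales
qinput = torch.randn(4, 8)    # stands in for the fp8 activation
weight = torch.randn(8, 16)   # stands in for the fp8 weight
scale_a, scale_b = 0.5, 0.25
output = (qinput * scale_a) @ (weight * scale_b)
```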
42 changes: 29 additions & 13 deletions vllm/worker/habana_model_runner.py
@@ -561,6 +561,7 @@ def __init__(
        # Lazy initialization
        self.lora_manager: LRUCacheWorkerLoRAManager = None
        self.model: torch.nn.Module = None
        self.inc_initialized_successfully = False

        # Profiler stats
        self.profiler_counter_helper = HabanaProfilerCounterHelper()
@@ -597,8 +598,7 @@ def _set_gc_threshold(self) -> None:

    def load_model(self) -> None:
        import habana_frameworks.torch.core as htcore
        if self.model_config.quantization == 'inc':
            htcore.hpu_set_env()
        htcore.hpu_set_env()
        with HabanaMemoryProfiler() as m:
            with HabanaMemoryProfiler() as m_getmodel:
                self.model = get_model(model_config=self.model_config,
@@ -643,6 +643,7 @@ def load_model(self) -> None:
                    self.model = convert(self.model, config)
                    htcore.hpu_initialize(self.model,
                                          mark_only_scales_as_const=True)
                    self.inc_initialized_successfully = True
                    logger.info("Preparing model with INC took %s",
                                m_inc.get_summary_string())
                elif not is_fake_hpu():
@@ -679,7 +680,6 @@ def _setup_buckets(self) -> None:
        if self.lora_config and \
                max_bucket_cfg > self.max_num_batched_tokens // self.block_size:
            max_bucket_cfg = self.max_num_batched_tokens // self.block_size
        blocks_step = 128
        # FIXME: The default values should be max_model_len
        max_prompt_seq = 1024
        max_decode_seq = 2048
@@ -691,7 +691,7 @@ def _setup_buckets(self) -> None:
                                                  max=align_bs(max_bucket_cfg))
        self.decode_bs_bucket_cfg = read_bucket_settings('decode',
                                                         'bs',
                                                         min=align_bs(32),
                                                         min=1,
                                                         step=align_bs(32),
                                                         max=self.max_num_seqs)
        self.prompt_seq_bucket_cfg = read_bucket_settings('prompt',
@@ -702,9 +702,9 @@ def _setup_buckets(self) -> None:
        self.decode_block_bucket_cfg = read_bucket_settings(
            'decode',
            'block',
            min=blocks_step,
            step=blocks_step,
            max=max(blocks_step,
            min=self.block_size,
            step=self.block_size,
            max=max(self.block_size,
                    self.max_num_seqs * max_decode_seq // self.block_size))
        self.graphed_buckets: Set[Any] = set()

@@ -1571,6 +1571,17 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
                    len(self.decode_buckets),
                    list(sorted(self.decode_buckets)))

        if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
            cache_size_limit = len(self.prompt_buckets) + len(
                self.decode_buckets) + 1
            torch._dynamo.config.cache_size_limit = max(
                cache_size_limit, torch._dynamo.config.cache_size_limit)
            # Multiply by 8 to follow the original default ratio between
            # the cache_size_limit and accumulated_cache_size_limit
            torch._dynamo.config.accumulated_cache_size_limit = max(
                cache_size_limit * 8,
                torch._dynamo.config.accumulated_cache_size_limit)

        start_mem = HabanaMemoryProfiler.current_device_memory_usage()
        start_time = time.perf_counter()
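With the bucket counts from the warmup logs quoted earlier (24 prompt buckets, 48 decode buckets), the cache sizing above works out as follows (illustrative numbers only):

```python
# torch.compile graph-cache sizing for the bucket counts in the logs above
prompt_buckets, decode_buckets = 24, 48
cache_size_limit = prompt_buckets + decode_buckets + 1    # 73
accumulated_cache_size_limit = cache_size_limit * 8       # 584
```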

@@ -1601,7 +1612,7 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
            graph_free_mem = align_workers(graph_free_mem,
                                           torch.distributed.ReduceOp.MIN)
            prompt_graph_mem_ratio = float(
                os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5'))
                os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.3'))
            prompt_available_memory = (prompt_graph_mem_ratio *
                                       graph_free_mem)
            decode_available_memory = (graph_free_mem -
@@ -1799,6 +1810,7 @@ def make_model_input_from_broadcasted_tensor_dict(
                attn_backend=self.attn_backend,
            ))

    @torch.inference_mode()
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -1958,14 +1970,18 @@ def execute_model(
            return [output]

    def shutdown_inc(self):
        print('inc shutdown')
        if (model_config := getattr(self, "model_config", None)) and \
                getattr(model_config, "quantization", None) == 'inc':
            print('inc shutdown start')
        can_finalize_inc = False
        from contextlib import suppress
        with suppress(AttributeError):
            can_finalize_inc = (self.model_config.quantization == 'inc') and \
                (self.model.model is not None) and \
                self.inc_initialized_successfully and \
                not getattr(self, "_is_inc_finalized", False)
        if can_finalize_inc:
            from neural_compressor.torch.quantization import (
                finalize_calibration)
            finalize_calibration(self.model.model)
            print('inc shutdown')
            self._is_inc_finalized = True

    def __del__(self):
        self.shutdown_inc()
3 changes: 0 additions & 3 deletions vllm/worker/habana_worker.py
@@ -320,9 +320,6 @@ def list_prompt_adapters(self) -> Set[int]:
    def shutdown_inc(self):
        self.model_runner.shutdown_inc()

    def __del__(self):
        self.shutdown_inc()

    @property
    def max_model_len(self) -> int:
        return self.model_config.max_model_len
