Skip to content

Commit

Permalink
Merge branch 'habana_main' into adobrzyniewicz/tflops_main
Browse files Browse the repository at this point in the history
  • Loading branch information
adobrzyniewicz-habana authored Aug 19, 2024
2 parents cb5c62a + f7dd91d commit b44f672
Show file tree
Hide file tree
Showing 26 changed files with 1,064 additions and 150 deletions.
21 changes: 0 additions & 21 deletions .github/workflows/reminder_comment.yml

This file was deleted.

500 changes: 437 additions & 63 deletions README_GAUDI.md

Large diffs are not rendered by default.

254 changes: 236 additions & 18 deletions docs/source/getting_started/gaudi-installation.rst

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/getting_started/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ This guide shows how to use vLLM to:
* build an API server for a large language model;
* start an OpenAI-compatible API server.

Be sure to complete the :ref:`installation instructions <installation>` before continuing with this guide.
Be sure to complete the `Gaudi installation instructions <https://github.com/HabanaAI/vllm-fork/blob/habana_main/docs/source/getting_started/gaudi-installation.rst#run-docker-image>`_ before continuing with this guide.

.. note::

Expand Down
26 changes: 21 additions & 5 deletions vllm/attention/backends/habana_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
AttentionMetadata, AttentionType)
from vllm.attention.ops.habana_paged_attn import (HabanaPagedAttention,
HabanaPagedAttentionMetadata)
from vllm.hpu import cache_ops
from vllm.hpu.utils import Matmul, Softmax, VLLMKVCache
from vllm.logger import init_logger

logger = init_logger(__name__)
Expand Down Expand Up @@ -108,7 +110,7 @@ def __post_init__(self):
self.attn_bias: Optional[torch.Tensor] = None


class HabanaAttentionImpl(AttentionImpl):
class HabanaAttentionImpl(AttentionImpl, torch.nn.Module):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
Expand Down Expand Up @@ -137,10 +139,16 @@ def __init__(
blocksparse_params: Optional[Dict[str, Any]] = None,
max_seq_len: int = 4096,
) -> None:
super(AttentionImpl, self).__init__()
self.kv_cache_dtype = kv_cache_dtype
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.matmul_qk = Matmul()
self.softmax = Softmax()
self.matmul_av = Matmul()
self.k_cache = VLLMKVCache()
self.v_cache = VLLMKVCache()
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
self.position_bias = None
Expand Down Expand Up @@ -204,9 +212,13 @@ def forward(
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
HabanaPagedAttention.write_to_paged_cache(
key, value, key_cache, value_cache, attn_metadata.slot_mapping,
self.kv_cache_dtype, attn_metadata.is_prompt)
num_kv_cache_passes, num_slots_available, indices, offsets = \
cache_ops.prepare_to_cache(key_cache,
attn_metadata.slot_mapping)
key_cache = self.k_cache(key, key_cache, num_kv_cache_passes,
num_slots_available, indices, offsets)
value_cache = self.v_cache(value, value_cache, num_kv_cache_passes,
num_slots_available, indices, offsets)

if attn_metadata.is_prompt:
# Prompt run.
Expand All @@ -232,6 +244,9 @@ def forward(
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
)
output = out.reshape(batch_size, seq_len, hidden_size)
else:
Expand All @@ -255,7 +270,8 @@ def forward(
query, key_cache, value_cache, attn_metadata.block_tables,
attn_metadata.seq_lens_tensor, self.kv_cache_dtype,
self.num_kv_heads, self.scale, self.position_bias, k_scale,
v_scale)
v_scale, self.matmul_qk, self.softmax, self.matmul_av,
self.k_cache, self.v_cache)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)

Expand Down
10 changes: 10 additions & 0 deletions vllm/attention/ops/habana_paged_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def forward_decode(
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
matmul_qk_op,
softmax_op,
matmul_av_op,
k_cache_cls,
v_cache_cls,
) -> torch.Tensor:
block_size = value_cache.shape[1]
return ops.paged_attention_v1(
Expand All @@ -88,6 +93,11 @@ def forward_decode(
block_size,
alibi_slopes,
kv_cache_dtype,
matmul_qk_op,
softmax_op,
matmul_av_op,
k_cache_cls,
v_cache_cls,
)

@staticmethod
Expand Down
8 changes: 5 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,12 +474,13 @@ def _verify_args(self) -> None:
def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor")
"scaling factor. "
"Intel Gaudi (HPU) supports fp8 (using fp8_inc).")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

Expand Down Expand Up @@ -600,11 +601,12 @@ class LoadConfig:
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
device: Device on which weights are loaded.
"""

load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
download_dir: Optional[str] = None
device: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(
default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None
Expand Down
14 changes: 12 additions & 2 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class EngineArgs:
trust_remote_code: bool = False
download_dir: Optional[str] = None
load_format: str = 'auto'
weights_load_device: Optional[str] = None
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
Expand Down Expand Up @@ -205,6 +206,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument("--weights-load-device",
type=str,
default=EngineArgs.weights_load_device,
choices=["cuda", "neuron", "hpu", "cpu"],
help='Device on which weights are loaded.')
parser.add_argument(
'--dtype',
type=str,
Expand All @@ -223,11 +229,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'fp8_inc'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3). '
'Intel Gaudi (HPU) supports fp8 (using fp8_inc).')
parser.add_argument(
'--quantization-param-path',
type=nullable_str,
Expand Down Expand Up @@ -835,9 +842,12 @@ def create_engine_config(self, ) -> EngineConfig:
self.model_loader_extra_config[
"qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

device = device_config.device if self.weights_load_device is None else \
self.weights_load_device
load_config = LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
device=device,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
)
Expand Down
6 changes: 5 additions & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def __init__(
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"pipeline_parallel_size=%d, "
"disable_custom_all_reduce=%s, quantization=%s, "
"enforce_eager=%s, kv_cache_dtype=%s, "
"weights_load_device=%s, enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
Expand All @@ -206,6 +206,7 @@ def __init__(
parallel_config.pipeline_parallel_size,
parallel_config.disable_custom_all_reduce,
model_config.quantization,
load_config.device,
model_config.enforce_eager,
cache_config.cache_dtype,
model_config.quantization_param_path,
Expand Down Expand Up @@ -853,6 +854,9 @@ def _process_model_outputs(
request_outputs.append(request_output)
return request_outputs

def finish_measurements(self):
self.model_executor.finish_measurements()

def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
Expand Down
3 changes: 3 additions & 0 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ def set_tokenizer(
self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer(
tokenizer)

def finish_measurements(self):
self.llm_engine.finish_measurements()

@overload # LEGACY: single (prompt + optional token ids)
def generate(
self,
Expand Down
9 changes: 9 additions & 0 deletions vllm/executor/habana_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
msg = f"init_cache_engine took {cache_init_m.get_summary_string()}"
logger.info(msg)

def finish_measurements(self):
self.driver_worker.finish_measurements()

def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
Expand Down Expand Up @@ -180,6 +183,12 @@ def check_health(self) -> None:
# it's running.
return

def shutdown(self) -> None:
self.driver_worker.shutdown_inc()

def __del__(self):
self.shutdown()


class HabanaExecutorAsync(HabanaExecutor, ExecutorAsyncBase):

Expand Down
3 changes: 3 additions & 0 deletions vllm/executor/ray_habana_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ def _driver_execute_model(
return self.driver_worker.execute_method("execute_model",
execute_model_req)

def finish_measurements(self):
self._run_workers("finish_measurements")

def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
Expand Down
31 changes: 31 additions & 0 deletions vllm/hpu/cache_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,37 @@ def reshape_and_cache(key,
value[start_idx:end_idx])


def prepare_to_cache(cache, slot_mapping):
num_blocks = cache.size(0)
block_size = cache.size(1)
slot_mapping = slot_mapping.flatten()
indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
offsets = torch.fmod(slot_mapping, block_size)
num_slots_requested = slot_mapping.size(0)
num_slots_available = num_blocks * block_size
# NOTE(kzawora): HPU PT bridge crashes with
# RuntimeError: Invalid inputs for scatter_nd_onnx
# on index_put when num_slots_requested > num_slots_available.
# This case might occur when we have little kv cache blocks and
# lots of padding, or are doing warmup.
# This loop is a workaround for this issue. Please remove it
# once key_cache.index_put_(indices, offsets), key) works.
num_kv_cache_passes = torch.div(num_slots_requested,
num_slots_available).ceil().int().item()

return num_kv_cache_passes, num_slots_available, indices, offsets


def insert_or_update_cache(input, cache, num_kv_cache_passes,
num_slots_available, block_indices, block_offsets):
for i in range(num_kv_cache_passes):
start_idx = i * num_slots_available
end_idx = (i + 1) * num_slots_available
cache.index_put_((block_indices[start_idx:end_idx],
block_offsets[start_idx:end_idx]),
input[start_idx:end_idx])


def swap_blocks(src, dst, block_mapping):
index_src = torch.zeros((1, ), dtype=torch.int32, device=src.device)
index_dst = torch.zeros((1, ), dtype=torch.int32, device=dst.device)
Expand Down
32 changes: 22 additions & 10 deletions vllm/hpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from vllm.logger import init_logger
from vllm.worker.profiler import Profiler

logger = init_logger()
logger = init_logger(__name__)
HPUFusedRMSNorm = None
try:
from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
Expand Down Expand Up @@ -44,7 +44,12 @@ def paged_attention_v1(query,
context_lens,
block_size,
alibi_slopes=None,
kv_cache_dtype=None) -> None:
kv_cache_dtype=None,
matmul_qk_op=torch.matmul,
softmax_op=torch.softmax,
matmul_av_op=torch.matmul,
k_cache_cls=None,
v_cache_cls=None) -> None:
habana_profiler = Profiler()
torch.hpu.synchronize()
start_time = time.time()
Expand All @@ -62,27 +67,31 @@ def paged_attention_v1(query,
batch_size, 1, 1, -1))
query.mul_(scale)
query = query.unsqueeze(-2)
keys = fetch_from_cache(key_cache, block_tables, (0, 2, 3, 1))
fetch_keys = fetch_from_cache if k_cache_cls is None else \
k_cache_cls.fetch_from_cache
keys = fetch_keys(key_cache, block_tables, (0, 2, 3, 1))
if query_heads != kv_heads:
query = query.unflatten(1, (kv_heads, -1))
keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
mask = mask.unsqueeze(2)

attn_weights = torch.cat([torch.matmul(query, k) for k in keys], dim=-1)
attn_weights = torch.cat([matmul_qk_op(query, k) for k in keys], dim=-1)
if alibi_slopes is not None:
attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):,
-attn_weights.size(3):])
attn_weights = (attn_weights.masked_fill(mask, min_inf).softmax(dim=-1))
attn_weights = softmax_op(attn_weights.masked_fill(mask, min_inf), dim=-1)

values = fetch_from_cache(value_cache, block_tables, (0, 2, 1, 3))
fetch_values = fetch_from_cache if v_cache_cls is None else \
v_cache_cls.fetch_from_cache
values = fetch_values(value_cache, block_tables, (0, 2, 1, 3))
if PA_SPLIT_VALUE:
attn_weights = attn_weights.split(block_size, dim=-1)
else:
values = [torch.cat(values, dim=-2)]
attn_weights = [attn_weights]
if query_heads != kv_heads:
values = [v.unflatten(1, (kv_heads, 1)) for v in values]
attn_weights = [torch.matmul(a, v) for a, v in zip(attn_weights, values)]
attn_weights = [matmul_av_op(a, v) for a, v in zip(attn_weights, values)]
if query_heads != kv_heads:
attn_weights = [a.flatten(1, 2) for a in attn_weights]
attn_weights = sum(attn_weights)
Expand Down Expand Up @@ -148,6 +157,9 @@ def prompt_attention(
attn_bias: Optional[torch.Tensor] = None,
p: float = 0.0,
scale: Optional[float] = None,
matmul_qk_op=torch.matmul,
softmax_op=torch.softmax,
matmul_av_op=torch.matmul,
) -> torch.Tensor:
habana_profiler = Profiler()
start_time = time.time()
Expand All @@ -164,11 +176,11 @@ def prompt_attention(
value = value.unflatten(1, (kv_heads, 1))
if attn_bias is not None:
attn_bias = attn_bias.unsqueeze(2)
attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
if attn_bias is not None:
attn_weights.add_(attn_bias)
attn_weights = torch.softmax(attn_weights, dim=-1)
attn_weights = torch.matmul(attn_weights, value)
attn_weights = softmax_op(attn_weights, dim=-1)
attn_weights = matmul_av_op(attn_weights, value)
if query_heads != kv_heads:
attn_weights = attn_weights.flatten(1, 2)
attn_weights = attn_weights.transpose(1, 2)
Expand Down
Loading

0 comments on commit b44f672

Please sign in to comment.