Skip to content

Commit

Permalink
Inc on vLLM - Fix CR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nirda7 committed Aug 13, 2024
1 parent 3b34893 commit 19c96b8
Show file tree
Hide file tree
Showing 17 changed files with 110 additions and 80 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _is_hpu() -> bool:
is_hpu_available = True
try:
subprocess.run(["hl-smi"], capture_output=True, check=True)
except (FileNotFoundError, NotADirectoryError, PermissionError, subprocess.CalledProcessError):
except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
if not os.path.exists('/dev/accel/accel0') and not os.path.exists(
'/dev/accel/accel_controlD0'):
# last resort...
Expand Down Expand Up @@ -267,7 +267,7 @@ def _is_neuron() -> bool:
torch_neuronx_installed = True
try:
subprocess.run(["neuron-ls"], capture_output=True, check=True)
except (FileNotFoundError, NotADirectoryError, PermissionError, subprocess.CalledProcessError):
except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
torch_neuronx_installed = False
return torch_neuronx_installed or VLLM_TARGET_DEVICE == "neuron"

Expand Down
28 changes: 16 additions & 12 deletions vllm/attention/backends/habana_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
AttentionMetadata, AttentionType)
from vllm.attention.ops.habana_paged_attn import (HabanaPagedAttention,
HabanaPagedAttentionMetadata)
from vllm.hpu.utils import Matmul, Softmax, VLLMKVCache
from vllm.hpu import cache_ops
from vllm.hpu.utils import Matmul, Softmax, VLLMKVCache
from vllm.logger import init_logger

logger = init_logger(__name__)
Expand Down Expand Up @@ -144,11 +144,11 @@ def __init__(
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.qk_matmul = Matmul()
self.matmul_qk = Matmul()
self.softmax = Softmax()
self.av_matmul = Matmul()
self.key_cache = VLLMKVCache()
self.value_cache = VLLMKVCache()
self.matmul_av = Matmul()
self.k_cache = VLLMKVCache()
self.v_cache = VLLMKVCache()
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
self.position_bias = None
Expand Down Expand Up @@ -212,9 +212,13 @@ def forward(
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
num_kv_cache_passes, num_slots_available, indices, offsets = cache_ops.prepare_to_cache(key_cache, attn_metadata.slot_mapping)
key_cache = self.key_cache(key, key_cache, num_kv_cache_passes, num_slots_available, indices, offsets)
value_cache = self.value_cache(value, value_cache, num_kv_cache_passes, num_slots_available, indices, offsets)
num_kv_cache_passes, num_slots_available, indices, offsets = \
cache_ops.prepare_to_cache(key_cache,
attn_metadata.slot_mapping)
key_cache = self.k_cache(key, key_cache, num_kv_cache_passes,
num_slots_available, indices, offsets)
value_cache = self.v_cache(value, value_cache, num_kv_cache_passes,
num_slots_available, indices, offsets)

if attn_metadata.is_prompt:
# Prompt run.
Expand All @@ -240,9 +244,9 @@ def forward(
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
qk_matmul_op=self.qk_matmul,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
av_matmul_op=self.av_matmul,
matmul_av_op=self.matmul_av,
)
output = out.reshape(batch_size, seq_len, hidden_size)
else:
Expand All @@ -266,8 +270,8 @@ def forward(
query, key_cache, value_cache, attn_metadata.block_tables,
attn_metadata.seq_lens_tensor, self.kv_cache_dtype,
self.num_kv_heads, self.scale, self.position_bias, k_scale,
v_scale, self.qk_matmul, self.softmax, self.av_matmul,
self.key_cache, self.value_cache)
v_scale, self.matmul_qk, self.softmax, self.matmul_av,
self.k_cache, self.v_cache)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)

Expand Down
8 changes: 4 additions & 4 deletions vllm/attention/ops/habana_paged_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def forward_decode(
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
qk_matmul_op,
matmul_qk_op,
softmax_op,
av_matmul_op,
matmul_av_op,
k_cache_cls,
v_cache_cls,
) -> torch.Tensor:
Expand All @@ -93,9 +93,9 @@ def forward_decode(
block_size,
alibi_slopes,
kv_cache_dtype,
qk_matmul_op,
matmul_qk_op,
softmax_op,
av_matmul_op,
matmul_av_op,
k_cache_cls,
v_cache_cls,
)
Expand Down
4 changes: 2 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,13 +474,13 @@ def _verify_args(self) -> None:
def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "hf8"):
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor. "
"FP8_E4M3 is also supported on hpu (hf8).")
"Intel Gaudi (HPU) supports fp8 (using fp8_inc).")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

Expand Down
7 changes: 4 additions & 3 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'hf8'],
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'fp8_inc'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3). '
'FP8_E4M3 is also supported on hpu (hf8).')
'Intel Gaudi (HPU) supports fp8 (using fp8_inc).')
parser.add_argument(
'--quantization-param-path',
type=nullable_str,
Expand Down Expand Up @@ -842,7 +842,8 @@ def create_engine_config(self, ) -> EngineConfig:
self.model_loader_extra_config[
"qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

device = device_config.device if self.weights_load_device is None else self.weights_load_device
device = device_config.device if self.weights_load_device is None else \
self.weights_load_device
load_config = LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
Expand Down
9 changes: 5 additions & 4 deletions vllm/hpu/cache_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,14 @@ def prepare_to_cache(cache, slot_mapping):
return num_kv_cache_passes, num_slots_available, indices, offsets


def insert_or_update_cache(input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offsets):
def insert_or_update_cache(input, cache, num_kv_cache_passes,
num_slots_available, block_indices, block_offsets):
for i in range(num_kv_cache_passes):
start_idx = i * num_slots_available
end_idx = (i + 1) * num_slots_available
cache.index_put_(
(block_indices[start_idx:end_idx], block_offsets[start_idx:end_idx]),
input[start_idx:end_idx])
cache.index_put_((block_indices[start_idx:end_idx],
block_offsets[start_idx:end_idx]),
input[start_idx:end_idx])


def swap_blocks(src, dst, block_mapping):
Expand Down
25 changes: 13 additions & 12 deletions vllm/hpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch
import torch.nn.functional as F

import vllm.hpu.utils as hpu_utils
from vllm.logger import init_logger

logger = init_logger()
Expand Down Expand Up @@ -43,9 +42,9 @@ def paged_attention_v1(query,
block_size,
alibi_slopes=None,
kv_cache_dtype=None,
qk_matmul_op=torch.matmul,
matmul_qk_op=torch.matmul,
softmax_op=torch.softmax,
av_matmul_op=torch.matmul,
matmul_av_op=torch.matmul,
k_cache_cls=None,
v_cache_cls=None) -> None:
seq_len = block_tables.size(1)
Expand All @@ -60,20 +59,22 @@ def paged_attention_v1(query,
batch_size, 1, 1, -1))
query.mul_(scale)
query = query.unsqueeze(-2)
fetch_keys = fetch_from_cache if k_cache_cls is None else k_cache_cls.fetch_from_cache
fetch_keys = fetch_from_cache if k_cache_cls is None else \
k_cache_cls.fetch_from_cache
keys = fetch_keys(key_cache, block_tables, (0, 2, 3, 1))
if query_heads != kv_heads:
query = query.unflatten(1, (kv_heads, -1))
keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
mask = mask.unsqueeze(2)

attn_weights = torch.cat([qk_matmul_op(query, k) for k in keys], dim=-1)
attn_weights = torch.cat([matmul_qk_op(query, k) for k in keys], dim=-1)
if alibi_slopes is not None:
attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):,
-attn_weights.size(3):])
attn_weights = softmax_op(attn_weights.masked_fill(mask, min_inf), dim=-1)

fetch_values = fetch_from_cache if v_cache_cls is None else k_cache_cls.fetch_from_cache
fetch_values = fetch_from_cache if v_cache_cls is None else \
v_cache_cls.fetch_from_cache
values = fetch_values(value_cache, block_tables, (0, 2, 1, 3))
if PA_SPLIT_VALUE:
attn_weights = attn_weights.split(block_size, dim=-1)
Expand All @@ -82,7 +83,7 @@ def paged_attention_v1(query,
attn_weights = [attn_weights]
if query_heads != kv_heads:
values = [v.unflatten(1, (kv_heads, 1)) for v in values]
attn_weights = [av_matmul_op(a, v) for a, v in zip(attn_weights, values)]
attn_weights = [matmul_av_op(a, v) for a, v in zip(attn_weights, values)]
if query_heads != kv_heads:
attn_weights = [a.flatten(1, 2) for a in attn_weights]
attn_weights = sum(attn_weights)
Expand Down Expand Up @@ -132,9 +133,9 @@ def prompt_attention(
attn_bias: Optional[torch.Tensor] = None,
p: float = 0.0,
scale: Optional[float] = None,
qk_matmul_op = torch.matmul,
softmax_op = torch.softmax,
av_matmul_op = torch.matmul,
matmul_qk_op=torch.matmul,
softmax_op=torch.softmax,
matmul_av_op=torch.matmul,
) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
Expand All @@ -147,11 +148,11 @@ def prompt_attention(
value = value.unflatten(1, (kv_heads, 1))
if attn_bias is not None:
attn_bias = attn_bias.unsqueeze(2)
attn_weights = qk_matmul_op(query * scale, key.transpose(-1, -2))
attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
if attn_bias is not None:
attn_weights.add_(attn_bias)
attn_weights = softmax_op(attn_weights, dim=-1)
attn_weights = av_matmul_op(attn_weights, value)
attn_weights = matmul_av_op(attn_weights, value)
if query_heads != kv_heads:
attn_weights = attn_weights.flatten(1, 2)
attn_weights = attn_weights.transpose(1, 2)
Expand Down
25 changes: 19 additions & 6 deletions vllm/hpu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
# LICENSE file in the root directory of this source tree.
###############################################################################

import torch
from functools import wraps

import habana_frameworks.torch as htorch
import torch

from vllm.hpu.cache_ops import insert_or_update_cache


def with_mark_steps(fn):

@wraps(fn)
Expand All @@ -24,7 +26,9 @@ def wrapped(*args, **kwargs):

return wrapped


class Matmul(torch.nn.Module):

def __init__(self):
super(Matmul, self).__init__()

Expand All @@ -33,19 +37,28 @@ def forward(self, x, y):


class Softmax(torch.nn.Module):
def __init__(self):

def __init__(self):
super().__init__()

def forward(self, x, dim = None, inv_head = None):
def forward(self, x, dim=None, inv_head=None):
return torch.softmax(x, dim)


class VLLMKVCache(torch.nn.Module):

def __init__(self):
super(VLLMKVCache, self).__init__()

def forward(self, input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset):
insert_or_update_cache(input, cache, num_kv_cache_passes, num_slots_available, block_indices, block_offset)
def forward(self, input, cache, num_kv_cache_passes, num_slots_available,
block_indices, block_offset):
insert_or_update_cache(input, cache, num_kv_cache_passes,
num_slots_available, block_indices,
block_offset)
return cache

def fetch_from_cache(self, cache, blocks, permutations):
return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))]
return [
cache.index_select(0, blocks[:, i]).permute(permutations)
for i in range(blocks.size(1))
]
3 changes: 1 addition & 2 deletions vllm/model_executor/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ def forward_hpu(
self.variance_epsilon)
return x.view(orig_shape), residual

x = HPUFusedRMSNorm.apply(x, self.weight,
self.variance_epsilon)
x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
return x

def forward_xpu(
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.inc import INCConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.inc import INCConfig

QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
Expand Down
16 changes: 9 additions & 7 deletions vllm/model_executor/layers/quantization/inc.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
Expand Down Expand Up @@ -59,14 +57,16 @@ def get_quant_method(self, layer: torch.nn.Module,
def get_scaled_act_names(self) -> List[str]:
return []

def get_min_capability(self) -> int:
@classmethod
def get_min_capability(cls) -> int:
# The AWQ kernel only supports Turing or newer GPUs.
return 75

@staticmethod
def get_config_filenames() -> List[str]:
return []


class INCLinearMethod(LinearMethodBase):
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
Expand All @@ -83,7 +83,9 @@ class INCLinearMethod(LinearMethodBase):
quant_config: The quantization config.
"""

def __init__(self, quant_config: INCConfig, separate_bias_add: bool = False):
def __init__(self,
quant_config: INCConfig,
separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add
self.quant_config = quant_config

Expand All @@ -110,4 +112,4 @@ def apply(self,
if bias is not None:
return F.linear(x, weight) + bias
return F.linear(x, weight)
return F.linear(x, weight, bias)
return F.linear(x, weight, bias)
Loading

0 comments on commit 19c96b8

Please sign in to comment.