Add forward_hpu to RotaryEmbedding, remove custom module (HabanaAI#404)
This PR removes the custom HPU RotaryEmbedding modules and instead adds a
forward_hpu method to the existing RotaryEmbedding, so the various derived
RoPE implementations can be reused without adding each of them to the HPU
extension.
The mark_step calls should not be needed inside the test, but without them
the PT bridge crashes, for reasons yet to be investigated. They do not affect
actual model execution in any way I could test or observe.
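
For context, RotaryEmbedding derives from vLLM's CustomOp, whose forward() dispatches to a platform-specific forward_* implementation (forward_cuda, forward_xpu, and now forward_hpu) when one is defined, roughly falling back to forward_native otherwise. The toy sketch below only illustrates that dispatch pattern; it is not vLLM's actual CustomOp, and the _is_hpu helper is a made-up stand-in for current_platform.is_hpu().

import torch.nn as nn


def _is_hpu() -> bool:
    # Made-up stand-in for vllm.platforms.current_platform.is_hpu(): assume
    # habana_frameworks is only importable on Gaudi hosts.
    try:
        import habana_frameworks.torch  # noqa: F401
        return True
    except ImportError:
        return False


class ToyCustomOp(nn.Module):
    # Toy dispatcher: route forward() to forward_hpu() on HPU when a subclass
    # provides one, otherwise to forward_native(). vLLM's real CustomOp
    # covers more backends.
    def forward(self, *args, **kwargs):
        if _is_hpu() and hasattr(self, "forward_hpu"):
            return self.forward_hpu(*args, **kwargs)
        return self.forward_native(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

Because the dispatch happens at the base class, adding forward_hpu to RotaryEmbedding is enough for every subclass that does not override it.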
kzawora-intel authored and xuechendi committed Oct 23, 2024
1 parent f841288 commit 0b565e0
Showing 2 changed files with 73 additions and 31 deletions.
10 changes: 10 additions & 0 deletions tests/kernels/test_pos_encoding.py
@@ -5,6 +5,7 @@
 import torch
 
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.platforms import current_platform
 from vllm.utils import seed_everything
 
 from .allclose_default import get_default_atol, get_default_rtol
@@ -20,6 +21,9 @@
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
+if current_platform.is_hpu():
+    import habana_frameworks.torch as htorch
+    CUDA_DEVICES = ['hpu']
 
 
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -65,6 +69,8 @@ def test_rotary_embedding(
     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
     ref_query, ref_key = rope.forward_native(positions, query, key)
+    if current_platform.is_hpu():
+        htorch.core.mark_step()
     out_query, out_key = rope.forward(positions, query, key)
     # Compare the results.
     torch.testing.assert_close(out_query,
@@ -120,6 +126,8 @@ def test_batched_rotary_embedding(
     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
     ref_query, ref_key = rope.forward_native(positions, query, key)
+    if current_platform.is_hpu():
+        htorch.core.mark_step()
     out_query, out_key = rope.forward(positions,
                                       query,
                                       key,
@@ -193,6 +201,8 @@ def test_batched_rotary_embedding_multi_lora(
     # because the custom kernel is in-place.
     ref_query, ref_key = rope.forward_native(positions, query, key,
                                              query_offsets)
+    if current_platform.is_hpu():
+        htorch.core.mark_step()
     out_query, out_key = rope.forward(positions, query, key,
                                       query_offsets.flatten())
     # Compare the results.
94 changes: 63 additions & 31 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -28,7 +28,6 @@
 import torch.nn as nn
 
 from vllm.model_executor.custom_op import CustomOp
-from vllm.platforms import current_platform
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -195,6 +194,61 @@ def forward_xpu(
                                  self.cos_sin_cache, self.is_neox_style)
         return query, key
 
+    def forward_hpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        from habana_frameworks.torch.hpex.kernels import (
+            RotaryPosEmbeddingMode, apply_rotary_pos_emb)
+        positions = positions.flatten()
+        if offsets is not None:
+            positions = positions + offsets
+        num_tokens = positions.shape[0]
+        cos_sin = self.cos_sin_cache.index_select(0, positions).view(
+            num_tokens, 1, -1)
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
+        # to query hidden dimension, so the original tensors need to be
+        # expanded
+        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
+        # and expansion of cos/sin tensors via concatenation
+        # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
+        # and expansion of cos/sin tensors via repeat_interleave
+        rope_mode: RotaryPosEmbeddingMode
+        if self.is_neox_style:
+            rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
+            cos = torch.cat((cos, cos), dim=-1)
+            sin = torch.cat((sin, sin), dim=-1)
+        else:
+            rope_mode = RotaryPosEmbeddingMode.PAIRWISE
+            sin = torch.repeat_interleave(sin,
+                                          2,
+                                          dim=-1,
+                                          output_size=cos_sin.shape[-1])
+            cos = torch.repeat_interleave(cos,
+                                          2,
+                                          dim=-1,
+                                          output_size=cos_sin.shape[-1])
+
+        query_shape = query.shape
+        query = query.view(num_tokens, -1, self.head_size)
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim:]
+        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0,
+                                         rope_mode)
+        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+        key_shape = key.shape
+        key = key.view(num_tokens, -1, self.head_size)
+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
+        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        return query, key
+
     def extra_repr(self) -> str:
         s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
         s += f", max_position_embeddings={self.max_position_embeddings}"
@@ -918,17 +972,8 @@ def get_rope(
         return _ROPE_DICT[key]
 
     if rope_scaling is None:
-        if current_platform.is_hpu():
-            from vllm_hpu_extension.rotary_embed import HpuRotaryEmbedding
-            rotary_emb = HpuRotaryEmbedding(head_size,
-                                            rotary_dim,
-                                            max_position,
-                                            base,
-                                            is_neox_style,
-                                            RoPEFallback=RotaryEmbedding)
-        else:
-            rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position,
-                                         base, is_neox_style, dtype)
+        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                                     is_neox_style, dtype)
     else:
         scaling_type = rope_scaling[
             "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
@@ -941,25 +986,12 @@
             high_freq_factor = rope_scaling["high_freq_factor"]
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
-            if current_platform.is_hpu():
-                from vllm_hpu_extension.rotary_embed import (
-                    HpuLlama3RotaryEmbedding)
-                rotary_emb = HpuLlama3RotaryEmbedding(
-                    head_size,
-                    rotary_dim,
-                    max_position,
-                    base,
-                    is_neox_style,
-                    scaling_factor,
-                    low_freq_factor,
-                    high_freq_factor,
-                    original_max_position,
-                    RoPEFallback=Llama3RotaryEmbedding)
-            else:
-                rotary_emb = Llama3RotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    dtype, scaling_factor, low_freq_factor, high_freq_factor,
-                    original_max_position)
+            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
+                                               max_position, base,
+                                               is_neox_style, dtype,
+                                               scaling_factor, low_freq_factor,
+                                               high_freq_factor,
+                                               original_max_position)
         elif scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
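
With these platform branches removed, get_rope builds the same classes on every host and the HPU specialization is deferred to forward_hpu at call time. A quick sanity check, assuming a vLLM build that contains this patch (argument values are arbitrary and the keyword names are inferred from the diff above):

from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
                                                         get_rope)

# Same object on CUDA, CPU, or HPU hosts; the backend-specific path is chosen
# later, when forward() dispatches to forward_cuda/forward_hpu/forward_native.
rope = get_rope(head_size=128,
                rotary_dim=128,
                max_position=4096,
                base=10000,
                is_neox_style=True)
assert isinstance(rope, RotaryEmbedding)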
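
For reference, the BLOCKWISE/PAIRWISE handling in forward_hpu above follows the two standard RoPE conventions: GPT-NeoX rotates element i against element i + rotary_dim/2 and expands cos/sin by concatenation, while GPT-J rotates adjacent pairs and expands cos/sin by repeat_interleave. The pure-PyTorch sketch below restates that math for a single token; it only illustrates what the two modes compute, it is not the HPU apply_rotary_pos_emb kernel, and the function names are invented for this example.

import torch


def neox_rope(x: torch.Tensor, cos: torch.Tensor,
              sin: torch.Tensor) -> torch.Tensor:
    # BLOCKWISE / GPT-NeoX style: expand cos/sin by concatenation and pair
    # element i with element i + d/2.
    cos = torch.cat((cos, cos), dim=-1)
    sin = torch.cat((sin, sin), dim=-1)
    d = x.shape[-1] // 2
    rotated = torch.cat((-x[..., d:], x[..., :d]), dim=-1)
    return x * cos + rotated * sin


def gptj_rope(x: torch.Tensor, cos: torch.Tensor,
              sin: torch.Tensor) -> torch.Tensor:
    # PAIRWISE / GPT-J style: expand cos/sin by repeat_interleave and pair
    # element 2*i with element 2*i + 1.
    cos = torch.repeat_interleave(cos, 2, dim=-1)
    sin = torch.repeat_interleave(sin, 2, dim=-1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)
    return x * cos + rotated * sin


# Toy input: one token, rotary_dim = 8, so cos/sin carry rotary_dim // 2 angles.
x = torch.randn(1, 8)
angles = torch.rand(1, 4)
print(neox_rope(x, angles.cos(), angles.sin()).shape)  # torch.Size([1, 8])
print(gptj_rope(x, angles.cos(), angles.sin()).shape)  # torch.Size([1, 8])

These are the same conventions implemented by forward_native, which is why the test can compare forward_native output against the HPU path directly.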
