Skip to content

Commit

Permalink
Add rope_scaling support for LLama3.1 (#356)
Browse files Browse the repository at this point in the history
Add support for rope scaling and FusedRoPE in LLama3.1
  • Loading branch information
kdamaszk authored Oct 3, 2024
1 parent da03d8b commit f848d27
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
2 changes: 1 addition & 1 deletion requirements-hpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@940fdb7
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@bb56d3b
26 changes: 19 additions & 7 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
from vllm.platforms import current_platform

if current_platform.is_hpu():
from vllm_hpu_extension.rotary_embed import HpuRotaryEmbedding
from vllm_hpu_extension.rotary_embed import (HpuLlama3RotaryEmbedding,
HpuRotaryEmbedding)


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -943,12 +944,23 @@ def get_rope(
high_freq_factor = rope_scaling["high_freq_factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style, dtype,
scaling_factor, low_freq_factor,
high_freq_factor,
original_max_position)
if current_platform.is_hpu():
rotary_emb = HpuLlama3RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
low_freq_factor,
high_freq_factor,
original_max_position,
RoPEFallback=Llama3RotaryEmbedding)
else:
rotary_emb = Llama3RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
dtype, scaling_factor, low_freq_factor, high_freq_factor,
original_max_position)
elif scaling_type == "linear":
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
Expand Down

0 comments on commit f848d27

Please sign in to comment.