Commit
add llama 3.1 rope to llama_sft
jquesnelle committed Aug 8, 2024
1 parent e387b0f commit 2420384
Showing 1 changed file with 142 additions and 16 deletions.
158 changes: 142 additions & 16 deletions src/nanotron/models/llama_sft.py
@@ -14,8 +14,9 @@
# limitations under the License.
"""PyTorch LLaMa model."""

-from typing import Dict, Optional, Union
from typing import Dict, Optional, Union, Tuple

import math
import torch
from flash_attn import flash_attn_varlen_func
from torch import nn
@@ -58,25 +59,92 @@

# NOTE(tj.solergibert) Copied from https://github.com/huggingface/transformers/blob/81233c069c166af033794134bd8888783ac49ebe/src/transformers/modeling_rope_utils.py#L29
def _compute_default_rope_parameters(
-config: LlamaConfig,
-) -> torch.Tensor:
config: Optional[LlamaConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
**rope_kwargs,
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies according to the original RoPE implementation
Args:
-config (LlamaConfig):
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
-inv_freq (torch.Tensor)
-    Contains the inverse frequencies for the RoPE embeddings
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
-with torch.autocast(device_type="cuda", enabled=False):
-    base = config.rope_theta  # NOTE(tj.solergibert) 500000.0
-    dim = int(config.hidden_size // config.num_attention_heads)  # NOTE(tj.solergibert) 128
if config is not None and len(rope_kwargs) > 0:
raise ValueError(
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
)
if len(rope_kwargs) > 0:
base = rope_kwargs["base"]
dim = rope_kwargs["dim"]
elif config is not None:
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)

attention_factor = 1.0 # Unused in this type of RoPE

# Compute the inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
return inv_freq, attention_factor

def _compute_llama3_parameters(
config: LlamaConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies for llama 3.1.
-# Compute the inverse frequencies
-inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)).cuda()
-return inv_freq
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""
# Gets the default RoPE parameters
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)

factor = config.rope_scaling["factor"] # `8` in the original implementation
low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in inv_freq:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
return inv_freq, attention_factor

ROPE_INIT_FUNCTIONS = {
"default": _compute_default_rope_parameters,
"llama3": _compute_llama3_parameters,
}
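
A note for readers comparing against the per-frequency loop in _compute_llama3_parameters: the same three-band rule can also be written with tensor ops. The snippet below is an illustrative sketch only, not part of this commit; the constants are the Llama 3.1 defaults quoted in the inline comments (rope_theta 500000.0, head dim 128, factor 8, low/high frequency factors 1 and 4, original context length 8192).

import math
import torch

base, dim = 500000.0, 128
factor, low_freq_factor, high_freq_factor, old_context_len = 8.0, 1.0, 4.0, 8192

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
wavelen = 2 * math.pi / inv_freq
low_freq_wavelen = old_context_len / low_freq_factor    # 8192
high_freq_wavelen = old_context_len / high_freq_factor  # 2048

# Long wavelengths (low frequencies) are divided by `factor`; short ones are kept as-is.
scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

# Wavelengths in between are linearly interpolated between the two regimes.
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
scaled = torch.where(medium, smoothed, scaled)

# `scaled` matches the inv_freq built by the loop above (up to float precision):
# the highest-frequency bins are untouched, the lowest are stretched by roughly 8x.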

# NOTE(tj.solergibert) Copied from https://github.com/huggingface/transformers/blob/5f841c74b62754f186a8c06a684d491524b7bc03/src/transformers/models/llama/modeling_llama.py#L81
# NOTE(tj.solergibert) FlashAttention RoPEs are faster (triton), but currently they don't support position_ids
@@ -85,17 +153,71 @@ def _compute_default_rope_parameters(
class LlamaRotaryEmbedding(nn.Module):
def __init__(
self,
-config: LlamaConfig,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[LlamaConfig] = None,
):
super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.45"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-self.inv_freq = _compute_default_rope_parameters(self.config)
-# self.register_buffer("inv_freq", inv_freq, persistent=False) # NOTE(tj.solergibert) register_buffer casts to bf16!
-# self.original_inv_freq = inv_freq
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq

def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)

# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
@@ -108,6 +230,10 @@ def forward(self, x, position_ids):
cos = emb.cos()
sin = emb.sin()

# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
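
The hunk ends here, so the consumer of the returned cos/sin is not shown in this diff. In the upstream Hugging Face modeling_llama.py they are fed to apply_rotary_pos_emb, roughly as sketched below; this follows the standard upstream pattern, and the exact helper names used elsewhere in llama_sft.py may differ.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Rotate the last dimension by halves: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin arrive as [batch, seq_len, head_dim]; unsqueeze so they broadcast
    # over the head dimension of q/k ([batch, num_heads, seq_len, head_dim]).
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed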


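To exercise the new llama3 code path end to end, the rotary module only needs a config exposing the attributes read above (rope_theta, hidden_size, num_attention_heads, max_position_embeddings, and a rope_scaling dict). Below is a minimal sketch with a stand-in config; it assumes LlamaRotaryEmbedding is importable from nanotron.models.llama_sft with its dependencies (e.g. flash-attn) installed, and the real nanotron LlamaConfig may expose these fields differently.

from types import SimpleNamespace

import torch

from nanotron.models.llama_sft import LlamaRotaryEmbedding  # assumed import path

# Stand-in config carrying the Llama 3.1 values quoted in the diff's comments.
toy_config = SimpleNamespace(
    hidden_size=4096,
    num_attention_heads=32,
    max_position_embeddings=131072,
    rope_theta=500000.0,
    rope_scaling={
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
)

rotary = LlamaRotaryEmbedding(config=toy_config)  # resolves rope_type="llama3"
hidden = torch.randn(1, 16, 4096)                 # only dtype/device are read from x
position_ids = torch.arange(16).unsqueeze(0)
cos, sin = rotary(hidden, position_ids)           # built from the scaled inverse frequencies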
