Skip to content

Commit

Permalink
Merge remote-tracking branch 'nous/ring_attention_with_chat' into rin…
Browse files Browse the repository at this point in the history
…g_attention_with_chat_fa3
  • Loading branch information
jquesnelle committed Aug 17, 2024
2 parents c19a3cd + c0964e8 commit ecb5b37
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 232 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,4 @@ cython_debug/

checkpoints/
wandb/
slurm-*
282 changes: 51 additions & 231 deletions src/nanotron/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from nanotron.logging import log_rank
from nanotron.models import NanotronModel
from nanotron.nn.activations import ACT2FN
from nanotron.nn.layer_norm import TritonRMSNorm
from nanotron.nn.layer_norm import RMSNorm
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
Expand Down Expand Up @@ -130,191 +130,76 @@ def forward(
return x_out.type(dtype)


## Copy from transformers. Non interleaved version of RoPE. Will be refactored later
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def _compute_default_rope_parameters(
config: Optional[LlamaConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
**rope_kwargs,
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies according to the original RoPE implementation
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
if config is not None and len(rope_kwargs) > 0:
raise ValueError(
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
)
if len(rope_kwargs) > 0:
base = rope_kwargs["base"]
dim = rope_kwargs["dim"]
elif config is not None:
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)

attention_factor = 1.0 # Unused in this type of RoPE

# Compute the inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
return inv_freq, attention_factor

def _compute_llama3_parameters(
config: LlamaConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies for llama 3.1.
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""
# Gets the default RoPE parameters
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)

factor = config.rope_scaling["factor"] # `8` in the original implementation
low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation

def apply_scaling(freqs: torch.Tensor):
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor

wavelen = 2 * math.pi / inv_freq
# wavelen < high_freq_wavelen: do nothing
# wavelen > low_freq_wavelen: divide by factor
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
# otherwise: interpolate between the two, using a smooth factor
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

return inv_freq_llama, attention_factor

ROPE_INIT_FUNCTIONS = {
"default": _compute_default_rope_parameters,
"llama3": _compute_llama3_parameters,
}
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
else:
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

class LlamaRotaryEmbedding(nn.Module):
def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[LlamaConfig] = None,
):
def __init__(self, dim: int, end: int, theta: float = 500000.0):
super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.45"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq

def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
self.dim = dim
self.end = end
self.theta = theta
self.init_rotary_embeddings()

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
def init_rotary_embeddings(self):
inv_freq = 1.0 / (
self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu") / self.dim)
) # important to compute on CPU
inv_freq = apply_scaling(inv_freq)
self.register_buffer(
"inv_freq", torch.empty(self.dim // 2, dtype=torch.float, device="cuda"), persistent=False
)
self.inv_freq = self.inv_freq.to(
torch.float
) # make it float32 before copy to avoid precision loss during copy_
self.inv_freq.copy_(inv_freq)

@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)

# Core RoPE block
def forward(
self,
x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk]
position_ids: Optional[torch.LongTensor], # [batch_size, seq_length]
):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
Expand All @@ -336,71 +221,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed

## Copy from transformers. Non interleaved version of RoPE. Will be refactored later
# def rotate_half(x):
# """Rotates half the hidden dims of the input."""
# x1 = x[..., : x.shape[-1] // 2]
# x2 = x[..., x.shape[-1] // 2 :]
# return torch.cat((-x2, x1), dim=-1)


# class LlamaRotaryEmbedding(nn.Module):
# def __init__(self, dim: int, end: int, theta: float = 500000.0):
# super().__init__()
# self.dim = dim
# self.end = end
# self.theta = theta
# self.init_rotary_embeddings()

# def init_rotary_embeddings(self):
# inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda") / self.dim))
# self.register_buffer("inv_freq", inv_freq, persistent=False)

# @torch.no_grad()
# def forward(
# self,
# x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk]
# position_ids: Optional[torch.LongTensor], # [batch_size, seq_length]
# ):
# # x: [bs, num_attention_heads, seq_len, head_size]
# # print("rotary")
# inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
# position_ids_expanded = position_ids[:, None, :].float()
# # Force float32 since bfloat16 loses precision on long contexts
# # See https://github.com/huggingface/transformers/pull/29285
# device_type = x.device.type
# device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
# with torch.autocast(device_type=device_type, enabled=False):
# freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
# emb = torch.cat((freqs, freqs), dim=-1)
# cos = emb.cos()
# sin = emb.sin()
# return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
# """Applies Rotary Position Embedding to the query and key tensors.
# Args:
# q (`torch.Tensor`): The query tensor.
# k (`torch.Tensor`): The key tensor.
# cos (`torch.Tensor`): The cosine part of the rotary embedding.
# sin (`torch.Tensor`): The sine part of the rotary embedding.
# unsqueeze_dim (`int`, *optional*, defaults to 1):
# The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
# sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
# that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
# k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
# cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
# the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
# Returns:
# `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
# """
# cos = cos.unsqueeze(unsqueeze_dim)
# sin = sin.unsqueeze(unsqueeze_dim)
# q_embed = (q * cos) + (rotate_half(q) * sin)
# k_embed = (k * cos) + (rotate_half(k) * sin)
# return q_embed, k_embed

class GLUActivation(nn.Module):
def __init__(self, act_fn_name: str):
super().__init__()
Expand Down Expand Up @@ -639,11 +459,11 @@ def __init__(
else:
self.rotary_embedding = LlamaRotaryEmbedding(
dim=self.d_qk,
max_position_embeddings=config.max_position_embeddings,
#end=config.max_position_embeddings,
#theta=config.rope_theta,
base=config.rope_theta,
config=config,
end=config.max_position_embeddings,
theta=config.rope_theta,
#max_position_embeddings=config.max_position_embeddings,
#base=config.rope_theta,
#config=config,
)
self.rope_interleaved = config.rope_interleaved

Expand Down Expand Up @@ -955,7 +775,7 @@ def __init__(
layer_idx: int,
):
super().__init__()
self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attn = CausalSelfAttention(
config=config,
parallel_config=parallel_config,
Expand All @@ -965,7 +785,7 @@ def __init__(
)
self.layer_idx = layer_idx

self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)

self.recompute_layer = parallel_config.recompute_layer
Expand Down Expand Up @@ -1103,7 +923,7 @@ def __init__(

self.final_layer_norm = PipelineBlock(
p2p=self.p2p,
module_builder=TritonRMSNorm,
module_builder=RMSNorm,
module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
module_input_keys={"input"},
module_output_keys={"hidden_states"},
Expand Down
17 changes: 17 additions & 0 deletions src/nanotron/nn/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,20 @@ def forward(
is_rms_norm=True,
return_dropout_mask=return_dropout_mask,
)

# equivalent to TritonRMSNorm
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-5):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, input):
input_dtype = input.dtype
input = input.to(torch.float32)
variance = input.pow(2).mean(-1, keepdim=True)
input = input * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * input.to(input_dtype)
Loading

0 comments on commit ecb5b37

Please sign in to comment.