Add Llama31 Config (#10260)
* add llama31 config

* Apply isort and black reformatting

Signed-off-by: suiyoubi <suiyoubi@users.noreply.github.com>

* fix init method

* typo

* revert llama3-70b init method std

---------

Signed-off-by: suiyoubi <suiyoubi@users.noreply.github.com>
Co-authored-by: suiyoubi <suiyoubi@users.noreply.github.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
3 people authored Aug 27, 2024
1 parent 57aa305 commit 19668e5
Showing 3 changed files with 99 additions and 2 deletions.
6 changes: 6 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -39,6 +39,9 @@
Llama2Config70B,
Llama3Config8B,
Llama3Config70B,
Llama31Config8B,
Llama31Config70B,
Llama31Config405B,
LlamaConfig,
LlamaModel,
MaskedTokenLossReduction,
@@ -93,6 +96,9 @@
"Llama2Config70B",
"Llama3Config8B",
"Llama3Config70B",
"Llama31Config8B",
"Llama31Config70B",
"Llama31Config405B",
"CodeLlamaConfig7B",
"CodeLlamaConfig13B",
"CodeLlamaConfig34B",
6 changes: 6 additions & 0 deletions nemo/collections/llm/gpt/model/__init__.py
@@ -27,6 +27,9 @@
Llama2Config70B,
Llama3Config8B,
Llama3Config70B,
Llama31Config8B,
Llama31Config70B,
Llama31Config405B,
LlamaConfig,
LlamaModel,
)
@@ -62,6 +65,9 @@
"Llama2Config70B",
"Llama3Config8B",
"Llama3Config70B",
"Llama31Config8B",
"Llama31Config70B",
"Llama31Config405B",
"NemotronConfig",
"Nemotron3Config4B",
"Nemotron3Config8B",
89 changes: 87 additions & 2 deletions nemo/collections/llm/gpt/model/llama.py
@@ -1,3 +1,4 @@
import math
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Callable, Optional
@@ -9,6 +10,7 @@
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.collections.llm.utils import Config
from nemo.lightning import OptimizerModule, io, teardown
from nemo.utils import logging

if TYPE_CHECKING:
from transformers import LlamaConfig as HFLlamaConfig
@@ -66,7 +68,7 @@ class Llama3Config(GPTConfig):
num_query_groups: int = 8
hidden_dropout: float = 0.0
attention_dropout: float = 0.0
normalization = "RMSNorm"
normalization: str = "RMSNorm"
init_method_std: float = 0.01
layernorm_epsilon: float = 1.0e-05
add_bias_linear: bool = False
@@ -80,10 +82,31 @@ class Llama3Config(GPTConfig):
bias_dropout_fusion: bool = True
apply_rope_fusion: bool = True
share_embeddings_and_output_weights: bool = False
position_embedding_type = "rope"
position_embedding_type: str = "rope"
rotary_percent: float = 1.0


@dataclass
class Llama31Config(Llama3Config):
scale_factor: int = 8
low_freq_factor: int = 1
high_freq_factor: int = 4
old_context_len: int = 8192
init_method_std: float = 0.02

def configure_model(self, tokenizer) -> "MCoreGPTModel":
model = super().configure_model(tokenizer)
# Apply rope scaling for Llama3.1 model
model.rotary_pos_emb.inv_freq = apply_rope_scaling(
model.rotary_pos_emb.inv_freq,
factor=self.scale_factor,
low_freq_factor=self.low_freq_factor,
high_freq_factor=self.high_freq_factor,
old_context_len=self.old_context_len,
)
return model

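# Illustrative usage sketch (assuming a NeMo-compatible tokenizer is available to the caller):
#
#     config = Llama31Config8B()
#     model = config.configure_model(tokenizer)
#     # model.rotary_pos_emb.inv_freq now holds the Llama 3.1 scaled rotary frequencies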

@dataclass
class Llama3Config8B(Llama3Config):
rotary_base: int = 500_000
@@ -106,6 +129,38 @@ class Llama3Config70B(Llama3Config):
make_vocab_size_divisible_by: int = 128


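# The Llama 3.1 size variants below all use seq_length=131072 (128K context) and
# rotary_base=500_000 and inherit the rope-scaling hook from Llama31Config; they differ
# only in num_layers, hidden_size, ffn_hidden_size, num_attention_heads, and
# make_vocab_size_divisible_by.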
@dataclass
class Llama31Config8B(Llama31Config):
rotary_base: int = 500_000
seq_length: int = 131072
num_layers: int = 32
hidden_size: int = 4096
ffn_hidden_size: int = 14336
num_attention_heads: int = 32


@dataclass
class Llama31Config70B(Llama31Config):
rotary_base: int = 500_000
seq_length: int = 131072
num_layers: int = 80
hidden_size: int = 8192
ffn_hidden_size: int = 28672
num_attention_heads: int = 64
make_vocab_size_divisible_by: int = 128


@dataclass
class Llama31Config405B(Llama31Config):
rotary_base: int = 500_000
seq_length: int = 131072
num_layers: int = 126
hidden_size: int = 16384
ffn_hidden_size: int = 53248
num_attention_heads: int = 128
make_vocab_size_divisible_by: int = 128


@dataclass
class CodeLlamaConfig7B(Llama2Config7B):
rotary_base: int = 1_000_000
@@ -365,13 +420,43 @@ def _export_linear_fc1(linear_fc1):
return gate_proj, up_proj


def apply_rope_scaling(
inv_freq,
factor: int = 8,
low_freq_factor: int = 1,
high_freq_factor: int = 4,
old_context_len: int = 8192,
):
logging.info(
f"Apply rope scaling with factor={factor}, low_freq_factor={low_freq_factor}, high_freq_factor={high_freq_factor}, old_context_len={old_context_len}."
)

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor

wavelen = 2 * math.pi / inv_freq
# wavelen < high_freq_wavelen: do nothing
# wavelen > low_freq_wavelen: divide by factor
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
# otherwise: interpolate between the two, using a smooth factor
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

return inv_freq_llama

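# Worked example with the defaults above (factor=8, low_freq_factor=1, high_freq_factor=4,
# old_context_len=8192): high_freq_wavelen = 8192/4 = 2048 and low_freq_wavelen = 8192/1 = 8192,
# so components with wavelength < 2048 tokens keep their frequency, components with wavelength
# > 8192 tokens have inv_freq divided by 8, and components in between are blended via
# smooth_factor = (8192/wavelen - 1) / 3, which falls from 1 at wavelen=2048 to 0 at wavelen=8192.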

__all__ = [
"LlamaConfig",
"Llama2Config7B",
"Llama2Config13B",
"Llama2Config70B",
"Llama3Config8B",
"Llama3Config70B",
"Llama31Config8B",
"Llama31Config70B",
"Llama31Config405B",
"CodeLlamaConfig7B",
"CodeLlamaConfig13B",
"CodeLlamaConfig34B",
