Add gpt2 converter, hellaswag eval tool, misc fixes (#38)
francoishernandez authored Jun 20, 2024
1 parent 93f445f commit c4a8a3d
Showing 12 changed files with 187 additions and 15 deletions.
55 changes: 51 additions & 4 deletions eole/bin/convert/convert_HF.py
@@ -116,26 +116,55 @@
".input_layernorm.weight": ".input_layernorm.weight",
".post_attention_layernorm.weight": ".post_attention_layernorm.weight",
}

# note: weights are transposed in Linear, not in Conv1D, which is used in HF
key_maps["GPT2LMHeadModel"] = {
"layer_prefix": "h.",
"tgt_emb.embeddings.weight": "wte.weight",
"generator.weight": "wte.weight", # shared with embeddings
"tgt_emb.pe.weight": "wpe.weight",
".self_attn.linear_query.": (".attn.c_attn.", ".t()[:hidden_size, ...]"),
".self_attn.linear_keys.": (
".attn.c_attn.",
".t()[hidden_size:2*hidden_size, ...]",
),
".self_attn.linear_values.": (".attn.c_attn.", ".t()[-hidden_size:, ...]"),
".self_attn.final_linear.": (".attn.c_proj.", ".t()"),
".mlp.gate_up_proj.": (".mlp.c_fc.", ".t()"),
".mlp.down_proj.": (".mlp.c_proj.", ".t()"),
".input_layernorm.weight": ".ln_1.weight",
".input_layernorm.bias": ".ln_1.bias",
".post_attention_layernorm.weight": ".ln_2.weight",
".post_attention_layernorm.bias": ".ln_2.bias",
"decoder.layer_norm.weight": "ln_f.weight",
"decoder.layer_norm.bias": "ln_f.bias",
}

ln_table = {
"LlamaForCausalLM": "rms",
"MistralForCausalLM": "rms",
"MixtralForCausalLM": "rms",
"PhiForCausalLM": "standard",
"Phi3ForCausalLM": "rms",
"GPT2LMHeadModel": "standard",
}

act_table = {
"LlamaForCausalLM": "gated-silu",
"MistralForCausalLM": "gated-silu",
"MixtralForCausalLM": "gated-silu",
"PhiForCausalLM": "gelu",
"Phi3ForCausalLM": "gated-silu",
"GPT2LMHeadModel": "gelu",
}

decoder_start_table = {
"LlamaForCausalLM": "<s>",
"MistralForCausalLM": "<s>",
"MixtralForCausalLM": "<s>",
"PhiForCausalLM": "",
"Phi3ForCausalLM": "<s>",
"GPT2LMHeadModel": "</s>",
}
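
The tuple entries in key_maps["GPT2LMHeadModel"] pair an HF checkpoint key with a slicing expression that the converter applies to the loaded tensor via eval("w" + srcmap) (visible further down in this file). A minimal sketch with toy sizes and random weights, not a real checkpoint, of what the transpose-and-slice does to GPT-2's fused c_attn weight:

import torch

hidden_size = 8  # hypothetical toy size; GPT-2 small uses 768

# HF GPT-2 stores the attention projections as a Conv1D whose fused weight has
# shape (hidden_size, 3 * hidden_size), columns ordered [query | key | value].
w = torch.randn(hidden_size, 3 * hidden_size)

# eole Linear layers expect (out_features, in_features), hence the ".t()" in the
# map above, followed by row slicing to split Q, K and V.
linear_query = eval("w" + ".t()[:hidden_size, ...]")
linear_keys = eval("w" + ".t()[hidden_size:2*hidden_size, ...]")
linear_values = eval("w" + ".t()[-hidden_size:, ...]")
assert linear_query.shape == (hidden_size, hidden_size)

# Slicing a transposed tensor yields a non-contiguous view; the .contiguous()
# calls added in get_weight() below make the result serializable.
linear_query = linear_query.contiguous()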


@@ -448,13 +477,28 @@ def run(cls, args):
add_ffnbias = False
rotary_interleave = False
shared_layer_norm = False
max_relative_positions = -1
position_encoding = {}
left_pad = True

if arch == "PhiForCausalLM":
parallel_residual = True
shared_layer_norm = True
add_qkvbias = True
add_ffnbias = True
rotary_interleave = False
if arch == "GPT2LMHeadModel":
parallel_residual = False
shared_layer_norm = True
add_qkvbias = True
add_ffnbias = True
max_relative_positions = 0
position_encoding = {
"position_encoding": True,
"position_encoding_type": "Learned",
"n_positions": 1024,
}
left_pad = False

if wmap_path:
with open(wmap_path, encoding="utf-8") as fweights:
@@ -486,15 +530,15 @@ def get_load_ckpt(dir_path, file_path):
def get_weight(checkpoint, tensor_name):
if isinstance(checkpoint, dict):
if tensor_name in checkpoint.keys():
return checkpoint[tensor_name]
return checkpoint[tensor_name].contiguous()
else:
return None
else:
with safetensors.safe_open(
checkpoint, framework="pt", device="cpu"
) as f:
if tensor_name in f.keys():
return f.get_tensor(tensor_name)
return f.get_tensor(tensor_name).contiguous()
else:
return None

@@ -506,6 +550,7 @@ def get_weight(checkpoint, tensor_name):
if shard == 0:
targetlist = [
"tgt_emb.embeddings.weight",
"tgt_emb.pe.weight",
"decoder.layer_norm.weight",
"decoder.layer_norm.bias",
"generator.weight",
@@ -615,7 +660,7 @@ def get_weight(checkpoint, tensor_name):

if w is not None:
if type(source) == tuple:
w = eval("w" + srcmap)
w = eval("w" + srcmap).contiguous()
eole_safetensor[
"decoder.transformer_layers."
+ str(i)
@@ -840,14 +885,15 @@ def get_weight(checkpoint, tensor_name):
embeddings=EmbeddingsConfig(
src_word_vec_size=src_word_vec_size,
tgt_word_vec_size=tgt_word_vec_size,
**position_encoding,
),
# src_word_vec_size=src_word_vec_size,
# tgt_word_vec_size=tgt_word_vec_size,
model_type="text",
layer_norm=layer_norm,
norm_eps=norm_eps,
mlp_activation_fn=mlp_activation_fn,
max_relative_positions=-1,
max_relative_positions=max_relative_positions,
rotary_interleave=rotary_interleave,
rotary_theta=rope_theta,
rotary_dim=rotary_dim,
@@ -859,6 +905,7 @@ def get_weight(checkpoint, tensor_name):
add_ffnbias=add_ffnbias,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
left_pad=left_pad,
),
training=TrainingConfig(
model_dtype="fp16",
4 changes: 2 additions & 2 deletions eole/config/inference.py
@@ -22,7 +22,7 @@ class DecodingConfig(Config):
"Restrict tokens to the most likely until the cumulated probability "
"is over p. In range [0,1]. (https://arxiv.org/abs/1904.09751)",
ge=0.0,
lt=1.0,
lte=1.0,
)
random_sampling_temp: float = Field(
default=1.0,
@@ -173,5 +173,5 @@ def _validate_running_config(self):
), "-replace_unk option can not be used with -gold_align enabled"
assert self.tgt, "-tgt should be specified with -gold_align"
# originally in validate_translate_opts_dynamic, not sure why
self.__dict__["share_vocab"] = False
# self.__dict__["share_vocab"] = False
return self
20 changes: 19 additions & 1 deletion eole/config/models.py
@@ -24,13 +24,27 @@ class EmbeddingsConfig(Config):
)
position_encoding: bool = Field(
default=False,
description="Use a sin to mark relative words positions. "
description="Absolute position encoding, see position_encoding_type. "
"Necessary for non-RNN style models.",
)
position_encoding_type: PositionEncodingType = Field(
default=PositionEncodingType.SinusoidalInterleaved,
description="Type of positional encoding.",
)
n_positions: int | None = Field(
default=None,
description="Absolute number of positions to learn "
"position embeddings on (position_encoding_type: Learned)",
)

@model_validator(mode="after")
def validate_embeddings(self):
if self.position_encoding_type == PositionEncodingType.Learned:
assert self.n_positions is not None, (
"n_positions must be set if position_encoding_type "
f"is {PositionEncodingType.Learned}"
)
return self


class EncoderConfig(Config):
@@ -352,6 +366,10 @@ class BaseModelConfig(Config):
)
add_estimator: bool = Field(default=False, description="Add estimator layer")

left_pad: bool = Field(
default=False, description="Enable left-padding, useful for some LLMs."
)

# @computed_field()
# @property
# def brnn(self) -> bool:
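
A hedged usage sketch of the EmbeddingsConfig changes in this file, assuming the import paths below match the repository layout and that the config's remaining fields have workable defaults (both are assumptions, not verified here):

from eole.config.models import EmbeddingsConfig  # assumed import path
from eole.constants import PositionEncodingType

# OK: GPT-2 style learned absolute positions over a 1024-token window.
cfg = EmbeddingsConfig(
    position_encoding=True,
    position_encoding_type=PositionEncodingType.Learned,
    n_positions=1024,
)

# Would trip the assertion in validate_embeddings(): n_positions left unset.
# EmbeddingsConfig(position_encoding_type=PositionEncodingType.Learned)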
2 changes: 1 addition & 1 deletion eole/config/training.py
@@ -229,7 +229,7 @@ class TrainingConfig(
)

model_dtype: Literal["fp32", "fp16"] = Field(
default="fp32", description="Data type of the model."
default="fp16", description="Data type of the model."
)
loss_scale: float = Field(
default=0.0,
1 change: 1 addition & 0 deletions eole/constants.py
@@ -49,6 +49,7 @@ class ModelTask(str, Enum):
class PositionEncodingType(str, Enum):
SinusoidalInterleaved = "SinusoidalInterleaved"
SinusoidalConcat = "SinusoidalConcat"
Learned = "Learned"


class ActivationFunction(str, Enum):
3 changes: 2 additions & 1 deletion eole/inference_engine.py
@@ -154,6 +154,7 @@ def __init__(self, config):
self.transforms_cls = get_transforms_cls(config._all_transform)
self.vocabs = self.predictor.vocabs
self.transforms = make_transforms(config, self.transforms_cls, self.vocabs)
self.transform_pipe = TransformPipe.build_from(self.transforms.values())

def _predict(self, infer_iter):
scores, estims, preds = self.predictor._predict(
@@ -166,7 +167,7 @@ def _predict(self, infer_iter):

def _score(self, infer_iter):
self.predictor.with_scores = True
self.return_gold_log_probs = True
self.predictor.return_gold_log_probs = True
return self.predictor._score(infer_iter)

def score_list_parallel(self, src):
8 changes: 6 additions & 2 deletions eole/inputters/dynamic_iterator.py
@@ -1,7 +1,7 @@
"""Module that contain iterator used for dynamic data."""
import torch
from itertools import cycle
from eole.constants import CorpusTask, ModelTask
from eole.constants import CorpusTask # , ModelTask
from eole.inputters.text_corpus import get_corpora, build_corpora_iters
from eole.inputters.text_utils import (
text_sort_key,
@@ -136,6 +136,7 @@ def __init__(
stride=1,
offset=0,
score_threshold=0,
left_pad=False,
):
super(DynamicDatasetIter).__init__()
self.corpora = corpora
@@ -162,7 +163,9 @@ def __init__(
self.skip_empty_level = skip_empty_level
self.random_shuffler = RandomShuffler()
self.bucket_idx = 0
if task != CorpusTask.TRAIN and vocabs["data_task"] == ModelTask.LANGUAGE_MODEL:
# TODO: we might want to enable some hybrid mode (default left_pad True for LM, else False)
# if task != CorpusTask.TRAIN and vocabs["data_task"] == ModelTask.LANGUAGE_MODEL:
if task != CorpusTask.TRAIN and left_pad:
self.left_pad = True
else:
self.left_pad = False
@@ -229,6 +232,7 @@ def from_config(
score_threshold=0
if isinstance(config, PredictConfig)
else running_config.score_threshold,
left_pad=getattr(config.model, "left_pad", False),
)

def _init_datasets(self, worker_id):
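
A minimal sketch, not eole's actual collate code, of what the left_pad flag changes when batching variable-length prompts for decoder-only inference (the pad id and sequences below are made up):

import torch

pad_id = 0  # hypothetical pad token id
seqs = [[5, 6, 7], [8, 9]]
max_len = max(len(s) for s in seqs)

# Right padding (default): [[5, 6, 7], [8, 9, 0]]
right = torch.tensor([s + [pad_id] * (max_len - len(s)) for s in seqs])

# Left padding (left_pad=True): [[5, 6, 7], [0, 8, 9]]. Every prompt's last real
# token sits in the final column, which is what batched generation with a
# decoder-only LM expects.
left = torch.tensor([[pad_id] * (max_len - len(s)) + s for s in seqs])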
9 changes: 8 additions & 1 deletion eole/models/model.py
@@ -70,6 +70,7 @@ def build_src_emb(model_config, vocabs, running_config=None):
word_vocab_size=len(vocabs["src"]),
sparse=getattr(running_config, "optim", None) == "sparseadam",
freeze_word_vecs=model_config.embeddings.freeze_word_vecs_enc,
n_positions=model_config.embeddings.n_positions,
)
else:
src_emb = None
@@ -89,6 +90,7 @@ def build_tgt_emb(
word_vocab_size=len(vocabs["tgt"]),
sparse=getattr(running_config, "optim", None) == "sparseadam",
freeze_word_vecs=model_config.embeddings.freeze_word_vecs_dec,
n_positions=model_config.embeddings.n_positions,
)

if share_embeddings:
@@ -309,6 +311,8 @@ def inference_logic(self, checkpoint, running_config, vocabs, device_id=None):
# override gpu_ranks/world_size to prevent warnings
training_config.gpu_ranks = running_config.gpu_ranks
training_config.world_size = running_config.world_size
# retrieve share_vocab flag from checkpoint config
running_config.share_vocab = checkpoint["config"].share_vocab
# in fine we might have some nested Lora/QuantizeConfig that are updated from checkpoint values # noqa: E501
# should quant type be in model config or running config ?
if hasattr(training_config, "quant_type") and training_config.quant_type in [
@@ -574,7 +578,10 @@ def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset)
col_slice_start:col_slice_end,
row_slice_start:row_slice_end,
].size()
), "An error in model's partition and checkpoint's slice was detected"
), (
"An error in model's partition and checkpoint's slice was detected, "
f"[{name}, {module}, {param_name}, {param.data.size()}, {ckpt_t.size()}]"
)
if name + "." + param_name in buf_list:
if module.__class__.__name__ == "WQLinear_GEMM":
module.register_buffer(
28 changes: 26 additions & 2 deletions eole/modules/embeddings.py
@@ -5,6 +5,7 @@
from torch.nn.utils import skip_init

from eole.utils.logging import logger
from eole.constants import PositionEncodingType


class SequenceTooLongError(Exception):
@@ -99,6 +100,7 @@ def __init__(
dropout=0,
sparse=False,
freeze_word_vecs=False,
n_positions=1024,
):
super(Embeddings, self).__init__()
self._validate_args()
@@ -119,7 +121,12 @@ def __init__(
self.dropout_p = dropout

self.position_encoding = position_encoding
if self.position_encoding:
self.position_encoding_type = position_encoding_type

if self.position_encoding_type == PositionEncodingType.Learned:
self.pe = nn.Embedding(n_positions, word_vec_size)
self.past_length = 0
elif self.position_encoding:
self.pe = PositionalEncoding(word_vec_size, position_encoding_type)

if freeze_word_vecs:
@@ -157,7 +164,24 @@ def forward(self, source, step=None):
FloatTensor: Word embeddings ``(batch, len, embedding_size)``
"""
emb = self.embeddings(source)
if self.position_encoding:
if self.position_encoding_type == PositionEncodingType.Learned:
if step == 0 or step is None:
# reset
self.past_length = 0
position_ids = torch.arange(
self.past_length,
source.size(-1) + self.past_length,
dtype=torch.long,
device=source.device,
)
position_ids = position_ids.unsqueeze(0)
position_emb = self.pe(position_ids)
emb += position_emb
if self.past_length == 0:
self.past_length += source.size(-1)
else:
self.past_length += 1
elif self.position_encoding:
emb = self.pe(emb, step)

if self.dropout_p > 0:
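
A standalone sketch of the incremental position logic added above, using toy sizes rather than the real Embeddings module: positions start at 0 for the full prompt, then advance by one for each generated token.

import torch
import torch.nn as nn

n_positions, word_vec_size = 1024, 16  # toy embedding width; 1024 positions as in GPT-2
pe = nn.Embedding(n_positions, word_vec_size)
past_length = 0

# step 0 (or step is None): embed positions 0..4 for a 5-token prompt
prompt = torch.randint(0, 100, (1, 5))
position_ids = torch.arange(past_length, prompt.size(-1) + past_length).unsqueeze(0)
prompt_pos_emb = pe(position_ids)   # (1, 5, 16)
past_length += prompt.size(-1)      # past_length == 5

# later steps: each new single token gets position 5, then 6, ...
new_token = torch.randint(0, 100, (1, 1))
position_ids = torch.arange(past_length, new_token.size(-1) + past_length).unsqueeze(0)
step_pos_emb = pe(position_ids)     # (1, 1, 16)
past_length += 1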
3 changes: 2 additions & 1 deletion eole/predict/generator.py
@@ -186,7 +186,8 @@ def _score_target(self, batch, enc_out, src_len):
)

log_probs[:, :, self._tgt_pad_idx] = 0
gold_log_probs = log_probs.gather(2, tgt)
tgt = tgt.unsqueeze(2)
gold_log_probs = log_probs.gather(2, tgt).squeeze(-1)
gold_scores = gold_log_probs.sum(dim=1).view(-1)

if self.return_gold_log_probs:
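
The gather fix above can be sanity-checked in isolation: torch.gather needs an index tensor with the same number of dimensions as the source, hence the unsqueeze(2) / squeeze(-1) pair. A quick sketch with toy shapes:

import torch

batch, length, vocab = 2, 4, 10
log_probs = torch.log_softmax(torch.randn(batch, length, vocab), dim=-1)
tgt = torch.randint(0, vocab, (batch, length))

gold_log_probs = log_probs.gather(2, tgt.unsqueeze(2)).squeeze(-1)  # (batch, length)
gold_scores = gold_log_probs.sum(dim=1).view(-1)                    # (batch,)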