Commit e387b0f

add ChatML
jquesnelle committed Aug 2, 2024
1 parent d5228bb commit e387b0f
Showing 4 changed files with 54 additions and 8 deletions.
1 change: 1 addition & 0 deletions run_train.py
@@ -182,6 +182,7 @@ def get_dataloader_from_data_stage(
sequence_length=trainer.sequence_length,
train_on_completions_only=data.dataset.train_on_completions_only,
remove_cross_attention=data.dataset.remove_cross_attention,
chat_format=data.dataset.chat_format,
split=data.dataset.hf_dataset_split,
conversation_column_name=data.dataset.conversation_column_name,
dp_rank=trainer.parallel_context.dp_pg.rank(),
5 changes: 5 additions & 0 deletions src/nanotron/config/config.py
@@ -23,6 +23,7 @@
from nanotron.logging import get_logger
from nanotron.parallel.pipeline_parallel.engine import PipelineEngine
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
from nanotron.data.chat_tokenizer import ChatFormat

logger = get_logger(__name__)

@@ -112,6 +113,7 @@ class ChatDatasetsArgs:
hf_dataset: str
hf_dataset_split: str
conversation_column_name: str
chat_format: Union[str, ChatFormat] = ChatFormat.LLAMA3
# Debug
train_on_completions_only: bool = True
remove_cross_attention: bool = True
@@ -121,6 +123,8 @@ def __post_init__(self):
self.hf_dataset_split = "train"
if self.conversation_column_name is None:
self.conversation_column_name = "conversations"
if isinstance(self.chat_format, str):
self.chat_format = ChatFormat[self.chat_format.upper()]


@dataclass
@@ -449,6 +453,7 @@ def get_config_from_dict(
TensorParallelLinearMode: lambda x: TensorParallelLinearMode[x.upper()],
RecomputeGranularity: lambda x: RecomputeGranularity[x.upper()],
SamplerType: lambda x: SamplerType[x.upper()],
ChatFormat: lambda x: ChatFormat[x.upper()],
},
# strict_unions_match=True,
strict=True,
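For reference, a minimal sketch (not part of the commit) of how a chat_format string from a config file would resolve to the enum; the standalone resolve_chat_format helper is hypothetical, but it mirrors the ChatFormat[x.upper()] casting added to __post_init__ and get_config_from_dict above.

```python
# Illustrative only: resolving a config string to the ChatFormat enum,
# mirroring the ChatFormat[x.upper()] casting added in this commit.
from enum import Enum, auto
from typing import Union


class ChatFormat(Enum):
    LLAMA3 = auto()
    CHATML = auto()


def resolve_chat_format(value: Union[str, ChatFormat]) -> ChatFormat:
    # Upper-case the string and index the enum by member name, so both
    # "chatml" and "CHATML" resolve to ChatFormat.CHATML.
    return ChatFormat[value.upper()] if isinstance(value, str) else value


assert resolve_chat_format("chatml") is ChatFormat.CHATML
assert resolve_chat_format(ChatFormat.LLAMA3) is ChatFormat.LLAMA3
```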
15 changes: 12 additions & 3 deletions src/nanotron/data/chat_dataset.py
@@ -3,7 +3,7 @@
import numpy as np
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from nanotron.data.chat_tokenizer import ChatTokenizer
from nanotron.data.chat_tokenizer import ChatTokenizer, ChatFormat
from nanotron.data.collator import (
build_labels,
build_labels_completions_only,
@@ -41,6 +41,7 @@ def __init__(
conversation_column_name: str,
train_on_completions_only: bool = True,
remove_cross_attention: bool = True,
chat_format: ChatFormat = ChatFormat.CHATML,
split: str = "train",
dp_rank: int = 0,
dp_ranks_size: int = 1,
@@ -54,7 +55,7 @@ def __init__(
# TODO(tj.solergibert) Support interleaving datasets

self.dataset_path = dataset_path
self.chat_tokenizer = ChatTokenizer(tokenizer_name_or_path)
self.chat_tokenizer = ChatTokenizer(tokenizer_name_or_path, chat_format)
self.sequence_length = sequence_length
self.conversation_column_name = conversation_column_name
self.skip_num_samples = skip_num_samples
@@ -79,7 +80,10 @@ def __init__(

# TODO(tj.solergibert) Delete (debug)
self.debug_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) # TODO delete debug
self.debug_tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['from'] + '<|end_header_id|>\n\n'+ message['value'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>' }}{% endif %}"
if chat_format == ChatFormat.LLAMA3:
self.debug_tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['from'] + '<|end_header_id|>\n\n'+ message['value'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>' }}{% endif %}"
elif chat_format == ChatFormat.CHATML:
self.debug_tokenizer.chat_template = "{{- bos_token }}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['from'] + '\n' + message['value'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

def __iter__(self):
max_buffer_token_len = 1 + self.sequence_length
@@ -93,6 +97,11 @@ def __iter__(self):

# TODO(tj.solergibert) Delete (debug). Check if HF apply_chat_template produces the same result as ChatTokenizer
# The [:-1] of tokens is because apply_chat_template doesn't add the eos (NOT eot) token
for message in sample["conversations"]:
if message["from"] == "gpt":
message["from"] = "assistant"
if message["from"] == "human":
message["from"] = "user"
assert (
self.debug_tokenizer.apply_chat_template(sample["conversations"]) == tokens[:-1]
), f'{self.debug_tokenizer.apply_chat_template(sample["conversations"])}\n\n{tokens[:-1]}'
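As context for the debug check above, a standalone sketch of the role renaming it performs before comparing against apply_chat_template; normalize_roles is a hypothetical helper, not part of the commit, and assumes ShareGPT-style from/value messages as in the diff.

```python
# Illustrative only: the in-place role renaming done before the debug
# apply_chat_template comparison, extracted as a standalone function.
def normalize_roles(conversation: list) -> list:
    for message in conversation:
        if message["from"] == "gpt":
            message["from"] = "assistant"
        elif message["from"] == "human":
            message["from"] = "user"
    return conversation


sample = [{"from": "human", "value": "Hi"}, {"from": "gpt", "value": "Hello!"}]
assert [m["from"] for m in normalize_roles(sample)] == ["user", "assistant"]
```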
41 changes: 36 additions & 5 deletions src/nanotron/data/chat_tokenizer.py
@@ -1,8 +1,14 @@
from typing import List, Tuple
from enum import Enum, auto

from transformers import AutoTokenizer


class ChatFormat(Enum):
LLAMA3 = auto()
CHATML = auto()


class ChatTokenizer:
"""
The ChatTokenizer encodes a conversation, applying the selected chat template (Llama3 or ChatML), and returns the role (either user or assistant) of each token
@@ -11,12 +17,22 @@ class ChatTokenizer:
tokenizer_name_or_path (str): A path to a directory containing vocabulary files required by the tokenizer or the model id of a predefined tokenizer hosted inside a model repo on the Hugging Face Hub.
"""

def __init__(self, tokenizer_name_or_path: str):
def __init__(self, tokenizer_name_or_path: str, chat_format: ChatFormat = ChatFormat.LLAMA3):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

self._chat_format = chat_format
if chat_format == ChatFormat.LLAMA3:
self._header_start = "<|start_header_id|>"
self._header_end = "<|end_header_id|>\n\n"
self._turn_end = "<|eot_id|>"
elif chat_format == ChatFormat.CHATML:
self._header_start = "<|im_start|>"
self._header_end = "\n"
self._turn_end = "<|im_end|>\n"

# Add pad token if necessary
if self.tokenizer.pad_token is None:
self.tokenizer.add_special_tokens({"pad_token": "<|eot_id|>"})
self.tokenizer.add_special_tokens({"pad_token": self._turn_end})

def __call__(self, conversation: List[dict]) -> Tuple[List[int], List[bool]]:
"""
@@ -65,19 +81,34 @@ def encode_message(self, message: dict) -> Tuple[List[int], List[int]]:
# single format and document it properly rather than supporting multiple formats, as each DATASET will need a different
# ChatTokenizer and the idea is that all Datasets share the same ChatTokenizer

role, is_input = self._get_role(message)

# Encode header
tokens = self.tokenizer.encode(
f"<|start_header_id|>{message['from']}<|end_header_id|>\n\n", add_special_tokens=False
f"{self._header_start}{role}{self._header_end}", add_special_tokens=False
)
is_completitions = [False] * len(tokens)

# Encode message
tokens.extend(self.tokenizer.encode(message["value"].strip(), add_special_tokens=False))

# Append <|eot_id|> token
tokens.extend(self.tokenizer.encode("<|eot_id|>", add_special_tokens=False))
tokens.extend(self.tokenizer.encode(self._turn_end, add_special_tokens=False))

# True if token belongs to assistant answer, False otherwise
is_completitions.extend([True if message["from"] == "gpt" else False] * (len(tokens) - len(is_completitions)))
is_completitions.extend([not is_input] * (len(tokens) - len(is_completitions)))

return tokens, is_completitions

def _get_role(self, message: dict) -> Tuple[str, bool]:
"""
Return the canonical role for a given message, as well as whether its value
should be considered input (and therefore not trained on)
"""
role = message["from"]
if role == "gpt" or role == "assistant":
return "assistant", False
elif role == "human":
return "user", True
else:
return role, True
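To make the new format concrete, a short sketch of the per-turn text that encode_message tokenizes under ChatFormat.CHATML; render_chatml_turn is a hypothetical helper, not part of the commit, that only mirrors the delimiters set in __init__ and the role mapping in _get_role.

```python
# Illustrative only: the text each encode_message call tokenizes under
# ChatFormat.CHATML.
def render_chatml_turn(message: dict) -> str:
    role = {"gpt": "assistant", "human": "user"}.get(message["from"], message["from"])
    # header start + role + header end ("\n"), stripped message, then turn end
    return f"<|im_start|>{role}\n{message['value'].strip()}<|im_end|>\n"


print(render_chatml_turn({"from": "human", "value": "What is ChatML?"}))
# <|im_start|>user
# What is ChatML?<|im_end|>
#
# Only the assistant's message and turn-end tokens get is_completitions == True
# (headers stay False), so user turns like this one are masked out when
# train_on_completions_only is enabled.
```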
