diff --git a/examples/nlp/gpt/conf/gpt_dpo.yaml b/examples/nlp/gpt/conf/gpt_dpo.yaml index 2a165bf9d..192265244 100644 --- a/examples/nlp/gpt/conf/gpt_dpo.yaml +++ b/examples/nlp/gpt/conf/gpt_dpo.yaml @@ -114,6 +114,7 @@ model: data_impl: jsonl splits_string: null seq_length: ${model.encoder_seq_length} + pad_length_to_multiple_of: null # If using sequence_parallel, ensure divisible by tensor_model_parallel_size skip_warmup: True num_workers: 0 reset_position_ids: False # Reset position ids after end-of-document token diff --git a/examples/nlp/gpt/train_gpt_dpo.py b/examples/nlp/gpt/train_gpt_dpo.py index 49784d485..f16a9dacf 100644 --- a/examples/nlp/gpt/train_gpt_dpo.py +++ b/examples/nlp/gpt/train_gpt_dpo.py @@ -20,7 +20,7 @@ from nemo.utils import logging from nemo.utils.exp_manager import exp_manager from nemo_aligner.algorithms.dpo import DPOTrainer, dpo_custom_collate -from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets +from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets, identity_collate from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel from nemo_aligner.utils.distributed import Timer from nemo_aligner.utils.train_script_utils import ( @@ -85,7 +85,7 @@ def main(cfg) -> None: # use the entire dataset train_valid_test_num_samples = [-1 * cfg.model.global_batch_size] * 3 - train_ds, validation_ds, test_ds = build_train_valid_test_dpo_datasets( + train_ds, validation_ds, _ = build_train_valid_test_dpo_datasets( cfg=cfg.model, data_prefix=cfg.model.data.data_prefix, data_impl=cfg.model.data.data_impl, @@ -104,13 +104,7 @@ def main(cfg) -> None: gbs=cfg.model.global_batch_size, load_gbs=True, pad_samples_to_global_batch_size=False, - collate_fn=partial( - dpo_custom_collate, - eos_id=ptl_model.tokenizer.eos_id, - reset_position_ids=cfg.model.data.get("reset_position_ids", False), - reset_attention_mask=cfg.model.data.get("reset_attention_mask", False), - eod_mask_loss=cfg.model.data.get("eod_mask_loss", False), - ), + collate_fn=identity_collate, ) val_dataloader = build_dataloader( @@ -121,13 +115,7 @@ def main(cfg) -> None: gbs=cfg.model.global_batch_size, load_gbs=True, pad_samples_to_global_batch_size=False, - collate_fn=partial( - dpo_custom_collate, - eos_id=ptl_model.tokenizer.eos_id, - reset_position_ids=cfg.model.data.get("reset_position_ids", False), - reset_attention_mask=cfg.model.data.get("reset_attention_mask", False), - eod_mask_loss=cfg.model.data.get("eod_mask_loss", False), - ), + collate_fn=identity_collate, use_random_sampler=False, ) @@ -147,6 +135,14 @@ def main(cfg) -> None: train_dataloader=train_dataloader, val_dataloader=val_dataloader, test_dataloader=None, + collate_fn=partial( + dpo_custom_collate, + eos_id=ptl_model.tokenizer.eos_id, + reset_position_ids=cfg.model.data.get("reset_position_ids", False), + reset_attention_mask=cfg.model.data.get("reset_attention_mask", False), + eod_mask_loss=cfg.model.data.get("eod_mask_loss", False), + pad_length_to_multiple_of=cfg.model.data.get("pad_length_to_multiple_of", None), + ), logger=logger, ckpt_callback=ckpt_callback, run_timer=timer, diff --git a/nemo_aligner/algorithms/dpo.py b/nemo_aligner/algorithms/dpo.py index 6b2103328..626b7b58e 100644 --- a/nemo_aligner/algorithms/dpo.py +++ b/nemo_aligner/algorithms/dpo.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math from collections import defaultdict from statistics import mean +from typing import Any, Protocol import torch +import torch.distributed from omegaconf.dictconfig import DictConfig from tqdm import tqdm @@ -24,13 +27,34 @@ ) from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids from nemo.utils import logging +from nemo_aligner.utils import parallel_state from nemo_aligner.utils.distributed import SyncTimer from nemo_aligner.utils.train_utils import clip_gradients from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches, compute_num_steps_per_epoch from nemo_aligner.utils.utils import clear_memory -def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False): +class DistributedCollateFunction(Protocol): + def __call__(self, batch: list[dict], **kwargs: Any) -> dict[str, torch.Tensor]: + ... + + +def dpo_custom_collate( + batch: list[dict], + eos_id: int, + reset_position_ids: bool = False, + reset_attention_mask: bool = False, + eod_mask_loss: bool = False, + pad_length_to_multiple_of: int | None = None, +) -> dict[str, torch.Tensor]: + """ + Transposes minibatch from list[dict] -> dict[Tensor] and also pads + + This collate happens outside of the torch data loader and is not compatible with the multiprocessing + logic due to requiring communication collectives. + """ + if pad_length_to_multiple_of is not None and pad_length_to_multiple_of < 0: + raise ValueError(f"{pad_length_to_multiple_of=} must be >= 0") chosen_tokens = [item["chosen"] for item in batch] rejected_tokens = [item["rejected"] for item in batch] chosen_lengths = torch.LongTensor([item["chosen_length"] for item in batch]) @@ -44,9 +68,32 @@ def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_ rejected_tokens = torch.nn.utils.rnn.pad_sequence(rejected_tokens, batch_first=True, padding_value=eos_id) chosen_labels = torch.nn.utils.rnn.pad_sequence(chosen_labels, batch_first=True, padding_value=-100) rejected_labels = torch.nn.utils.rnn.pad_sequence(rejected_labels, batch_first=True, padding_value=-100) + assert chosen_tokens.shape == rejected_tokens.shape + assert chosen_labels.shape == rejected_labels.shape + + if pad_length_to_multiple_of: + # Assumes both chosen and rejected match + max_seq_len = torch.tensor(chosen_tokens.shape[1], device=torch.cuda.current_device()) + torch.distributed.all_reduce( + max_seq_len, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_data_parallel_group() + ) + + padded_max_len = math.ceil(max_seq_len / pad_length_to_multiple_of) * pad_length_to_multiple_of + chosen_tokens = torch.nn.functional.pad( + chosen_tokens, (0, padded_max_len - chosen_tokens.shape[1]), mode="constant", value=eos_id + ) + rejected_tokens = torch.nn.functional.pad( + rejected_tokens, (0, padded_max_len - rejected_tokens.shape[1]), mode="constant", value=eos_id + ) + chosen_labels = torch.nn.functional.pad( + chosen_labels, (0, padded_max_len - chosen_labels.shape[1]), mode="constant", value=-100 + ) + rejected_labels = torch.nn.functional.pad( + rejected_labels, (0, padded_max_len - rejected_labels.shape[1]), mode="constant", value=-100 + ) attention_mask, _, position_ids = get_ltor_masks_and_position_ids( - chosen_tokens, eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss, + chosen_tokens.cuda(), eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss, ) assert attention_mask.ndim == 4, "attention_mask is incorrect shape for dpo_custom_collate" if attention_mask.shape[0] == 1: @@ -70,8 +117,7 @@ def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_ class DPOTrainer: - """Trainer to coordinate DPO training - """ + """Trainer to coordinate DPO training""" def __init__( self, @@ -82,6 +128,7 @@ def __init__( train_dataloader, val_dataloader, test_dataloader, + collate_fn: DistributedCollateFunction, logger, ckpt_callback, run_timer, @@ -90,6 +137,7 @@ def __init__( self.train_dataloader = train_dataloader self.val_dataloader = val_dataloader self.test_dataloader = test_dataloader + self.collate_fn = collate_fn self.logger = logger self.cfg = cfg self.optimizer = optimizer @@ -317,6 +365,7 @@ def augment_dataloader(self, dataloader): while True: try: batch = next(iter_dataloader) + batch = self.collate_fn(batch) logprobs = self.model.get_ref_policy_logprobs(batch).cpu() chosen_logps, reject_logps = torch.split(logprobs, len(logprobs) // 2, dim=0) batch["ref_policy_log_probs_chosen"] = chosen_logps diff --git a/nemo_aligner/data/nlp/builders.py b/nemo_aligner/data/nlp/builders.py index a61fb46f9..800a4808d 100644 --- a/nemo_aligner/data/nlp/builders.py +++ b/nemo_aligner/data/nlp/builders.py @@ -22,6 +22,7 @@ import numpy as np import torch +import torch.utils.data from omegaconf.dictconfig import DictConfig from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( @@ -123,7 +124,7 @@ def build_train_valid_test_datasets( cfg=cfg, data_prefix=data_prefix["validation"], data_impl=data_impl, - num_samples=int(train_valid_test_num_samples[0]), + num_samples=int(train_valid_test_num_samples[1]), seq_length=seq_length, seed=seed, tokenizer=tokenizer, @@ -134,7 +135,7 @@ def build_train_valid_test_datasets( cfg=cfg, data_prefix=data_prefix["test"], data_impl=data_impl, - num_samples=int(train_valid_test_num_samples[0]), + num_samples=int(train_valid_test_num_samples[2]), seq_length=seq_length, seed=seed, tokenizer=tokenizer, @@ -318,8 +319,7 @@ def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, i def collate_with_pad_to_max_batch(max_seqlen, tokenizer_eos_id, cfg, generate_masks_and_position_ids=True): - """collate function that pads each sequence to the max in the batch - """ + """collate function that pads each sequence to the max in the batch""" return partial( collate_with_batch_max_sequence_length, response_token_length=max_seqlen, @@ -331,6 +331,14 @@ def collate_with_pad_to_max_batch(max_seqlen, tokenizer_eos_id, cfg, generate_ma ) +def identity_collate(batch): + """ + Useful since torch's data loader's default collate will crash with ragged sequences. + Also, this function is needed b/c lambda functions aren't pickle-able. + """ + return batch + + def build_dataloader( cfg, dataset, diff --git a/nemo_aligner/data/nlp/datasets.py b/nemo_aligner/data/nlp/datasets.py index 5b8acaeb6..d03101449 100644 --- a/nemo_aligner/data/nlp/datasets.py +++ b/nemo_aligner/data/nlp/datasets.py @@ -134,7 +134,7 @@ def __getitem__(self, idx): class RewardModelDataset(Dataset): """This class assumes that we only have 2 responses per prompt that is ranked. Chosen is the better - one(even index) whereas Rejected is the worse response(odd index) + one(even index) whereas Rejected is the worse response(odd index) """ def __init__( @@ -184,8 +184,7 @@ def encode(self, text): return text_ids, len(text_ids) def __getitem__(self, idx, multiple=2): - """Returns a pair of chosen/rejected pairs, and their respective lengths. - """ + """Returns a pair of chosen/rejected pairs, and their respective lengths.""" found = False while not found: chosen = self.data[multiple * idx] @@ -240,16 +239,16 @@ def __getitem__(self, idx, multiple=2): class DPOModelDataset(Dataset): """This class works only with jsonl files. It assumes each line of the json file is a dictionary - with the prompt, along with the chosen response (response only, no prompt), and the rejected response - (response only, no prompt). This Dataset will combine the prompt with each corresponding chosen and - rejected response, and then tokenize it. It also returns the labels for each, which is the response tokens - with -100 for the prompt part. - - WARNING: This class will tokenize the text, but it will raise an exception on model max seq len violations! - Meaning it will not truncate tokens to fit to model max seq len, because of special prefix/suffix - strings such as , it would not know where it is safe to truncate for each model. Therefore, - the user must do all truncation logic in their preprocessing step when generating the jsonl - used by this class. Put all special truncation logic there specific to your model. + with the prompt, along with the chosen response (response only, no prompt), and the rejected response + (response only, no prompt). This Dataset will combine the prompt with each corresponding chosen and + rejected response, and then tokenize it. It also returns the labels for each, which is the response tokens + with -100 for the prompt part. + + WARNING: This class will tokenize the text, but it will raise an exception on model max seq len violations! + Meaning it will not truncate tokens to fit to model max seq len, because of special prefix/suffix + strings such as , it would not know where it is safe to truncate for each model. Therefore, + the user must do all truncation logic in their preprocessing step when generating the jsonl + used by this class. Put all special truncation logic there specific to your model. """ def __init__( @@ -293,8 +292,7 @@ def encode(self, text, append_eod=False): return text_ids, len(text_ids) def __getitem__(self, idx): - """Returns a pair of chosen/rejected pairs, their respective lengths, and labels. - """ + """Returns a pair of chosen/rejected pairs, their respective lengths, and labels.""" payload = self.data[idx] prompt, prompt_len = self.encode(payload["prompt"], append_eod=False) chosen, chosen_len = self.encode( @@ -308,15 +306,14 @@ def __getitem__(self, idx): chosen_labels = ([-100] * prompt_len) + chosen[prompt_len:] reject_labels = ([-100] * prompt_len) + reject[prompt_len:] - assert chosen[0:prompt_len] == prompt, "the tokenizer for DPO has merged tokens between prompt and response" - assert reject[0:prompt_len] == prompt, "the tokenizer for DPO has merged tokens between prompt and response" + assert ( + chosen[0:prompt_len] == prompt + ), f"The tokenizer for DPO has merged tokens between prompt and response for {idx=}:\n[[prompt]]={repr(payload['prompt'])}\n[[chosen_response]]={repr(payload['chosen_response'])}" + assert ( + reject[0:prompt_len] == prompt + ), f"The tokenizer for DPO has merged tokens between prompt and response for {idx=}:\n[[prompt]]={repr(payload['prompt'])}\n[[rejected_response]]={repr(payload['rejected_response'])}" max_curr_seq_len = max(chosen_len, reject_len) - if max_curr_seq_len > self.seq_length: - logging.warning( - f"WARNING: Tokenized text exceeds max seq length ({max_curr_seq_len} vs {self.seq_length})." - + f"The example will be ignored." - ) chosen_tokens = torch.nn.functional.pad( torch.LongTensor(chosen), (0, max_curr_seq_len - chosen_len), mode="constant", value=self.eos_id @@ -333,6 +330,10 @@ def __getitem__(self, idx): # ignore the example whose tokenized text exceeds max seq length. if max_curr_seq_len > self.seq_length: + logging.warning( + f"WARNING: Tokenized text exceeds max seq length ({max_curr_seq_len} vs {self.seq_length})." + + f"The example will be ignored." + ) chosen_tokens = chosen_tokens[: self.nograd_length] rejected_tokens = rejected_tokens[: self.nograd_length] labels_chosen_tokens = torch.ones_like(chosen_tokens) * (-100) @@ -355,16 +356,16 @@ def __getitem__(self, idx): class KTOModelDataset(Dataset): """This class works only with jsonl files. It assumes each line of the json file is a dictionary - with the prompt, along with the response (response only, no prompt), and the status denoting whether the response is - chosen or rejected. This Dataset will combine the prompt with the corresponding response, and then tokenize it. It - will also create a score field that has 1 if the sample is chosen and 0 if rejected. It also returns the labels for - each, which is the response tokens with -100 for the prompt part. - - WARNING: This class will tokenize the text, but it will raise an exception on model max seq len violations! - Meaning it will not truncate tokens to fit to model max seq len, because of special prefix/suffix - strings such as , it would not know where it is safe to truncate for each model. Therefore, - the user must do all truncation logic in their preprocessing step when generating the jsonl - used by this class. Put all special truncation logic there specific to your model. + with the prompt, along with the response (response only, no prompt), and the status denoting whether the response is + chosen or rejected. This Dataset will combine the prompt with the corresponding response, and then tokenize it. It + will also create a score field that has 1 if the sample is chosen and 0 if rejected. It also returns the labels for + each, which is the response tokens with -100 for the prompt part. + + WARNING: This class will tokenize the text, but it will raise an exception on model max seq len violations! + Meaning it will not truncate tokens to fit to model max seq len, because of special prefix/suffix + strings such as , it would not know where it is safe to truncate for each model. Therefore, + the user must do all truncation logic in their preprocessing step when generating the jsonl + used by this class. Put all special truncation logic there specific to your model. """ def __init__( @@ -448,14 +449,14 @@ def __getitem__(self, idx): class RegressionRewardModelDataset(RewardModelDataset): - """This class assumes each line of the dataset file is a dictionary with "text" and "label" field, - where "text" is a string representing the input prompt, and "label" is a list of float or int values. - Note that when training the model with multiple datasets which contain different attributes, - we should set missing attributes to model.regression.loss_mask_val(according to training_rm.yaml) - in the dataset files so that their losses are masked. At least one attribute should be present for each sample. - - WARNING: It's recommended to preprocess your data in advance to ensure all samples are within self.seq_length. - Otherwise if all samples in a batch are longer than self.seq_length, you may get NaN loss. + """This class assumes each line of the dataset file is a dictionary with "text" and "label" field, + where "text" is a string representing the input prompt, and "label" is a list of float or int values. + Note that when training the model with multiple datasets which contain different attributes, + we should set missing attributes to model.regression.loss_mask_val(according to training_rm.yaml) + in the dataset files so that their losses are masked. At least one attribute should be present for each sample. + + WARNING: It's recommended to preprocess your data in advance to ensure all samples are within self.seq_length. + Otherwise if all samples in a batch are longer than self.seq_length, you may get NaN loss. """ def __init__( diff --git a/tests/conftest.py b/tests/conftest.py index a9f49b28f..8ac1c2af7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -401,7 +401,7 @@ def pytest_collection_modifyitems(config, items): def pytest_sessionstart(session): # Remove the file at the start of the session, if it exists - if os.path.exists(SUCCESS_FILE): + if os.path.exists(SUCCESS_FILE) and os.environ["LOCAL_RANK"] == "0": os.remove(SUCCESS_FILE) diff --git a/tests/test_datasets.py b/tests/test_datasets.py new file mode 100644 index 000000000..01425357b --- /dev/null +++ b/tests/test_datasets.py @@ -0,0 +1,310 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from functools import partial +from tempfile import TemporaryDirectory + +import pytest +import torch.distributed +from omegaconf import OmegaConf + +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo_aligner.algorithms.dpo import dpo_custom_collate +from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets +from nemo_aligner.utils import parallel_state + + +@pytest.fixture +def llama3_tokenizer(): + return AutoTokenizer("meta-llama/Meta-Llama-3-8b") + + +@pytest.fixture +def str_to_list_tokenizer(): + class StringToListTokenizer: + eos_id: int = -1 + + def text_to_ids(self, text: str) -> list[int]: + return [int(x) for x in text.split()] + + return StringToListTokenizer() + + +@pytest.fixture +def make_tmp_jsonl(): + with TemporaryDirectory() as tmp_dir: + + def write_jsonl(jsonl: list[dict], prefix="tmp"): + jsonl_path = f"{tmp_dir}/{prefix}.jsonl" + with open(jsonl_path, "w") as f: + for obj in jsonl: + f.write(json.dumps(obj) + "\n") + return jsonl_path + + yield write_jsonl + + +@pytest.mark.run_only_on("GPU") +def test_dpo_loader(init_model_parallel, make_tmp_jsonl, llama3_tokenizer): + init_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) + + tmp_jsonl = make_tmp_jsonl( + [ + { + "prompt": f"System\n\nUser\n{'+'.join('1'*i)}={i}?\nAssistant\n", + "chosen_response": f"yes\n", + "rejected_response": f"no\n", + } + for i in range(1, 100, 10) + ] + ) + cfg = OmegaConf.create( + { + "model": { + "data": { + "data_prefix": {"train": [tmp_jsonl], "validation": [tmp_jsonl], "test": [tmp_jsonl]}, + "splits_string": None, + "num_workers": 2, + }, + "seed": 42, + } + } + ) + mbs = 1 + minibs = 2 + gbs = minibs * torch.distributed.get_world_size() + + train_ds, _, _ = build_train_valid_test_dpo_datasets( + cfg=cfg.model, + data_prefix=cfg.model.data.data_prefix, + data_impl="jsonl", + splits_string=None, + train_valid_test_num_samples=[-1 * gbs] * 3, + seq_length=1024, + seed=cfg.model.seed, + tokenizer=llama3_tokenizer, + ) + + train_dataloader = build_dataloader( + cfg=cfg, + dataset=train_ds, + consumed_samples=0, + mbs=mbs, + gbs=gbs, + load_gbs=True, + pad_samples_to_global_batch_size=False, + collate_fn=lambda x: x, + ) + + distributed_collate_fn = partial( + dpo_custom_collate, + eos_id=llama3_tokenizer.eos_id, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False, + ) + + num_mini_batches = 0 + for mbatch in train_dataloader: + mbatch = distributed_collate_fn(mbatch) + padded_seq_len = mbatch["chosen"].shape[1] + for in_name, in_tensor in mbatch.items(): + assert in_tensor.shape[0] == minibs, f"Expected {in_name}.shape={in_tensor.shape} first dim to be {minibs}" + + assert mbatch["chosen"].shape == (minibs, padded_seq_len) + assert mbatch["rejected"].shape == (minibs, padded_seq_len) + assert mbatch["chosen_length"].shape == (minibs,) + assert mbatch["rejected_length"].shape == (minibs,) + assert mbatch["chosen_labels"].shape == (minibs, padded_seq_len) + assert mbatch["rejected_labels"].shape == (minibs, padded_seq_len) + assert mbatch["attention_mask"].shape == (minibs, 1, padded_seq_len, padded_seq_len) + assert mbatch["position_ids"].shape == (minibs, padded_seq_len) + assert mbatch["chosen_rewards"].shape == (minibs,) + assert mbatch["rejected_rewards"].shape == (minibs,) + num_mini_batches += 1 + assert num_mini_batches == 2 + + +@pytest.mark.run_only_on("GPU") +def test_dpo_loader_original(init_model_parallel, make_tmp_jsonl, llama3_tokenizer): + init_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) + + tmp_jsonl = make_tmp_jsonl( + [ + { + "prompt": f"System\n\nUser\n{'+'.join('1'*i)}={i}?\nAssistant\n", + "chosen_response": f"yes\n", + "rejected_response": f"no\n", + } + for i in range(1, 100, 10) + ] + ) + cfg = OmegaConf.create( + { + "model": { + "data": { + "data_prefix": {"train": [tmp_jsonl], "validation": [tmp_jsonl], "test": [tmp_jsonl]}, + "splits_string": None, + "num_workers": 2, + }, + "seed": 42, + } + } + ) + mbs = 1 + minibs = 2 + gbs = minibs * torch.distributed.get_world_size() + + train_ds, _, _ = build_train_valid_test_dpo_datasets( + cfg=cfg.model, + data_prefix=cfg.model.data.data_prefix, + data_impl="jsonl", + splits_string=None, + train_valid_test_num_samples=[-1 * gbs] * 3, + seq_length=1024, + seed=cfg.model.seed, + tokenizer=llama3_tokenizer, + ) + + train_dataloader = build_dataloader( + cfg=cfg, + dataset=train_ds, + consumed_samples=0, + mbs=mbs, + gbs=gbs, + load_gbs=True, + pad_samples_to_global_batch_size=False, + collate_fn=lambda x: x, + ) + + distributed_collate_fn = partial( + dpo_custom_collate, + eos_id=llama3_tokenizer.eos_id, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False, + ) + + num_mini_batches = 0 + for mbatch in train_dataloader: + mbatch = distributed_collate_fn(mbatch) + padded_seq_len = mbatch["chosen"].shape[1] + for in_name, in_tensor in mbatch.items(): + assert in_tensor.shape[0] == minibs, f"Expected {in_name}.shape={in_tensor.shape} first dim to be {minibs}" + + assert mbatch["chosen"].shape == (minibs, padded_seq_len) + assert mbatch["rejected"].shape == (minibs, padded_seq_len) + assert mbatch["chosen_length"].shape == (minibs,) + assert mbatch["rejected_length"].shape == (minibs,) + assert mbatch["chosen_labels"].shape == (minibs, padded_seq_len) + assert mbatch["rejected_labels"].shape == (minibs, padded_seq_len) + assert mbatch["attention_mask"].shape == (minibs, 1, padded_seq_len, padded_seq_len) + assert mbatch["position_ids"].shape == (minibs, padded_seq_len) + assert mbatch["chosen_rewards"].shape == (minibs,) + assert mbatch["rejected_rewards"].shape == (minibs,) + num_mini_batches += 1 + assert num_mini_batches == 2 + + +@pytest.mark.run_only_on("GPU") +def test_dpo_loader_pad_to_multiple(init_model_parallel, make_tmp_jsonl, str_to_list_tokenizer): + init_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) + + tmp_jsonl = make_tmp_jsonl( + [ + { + "prompt": f"{' '.join(str(x) for x in range(i))} ", + "chosen_response": f"{i * 10}", + "rejected_response": f"{i * 100}", + } + for i in range(1, 100, 10) + ] + ) + cfg = OmegaConf.create( + { + "model": { + "data": { + "data_prefix": {"train": [tmp_jsonl], "validation": [tmp_jsonl], "test": [tmp_jsonl]}, + "splits_string": None, + "num_workers": 2, + }, + "seed": 42, + } + } + ) + mbs = 1 + minibs = 2 + gbs = minibs * torch.distributed.get_world_size() + expected_seq_len_multiple = 29 # pick a prime to make sure + + train_ds, _, _ = build_train_valid_test_dpo_datasets( + cfg=cfg.model, + data_prefix=cfg.model.data.data_prefix, + data_impl="jsonl", + splits_string=None, + train_valid_test_num_samples=[-1 * gbs] * 3, + seq_length=1024, + seed=cfg.model.seed, + tokenizer=str_to_list_tokenizer, + ) + + train_dataloader = build_dataloader( + cfg=cfg, + dataset=train_ds, + consumed_samples=0, + mbs=mbs, + gbs=gbs, + load_gbs=True, + pad_samples_to_global_batch_size=False, + collate_fn=lambda x: x, + ) + + distributed_collate_fn = partial( + dpo_custom_collate, + eos_id=str_to_list_tokenizer.eos_id, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False, + pad_length_to_multiple_of=expected_seq_len_multiple, + ) + + num_mini_batches = 0 + for mbatch in train_dataloader: + chosen_lengths = [len(x["chosen"]) for x in mbatch] + rejected_lengths = [len(x["rejected"]) for x in mbatch] + assert chosen_lengths == rejected_lengths + + assert len(set(chosen_lengths)) == len( + chosen_lengths + ), f"Lengths should be unique in this test: {chosen_lengths=}" + + mbatch = distributed_collate_fn(mbatch) + assert mbatch["chosen"].shape[1] % expected_seq_len_multiple == 0 + assert mbatch["rejected"].shape[1] % expected_seq_len_multiple == 0 + assert mbatch["chosen_labels"].shape[1] % expected_seq_len_multiple == 0 + assert mbatch["rejected_labels"].shape[1] % expected_seq_len_multiple == 0 + assert mbatch["attention_mask"].shape[2] % expected_seq_len_multiple == 0 + assert mbatch["attention_mask"].shape[3] % expected_seq_len_multiple == 0 + assert mbatch["position_ids"].shape[1] % expected_seq_len_multiple == 0 + + # Check that all ranks have the same length + max_chosen_seq_length = torch.tensor(mbatch["chosen"].shape[1], device="cuda") + torch.distributed.all_reduce( + max_chosen_seq_length, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_data_parallel_group() + ) + assert mbatch["chosen"].shape[1] == max_chosen_seq_length.item() + + num_mini_batches += 1 + + assert num_mini_batches == 2