feat: DPO support for global padding of seq_len to a multiple (NVIDIA#386)

Signed-off-by: Terry Kong <terryk@nvidia.com>
Signed-off-by: abukharin <abukharin@nvidia.com>
terrykong authored and abukharin committed Nov 22, 2024
1 parent d1a913f commit 06a3ca4
Showing 7 changed files with 430 additions and 66 deletions.
1 change: 1 addition & 0 deletions examples/nlp/gpt/conf/gpt_dpo.yaml
@@ -114,6 +114,7 @@ model:
data_impl: jsonl
splits_string: null
seq_length: ${model.encoder_seq_length}
+pad_length_to_multiple_of: null # If using sequence_parallel, ensure divisible by tensor_model_parallel_size
skip_warmup: True
num_workers: 0
reset_position_ids: False # Reset position ids after end-of-document token
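Note on the new knob: with sequence parallelism enabled, Megatron shards the sequence dimension across tensor-parallel ranks, so the padded length must divide evenly by tensor_model_parallel_size, as the inline comment warns. A minimal validation sketch (the helper name and cfg access pattern are illustrative assumptions, not part of this commit):

    # Hypothetical sanity check; `cfg` follows the gpt_dpo.yaml layout above.
    def validate_pad_multiple(cfg):
        multiple = cfg.model.data.get("pad_length_to_multiple_of")
        tp_size = cfg.model.get("tensor_model_parallel_size", 1)
        if multiple is not None and cfg.model.get("sequence_parallel", False):
            if multiple % tp_size != 0:
                raise ValueError(
                    f"pad_length_to_multiple_of={multiple} must be divisible "
                    f"by tensor_model_parallel_size={tp_size}"
                )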
28 changes: 12 additions & 16 deletions examples/nlp/gpt/train_gpt_dpo.py
@@ -20,7 +20,7 @@
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.dpo import DPOTrainer, dpo_custom_collate
-from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets
+from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets, identity_collate
from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
@@ -85,7 +85,7 @@ def main(cfg) -> None:
# use the entire dataset
train_valid_test_num_samples = [-1 * cfg.model.global_batch_size] * 3

-train_ds, validation_ds, test_ds = build_train_valid_test_dpo_datasets(
+train_ds, validation_ds, _ = build_train_valid_test_dpo_datasets(
cfg=cfg.model,
data_prefix=cfg.model.data.data_prefix,
data_impl=cfg.model.data.data_impl,
@@ -104,13 +104,7 @@ def main(cfg) -> None:
gbs=cfg.model.global_batch_size,
load_gbs=True,
pad_samples_to_global_batch_size=False,
-collate_fn=partial(
-dpo_custom_collate,
-eos_id=ptl_model.tokenizer.eos_id,
-reset_position_ids=cfg.model.data.get("reset_position_ids", False),
-reset_attention_mask=cfg.model.data.get("reset_attention_mask", False),
-eod_mask_loss=cfg.model.data.get("eod_mask_loss", False),
-),
+collate_fn=identity_collate,
)

val_dataloader = build_dataloader(
@@ -121,13 +115,7 @@ def main(cfg) -> None:
gbs=cfg.model.global_batch_size,
load_gbs=True,
pad_samples_to_global_batch_size=False,
-collate_fn=partial(
-dpo_custom_collate,
-eos_id=ptl_model.tokenizer.eos_id,
-reset_position_ids=cfg.model.data.get("reset_position_ids", False),
-reset_attention_mask=cfg.model.data.get("reset_attention_mask", False),
-eod_mask_loss=cfg.model.data.get("eod_mask_loss", False),
-),
+collate_fn=identity_collate,
use_random_sampler=False,
)

@@ -147,6 +135,14 @@ def main(cfg) -> None:
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
test_dataloader=None,
+collate_fn=partial(
+dpo_custom_collate,
+eos_id=ptl_model.tokenizer.eos_id,
+reset_position_ids=cfg.model.data.get("reset_position_ids", False),
+reset_attention_mask=cfg.model.data.get("reset_attention_mask", False),
+eod_mask_loss=cfg.model.data.get("eod_mask_loss", False),
+pad_length_to_multiple_of=cfg.model.data.get("pad_length_to_multiple_of", None),
+),
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
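The collate refactor is the heart of this file's diff: both dataloaders now hand back raw list[dict] samples via identity_collate, and dpo_custom_collate moves into the DPOTrainer constructor instead. The reason is that the collate can now issue a torch.distributed.all_reduce, and collectives cannot run inside DataLoader worker processes; they need the initialized process group of the main training process. A rough sketch of the resulting flow (illustrative only; tokenizer and train_dataloader stand in for the objects built above):

    from functools import partial

    # Workers: identity_collate leaves samples as a list of dicts.
    # Main process: the distributed collate pads and transposes them.
    collate = partial(
        dpo_custom_collate,
        eos_id=tokenizer.eos_id,
        pad_length_to_multiple_of=8,  # example value
    )
    raw_batch = next(iter(train_dataloader))  # list[dict], untouched by workers
    batch = collate(raw_batch)                # dict[str, Tensor], padded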
57 changes: 53 additions & 4 deletions nemo_aligner/algorithms/dpo.py
@@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import math
from collections import defaultdict
from statistics import mean
+from typing import Any, Protocol

import torch
+import torch.distributed
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm

@@ -24,13 +27,34 @@
)
from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids
from nemo.utils import logging
+from nemo_aligner.utils import parallel_state
from nemo_aligner.utils.distributed import SyncTimer
from nemo_aligner.utils.train_utils import clip_gradients
from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches, compute_num_steps_per_epoch
from nemo_aligner.utils.utils import clear_memory


-def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False):
+class DistributedCollateFunction(Protocol):
+def __call__(self, batch: list[dict], **kwargs: Any) -> dict[str, torch.Tensor]:
+...


+def dpo_custom_collate(
+batch: list[dict],
+eos_id: int,
+reset_position_ids: bool = False,
+reset_attention_mask: bool = False,
+eod_mask_loss: bool = False,
+pad_length_to_multiple_of: int | None = None,
+) -> dict[str, torch.Tensor]:
+"""
+Transposes minibatch from list[dict] -> dict[Tensor] and also pads
+This collate happens outside of the torch data loader and is not compatible with the multiprocessing
+logic due to requiring communication collectives.
+"""
+if pad_length_to_multiple_of is not None and pad_length_to_multiple_of < 0:
+raise ValueError(f"{pad_length_to_multiple_of=} must be >= 0")
chosen_tokens = [item["chosen"] for item in batch]
rejected_tokens = [item["rejected"] for item in batch]
chosen_lengths = torch.LongTensor([item["chosen_length"] for item in batch])
@@ -44,9 +68,32 @@ def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_
rejected_tokens = torch.nn.utils.rnn.pad_sequence(rejected_tokens, batch_first=True, padding_value=eos_id)
chosen_labels = torch.nn.utils.rnn.pad_sequence(chosen_labels, batch_first=True, padding_value=-100)
rejected_labels = torch.nn.utils.rnn.pad_sequence(rejected_labels, batch_first=True, padding_value=-100)
+assert chosen_tokens.shape == rejected_tokens.shape
+assert chosen_labels.shape == rejected_labels.shape

+if pad_length_to_multiple_of:
+# Assumes both chosen and rejected match
+max_seq_len = torch.tensor(chosen_tokens.shape[1], device=torch.cuda.current_device())
+torch.distributed.all_reduce(
+max_seq_len, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_data_parallel_group()
+)

+padded_max_len = math.ceil(max_seq_len / pad_length_to_multiple_of) * pad_length_to_multiple_of
+chosen_tokens = torch.nn.functional.pad(
+chosen_tokens, (0, padded_max_len - chosen_tokens.shape[1]), mode="constant", value=eos_id
+)
+rejected_tokens = torch.nn.functional.pad(
+rejected_tokens, (0, padded_max_len - rejected_tokens.shape[1]), mode="constant", value=eos_id
+)
+chosen_labels = torch.nn.functional.pad(
+chosen_labels, (0, padded_max_len - chosen_labels.shape[1]), mode="constant", value=-100
+)
+rejected_labels = torch.nn.functional.pad(
+rejected_labels, (0, padded_max_len - rejected_labels.shape[1]), mode="constant", value=-100
+)

attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
-chosen_tokens, eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss,
+chosen_tokens.cuda(), eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss,
)
assert attention_mask.ndim == 4, "attention_mask is incorrect shape for dpo_custom_collate"
if attention_mask.shape[0] == 1:
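Two details of the padding block are easy to miss. The all_reduce with ReduceOp.MAX over the data-parallel group makes every rank pad to one global maximum, so all data-parallel ranks emit tensors of identical sequence length; the round-up itself is plain ceiling arithmetic. A worked example with no distributed setup (values illustrative):

    import math

    pad_to = 8
    local_max_lens = [117, 93, 120]      # per-rank batch maxima
    global_max = max(local_max_lens)     # what all_reduce(MAX) computes: 120
    assert math.ceil(global_max / pad_to) * pad_to == 120  # already a multiple of 8
    assert math.ceil(117 / pad_to) * pad_to == 120         # 117 rounds up to 120

Also note that the truthiness guard "if pad_length_to_multiple_of:" skips the block for both None and 0, which is why the earlier check only rejects negative values.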
@@ -70,8 +117,7 @@ def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_


class DPOTrainer:
"""Trainer to coordinate DPO training
"""
"""Trainer to coordinate DPO training"""

def __init__(
self,
@@ -82,6 +128,7 @@ def __init__(
train_dataloader,
val_dataloader,
test_dataloader,
+collate_fn: DistributedCollateFunction,
logger,
ckpt_callback,
run_timer,
@@ -90,6 +137,7 @@ def __init__(
self.train_dataloader = train_dataloader
self.val_dataloader = val_dataloader
self.test_dataloader = test_dataloader
+self.collate_fn = collate_fn
self.logger = logger
self.cfg = cfg
self.optimizer = optimizer
@@ -317,6 +365,7 @@ def augment_dataloader(self, dataloader):
while True:
try:
batch = next(iter_dataloader)
+batch = self.collate_fn(batch)
logprobs = self.model.get_ref_policy_logprobs(batch).cpu()
chosen_logps, reject_logps = torch.split(logprobs, len(logprobs) // 2, dim=0)
batch["ref_policy_log_probs_chosen"] = chosen_logps
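In augment_dataloader, the trainer now applies the collate itself before caching reference-policy log-probs. The torch.split(logprobs, len(logprobs) // 2, dim=0) call assumes the collated batch stacks all chosen examples ahead of all rejected ones along dim 0. A small illustration of that split semantics (shapes made up):

    import torch

    logprobs = torch.arange(8.0).reshape(8, 1)  # pretend: 4 chosen rows, then 4 rejected
    chosen, rejected = torch.split(logprobs, len(logprobs) // 2, dim=0)
    assert chosen.shape == rejected.shape == (4, 1)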
17 changes: 12 additions & 5 deletions nemo_aligner/data/nlp/builders.py
@@ -22,7 +22,7 @@

import numpy as np
import torch
-from datasets import load_dataset
+import torch.utils.data
from omegaconf.dictconfig import DictConfig

from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
@@ -212,7 +212,7 @@ def build_train_valid_test_datasets(
cfg=cfg,
data_prefix=data_prefix["validation"],
data_impl=data_impl,
-num_samples=int(train_valid_test_num_samples[0]),
+num_samples=int(train_valid_test_num_samples[1]),
seq_length=seq_length,
seed=seed,
tokenizer=tokenizer,
@@ -225,7 +225,7 @@
cfg=cfg,
data_prefix=data_prefix["test"],
data_impl=data_impl,
-num_samples=int(train_valid_test_num_samples[0]),
+num_samples=int(train_valid_test_num_samples[2]),
seq_length=seq_length,
seed=seed,
tokenizer=tokenizer,
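The two num_samples edits above fix an indexing bug: train_valid_test_num_samples is a [train, valid, test] triple, but the validation and test builders were both reading index 0. The fix matters whenever the splits are sized differently; in train_gpt_dpo.py above all three entries happen to be -1 * global_batch_size, which masked the bug. A one-line illustration (values made up):

    train_valid_test_num_samples = [1000, 200, 100]   # [train, valid, test]
    num_valid = int(train_valid_test_num_samples[1])  # 200; index 0 would give 1000
    num_test = int(train_valid_test_num_samples[2])   # 100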
@@ -431,8 +431,7 @@ def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, i


def collate_with_pad_to_max_batch(max_seqlen, tokenizer_eos_id, cfg, generate_masks_and_position_ids=True):
"""collate function that pads each sequence to the max in the batch
"""
"""collate function that pads each sequence to the max in the batch"""
return partial(
collate_with_batch_max_sequence_length,
response_token_length=max_seqlen,
Expand All @@ -444,6 +443,14 @@ def collate_with_pad_to_max_batch(max_seqlen, tokenizer_eos_id, cfg, generate_ma
)


+def identity_collate(batch):
+"""
+Useful since torch's data loader's default collate will crash with ragged sequences.
+Also, this function is needed b/c lambda functions aren't pickle-able.
+"""
+return batch


def build_dataloader(
cfg,
dataset,
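identity_collate addresses the two constraints its docstring hints at: torch's default collate tries to stack per-key tensors into one batch tensor and raises on ragged lengths, and a lambda passed as collate_fn cannot be pickled when num_workers > 0 under the spawn start method, so a named module-level function is required. A minimal repro sketch (stock PyTorch; identity_collate as defined above):

    import torch
    from torch.utils.data import DataLoader

    data = [{"chosen": torch.ones(3)}, {"chosen": torch.ones(5)}]
    # DataLoader(data, batch_size=2)  # default collate -> RuntimeError: ragged stack
    loader = DataLoader(data, batch_size=2, collate_fn=identity_collate)
    batch = next(iter(loader))  # list of 2 dicts, passed through untouched
    assert isinstance(batch, list) and len(batch) == 2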
81 changes: 41 additions & 40 deletions nemo_aligner/data/nlp/datasets.py
@@ -187,7 +187,7 @@ def __getitem__(self, idx):

class RewardModelDataset(Dataset):
"""This class assumes that we only have 2 responses per prompt that is ranked. Chosen is the better
-one(even index) whereas Rejected is the worse response(odd index)
+one(even index) whereas Rejected is the worse response(odd index)
"""

def __init__(
@@ -237,8 +237,7 @@ def encode(self, text):
return text_ids, len(text_ids)

def __getitem__(self, idx, multiple=2):
"""Returns a pair of chosen/rejected pairs, and their respective lengths.
"""
"""Returns a pair of chosen/rejected pairs, and their respective lengths."""
found = False
while not found:
chosen = self.data[multiple * idx]
@@ -293,16 +292,16 @@ def __getitem__(self, idx, multiple=2):

class DPOModelDataset(Dataset):
"""This class works only with jsonl files. It assumes each line of the json file is a dictionary
-with the prompt, along with the chosen response (response only, no prompt), and the rejected response
-(response only, no prompt). This Dataset will combine the prompt with each corresponding chosen and
-rejected response, and then tokenize it. It also returns the labels for each, which is the response tokens
-with -100 for the prompt part.
-WARNING: This class will tokenize the text, but it will raise an exception on model max seq len violations!
-Meaning it will not truncate tokens to fit to model max seq len, because of special prefix/suffix
-strings such as <extra_id_1>, it would not know where it is safe to truncate for each model. Therefore,
-the user must do all truncation logic in their preprocessing step when generating the jsonl
-used by this class. Put all special truncation logic there specific to your model.
+with the prompt, along with the chosen response (response only, no prompt), and the rejected response
+(response only, no prompt). This Dataset will combine the prompt with each corresponding chosen and
+rejected response, and then tokenize it. It also returns the labels for each, which is the response tokens
+with -100 for the prompt part.
+WARNING: This class will tokenize the text, but it will raise an exception on model max seq len violations!
+Meaning it will not truncate tokens to fit to model max seq len, because of special prefix/suffix
+strings such as <extra_id_1>, it would not know where it is safe to truncate for each model. Therefore,
+the user must do all truncation logic in their preprocessing step when generating the jsonl
+used by this class. Put all special truncation logic there specific to your model.
"""

def __init__(
@@ -346,8 +345,7 @@ def encode(self, text, append_eod=False):
return text_ids, len(text_ids)

def __getitem__(self, idx):
"""Returns a pair of chosen/rejected pairs, their respective lengths, and labels.
"""
"""Returns a pair of chosen/rejected pairs, their respective lengths, and labels."""
payload = self.data[idx]
prompt, prompt_len = self.encode(payload["prompt"], append_eod=False)
chosen, chosen_len = self.encode(
@@ -361,15 +359,14 @@ def __getitem__(self, idx):
chosen_labels = ([-100] * prompt_len) + chosen[prompt_len:]
reject_labels = ([-100] * prompt_len) + reject[prompt_len:]

-assert chosen[0:prompt_len] == prompt, "the tokenizer for DPO has merged tokens between prompt and response"
-assert reject[0:prompt_len] == prompt, "the tokenizer for DPO has merged tokens between prompt and response"
+assert (
+chosen[0:prompt_len] == prompt
+), f"The tokenizer for DPO has merged tokens between prompt and response for {idx=}:\n[[prompt]]={repr(payload['prompt'])}\n[[chosen_response]]={repr(payload['chosen_response'])}"
+assert (
+reject[0:prompt_len] == prompt
+), f"The tokenizer for DPO has merged tokens between prompt and response for {idx=}:\n[[prompt]]={repr(payload['prompt'])}\n[[rejected_response]]={repr(payload['rejected_response'])}"

max_curr_seq_len = max(chosen_len, reject_len)
-if max_curr_seq_len > self.seq_length:
-logging.warning(
-f"WARNING: Tokenized text exceeds max seq length ({max_curr_seq_len} vs {self.seq_length})."
-+ f"The example will be ignored."
-)

chosen_tokens = torch.nn.functional.pad(
torch.LongTensor(chosen), (0, max_curr_seq_len - chosen_len), mode="constant", value=self.eos_id
@@ -386,6 +383,10 @@ def __getitem__(self, idx):

# ignore the example whose tokenized text exceeds max seq length.
if max_curr_seq_len > self.seq_length:
+logging.warning(
+f"WARNING: Tokenized text exceeds max seq length ({max_curr_seq_len} vs {self.seq_length})."
++ f"The example will be ignored."
+)
chosen_tokens = chosen_tokens[: self.nograd_length]
rejected_tokens = rejected_tokens[: self.nograd_length]
labels_chosen_tokens = torch.ones_like(chosen_tokens) * (-100)
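With the warning relocated, the surrounding logic reads straight: an over-long example is not dropped from the dataset but neutralized. Its tokens are truncated to nograd_length and every label becomes -100, the conventional ignore index for cross-entropy, so the sample is masked out of the loss while batch shapes stay valid. A sketch of the resulting tensors (values illustrative; eos_id assumed in scope):

    import torch

    nograd_length = 32
    chosen_tokens = torch.full((nograd_length,), eos_id)  # truncated stand-in tokens
    labels_chosen = torch.full((nograd_length,), -100)    # every position ignored by the loss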
@@ -408,16 +409,16 @@ def __getitem__(self, idx):

class KTOModelDataset(Dataset):
"""This class works only with jsonl files. It assumes each line of the json file is a dictionary
-with the prompt, along with the response (response only, no prompt), and the status denoting whether the response is
-chosen or rejected. This Dataset will combine the prompt with the corresponding response, and then tokenize it. It
-will also create a score field that has 1 if the sample is chosen and 0 if rejected. It also returns the labels for
-each, which is the response tokens with -100 for the prompt part.
-WARNING: This class will tokenize the text, but it will raise an exception on model max seq len violations!
-Meaning it will not truncate tokens to fit to model max seq len, because of special prefix/suffix
-strings such as <extra_id_1>, it would not know where it is safe to truncate for each model. Therefore,
-the user must do all truncation logic in their preprocessing step when generating the jsonl
-used by this class. Put all special truncation logic there specific to your model.
+with the prompt, along with the response (response only, no prompt), and the status denoting whether the response is
+chosen or rejected. This Dataset will combine the prompt with the corresponding response, and then tokenize it. It
+will also create a score field that has 1 if the sample is chosen and 0 if rejected. It also returns the labels for
+each, which is the response tokens with -100 for the prompt part.
+WARNING: This class will tokenize the text, but it will raise an exception on model max seq len violations!
+Meaning it will not truncate tokens to fit to model max seq len, because of special prefix/suffix
+strings such as <extra_id_1>, it would not know where it is safe to truncate for each model. Therefore,
+the user must do all truncation logic in their preprocessing step when generating the jsonl
+used by this class. Put all special truncation logic there specific to your model.
"""

def __init__(
Expand Down Expand Up @@ -501,14 +502,14 @@ def __getitem__(self, idx):


class RegressionRewardModelDataset(RewardModelDataset):
"""This class assumes each line of the dataset file is a dictionary with "text" and "label" field,
where "text" is a string representing the input prompt, and "label" is a list of float or int values.
Note that when training the model with multiple datasets which contain different attributes,
we should set missing attributes to model.regression.loss_mask_val(according to training_rm.yaml)
in the dataset files so that their losses are masked. At least one attribute should be present for each sample.
WARNING: It's recommended to preprocess your data in advance to ensure all samples are within self.seq_length.
Otherwise if all samples in a batch are longer than self.seq_length, you may get NaN loss.
"""This class assumes each line of the dataset file is a dictionary with "text" and "label" field,
where "text" is a string representing the input prompt, and "label" is a list of float or int values.
Note that when training the model with multiple datasets which contain different attributes,
we should set missing attributes to model.regression.loss_mask_val(according to training_rm.yaml)
in the dataset files so that their losses are masked. At least one attribute should be present for each sample.
WARNING: It's recommended to preprocess your data in advance to ensure all samples are within self.seq_length.
Otherwise if all samples in a batch are longer than self.seq_length, you may get NaN loss.
"""

def __init__(
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -401,7 +401,7 @@ def pytest_collection_modifyitems(config, items):

def pytest_sessionstart(session):
# Remove the file at the start of the session, if it exists
-if os.path.exists(SUCCESS_FILE):
+if os.path.exists(SUCCESS_FILE) and os.environ["LOCAL_RANK"] == "0":
os.remove(SUCCESS_FILE)


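The conftest tweak closes a race under multi-process launchers: with torchrun, every local rank executes pytest_sessionstart, so restricting cleanup to local rank 0 keeps ranks from deleting the marker file out from under one another. A hypothetical defensive variant (not the committed code) that also tolerates runs without torchrun, where LOCAL_RANK is unset:

    import os

    if os.path.exists(SUCCESS_FILE) and os.environ.get("LOCAL_RANK", "0") == "0":
        os.remove(SUCCESS_FILE)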