diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index b8c4457ead65..c71f55faf9fe 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -5216,6 +5216,21 @@ jobs:
           rm -rf tests/collections/llm/t5_pretrain_results/${{ github.run_id }}
           rm -rf tests/collections/llm/t5_index_mappings/${{ github.run_id }}
 
+  L2_NeMo_2_T5_Finetuning:
+    needs: [cicd-test-container-setup]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_T5_Finetuning') || needs.cicd-test-container-setup.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure
+      SCRIPT: |
+        NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python tests/collections/llm/megatron_t5_finetuning.py \
+          --devices=2 \
+          --max-steps=250 \
+          --experiment-dir=tests/collections/llm/t5_finetune_results/${{ github.run_id }} \
+          --checkpoint-path=/home/TestData/nlp/megatron_t5/220m/nemo2.0_t5_220m_150steps
+      AFTER_SCRIPT: |
+        rm -rf tests/collections/llm/t5_finetune_results/${{ github.run_id }}
+
   L2_NeMo_2_Mixtral_Pretraining:
     needs: [cicd-test-container-setup]
     uses: ./.github/workflows/_test_template.yml
@@ -5365,6 +5380,7 @@ jobs:
       - L2_NeMo_2_SSM_Pretraining
       - L2_NeMo_2_SSM_Finetuning
      - L2_NeMo_2_T5_Pretraining
+      - L2_NeMo_2_T5_Finetuning
       - L2_NeMo_2_Mixtral_Pretraining
       - L2_PTQ_Llama2_INT8_SQ
       - L2_PTQ_Llama2_FP8
@@ -5511,4 +5527,4 @@ jobs:
 
       - name: "Pipeline not successful, set exit code to 1"
         if: ${{ always() && steps.pipeline-conclusion.outputs.SUCCESS == 'false' }}
-        run: exit 1
+        run: exit 1
\ No newline at end of file
diff --git a/nemo/collections/llm/t5/data/__init__.py b/nemo/collections/llm/t5/data/__init__.py
index 537c12fd9115..d65f6923033f 100644
--- a/nemo/collections/llm/t5/data/__init__.py
+++ b/nemo/collections/llm/t5/data/__init__.py
@@ -1,3 +1,5 @@
+from nemo.collections.llm.t5.data.fine_tuning import FineTuningDataModule
 from nemo.collections.llm.t5.data.pre_training import PreTrainingDataModule
+from nemo.collections.llm.t5.data.squad import SquadDataModule
 
-__all__ = ["PreTrainingDataModule"]
+__all__ = ["FineTuningDataModule", "PreTrainingDataModule", "SquadDataModule"]
diff --git a/nemo/collections/llm/t5/data/core.py b/nemo/collections/llm/t5/data/core.py
new file mode 100644
index 000000000000..11543274c3b9
--- /dev/null
+++ b/nemo/collections/llm/t5/data/core.py
@@ -0,0 +1,46 @@
+from pathlib import Path
+from typing import TYPE_CHECKING, Optional
+
+from nemo.lightning.base import NEMO_DATASETS_CACHE
+
+if TYPE_CHECKING:
+    from nemo.collections.common.tokenizers import TokenizerSpec
+    from nemo.collections.nlp.data.language_modeling.megatron.t5_sft_dataset import T5SFTDataset
+
+
+def get_dataset_root(name: str) -> Path:
+    output = Path(NEMO_DATASETS_CACHE) / name
+    output.mkdir(parents=True, exist_ok=True)
+
+    return output
+
+
+def create_sft_dataset(
+    path: Path,
+    tokenizer: "TokenizerSpec",
+    seq_length: int = 512,
+    seq_length_dec: int = 128,
+    add_bos: bool = True,
+    add_eos: bool = True,
+    replace_bos_with_pad: bool = False,
+    seed: int = 1234,
+    index_mapping_dir: Optional[str] = None,
+    memmap_workers: int = 2,
+    hf_dataset: bool = False,
+    **kwargs,
+) -> "T5SFTDataset":
+    from nemo.collections.nlp.data.language_modeling.megatron.t5_sft_dataset import T5SFTDataset
+
+    return T5SFTDataset(
+        file_path=str(path),
+        src_tokenizer=tokenizer,
+        tgt_tokenizer=tokenizer,
+        max_src_seq_length=seq_length,
+        max_tgt_seq_length=seq_length_dec,
+        memmap_workers=memmap_workers,
+        hf_dataset=hf_dataset,
+        add_bos_to_input=add_bos,
+        add_eos_to_input=add_eos,
+        replace_bos_with_pad=replace_bos_with_pad,
+        index_mapping_dir=index_mapping_dir,
+    )
diff --git a/nemo/collections/llm/t5/data/fine_tuning.py b/nemo/collections/llm/t5/data/fine_tuning.py
new file mode 100644
index 000000000000..b1315f7a708a
--- /dev/null
+++ b/nemo/collections/llm/t5/data/fine_tuning.py
@@ -0,0 +1,147 @@
+import math
+from functools import lru_cache
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import pytorch_lightning as pl
+from torch.utils.data import DataLoader
+
+from nemo.collections.llm.t5.data.core import create_sft_dataset
+from nemo.lightning.pytorch.plugins import MegatronDataSampler
+
+if TYPE_CHECKING:
+    from nemo.collections.common.tokenizers import TokenizerSpec
+
+
+class FineTuningDataModule(pl.LightningDataModule):
+    """Base class for fine-tuning an LLM.
+
+    This class provides a foundation for building custom data modules for fine-tuning NeMo NLP models. It inherits from
+    `pl.LightningDataModule` from the PyTorch Lightning library and handles data loading, preprocessing, and batch creation
+    for training, validation, and testing.
+
+    Args:
+        dataset_root (Union[str, Path]): The root directory containing the training, validation, and test data.
+        seq_length (int, optional): The maximum sequence length for the input and output text. Defaults to 512.
+        tokenizer (Optional[TokenizerSpec], optional): The tokenizer to use for preprocessing the text. Defaults to None.
+            If not provided, a BertWordPieceCase tokenizer will be used.
+        micro_batch_size (int, optional): The micro batch size for training. Defaults to 4.
+        global_batch_size (int, optional): The global batch size for training. Defaults to 8.
+        rampup_batch_size (Optional[List[int]], optional): A list of batch sizes for ramping up during training. Defaults to None.
+        seed (int, optional): The random seed for data shuffling. Defaults to 1234.
+        memmap_workers (int, optional): The number of worker processes for loading data using TextMemMapDataset. Defaults to 1.
+        num_workers (int, optional): The number of worker processes for data loading. Defaults to 8.
+        pin_memory (bool, optional): Whether to pin memory during data loading for faster GPU training. Defaults to True.
+        persistent_workers (bool, optional): Whether to keep data loading workers persistent across epochs. Defaults to False.
+        max_train_steps (int, optional): Maximum number of steps to train. Used to calculate samples mapping for the mmap dataset
+    """
+
+    def __init__(
+        self,
+        dataset_root: Union[str, Path],
+        seq_length: int = 512,
+        seq_length_dec: int = 128,
+        tokenizer: Optional["TokenizerSpec"] = None,
+        micro_batch_size: int = 4,
+        global_batch_size: int = 8,
+        rampup_batch_size: Optional[List[int]] = None,
+        seed: int = 1234,
+        memmap_workers: int = 1,
+        num_workers: int = 8,
+        pin_memory: bool = True,
+        persistent_workers: bool = False,
+    ):
+        super().__init__()
+        self.seq_length = seq_length
+        self.seq_length_dec = seq_length_dec
+        self.seed = seed
+        self.dataset_root = Path(dataset_root)
+
+        # add additional tokens for T5 tokenizer
+        from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
+
+        self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "BertWordPieceCase")
+        additional_tokens = {'additional_special_tokens': [f'<extra_id_{i}>' for i in range(100)]}
+        self.tokenizer.add_special_tokens(additional_tokens)
+
+        self.memmap_workers = memmap_workers
+        self.num_workers = num_workers
+        self.pin_memory = pin_memory
+        self.persistent_workers = persistent_workers
+        self.micro_batch_size = micro_batch_size
+        self.global_batch_size = global_batch_size
+        self.rampup_batch_size = rampup_batch_size
+        self.data_sampler = None
+        self.max_train_samples = None
+
+    def setup(self, stage: str):
+        self.data_sampler = MegatronDataSampler(
+            seq_len=self.seq_length,
+            micro_batch_size=self.micro_batch_size,
+            global_batch_size=self.global_batch_size,
+            rampup_batch_size=self.rampup_batch_size,
+            dataloader_type="batch",
+        )
+
+        # Follows the calculation in nemo.collections.nlp.data.language_modeling.megatron.
+        # base_dataset_utils.get_datasets_weights_and_num_samples
+        self.max_train_samples = int(math.ceil(self.global_batch_size * self.trainer.max_steps * 1.005))
+
+    def train_dataloader(self) -> DataLoader:
+        return self._create_dataloader(
+            self._create_dataset(
+                str(self.train_path),
+                max_num_samples=self.max_train_samples,
+            )
+        )
+
+    def val_dataloader(self) -> DataLoader:
+        return self._create_dataloader(
+            self._create_dataset(
+                str(self.validation_path),
+                is_test=True,
+            ),
+        )
+
+    def test_dataloader(self) -> DataLoader:
+        return self._create_dataloader(
+            self._create_dataset(
+                str(self.test_path),
+                tokens_to_generate=32,
+                is_test=True,
+            )
+        )
+
+    @lru_cache
+    def _create_dataset(self, path, **kwargs):
+        return create_sft_dataset(
+            path,
+            tokenizer=self.tokenizer,
+            seq_length=self.seq_length,
+            seq_length_dec=self.seq_length_dec,
+            memmap_workers=self.memmap_workers,
+            seed=self.seed,
+            **kwargs,
+        )
+
+    def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
+        return DataLoader(
+            dataset,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+            persistent_workers=self.persistent_workers,
+            collate_fn=dataset.collate_fn,
+            **kwargs,
+        )
+
+    @property
+    def train_path(self) -> Path:
+        return self.dataset_root / "training.jsonl"
+
+    @property
+    def validation_path(self) -> Path:
+        return self.dataset_root / "validation.jsonl"
+
+    @property
+    def test_path(self) -> Path:
+        return self.dataset_root / "test.jsonl"
diff --git a/nemo/collections/llm/t5/data/squad.py b/nemo/collections/llm/t5/data/squad.py
new file mode 100644
index 000000000000..cee0549c80be
--- /dev/null
+++ b/nemo/collections/llm/t5/data/squad.py
@@ -0,0 +1,140 @@
+import json
+import shutil
+from typing import TYPE_CHECKING, List, Optional
+
+from datasets import DatasetDict, load_dataset
+
+from nemo.collections.llm.t5.data.core import get_dataset_root
+from nemo.collections.llm.t5.data.fine_tuning import FineTuningDataModule
+from nemo.lightning.io.mixin import IOMixin
+from nemo.utils import logging
+
+if TYPE_CHECKING:
+    from nemo.collections.common.tokenizers import TokenizerSpec
+
+
+class SquadDataModule(FineTuningDataModule, IOMixin):
+    """A data module for fine-tuning on the SQuAD dataset.
+
+    This class inherits from the `FineTuningDataModule` class and is specifically designed for fine-tuning models on the
+    Stanford Question Answering Dataset (SQuAD). It handles data download, preprocessing, splitting, and preparing the data
+    in a format suitable for training, validation, and testing.
+
+    Args:
+        force_redownload (bool, optional): Whether to force re-download the dataset even if it exists locally. Defaults to False.
+        delete_raw (bool, optional): Whether to delete the raw downloaded dataset after preprocessing. Defaults to True.
+        See FineTuningDataModule for the other args.
+    """
+
+    def __init__(
+        self,
+        seq_length: int = 512,
+        seq_length_dec: int = 128,
+        tokenizer: Optional["TokenizerSpec"] = None,
+        micro_batch_size: int = 4,
+        global_batch_size: int = 8,
+        rampup_batch_size: Optional[List[int]] = None,
+        force_redownload: bool = False,
+        delete_raw: bool = True,
+        seed: int = 1234,
+        memmap_workers: int = 1,
+        num_workers: int = 8,
+        pin_memory: bool = True,
+        persistent_workers: bool = False,
+    ):
+        self.force_redownload = force_redownload
+        self.delete_raw = delete_raw
+
+        super().__init__(
+            dataset_root=get_dataset_root("squad"),
+            seq_length=seq_length,
+            seq_length_dec=seq_length_dec,
+            tokenizer=tokenizer,
+            micro_batch_size=micro_batch_size,
+            global_batch_size=global_batch_size,
+            rampup_batch_size=rampup_batch_size,
+            seed=seed,
+            memmap_workers=memmap_workers,
+            num_workers=num_workers,
+            pin_memory=pin_memory,
+            persistent_workers=persistent_workers,
+        )
+
+    def prepare_data(self) -> None:
+        # if the train file already exists and no re-download is forced, there is nothing to do
+        if self.train_path.exists() and not self.force_redownload:
+            return
+
+        dset = self._download_data()
+        self._preprocess_and_split_data(dset, split_val_from_train=False)
+
+    def _download_data(self):
+        logging.info(f"Downloading {self.__class__.__name__}...")
+        return load_dataset(
+            "squad",
+            cache_dir=str(self.dataset_root),
+            download_mode="force_redownload" if self.force_redownload else None,
+        )
+
+    def _preprocess_and_split_data(
+        self, dset: DatasetDict, split_val_from_train: bool = True, val_proportion: float = 0.05
+    ):
+        """Preprocesses and splits the downloaded dataset into training, validation, and test sets.
+
+        Args:
+            dset (DatasetDict): The downloaded dataset object.
+            split_val_from_train (bool, optional): Whether to split the validation set from the training set.
+                If False, the validation set is split from the test set. Defaults to True.
+            val_proportion (float, optional): The proportion of the training or test set to be used for the validation split.
+                Defaults to 0.05.
+        """
+        logging.info(f"Preprocessing {self.__class__.__name__} to jsonl format and splitting...")
+        save_splits = {}
+        train_set = dset.get('train')
+        val_set = dset.get('validation')
+
+        if split_val_from_train:
+            split_dataset = train_set.train_test_split(test_size=val_proportion, seed=self.seed)
+            save_splits['training'] = split_dataset['train']
+            save_splits['validation'] = split_dataset['test']
+            save_splits['test'] = val_set
+        else:
+            split_dataset = val_set.train_test_split(test_size=val_proportion, seed=self.seed)
+            save_splits['training'] = train_set
+            save_splits['validation'] = split_dataset['test']
+            save_splits['test'] = split_dataset['train']
+
+        for split_name, dataset in save_splits.items():
+            output_file = self.dataset_root / f"{split_name}.jsonl"
+
+            with output_file.open("w", encoding="utf-8") as f:
+                for example in dataset:
+                    json_line = {}
+                    # Write each example as a JSON line in the output file
+                    # Using similar template as for NeMo 1.0 T5
+                    json_line["input"] = (
+                        "Title: "
+                        + example["title"]
+                        + " "
+                        + "Paragraph: "
+                        + example["context"]
+                        + " "
+                        + " Question: "
+                        + example['question']
+                    )
+                    json_line["output"] = example["answers"]["text"][0]
+                    if split_name == "test":
+                        json_line["original_answers"] = example["answers"]["text"]
+                    f.write(json.dumps(json_line) + "\n")
+
+            logging.info(f"{split_name} split saved to {output_file}")
+
+        if self.delete_raw:
+            for p in self.dataset_root.iterdir():
+                if p.is_dir():
+                    shutil.rmtree(p)
+                elif '.jsonl' not in str(p.name):
+                    p.unlink()
+
+    def reconfigure_limit_batches(self):
+        return
diff --git a/nemo/collections/llm/t5/model/t5.py b/nemo/collections/llm/t5/model/t5.py
index 2df5d633e200..dcba70bc8986 100644
--- a/nemo/collections/llm/t5/model/t5.py
+++ b/nemo/collections/llm/t5/model/t5.py
@@ -11,6 +11,8 @@
 from torch import nn
 
 from nemo.collections.llm import fn
+from nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder import AttnMaskType
+from nemo.collections.nlp.modules.common.megatron.utils import build_attention_mask_3d
 from nemo.lightning import get_vocab_size, io
 from nemo.lightning.megatron_parallel import MaskedTokenLossReduction
 from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule
@@ -39,10 +41,21 @@ def t5_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
     else:
         _batch = batch
 
-    # convert attention mask values from int to True/False
-    _batch['enc_mask'] = _batch['enc_mask'] < 0.5
-    _batch['dec_mask'] = _batch['dec_mask'] < 0.5
-    _batch['enc_dec_mask'] = _batch['enc_dec_mask'] < 0.5
+    # if Dataset object is NeMo 1.0's T5SFTDataset (e.g. when finetuning with SQUAD)
+    if 'enc_dec_mask' not in _batch:
+        encoder_attn_mask_3d = build_attention_mask_3d(_batch['enc_mask'], _batch['enc_mask'], AttnMaskType.padding)
+        decoder_attn_mask_3d = build_attention_mask_3d(_batch['dec_mask'], _batch['dec_mask'], AttnMaskType.causal)
+        enc_dec_attn_mask_3d = build_attention_mask_3d(_batch['dec_mask'], _batch['enc_mask'], AttnMaskType.padding)
+        _batch['enc_mask'] = encoder_attn_mask_3d
+        _batch['dec_mask'] = decoder_attn_mask_3d
+        _batch['enc_dec_mask'] = enc_dec_attn_mask_3d
+
+    # if Dataset object is Mcore T5 dataset (e.g. pretraining)
+    else:
+        # convert attention mask values from int to True/False
+        _batch['enc_mask'] = _batch['enc_mask'] < 0.5
+        _batch['dec_mask'] = _batch['dec_mask'] < 0.5
+        _batch['enc_dec_mask'] = _batch['enc_dec_mask'] < 0.5
 
     required_keys = set()
     required_keys.update(["enc_mask", "dec_mask", "enc_dec_mask"])
diff --git a/tests/collections/llm/megatron_t5_finetuning.py b/tests/collections/llm/megatron_t5_finetuning.py
new file mode 100644
index 000000000000..76a23d36975b
--- /dev/null
+++ b/tests/collections/llm/megatron_t5_finetuning.py
@@ -0,0 +1,130 @@
+## NOTE: This script is present for github-actions testing only.
+## There are no guarantees that this script is up-to-date with latest NeMo.
+
+import argparse
+
+import torch
+from megatron.core.optimizer import OptimizerConfig
+from pytorch_lightning.loggers import WandbLogger
+
+from nemo import lightning as nl
+from nemo.collections import llm
+from nemo.collections.llm.api import finetune
+from nemo.collections.llm.t5.data import SquadDataModule
+from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
+from nemo.lightning import NeMoLogger
+from nemo.lightning.pytorch.callbacks import ModelCheckpoint
+from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Finetune a small T5 model using NeMo 2.0')
+    parser.add_argument('--devices', type=int, help="Number of devices to use for training")
+    parser.add_argument('--max-steps', type=int, help="Number of steps to train for")
+    parser.add_argument('--experiment-dir', type=str, help="directory to write results and checkpoints to")
+    parser.add_argument('--experiment-name', type=str, help="name of experiment")
+    parser.add_argument('--wandb-project', type=str, default=None, help="wandb project name")
+    parser.add_argument('--checkpoint-path', type=str, help="Path to checkpoint dir")
+    parser.add_argument('--index-mapping-dir', type=str, help="directory to write index mappings to")
+
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+
+    args = get_args()
+
+    tokenizer = get_nmt_tokenizer(
+        "megatron",
+        "BertWordPieceCase",
+    )
+
+    data = SquadDataModule(
+        seq_length=512,
+        seq_length_dec=128,
+        micro_batch_size=16,
+        global_batch_size=128,
+        tokenizer=tokenizer,
+        num_workers=4,
+    )
+
+    t5_config = llm.t5.model.t5.T5Config(
+        num_layers=12,
+        encoder_num_layers=12,
+        hidden_size=768,
+        ffn_hidden_size=3072,
+        num_attention_heads=12,
+        kv_channels=64,
+        init_method_std=0.015,
+        hidden_dropout=0.1,
+        attention_dropout=0.1,
+        layernorm_epsilon=1e-5,
+        make_vocab_size_divisible_by=128,
+        max_position_embeddings=512,
+    )
+    model = llm.t5.model.t5.T5Model(t5_config, tokenizer=data.tokenizer)
+
+    strategy = nl.MegatronStrategy(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        pipeline_dtype=torch.float32,
+        ckpt_load_optimizer=False,
+        # ckpt_load_optimizer=True,
+    )
+    checkpoint_callback = ModelCheckpoint(
+        every_n_train_steps=5000,
+    )
+    callbacks = [checkpoint_callback]
+
+    resume = nl.AutoResume(
+        resume_if_exists=True,
+        resume_ignore_no_checkpoint=True,
+        resume_from_path=args.checkpoint_path,
+    )
+
+    opt_config = OptimizerConfig(
+        optimizer='adam',
+        lr=2.0e-5,
+        use_distributed_optimizer=False,
+        bf16=False,
+        weight_decay=0.1,
+    )
+    opt = MegatronOptimizerModule(
+        config=opt_config,
+    )
+
+    trainer = nl.Trainer(
+        devices=args.devices,
+        max_steps=args.max_steps,
+        accelerator="gpu",
+        strategy=strategy,
+        callbacks=callbacks,
+        log_every_n_steps=1,
+        limit_val_batches=2,
+        val_check_interval=50,
+        plugins=nl.MegatronMixedPrecision(precision="32"),
+    )
+
+    if args.wandb_project is not None:
+        wandb_logger = WandbLogger(
+            name=args.experiment_name,
+            project=args.wandb_project,
+            log_model="all",
+        )
+    else:
+        wandb_logger = None
+    nemo_logger = NeMoLogger(
+        name=args.experiment_name,
+        use_datetime_version=False,
+        log_dir=args.experiment_dir,
+        wandb=wandb_logger,
+    )
+
+    finetune(
+        model=model,
+        resume=resume,
+        data=data,
+        trainer=trainer,
+        log=nemo_logger,
+        optim=opt,
+    )
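
A note on the t5_data_step change in nemo/collections/llm/t5/model/t5.py above: when the batch comes from NeMo 1.0's T5SFTDataset (the fine-tuning path) it only carries 1D padding masks, so the step builds 3D attention masks before the batch reaches the model. The snippet below is not the NeMo build_attention_mask_3d helper itself, only a minimal illustrative re-implementation of the same idea (assuming the convention that True marks positions excluded from attention), meant to make the resulting shapes concrete:

import torch


def padding_mask_3d(q_pad: torch.Tensor, k_pad: torch.Tensor) -> torch.Tensor:
    # q_pad: [b, s_q], k_pad: [b, s_k]; 1 = real token, 0 = padding.
    # Result: [b, s_q, s_k] boolean mask where True marks a masked-out position.
    return ~(q_pad[:, :, None].bool() & k_pad[:, None, :].bool())


def causal_mask_3d(q_pad: torch.Tensor, k_pad: torch.Tensor) -> torch.Tensor:
    # Padding mask combined with a "no looking ahead" upper-triangular constraint.
    pad = padding_mask_3d(q_pad, k_pad)
    future = torch.triu(torch.ones(q_pad.size(1), k_pad.size(1)), diagonal=1).bool()
    return pad | future


enc = torch.tensor([[1, 1, 1, 0]])  # encoder padding mask, shape [1, 4]
dec = torch.tensor([[1, 1, 0]])     # decoder padding mask, shape [1, 3]

enc_mask = padding_mask_3d(enc, enc)      # [1, 4, 4], encoder self-attention
dec_mask = causal_mask_3d(dec, dec)       # [1, 3, 3], decoder self-attention
enc_dec_mask = padding_mask_3d(dec, enc)  # [1, 3, 4], cross-attention (decoder queries, encoder keys)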
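
For reference, SquadDataModule._preprocess_and_split_data writes one JSON object per line: "input" concatenates the title, paragraph, and question using the template shown in the diff, "output" holds the first answer, and "original_answers" is added only for the test split. A sketch of one such record follows; the title, paragraph, and question strings are invented placeholders for illustration, not quotes from the actual dataset:

import json

# Hypothetical record mirroring the template in _preprocess_and_split_data;
# the text below is made up for illustration only.
record = {
    "input": (
        "Title: Example_Title "
        "Paragraph: An example paragraph containing the answer. "
        " Question: What does the paragraph contain?"
    ),
    "output": "the answer",
}
print(json.dumps(record))  # one such line per example in training.jsonl / validation.jsonl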