Adding NeMo 2.0 T5 finetuning (on Squad dataset) (#10716)
* huvu/t5_nemo2.0 first commit from local

* runnable training

* commit to save

* update nemo/collections/llm/t5/data/pre_training.py, adding cicd test

* updating code

* reset nemo/collections/nlp/parts/megatron_trainer_builder.py

* reset megatron_lm_encoder_decoder_model.py, remove t5_release_test_config.sh

* update init files

* update Dockerfile.ci

* fix wandb for cicd test

* update training data path

* remove uninstall TE

* update .github/workflows/cicd-main.yml, disable fused/flashAttn

* adjusting val_check_interval for action ci-cd tests

* restore .github/workflows/cicd-main.yml

* update

* update nemologger args

* Apply isort and black reformatting

Signed-off-by: huvunvidia <huvunvidia@users.noreply.github.com>

* a checkpoint for NeMo 2.0 T5 finetune

* moving NeMo 2.0 examples to tests

* reorganizing files, putting examples files into tests

* runs with matched NeMo 1.0 and NeMo 2.0 loss curves

* fix tiny/irrelevant changes

* update AutoResume resume_from_path

* add pre-trained checkpoint to Azure's TestData, add CI-CD test

* only run t5 finetune test

* fix cicd test

* restore cicd file

* restore cicd file

* restore cicd file

* Apply isort and black reformatting

Signed-off-by: artbataev <artbataev@users.noreply.github.com>

---------

Signed-off-by: huvunvidia <huvunvidia@users.noreply.github.com>
Signed-off-by: Huy Vu <86480512+huvunvidia@users.noreply.github.com>
Signed-off-by: artbataev <artbataev@users.noreply.github.com>
Co-authored-by: Huy Vu2 <huvu@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: huvunvidia <huvunvidia@users.noreply.github.com>
Co-authored-by: Huy Vu2 <huvu@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: artbataev <artbataev@users.noreply.github.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
5 people authored and akoumpa committed Oct 24, 2024
1 parent e4afbe7 commit 8079a72
Showing 7 changed files with 500 additions and 6 deletions.
18 changes: 17 additions & 1 deletion .github/workflows/cicd-main.yml
@@ -5167,6 +5167,21 @@ jobs:
        rm -rf tests/collections/llm/t5_pretrain_results/${{ github.run_id }}
        rm -rf tests/collections/llm/t5_index_mappings/${{ github.run_id }}

  L2_NeMo_2_T5_Finetuning:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_T5_Finetuning') || needs.cicd-test-container-setup.outputs.all == 'true'
    with:
      RUNNER: self-hosted-azure
      SCRIPT: |
        NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python tests/collections/llm/megatron_t5_finetuning.py \
          --devices=2 \
          --max-steps=250 \
          --experiment-dir=tests/collections/llm/t5_finetune_results/${{ github.run_id }} \
          --checkpoint-path=/home/TestData/nlp/megatron_t5/220m/nemo2.0_t5_220m_150steps
      AFTER_SCRIPT: |
        rm -rf tests/collections/llm/t5_finetune_results/${{ github.run_id }}

  L2_NeMo_2_Mixtral_Pretraining:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
@@ -5316,6 +5331,7 @@ jobs:
      - L2_NeMo_2_SSM_Pretraining
      - L2_NeMo_2_SSM_Finetuning
      - L2_NeMo_2_T5_Pretraining
      - L2_NeMo_2_T5_Finetuning
      - L2_NeMo_2_Mixtral_Pretraining
      - L2_PTQ_Llama2_INT8_SQ
      - L2_PTQ_Llama2_FP8
@@ -5462,4 +5478,4 @@ jobs:
      - name: "Pipeline not successful, set exit code to 1"
        if: ${{ always() && steps.pipeline-conclusion.outputs.SUCCESS == 'false' }}
        run: exit 1
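For local debugging, the new smoke test can be reproduced outside CI. The sketch below mirrors the job's SCRIPT step from Python with fused/flash attention disabled as in the workflow; the experiment directory and the location of the pretrained 220M checkpoint are assumptions and will differ outside the self-hosted Azure runner.

```python
# Minimal local reproduction of the L2_NeMo_2_T5_Finetuning CI job (a sketch).
# The experiment dir and checkpoint path below are placeholders, not the CI paths.
import os
import subprocess

env = dict(os.environ, NVTE_FUSED_ATTN="0", NVTE_FLASH_ATTN="0")  # disable fused/flash attention
subprocess.run(
    [
        "python", "tests/collections/llm/megatron_t5_finetuning.py",
        "--devices=2",
        "--max-steps=250",
        "--experiment-dir=/tmp/t5_finetune_results",
        "--checkpoint-path=/path/to/nemo2.0_t5_220m_150steps",  # hypothetical local copy
    ],
    env=env,
    check=True,
)
```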
4 changes: 3 additions & 1 deletion nemo/collections/llm/t5/data/__init__.py
@@ -1,3 +1,5 @@
from nemo.collections.llm.t5.data.fine_tuning import FineTuningDataModule
from nemo.collections.llm.t5.data.pre_training import PreTrainingDataModule
from nemo.collections.llm.t5.data.squad import SquadDataModule

__all__ = ["PreTrainingDataModule"]
__all__ = ["FineTuningDataModule", "PreTrainingDataModule", "SquadDataModule"]
46 changes: 46 additions & 0 deletions nemo/collections/llm/t5/data/core.py
@@ -0,0 +1,46 @@
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from nemo.lightning.base import NEMO_DATASETS_CACHE

if TYPE_CHECKING:
    from nemo.collections.common.tokenizers import TokenizerSpec
    from nemo.collections.nlp.data.language_modeling.megatron.t5_sft_dataset import T5SFTDataset


def get_dataset_root(name: str) -> Path:
    output = Path(NEMO_DATASETS_CACHE) / name
    output.mkdir(parents=True, exist_ok=True)

    return output


def create_sft_dataset(
    path: Path,
    tokenizer: "TokenizerSpec",
    seq_length: int = 512,
    seq_length_dec: int = 128,
    add_bos: bool = True,
    add_eos: bool = True,
    replace_bos_with_pad: bool = False,
    seed: int = 1234,
    index_mapping_dir: Optional[str] = None,
    memmap_workers: int = 2,
    hf_dataset: bool = False,
    **kwargs,
) -> "T5SFTDataset":
    from nemo.collections.nlp.data.language_modeling.megatron.t5_sft_dataset import T5SFTDataset

    return T5SFTDataset(
        file_path=str(path),
        src_tokenizer=tokenizer,
        tgt_tokenizer=tokenizer,
        max_src_seq_length=seq_length,
        max_tgt_seq_length=seq_length_dec,
        memmap_workers=memmap_workers,
        hf_dataset=hf_dataset,
        add_bos_to_input=add_bos,
        add_eos_to_input=add_eos,
        replace_bos_with_pad=replace_bos_with_pad,
        index_mapping_dir=index_mapping_dir,
    )
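As a quick illustration (not part of the commit), `create_sft_dataset` can be exercised directly once a preprocessed JSONL split exists. The tokenizer setup below mirrors what `FineTuningDataModule` does; the dataset path is an assumption.

```python
# Hypothetical usage sketch for create_sft_dataset / get_dataset_root.
from nemo.collections.llm.t5.data.core import create_sft_dataset, get_dataset_root
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

# Same tokenizer choice as FineTuningDataModule; T5 needs the <extra_id_*> sentinel tokens.
tokenizer = get_nmt_tokenizer("megatron", "BertWordPieceCase")
tokenizer.add_special_tokens({'additional_special_tokens': [f'<extra_id_{i}>' for i in range(100)]})

dataset = create_sft_dataset(
    path=get_dataset_root("squad") / "training.jsonl",  # assumes the JSONL split already exists
    tokenizer=tokenizer,
    seq_length=512,      # encoder (source) length
    seq_length_dec=128,  # decoder (target) length
)
```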
147 changes: 147 additions & 0 deletions nemo/collections/llm/t5/data/fine_tuning.py
@@ -0,0 +1,147 @@
import math
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union

import pytorch_lightning as pl
from torch.utils.data import DataLoader

from nemo.collections.llm.t5.data.core import create_sft_dataset
from nemo.lightning.pytorch.plugins import MegatronDataSampler

if TYPE_CHECKING:
    from nemo.collections.common.tokenizers import TokenizerSpec


class FineTuningDataModule(pl.LightningDataModule):
    """Base class for fine-tuning an LLM.

    This class provides a foundation for building custom data modules for fine-tuning NeMo NLP models. It inherits from
    `pl.LightningDataModule` from the PyTorch Lightning library and handles data loading, preprocessing, and batch creation
    for training, validation, and testing.

    Args:
        dataset_root (Union[str, Path]): The root directory containing the training, validation, and test data.
        seq_length (int, optional): The maximum sequence length for the input (encoder) text. Defaults to 512.
        seq_length_dec (int, optional): The maximum sequence length for the output (decoder) text. Defaults to 128.
        tokenizer (Optional[TokenizerSpec], optional): The tokenizer to use for preprocessing the text. Defaults to None.
            If not provided, a BertWordPieceCase tokenizer will be used.
        micro_batch_size (int, optional): The micro batch size for training. Defaults to 4.
        global_batch_size (int, optional): The global batch size for training. Defaults to 8.
        rampup_batch_size (Optional[List[int]], optional): A list of batch sizes for ramping up during training. Defaults to None.
        seed (int, optional): The random seed for data shuffling. Defaults to 1234.
        memmap_workers (int, optional): The number of worker processes for loading data using TextMemMapDataset. Defaults to 1.
        num_workers (int, optional): The number of worker processes for data loading. Defaults to 8.
        pin_memory (bool, optional): Whether to pin memory during data loading for faster GPU training. Defaults to True.
        persistent_workers (bool, optional): Whether to keep data loading workers persistent across epochs. Defaults to False.
    """

    def __init__(
        self,
        dataset_root: Union[str, Path],
        seq_length: int = 512,
        seq_length_dec: int = 128,
        tokenizer: Optional["TokenizerSpec"] = None,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        rampup_batch_size: Optional[List[int]] = None,
        seed: int = 1234,
        memmap_workers: int = 1,
        num_workers: int = 8,
        pin_memory: bool = True,
        persistent_workers: bool = False,
    ):
        super().__init__()
        self.seq_length = seq_length
        self.seq_length_dec = seq_length_dec
        self.seed = seed
        self.dataset_root = Path(dataset_root)

        # add additional tokens for T5 tokenizer
        from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

        self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "BertWordPieceCase")
        additional_tokens = {'additional_special_tokens': [f'<extra_id_{i}>' for i in range(100)]}
        self.tokenizer.add_special_tokens(additional_tokens)

        self.memmap_workers = memmap_workers
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.rampup_batch_size = rampup_batch_size
        self.data_sampler = None
        self.max_train_samples = None

    def setup(self, stage: str):
        self.data_sampler = MegatronDataSampler(
            seq_len=self.seq_length,
            micro_batch_size=self.micro_batch_size,
            global_batch_size=self.global_batch_size,
            rampup_batch_size=self.rampup_batch_size,
            dataloader_type="batch",
        )

        # Follows the calculation in nemo.collections.nlp.data.language_modeling.megatron.
        # base_dataset_utils.get_datasets_weights_and_num_samples
        self.max_train_samples = int(math.ceil(self.global_batch_size * self.trainer.max_steps * 1.005))

    def train_dataloader(self) -> DataLoader:
        return self._create_dataloader(
            self._create_dataset(
                str(self.train_path),
                max_num_samples=self.max_train_samples,
            )
        )

    def val_dataloader(self) -> DataLoader:
        return self._create_dataloader(
            self._create_dataset(
                str(self.validation_path),
                is_test=True,
            ),
        )

    def test_dataloader(self) -> DataLoader:
        return self._create_dataloader(
            self._create_dataset(
                str(self.test_path),
                tokens_to_generate=32,
                is_test=True,
            )
        )

    @lru_cache
    def _create_dataset(self, path, **kwargs):
        return create_sft_dataset(
            path,
            tokenizer=self.tokenizer,
            seq_length=self.seq_length,
            seq_length_dec=self.seq_length_dec,
            memmap_workers=self.memmap_workers,
            seed=self.seed,
            **kwargs,
        )

    def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
        return DataLoader(
            dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
            collate_fn=dataset.collate_fn,
            **kwargs,
        )

    @property
    def train_path(self) -> Path:
        return self.dataset_root / "training.jsonl"

    @property
    def validation_path(self) -> Path:
        return self.dataset_root / "validation.jsonl"

    @property
    def test_path(self) -> Path:
        return self.dataset_root / "test.jsonl"
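A minimal sketch of wiring the new data module into a script, assuming a dataset root that already contains training.jsonl / validation.jsonl / test.jsonl (the path is hypothetical):

```python
# Hypothetical usage sketch for FineTuningDataModule; the dataset root is a placeholder.
from nemo.collections.llm.t5.data import FineTuningDataModule

data = FineTuningDataModule(
    dataset_root="/data/my_t5_sft",  # must contain training.jsonl, validation.jsonl, test.jsonl
    seq_length=512,
    seq_length_dec=128,
    micro_batch_size=4,
    global_batch_size=8,
)
# Note: setup() builds the MegatronDataSampler and derives max_train_samples from
# trainer.max_steps, so the module is meant to be attached to a Lightning/NeMo trainer
# (e.g. trainer.fit(model, data)) rather than used fully standalone.
```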
140 changes: 140 additions & 0 deletions nemo/collections/llm/t5/data/squad.py
@@ -0,0 +1,140 @@
import json
import shutil
from typing import TYPE_CHECKING, List, Optional

from datasets import DatasetDict, load_dataset

from nemo.collections.llm.t5.data.core import get_dataset_root
from nemo.collections.llm.t5.data.fine_tuning import FineTuningDataModule
from nemo.lightning.io.mixin import IOMixin
from nemo.utils import logging

if TYPE_CHECKING:
    from nemo.collections.common.tokenizers import TokenizerSpec


class SquadDataModule(FineTuningDataModule, IOMixin):
    """A data module for fine-tuning on the Squad dataset.

    This class inherits from the `FineTuningDataModule` class and is specifically designed for fine-tuning models on the
    Stanford Question Answering Dataset (SQuAD). It handles data download, preprocessing, splitting, and preparing the data
    in a format suitable for training, validation, and testing.

    Args:
        force_redownload (bool, optional): Whether to force re-download the dataset even if it exists locally. Defaults to False.
        delete_raw (bool, optional): Whether to delete the raw downloaded dataset after preprocessing. Defaults to True.

        See FineTuningDataModule for the other args.
    """

    def __init__(
        self,
        seq_length: int = 512,
        seq_length_dec: int = 128,
        tokenizer: Optional["TokenizerSpec"] = None,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        rampup_batch_size: Optional[List[int]] = None,
        force_redownload: bool = False,
        delete_raw: bool = True,
        seed: int = 1234,
        memmap_workers: int = 1,
        num_workers: int = 8,
        pin_memory: bool = True,
        persistent_workers: bool = False,
    ):
        self.force_redownload = force_redownload
        self.delete_raw = delete_raw

        super().__init__(
            dataset_root=get_dataset_root("squad"),
            seq_length=seq_length,
            seq_length_dec=seq_length_dec,
            tokenizer=tokenizer,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            rampup_batch_size=rampup_batch_size,
            seed=seed,
            memmap_workers=memmap_workers,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )

    def prepare_data(self) -> None:
        # if the train file already exists, no need to do anything
        if self.train_path.exists() and not self.force_redownload:
            return

        dset = self._download_data()
        self._preprocess_and_split_data(dset, split_val_from_train=False)

    def _download_data(self):
        logging.info(f"Downloading {self.__class__.__name__}...")
        return load_dataset(
            "squad",
            cache_dir=str(self.dataset_root),
            download_mode="force_redownload" if self.force_redownload else None,
        )

    def _preprocess_and_split_data(
        self, dset: DatasetDict, split_val_from_train: bool = True, val_proportion: float = 0.05
    ):
        """Preprocesses and splits the downloaded dataset into training, validation, and test sets.

        Args:
            dset (DatasetDict): The downloaded dataset object.
            split_val_from_train (bool, optional): Whether to split the validation set from the training set.
                If False, the validation set is split from the test set. Defaults to True.
            val_proportion (float, optional): The proportion of the training or test set to be used for the validation split.
                Defaults to 0.05.
        """
        logging.info(f"Preprocessing {self.__class__.__name__} to jsonl format and splitting...")
        save_splits = {}
        train_set = dset.get('train')
        val_set = dset.get('validation')

        if split_val_from_train:
            split_dataset = train_set.train_test_split(test_size=val_proportion, seed=self.seed)
            save_splits['training'] = split_dataset['train']
            save_splits['validation'] = split_dataset['test']
            save_splits['test'] = val_set
        else:
            split_dataset = val_set.train_test_split(test_size=val_proportion, seed=self.seed)
            save_splits['training'] = train_set
            save_splits['validation'] = split_dataset['test']
            save_splits['test'] = split_dataset['train']

        for split_name, dataset in save_splits.items():
            output_file = self.dataset_root / f"{split_name}.jsonl"

            with output_file.open("w", encoding="utf-8") as f:
                for example in dataset:
                    json_line = {}
                    # Write each example as a JSON line in the output file
                    # Using similar template as for NeMo 1.0 T5
                    json_line["input"] = (
                        "Title: "
                        + example["title"]
                        + " "
                        + "Paragraph: "
                        + example["context"]
                        + " "
                        + " Question: "
                        + example['question']
                    )
                    json_line["output"] = example["answers"]["text"][0]
                    if split_name == "test":
                        json_line["original_answers"] = example["answers"]["text"]
                    f.write(json.dumps(json_line) + "\n")

            logging.info(f"{split_name} split saved to {output_file}")

        if self.delete_raw:
            for p in self.dataset_root.iterdir():
                if p.is_dir():
                    shutil.rmtree(p)
                elif '.jsonl' not in str(p.name):
                    p.unlink()

    def reconfigure_limit_batches(self):
        return
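Finally, a sketch of the SQuAD module on its own: `prepare_data` downloads SQuAD via Hugging Face `datasets` and writes the JSONL splits under the NeMo datasets cache. The batch sizes here are illustrative, not values taken from the commit.

```python
# Hypothetical usage sketch for SquadDataModule.
from nemo.collections.llm.t5.data import SquadDataModule

squad = SquadDataModule(
    seq_length=512,
    seq_length_dec=128,
    micro_batch_size=16,
    global_batch_size=128,
)
squad.prepare_data()  # downloads SQuAD and writes training/validation/test JSONL splits
# In a training script, the module is passed to the trainer together with the T5 model;
# the CI job above exercises this path via tests/collections/llm/megatron_t5_finetuning.py.
```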