Adding NeMo 2.0 T5 finetuning (on Squad dataset) #10716

Merged
39 commits merged on Oct 10, 2024

Changes from all commits
Commits (39)
7ff1737
huvu/t5_nemo2.0 first commit from local
Aug 16, 2024
93a6661
runable training
Aug 26, 2024
ce721ea
commit to save
Aug 28, 2024
62e37c8
Merge remote-tracking branch 'origin/main' into huvu/t5_nemo2.0
Aug 28, 2024
769e8c6
update nemo/collections/llm/t5/data/pre_training.py, adding cicd test
Sep 4, 2024
8c086e8
updating codes
Sep 4, 2024
e6b01e8
reset nemo/collections/nlp/parts/megatron_trainer_builder.py
Sep 4, 2024
3bf5140
reset megatron_lm_encoder_decoder_model.py, remove t5_release_test_co…
Sep 4, 2024
47b1174
update init files
Sep 4, 2024
1f12e6c
keep changes of cicd
Sep 4, 2024
fcaf5bf
update Dockerfile.ci
Sep 4, 2024
575ab86
fix wandb for cicd test
Sep 4, 2024
4deacf4
update training data path
Sep 4, 2024
bfad4b1
remove uninstall TE
Sep 4, 2024
726ca7c
update .github/workflows/cicd-main.yml, disable fused/flashAttn
Sep 5, 2024
f151272
adjusting val_check_interval for action ci-cd tests
Sep 5, 2024
62db039
restore .github/workflows/cicd-main.yml
Sep 5, 2024
e880687
update
Sep 17, 2024
4b6c431
resolve conflict
Sep 17, 2024
910db2f
update nemologger args
Sep 18, 2024
d49e0fe
Apply isort and black reformatting
huvunvidia Sep 18, 2024
e1620b7
Merge branch 'main' into huvu/t5_nemo2.0
huvunvidia Sep 18, 2024
7d0800f
a checkpoint for NeMo 2.0 T5 finetune
Oct 2, 2024
011bcea
solving conflicts
Oct 2, 2024
ff11a07
moving NeMo 2.0 examples to tests
Oct 2, 2024
b9de083
reorganizing files, putting examples files into tests
Oct 2, 2024
a92635f
runs with matched NeMo 1.0 and NeMo 2.0 curve
Oct 4, 2024
1286b4a
resolve merge conflict
Oct 4, 2024
7573fda
fix tiny/irrelevant changes
Oct 4, 2024
4e7a838
update AutoResume resume_from_path
Oct 8, 2024
fa9a2db
Merge remote-tracking branch 'origin/main' into huvu/t5_nemo2.0_finetune
Oct 8, 2024
fb79ee8
add pre-trained checkpoint to Azure's TestData, add CI-CD test
Oct 8, 2024
db2c36b
only run t5 finetune test
Oct 8, 2024
4ab8a82
fix cicd test
Oct 9, 2024
29b9375
resolve conflict
Oct 9, 2024
eeac5aa
restore cicd file
Oct 9, 2024
cf32841
restore cicd file
Oct 9, 2024
87fe4b5
restore cicd file
Oct 9, 2024
c158f7c
Apply isort and black reformatting
artbataev Oct 9, 2024
18 changes: 17 additions & 1 deletion .github/workflows/cicd-main.yml
@@ -5216,6 +5216,21 @@ jobs:
        rm -rf tests/collections/llm/t5_pretrain_results/${{ github.run_id }}
        rm -rf tests/collections/llm/t5_index_mappings/${{ github.run_id }}

  L2_NeMo_2_T5_Finetuning:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_T5_Finetuning') || needs.cicd-test-container-setup.outputs.all == 'true'
    with:
      RUNNER: self-hosted-azure
      SCRIPT: |
        NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python tests/collections/llm/megatron_t5_finetuning.py \
          --devices=2 \
          --max-steps=250 \
          --experiment-dir=tests/collections/llm/t5_finetune_results/${{ github.run_id }} \
          --checkpoint-path=/home/TestData/nlp/megatron_t5/220m/nemo2.0_t5_220m_150steps
      AFTER_SCRIPT: |
        rm -rf tests/collections/llm/t5_finetune_results/${{ github.run_id }}

  L2_NeMo_2_Mixtral_Pretraining:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
@@ -5365,6 +5380,7 @@ jobs:
      - L2_NeMo_2_SSM_Pretraining
      - L2_NeMo_2_SSM_Finetuning
      - L2_NeMo_2_T5_Pretraining
      - L2_NeMo_2_T5_Finetuning
      - L2_NeMo_2_Mixtral_Pretraining
      - L2_PTQ_Llama2_INT8_SQ
      - L2_PTQ_Llama2_FP8
@@ -5511,4 +5527,4 @@ jobs:

      - name: "Pipeline not successful, set exit code to 1"
        if: ${{ always() && steps.pipeline-conclusion.outputs.SUCCESS == 'false' }}
        run: exit 1
        run: exit 1
4 changes: 3 additions & 1 deletion nemo/collections/llm/t5/data/__init__.py
@@ -1,3 +1,5 @@
from nemo.collections.llm.t5.data.fine_tuning import FineTuningDataModule
from nemo.collections.llm.t5.data.pre_training import PreTrainingDataModule
from nemo.collections.llm.t5.data.squad import SquadDataModule

__all__ = ["PreTrainingDataModule"]
__all__ = ["FineTuningDataModule", "PreTrainingDataModule", "SquadDataModule"]
46 changes: 46 additions & 0 deletions nemo/collections/llm/t5/data/core.py
@@ -0,0 +1,46 @@
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from nemo.lightning.base import NEMO_DATASETS_CACHE

if TYPE_CHECKING:
    from nemo.collections.common.tokenizers import TokenizerSpec
    from nemo.collections.nlp.data.language_modeling.megatron.t5_sft_dataset import T5SFTDataset


def get_dataset_root(name: str) -> Path:
    output = Path(NEMO_DATASETS_CACHE) / name
    output.mkdir(parents=True, exist_ok=True)

    return output


def create_sft_dataset(
    path: Path,
    tokenizer: "TokenizerSpec",
    seq_length: int = 512,
    seq_length_dec: int = 128,
    add_bos: bool = True,
    add_eos: bool = True,
    replace_bos_with_pad: bool = False,
    seed: int = 1234,
    index_mapping_dir: Optional[str] = None,
    memmap_workers: int = 2,
    hf_dataset: bool = False,
    **kwargs,
) -> "T5SFTDataset":
    from nemo.collections.nlp.data.language_modeling.megatron.t5_sft_dataset import T5SFTDataset

    return T5SFTDataset(
        file_path=str(path),
        src_tokenizer=tokenizer,
        tgt_tokenizer=tokenizer,
        max_src_seq_length=seq_length,
        max_tgt_seq_length=seq_length_dec,
        memmap_workers=memmap_workers,
        hf_dataset=hf_dataset,
        add_bos_to_input=add_bos,
        add_eos_to_input=add_eos,
        replace_bos_with_pad=replace_bos_with_pad,
        index_mapping_dir=index_mapping_dir,
    )
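
As a rough illustration of the helper above (not part of the diff), create_sft_dataset can be pointed at a prepared JSONL file. The tokenizer construction mirrors what FineTuningDataModule does below, and the file path is a placeholder:

# Illustrative sketch only: build a T5 SFT dataset from a JSONL file.
# "training.jsonl" is a hypothetical file of {"input": ..., "output": ...} lines.
from pathlib import Path

from nemo.collections.llm.t5.data.core import create_sft_dataset
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

tokenizer = get_nmt_tokenizer("megatron", "BertWordPieceCase")  # same tokenizer choice as FineTuningDataModule
dataset = create_sft_dataset(
    path=Path("training.jsonl"),
    tokenizer=tokenizer,
    seq_length=512,      # max encoder (source) sequence length
    seq_length_dec=128,  # max decoder (target) sequence length
)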
147 changes: 147 additions & 0 deletions nemo/collections/llm/t5/data/fine_tuning.py
@@ -0,0 +1,147 @@
import math
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union

import pytorch_lightning as pl
from torch.utils.data import DataLoader

from nemo.collections.llm.t5.data.core import create_sft_dataset
from nemo.lightning.pytorch.plugins import MegatronDataSampler

if TYPE_CHECKING:
    from nemo.collections.common.tokenizers import TokenizerSpec


class FineTuningDataModule(pl.LightningDataModule):
    """Base class for fine-tuning an LLM.

    This class provides a foundation for building custom data modules for fine-tuning NeMo NLP models.
    It inherits from `pl.LightningDataModule` from the PyTorch Lightning library and handles data loading,
    preprocessing, and batch creation for training, validation, and testing.

    Args:
        dataset_root (Union[str, Path]): The root directory containing the training, validation, and test data.
        seq_length (int, optional): The maximum encoder (input) sequence length. Defaults to 512.
        seq_length_dec (int, optional): The maximum decoder (output) sequence length. Defaults to 128.
        tokenizer (Optional[TokenizerSpec], optional): The tokenizer to use for preprocessing the text.
            If not provided, a Megatron BertWordPieceCase tokenizer will be used. Defaults to None.
        micro_batch_size (int, optional): The micro batch size for training. Defaults to 4.
        global_batch_size (int, optional): The global batch size for training. Defaults to 8.
        rampup_batch_size (Optional[List[int]], optional): A list of batch sizes for ramping up during training. Defaults to None.
        seed (int, optional): The random seed for data shuffling. Defaults to 1234.
        memmap_workers (int, optional): The number of worker processes for loading data using TextMemMapDataset. Defaults to 1.
        num_workers (int, optional): The number of worker processes for data loading. Defaults to 8.
        pin_memory (bool, optional): Whether to pin memory during data loading for faster GPU training. Defaults to True.
        persistent_workers (bool, optional): Whether to keep data loading workers persistent across epochs. Defaults to False.
    """

    def __init__(
        self,
        dataset_root: Union[str, Path],
        seq_length: int = 512,
        seq_length_dec: int = 128,
        tokenizer: Optional["TokenizerSpec"] = None,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        rampup_batch_size: Optional[List[int]] = None,
        seed: int = 1234,
        memmap_workers: int = 1,
        num_workers: int = 8,
        pin_memory: bool = True,
        persistent_workers: bool = False,
    ):
        super().__init__()
        self.seq_length = seq_length
        self.seq_length_dec = seq_length_dec
        self.seed = seed
        self.dataset_root = Path(dataset_root)

        # add additional tokens for T5 tokenizer
        from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

        self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "BertWordPieceCase")
        additional_tokens = {'additional_special_tokens': [f'<extra_id_{i}>' for i in range(100)]}
        self.tokenizer.add_special_tokens(additional_tokens)

        self.memmap_workers = memmap_workers
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.rampup_batch_size = rampup_batch_size
        self.data_sampler = None
        self.max_train_samples = None

    def setup(self, stage: str):
        self.data_sampler = MegatronDataSampler(
            seq_len=self.seq_length,
            micro_batch_size=self.micro_batch_size,
            global_batch_size=self.global_batch_size,
            rampup_batch_size=self.rampup_batch_size,
            dataloader_type="batch",
        )

        # Follows the calculation in nemo.collections.nlp.data.language_modeling.megatron.
        # base_dataset_utils.get_datasets_weights_and_num_samples
        self.max_train_samples = int(math.ceil(self.global_batch_size * self.trainer.max_steps * 1.005))

    def train_dataloader(self) -> DataLoader:
        return self._create_dataloader(
            self._create_dataset(
                str(self.train_path),
                max_num_samples=self.max_train_samples,
            )
        )

    def val_dataloader(self) -> DataLoader:
        return self._create_dataloader(
            self._create_dataset(
                str(self.validation_path),
                is_test=True,
            ),
        )

    def test_dataloader(self) -> DataLoader:
        return self._create_dataloader(
            self._create_dataset(
                str(self.test_path),
                tokens_to_generate=32,
                is_test=True,
            )
        )

    @lru_cache
    def _create_dataset(self, path, **kwargs):
        return create_sft_dataset(
            path,
            tokenizer=self.tokenizer,
            seq_length=self.seq_length,
            seq_length_dec=self.seq_length_dec,
            memmap_workers=self.memmap_workers,
            seed=self.seed,
            **kwargs,
        )

    def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
        return DataLoader(
            dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
            collate_fn=dataset.collate_fn,
            **kwargs,
        )

    @property
    def train_path(self) -> Path:
        return self.dataset_root / "training.jsonl"

    @property
    def validation_path(self) -> Path:
        return self.dataset_root / "validation.jsonl"

    @property
    def test_path(self) -> Path:
        return self.dataset_root / "test.jsonl"
140 changes: 140 additions & 0 deletions nemo/collections/llm/t5/data/squad.py
@@ -0,0 +1,140 @@
import json
import shutil
from typing import TYPE_CHECKING, List, Optional

from datasets import DatasetDict, load_dataset

from nemo.collections.llm.t5.data.core import get_dataset_root
from nemo.collections.llm.t5.data.fine_tuning import FineTuningDataModule
from nemo.lightning.io.mixin import IOMixin
from nemo.utils import logging

if TYPE_CHECKING:
    from nemo.collections.common.tokenizers import TokenizerSpec


class SquadDataModule(FineTuningDataModule, IOMixin):
    """A data module for fine-tuning on the SQuAD dataset.

    This class inherits from the `FineTuningDataModule` class and is specifically designed for fine-tuning models
    on the Stanford Question Answering Dataset (SQuAD). It handles data download, preprocessing, splitting, and
    preparing the data in a format suitable for training, validation, and testing.

    Args:
        force_redownload (bool, optional): Whether to force re-download the dataset even if it exists locally. Defaults to False.
        delete_raw (bool, optional): Whether to delete the raw downloaded dataset after preprocessing. Defaults to True.
        See FineTuningDataModule for the other args.
    """

    def __init__(
        self,
        seq_length: int = 512,
        seq_length_dec: int = 128,
        tokenizer: Optional["TokenizerSpec"] = None,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        rampup_batch_size: Optional[List[int]] = None,
        force_redownload: bool = False,
        delete_raw: bool = True,
        seed: int = 1234,
        memmap_workers: int = 1,
        num_workers: int = 8,
        pin_memory: bool = True,
        persistent_workers: bool = False,
    ):
        self.force_redownload = force_redownload
        self.delete_raw = delete_raw

        super().__init__(
            dataset_root=get_dataset_root("squad"),
            seq_length=seq_length,
            seq_length_dec=seq_length_dec,
            tokenizer=tokenizer,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            rampup_batch_size=rampup_batch_size,
            seed=seed,
            memmap_workers=memmap_workers,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )

    def prepare_data(self) -> None:
        # if the train file already exists, no need to do anything
        if self.train_path.exists() and not self.force_redownload:
            return

        dset = self._download_data()
        self._preprocess_and_split_data(dset, split_val_from_train=False)

    def _download_data(self):
        logging.info(f"Downloading {self.__class__.__name__}...")
        return load_dataset(
            "squad",
            cache_dir=str(self.dataset_root),
            download_mode="force_redownload" if self.force_redownload else None,
        )

    def _preprocess_and_split_data(
        self, dset: DatasetDict, split_val_from_train: bool = True, val_proportion: float = 0.05
    ):
        """Preprocesses and splits the downloaded dataset into training, validation, and test sets.

        Args:
            dset (DatasetDict): The downloaded dataset object.
            split_val_from_train (bool, optional): Whether to split the validation set from the training set.
                If False, the validation set is split from the test set. Defaults to True.
            val_proportion (float, optional): The proportion of the training or test set to be used for the validation split.
                Defaults to 0.05.
        """
        logging.info(f"Preprocessing {self.__class__.__name__} to jsonl format and splitting...")
        save_splits = {}
        train_set = dset.get('train')
        val_set = dset.get('validation')

        if split_val_from_train:
            split_dataset = train_set.train_test_split(test_size=val_proportion, seed=self.seed)
            save_splits['training'] = split_dataset['train']
            save_splits['validation'] = split_dataset['test']
            save_splits['test'] = val_set
        else:
            split_dataset = val_set.train_test_split(test_size=val_proportion, seed=self.seed)
            save_splits['training'] = train_set
            save_splits['validation'] = split_dataset['test']
            save_splits['test'] = split_dataset['train']

        for split_name, dataset in save_splits.items():
            output_file = self.dataset_root / f"{split_name}.jsonl"

            with output_file.open("w", encoding="utf-8") as f:
                for example in dataset:
                    json_line = {}
                    # Write each example as a JSON line in the output file,
                    # using a similar template as for NeMo 1.0 T5
                    json_line["input"] = (
                        "Title: "
                        + example["title"]
                        + " "
                        + "Paragraph: "
                        + example["context"]
                        + " "
                        + " Question: "
                        + example['question']
                    )
                    json_line["output"] = example["answers"]["text"][0]
                    if split_name == "test":
                        json_line["original_answers"] = example["answers"]["text"]
                    f.write(json.dumps(json_line) + "\n")

            logging.info(f"{split_name} split saved to {output_file}")

        if self.delete_raw:
            for p in self.dataset_root.iterdir():
                if p.is_dir():
                    shutil.rmtree(p)
                elif '.jsonl' not in str(p.name):
                    p.unlink()

    def reconfigure_limit_batches(self):
        return
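
Putting it together (illustrative only, not code from this PR): instantiating SquadDataModule and calling prepare_data() downloads SQuAD via Hugging Face datasets and writes the training/validation/test JSONL splits under the NeMo datasets cache. In a real run the Lightning Trainer triggers these hooks, presumably via the tests/collections/llm/megatron_t5_finetuning.py script exercised by the CI job above:

# Sketch: the SQuAD data module added above, with illustrative argument values.
from nemo.collections.llm.t5.data import SquadDataModule

squad = SquadDataModule(
    seq_length=512,
    seq_length_dec=128,
    micro_batch_size=4,
    global_batch_size=8,
)
squad.prepare_data()  # downloads SQuAD and writes training/validation/test .jsonl files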