Skip to content

Commit

Permalink
Packed sequence bug fixes (NVIDIA#10898)
Browse files Browse the repository at this point in the history
* save prepared dataset to different folders according to tokenizer name

Signed-off-by: Chen Cui <chcui@nvidia.com>

* fix hang

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>

* Apply isort and black reformatting

Signed-off-by: artbataev <artbataev@users.noreply.github.com>

* fix hang

Signed-off-by: Chen Cui <chcui@nvidia.com>

* raise mbs>1 error and provide suggestion to user instead of automatically changing config

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>

* add ci for packed seq

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>

* fix bug

Signed-off-by: Chen Cui <chcui@nvidia.com>

---------

Signed-off-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>
Signed-off-by: artbataev <artbataev@users.noreply.github.com>
Co-authored-by: cuichenx <cuichenx@users.noreply.github.com>
Co-authored-by: artbataev <artbataev@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 18, 2024
1 parent a1fdf07 commit 76352fb
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 42 deletions.
74 changes: 58 additions & 16 deletions .github/workflows/cicd-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5225,8 +5225,6 @@ jobs:
--pp_size 1 \
--mbs 1
AFTER_SCRIPT: |
rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}
L2_NeMo_2_GPT_SFT_TP1PP1_MBS2:
needs: [cicd-test-container-setup]
Expand Down Expand Up @@ -5256,8 +5254,6 @@ jobs:
--pp_size 1 \
--mbs 2
AFTER_SCRIPT: |
rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}
L2_NeMo_2_GPT_SFT_TP1PP2_MBS2:
needs: [cicd-test-container-setup]
Expand Down Expand Up @@ -5287,8 +5283,6 @@ jobs:
--pp_size 2 \
--mbs 2
AFTER_SCRIPT: |
rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}
L2_NeMo_2_GPT_SFT_TP2PP1_MBS2:
needs: [cicd-test-container-setup]
Expand Down Expand Up @@ -5318,8 +5312,35 @@ jobs:
--pp_size 1 \
--mbs 2
AFTER_SCRIPT: |
rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}
L2_NeMo_2_GPT_SFT_TP1PP1_MBS1_PACKED:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_GPT_SFT_TP1PP1_MBS1_PACKED') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
python tests/collections/llm/gpt_finetuning.py \
--restore_path /home/TestData/nemo2_ckpt/llama_68M \
--devices 2 \
--max_steps 3 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
--peft none \
--tp_size 1 \
--pp_size 1 \
--mbs 1 --packed
python tests/collections/llm/gpt_finetuning.py \
--restore_path /home/TestData/nemo2_ckpt/llama_68M \
--devices 2 \
--max_steps 6 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
--peft none \
--tp_size 1 \
--pp_size 1 \
--mbs 1 --packed
L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1:
needs: [cicd-test-container-setup]
Expand Down Expand Up @@ -5349,8 +5370,6 @@ jobs:
--pp_size 1 \
--mbs 1
AFTER_SCRIPT: |
rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}
L2_NeMo_2_GPT_LoRA_TP1PP1_MBS2:
needs: [cicd-test-container-setup]
Expand Down Expand Up @@ -5380,8 +5399,6 @@ jobs:
--pp_size 1 \
--mbs 2
AFTER_SCRIPT: |
rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}
L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2:
needs: [cicd-test-container-setup]
Expand Down Expand Up @@ -5411,8 +5428,6 @@ jobs:
--pp_size 2 \
--mbs 2
AFTER_SCRIPT: |
rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}
L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2:
needs: [cicd-test-container-setup]
Expand Down Expand Up @@ -5442,8 +5457,33 @@ jobs:
--pp_size 1 \
--mbs 2
AFTER_SCRIPT: |
rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}
L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
python tests/collections/llm/gpt_finetuning.py \
--restore_path /home/TestData/nemo2_ckpt/llama_68M \
--devices 2 \
--max_steps 3 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
--peft lora \
--tp_size 1 \
--pp_size 1 \
--mbs 1 --packed
python tests/collections/llm/gpt_finetuning.py \
--restore_path /home/TestData/nemo2_ckpt/llama_68M \
--devices 2 \
--max_steps 6 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
--peft lora \
--tp_size 1 \
--pp_size 1 \
--mbs 1 --packed
L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact:
needs: [cicd-test-container-setup]
Expand Down Expand Up @@ -5597,10 +5637,12 @@ jobs:
- L2_NeMo_2_GPT_SFT_TP1PP1_MBS2
- L2_NeMo_2_GPT_SFT_TP1PP2_MBS2
- L2_NeMo_2_GPT_SFT_TP2PP1_MBS2
- L2_NeMo_2_GPT_SFT_TP1PP1_MBS1_PACKED
- L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1
- L2_NeMo_2_GPT_LoRA_TP1PP1_MBS2
- L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2
- L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2
- L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED
- L2_NeMo_2_Mixtral_Pretraining
- L2_PTQ_Llama2_INT8_SQ
- L2_PTQ_Llama2_FP8
Expand Down
5 changes: 3 additions & 2 deletions nemo/collections/llm/gpt/data/dolly.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

if TYPE_CHECKING:
from nemo.collections.common.tokenizers import TokenizerSpec
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs


class DollyDataModule(FineTuningDataModule, IOMixin):
Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(
pin_memory: bool = True,
persistent_workers: bool = False,
pad_to_max_length: bool = False,
packed_sequence_size: int = -1,
packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
):
self.force_redownload = force_redownload
self.delete_raw = delete_raw
Expand All @@ -74,7 +75,7 @@ def __init__(
pin_memory=pin_memory,
persistent_workers=persistent_workers,
pad_to_max_length=pad_to_max_length,
packed_sequence_size=packed_sequence_size,
packed_sequence_specs=packed_sequence_specs,
)

def prepare_data(self) -> None:
Expand Down
52 changes: 35 additions & 17 deletions nemo/collections/llm/gpt/data/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from nemo.collections.common.tokenizers import AutoTokenizer
from nemo.collections.llm.gpt.data.core import create_sft_dataset
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.utils import logging

if TYPE_CHECKING:
from nemo.collections.common.tokenizers import TokenizerSpec
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs


class FineTuningDataModule(pl.LightningDataModule):
Expand All @@ -50,10 +52,7 @@ class FineTuningDataModule(pl.LightningDataModule):
persistent_workers (bool, optional): Whether to keep data loading workers persistent across epochs. Defaults to False.
max_train_steps (int, optional): Maximum number of steps to train. Used to calculate samples mapping for the mmap dataset
pad_to_max_length (bool, optional): Whether to pad the input to the max sequence length. If False, will pad to the max length of the current batch.
packed_sequence_size (int, optional): If a positive integer, this arg enables training with sequence packing and specifies the pack size
If less than or equal to 0, sequence packing is disabled. Defaults to -1.
Note: This arg is distinct from `seq_length` because `seq_length` specifies the maximum length of the original sequence
(i.e. the length to truncate long sequences in the input data).
packed_sequence_specs (PackedSequenceSpecs, optional): See PackedSequenceSpecs for details
"""

def __init__(
Expand All @@ -70,7 +69,7 @@ def __init__(
pin_memory: bool = True,
persistent_workers: bool = False,
pad_to_max_length: bool = False,
packed_sequence_size: int = -1,
packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
):
super().__init__()
self.seq_length = seq_length
Expand All @@ -87,22 +86,21 @@ def __init__(
self.data_sampler = None
self.max_train_samples = None
self.pad_to_max_length = pad_to_max_length
self.packed_sequence_size = packed_sequence_size
self._adjust_batch_sizes_for_packed_sequence()
self.packed_sequence_specs = packed_sequence_specs
self.packed_sequence_size = -1 if not packed_sequence_specs else packed_sequence_specs.packed_sequence_size
self.validate_batch_size_for_packed_sequence()

def _adjust_batch_sizes_for_packed_sequence(self):
def validate_batch_size_for_packed_sequence(self):
if self.packed_sequence_size > 0 and self.micro_batch_size > 1:
logging.warning(
raise ValueError(
"Micro batch size should be 1 when training with packed sequence, but your micro batch size "
f"is {self.micro_batch_size}. Your config will be automatically updated to the following: "
f"MBS will be set to 1 (from {self.micro_batch_size}), "
f"GBS will be set to {self.global_batch_size // self.micro_batch_size} (from {self.global_batch_size}), "
f"packed sequence length will be set to {self.packed_sequence_size*self.micro_batch_size} (from {self.packed_sequence_size}). "
f"is {self.micro_batch_size}. \nThe following config is equivalent to your current setting for "
f"a packed dataset. Please update your config to the following: \n"
f"Set micro batch size to 1 (currently {self.micro_batch_size})\n"
f"Set global batch size to {self.global_batch_size // self.micro_batch_size} (currently {self.global_batch_size}) \n"
f"Set packed sequence length to {self.packed_sequence_size*self.micro_batch_size} (currently {self.packed_sequence_size}) \n"
f"For details please visit https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/optimizations/sequence_packing.html"
)
self.global_batch_size //= self.micro_batch_size
self.packed_sequence_size *= self.micro_batch_size
self.micro_batch_size = 1

def prepare_data(self) -> None:
if self.packed_sequence_size > 0 and not self.train_path_packed.is_file():
Expand Down Expand Up @@ -187,7 +185,12 @@ def train_path(self) -> Path:
@property
def train_path_packed(self) -> Path:
if self.packed_sequence_size > 0:
return self.dataset_root / f"training_packed{self.packed_sequence_size}.npy"
if self.packed_sequence_specs.packed_data_path is not None:
return self.packed_sequence_specs.packed_data_path
tokenizer_model_name = self._extract_tokenizer_model_name()
folder_name = self.dataset_root / "packed" / tokenizer_model_name
folder_name.mkdir(parents=True, exist_ok=True)
return folder_name / f"training_{self.packed_sequence_size}.npy"
else:
raise ValueError("`train_path_packed` invalid since packed sequence size is not specified.")

Expand All @@ -198,3 +201,18 @@ def validation_path(self) -> Path:
@property
def test_path(self) -> Path:
return self.dataset_root / "test.jsonl"

def _extract_tokenizer_model_name(self) -> str:
if self.packed_sequence_specs.tokenizer_model_name is not None:
tokenizer_model_name = self.packed_sequence_specs.tokenizer_model_name
elif isinstance(self.tokenizer, AutoTokenizer):
name = self.tokenizer.tokenizer.name_or_path
if name.endswith("nemo_tokenizer"):
# NEMO_HOME/hf_org/hf_model/nemo_tokenizer => hf_org--hf_model
tokenizer_model_name = '--'.join(name.split("/")[-3:-1])
else:
# hf_org/hf_model => hf_org--hf_model
tokenizer_model_name = name.replace("/", "--")
else:
tokenizer_model_name = f"unknown_tokenizer_{hash(self.tokenizer)}"
return tokenizer_model_name
31 changes: 30 additions & 1 deletion nemo/collections/llm/gpt/data/packed_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

Expand Down Expand Up @@ -83,3 +83,32 @@ def prepare_packed_sequence_data(
# save output data
np.save(output_path, output_data)
logging.info(f"Packed sequence is prepared and saved to {output_path}")


@dataclass
class PackedSequenceSpecs:
packed_sequence_size: int = -1
"""
If a positive integer, this arg enables training with sequence packing and specifies the pack size
If less than or equal to 0, sequence packing is disabled. Defaults to -1.
Note: This arg is distinct from `seq_length` because `seq_length` specifies the maximum length of the original sequence
(i.e. the length to truncate long sequences in the input data).
"""

tokenizer_model_name: str = None
"""
Keep track of tokenizer model name, since each tokenizer produces a different packed sequence dataset file.
This field is set by llm.finetune api.
"""

packed_data_path: Path = None
"""
If specified, use the packed dataset from this file instead of the default path.
"""

def __post_init__(self):
if self.packed_data_path is not None:
assert (
self.packed_data_path.suffix == ".npy"
), f"packed data file must be a .npy file: {self.packed_data_path}"
assert self.packed_data_path.exists(), f"packed data file does not exist: {self.packed_data_path}"
5 changes: 3 additions & 2 deletions nemo/collections/llm/gpt/data/squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

if TYPE_CHECKING:
from nemo.collections.common.tokenizers import TokenizerSpec
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs


class SquadDataModule(FineTuningDataModule, IOMixin):
Expand Down Expand Up @@ -54,7 +55,7 @@ def __init__(
pin_memory: bool = True,
persistent_workers: bool = False,
pad_to_max_length: bool = False,
packed_sequence_size: int = -1,
packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
):
self.force_redownload = force_redownload
self.delete_raw = delete_raw
Expand All @@ -72,7 +73,7 @@ def __init__(
pin_memory=pin_memory,
persistent_workers=persistent_workers,
pad_to_max_length=pad_to_max_length,
packed_sequence_size=packed_sequence_size,
packed_sequence_specs=packed_sequence_specs,
)

def prepare_data(self) -> None:
Expand Down
20 changes: 18 additions & 2 deletions nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __init__(
index_mapping_dir=index_mapping_dir,
)

if is_distributed:
if is_distributed and not _lightning_prepare_data():
torch.distributed.barrier()

if is_distributed and AppState().local_rank == 0:
Expand All @@ -152,7 +152,7 @@ def __init__(
index_mapping_dir=index_mapping_dir,
)

if is_distributed:
if is_distributed and not _lightning_prepare_data():
torch.distributed.barrier()

logging.info(f"Loading data files")
Expand Down Expand Up @@ -749,3 +749,19 @@ def get_sample_block(self, block_idx: int) -> np.ndarray:
sample_block = sample_block % self.dataset_size

return sample_block


def _lightning_prepare_data():
"""
This function checks whether it is invoked in lightning's hook "prepare_data", which is run only on rank 0.
TextMemMapDataset contains a torch.distributed.barrier operation, so when run inside the single-process hook
prepare_data, the barrier operation would hang forever.
"""
import inspect

return any(
[
frame.function == 'prepare_data' and 'prepare_packed_sequence_data' in frame.code_context[0]
for frame in inspect.stack()
]
)
Loading

0 comments on commit 76352fb

Please sign in to comment.