From eb7941dc5f907d9ed95be0664876b0e82b9ce824 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Tue, 26 Mar 2024 10:12:14 -0400
Subject: [PATCH] Packed sequence data shuffling & without `thd` attention (#8693)

* packed seq without thd attention for mlperf

Signed-off-by: Chen Cui

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* enable shuffling for packed dataset

Signed-off-by: Chen Cui

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* set seed

Signed-off-by: Chen Cui

---------

Signed-off-by: Chen Cui
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../megatron/gpt_sft_dataset.py | 99 +++++++++++++------
 .../megatron_gpt_sft_model.py   |  5 +-
 2 files changed, 74 insertions(+), 30 deletions(-)

diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
index d8314990b5cd..bb7bf07e4ad1 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
@@ -104,6 +104,8 @@ def __init__(
         self.prompt_template = prompt_template
         self.virtual_tokens = virtual_tokens
         self.tokens_to_generate = tokens_to_generate
+        self.memmap_workers = memmap_workers
+        self.hf_dataset = hf_dataset
         self.truncation_method = truncation_method
         self.is_test = is_test
         self.output_original_text = output_original_text
@@ -118,25 +120,32 @@ def __init__(
         else:
             self.special_tokens = special_tokens
 
-        if hf_dataset:
+        self._load_dataset()
+
+        # Validate prompt template
+        self._maybe_validate_prompt_template()
+
+        # Will be None after this call if `max_num_samples` is None
+        self._build_samples_mapping()
+
+    def _load_dataset(self):
+        if self.hf_dataset:
             self.indexed_dataset = load_dataset(
-                'json', data_files=file_path, cache_dir=index_mapping_dir, num_proc=memmap_workers, split='train'
+                'json',
+                data_files=self.file_path,
+                cache_dir=self.index_mapping_dir,
+                num_proc=self.memmap_workers,
+                split='train',
             )
         else:
             self.indexed_dataset = JSONLMemMapDataset(
-                dataset_paths=[file_path],
+                dataset_paths=[self.file_path],
                 tokenizer=None,
                 header_lines=0,
-                index_mapping_dir=index_mapping_dir,
-                workers=memmap_workers,
+                index_mapping_dir=self.index_mapping_dir,
+                workers=self.memmap_workers,
             )
 
-        # Validate prompt template
-        self._maybe_validate_prompt_template()
-
-        # Will be None after this call if `max_num_samples` is None
-        self._build_samples_mapping()
-
     def _maybe_validate_prompt_template(self):
         assert (
             self.prompt_template is not None
@@ -476,23 +485,28 @@ def collate_fn(self, batch):
 
 
 class GPTSFTPackedDataset(GPTSFTDataset):
-    def __init__(self, file_path: str, tokenizer: TokenizerSpec, **kwargs):
+    def __init__(self, file_path: str, tokenizer: TokenizerSpec, return_cu_seqlen: bool = True, **kwargs):
         super().__init__(file_path, tokenizer, **kwargs)
         assert self.virtual_tokens == 0, "P-Tuning with packed sequence is not supported."
-        self._load_packed_dataset(file_path)
+
+        # Whether to return `cu_seqlen` to pass to the model. This should be true for almost all use cases.
+        self.return_cu_seqlen = return_cu_seqlen
+
+        np.random.seed(self.seed)
 
     def __getitem__(self, idx):
+        if self.samples_mapping is not None:
+            # assert idx < len(self.samples_mapping)
+            idx = self.samples_mapping[idx]
+
         input_ids = self.indexed_dataset[idx]['input_ids']
         seq_boundaries = self.indexed_dataset[idx]['seq_start_id'] + [len(input_ids)]
         loss_mask = self.indexed_dataset[idx]['loss_mask']
         return {'input_ids': input_ids, 'seq_boundaries': seq_boundaries, 'loss_mask': loss_mask}
 
-    def __len__(self):
-        return len(self.indexed_dataset)
-
-    def _load_packed_dataset(self, file_path):
+    def _load_dataset(self):
         try:
-            self.indexed_dataset = np.load(file_path, allow_pickle=True)
+            self.indexed_dataset = np.load(self.file_path, allow_pickle=True)
         except Exception as e:
             logging.error(
                 f"Failed to load packed dataset. The dataset should be a `.npy` file. "
@@ -500,6 +514,19 @@ def _load_packed_dataset(self, file_path):
             )
             exit(1)
 
+    def _build_samples_mapping(self):
+        if self.max_num_samples is not None:
+            # custom samples mapping logic, following the format for unpacked sft dataset
+            # Note: this is epoch-level shuffling, i.e. sampling without replacement until end of epoch, then repeat.
+            # Unpacked dataset shuffles by sampling with replacement indefinitely.
+            dataset_len = len(self.indexed_dataset)
+            max_num_epochs = int(np.ceil(self.max_num_samples / dataset_len))
+            indices = np.arange(dataset_len)[None, :].repeat(max_num_epochs, axis=0)
+            [np.random.shuffle(x) for x in indices]
+            self.samples_mapping = indices.reshape(1, -1).squeeze()[: self.max_num_samples]
+        else:
+            self.samples_mapping = None
+
     def _build_loss_mask(self, processed_example):
         if self.answer_only_loss:
             seq_boundaries = processed_example['seq_boundaries']
@@ -565,28 +592,42 @@ def collate_fn(self, batch):
                 position_ids[0]
             ), "Dataset problem: input_ids and position_ids lengths don't match"
 
-        cu_seqlens = self._collate_item(cu_seqlens, max_length=max(len(l) for l in cu_seqlens) + 1, pad_id=-1)
         input_ids = self._collate_item(input_ids, max_length=max_length, pad_id=self.tokenizer.eos_id)
         labels = self._collate_item(labels, max_length=max_length, pad_id=self.tokenizer.eos_id)
         loss_mask = self._collate_item(loss_mask, max_length=max_length, pad_id=0)
         position_ids = self._collate_item(position_ids, max_length=max_length, pad_id=0)
 
-        # Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensor to avoid device-to-host copies.
-        cu_seqlens = torch.IntTensor(cu_seqlens)
-        cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)
-        seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
-        max_seqlen, _ = seqlens.max(dim=1, keepdim=True)
-
         processed_batch = {
             'tokens': torch.LongTensor(input_ids),
             'labels': torch.LongTensor(labels),
-            'attention_mask': torch.LongTensor([1] * len(input_ids)),  # no attention mask is needed for packed seq
             'loss_mask': torch.LongTensor(loss_mask),
             'position_ids': torch.LongTensor(position_ids),
-            'cu_seqlens': torch.IntTensor(cu_seqlens),  # cu_seqlens_q must be in dtype torch.int32
             'token_count': token_count,
-            'cu_seqlens_argmin': cu_seqlens_argmin,
-            'max_seqlen': max_seqlen,
         }
 
+        if self.return_cu_seqlen:
+            cu_seqlens = self._collate_item(cu_seqlens, max_length=max(len(l) for l in cu_seqlens) + 1, pad_id=-1)
+
+            # Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensors to avoid device-to-host copies.
+            cu_seqlens = torch.IntTensor(cu_seqlens)
+            cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)
+            seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
+            max_seqlen, _ = seqlens.max(dim=1, keepdim=True)
+
+            processed_batch.update(
+                {
+                    'attention_mask': torch.LongTensor(
+                        [1] * len(input_ids)
+                    ),  # no attention mask is needed for packed seq
+                    'cu_seqlens': torch.IntTensor(cu_seqlens),  # cu_seqlens_q must be in dtype torch.int32
+                    'cu_seqlens_argmin': cu_seqlens_argmin,  # only required for perf
+                    'max_seqlen': max_seqlen,  # only required for perf
+                }
+            )
+        else:
+            attention_mask = [self._create_attention_mask(max_length) for _ in batch]
+            processed_batch.update(
+                {'attention_mask': torch.stack(attention_mask),}
+            )
 
         return processed_batch

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
index b6aeeb29ed3f..b2879a9171a7 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
@@ -258,11 +258,13 @@ def _build_dataset(self, data_cfg, is_train=True):
             8 * self.cfg.get('tensor_model_parallel_size', 1) if self.cfg.get('sequence_parallel', False) else 16
         )
 
+        dataset_kwargs = {}
         for file_path, num_samples in zip(data_cfg.file_names, num_train_samples_per_dataset):
             if self.cfg.data.get("chat", False):
                 dataset_cls = GPTSFTChatDataset
             elif packed_sequence:
                 dataset_cls = GPTSFTPackedDataset
+                dataset_kwargs = {'return_cu_seqlen': data_cfg.get("packed_sequence_return_cu_seqlen", True)}
                 assert data_cfg.micro_batch_size == 1, "Micro batch size must be 1 if using packed sequence"
             else:
                 dataset_cls = GPTSFTDataset
@@ -281,7 +283,7 @@ def _build_dataset(self, data_cfg, is_train=True):
                 add_eos=data_cfg.get('add_eos', True),
                 add_sep=data_cfg.get('add_sep', False),
                 sep_id=self.sep_id,
-                max_num_samples=num_samples[0] if not packed_sequence else None,
+                max_num_samples=num_samples[0],
                 seed=data_cfg.get('seed', 1234),
                 label_key=data_cfg.get('label_key', 'answer'),
                 answer_only_loss=self.cfg.get('answer_only_loss', True),
@@ -306,6 +308,7 @@ def _build_dataset(self, data_cfg, is_train=True):
                     'chat_prompt_tokens', None
                 ),  # special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '', 'turn_start': '', 'label_start': '', 'end_of_turn': '\n', "end_of_name": "\n"}
                 is_test=not is_train,
+                **dataset_kwargs,
             )
             datasets.append(dataset)
         if is_train:
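
Notes on the shuffling logic: the new `_build_samples_mapping` samples without replacement within each epoch and concatenates shuffled epochs, unlike the unpacked SFT dataset, which samples with replacement indefinitely; `__getitem__` then simply redirects `idx` through the resulting mapping. The following is a minimal standalone sketch of that logic; the concrete values (5 bins, 12 requested samples, seed 1234) are illustrative, not from the patch:

    import numpy as np

    np.random.seed(1234)  # mirrors np.random.seed(self.seed) in __init__

    dataset_len = 5       # number of packed bins in the .npy file (illustrative)
    max_num_samples = 12  # total training samples requested (illustrative)

    # One row of indices per epoch, each row shuffled independently.
    max_num_epochs = int(np.ceil(max_num_samples / dataset_len))  # -> 3
    indices = np.arange(dataset_len)[None, :].repeat(max_num_epochs, axis=0)
    for epoch in indices:
        np.random.shuffle(epoch)  # in-place, per-epoch shuffle

    samples_mapping = indices.reshape(-1)[:max_num_samples]
    print(samples_mapping)  # e.g. [3 0 4 1 2 0 2 4 3 1 4 0] -- each index appears once per epoch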
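
Notes on the `cu_seqlens` bookkeeping: when `return_cu_seqlen=True`, `collate_fn` emits cumulative sequence boundaries per packed bin, padded with -1, plus two CPU-side tensors derived from them. A minimal sketch of the arithmetic, using one illustrative bin that packs three sequences of lengths 3, 5, and 2:

    import torch

    # Boundaries [0, 3, 8, 10] for sequences of lengths 3, 5, 2, padded with -1
    # as _collate_item does (pad_id=-1, max_length = longest row + 1).
    cu_seqlens = torch.IntTensor([[0, 3, 8, 10, -1]])  # must be torch.int32

    # Position of the first pad (-1); recovers the real boundary count per bin.
    cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)  # -> [[4]]

    # Per-sequence lengths; the pad yields a negative entry that max() never selects.
    seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]  # -> [[3, 5, 2, -11]]
    max_seqlen, _ = seqlens.max(dim=1, keepdim=True)  # -> [[5]]

Both derived tensors are built on the CPU in the collate function precisely so that the model does not have to issue device-to-host copies to read them later.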
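
Notes on the non-`thd` path: with `packed_sequence_return_cu_seqlen: False` in the data config, `collate_fn` skips the `cu_seqlens` metadata and instead stacks one dense causal mask per bin via the parent class's `_create_attention_mask`. Conceptually it behaves like the sketch below (an illustration of the idea, not the exact NeMo source):

    import torch

    def create_causal_mask(max_length: int) -> torch.Tensor:
        # Lower-triangular causal mask; True marks positions to block.
        mask = torch.tril(torch.ones((max_length, max_length))).unsqueeze(0)
        return mask < 0.5

    batch_size, max_length = 4, 6
    attention_mask = torch.stack([create_causal_mask(max_length) for _ in range(batch_size)])
    print(attention_mask.shape)  # torch.Size([4, 1, 6, 6])

Note that a plain causal mask lets tokens attend across packed-sequence boundaries within a bin; accepting that cross-sequence attention appears to be the trade-off this mode makes to avoid requiring `thd`-format attention kernels.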