Commit

Merge branch 'main' into jiemingz/ckpt_mem_fixes
JimmyZhang12 authored Mar 26, 2024
2 parents 55ca157 + eb7941d commit 3d8dd95
Showing 2 changed files with 74 additions and 30 deletions.
First changed file
@@ -104,6 +104,8 @@ def __init__(
self.prompt_template = prompt_template
self.virtual_tokens = virtual_tokens
self.tokens_to_generate = tokens_to_generate
self.memmap_workers = memmap_workers
self.hf_dataset = hf_dataset
self.truncation_method = truncation_method
self.is_test = is_test
self.output_original_text = output_original_text
@@ -118,25 +120,32 @@ def __init__(
else:
self.special_tokens = special_tokens

if hf_dataset:
self._load_dataset()

# Validate prompt template
self._maybe_validate_prompt_template()

# Will be None after this call if `max_num_samples` is None
self._build_samples_mapping()

def _load_dataset(self):
if self.hf_dataset:
self.indexed_dataset = load_dataset(
'json', data_files=file_path, cache_dir=index_mapping_dir, num_proc=memmap_workers, split='train'
'json',
data_files=self.file_path,
cache_dir=self.index_mapping_dir,
num_proc=self.memmap_workers,
split='train',
)
else:
self.indexed_dataset = JSONLMemMapDataset(
dataset_paths=[file_path],
dataset_paths=[self.file_path],
tokenizer=None,
header_lines=0,
index_mapping_dir=index_mapping_dir,
workers=memmap_workers,
index_mapping_dir=self.index_mapping_dir,
workers=self.memmap_workers,
)

# Validate prompt template
self._maybe_validate_prompt_template()

# Will be None after this call if `max_num_samples` is None
self._build_samples_mapping()

def _maybe_validate_prompt_template(self):
assert (
self.prompt_template is not None
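Taken together, the hunks above move dataset construction out of `__init__` into a new `_load_dataset()` hook that reads the stored `self.file_path`, `self.index_mapping_dir`, `self.memmap_workers`, and `self.hf_dataset` attributes, which lets subclasses swap in their own loader. A condensed sketch of the resulting shape, not the actual NeMo classes (the `*Sketch` names are stand-ins and the JSONLMemMapDataset branch is stubbed out):

```python
import numpy as np
from datasets import load_dataset  # HuggingFace datasets


class GPTSFTDatasetSketch:
    """Condensed stand-in for GPTSFTDataset after this change."""

    def __init__(self, file_path, index_mapping_dir=None, memmap_workers=1, hf_dataset=False):
        self.file_path = file_path
        self.index_mapping_dir = index_mapping_dir
        self.memmap_workers = memmap_workers
        self.hf_dataset = hf_dataset

        # Loading now goes through an overridable hook instead of inline __init__ code.
        self._load_dataset()

    def _load_dataset(self):
        if self.hf_dataset:
            self.indexed_dataset = load_dataset(
                'json',
                data_files=self.file_path,
                cache_dir=self.index_mapping_dir,
                num_proc=self.memmap_workers,
                split='train',
            )
        else:
            # The real class builds a JSONLMemMapDataset here; omitted in this sketch.
            raise NotImplementedError


class GPTSFTPackedDatasetSketch(GPTSFTDatasetSketch):
    def _load_dataset(self):
        # The packed subclass overrides the hook (see the rename of
        # _load_packed_dataset to _load_dataset in the next hunk).
        self.indexed_dataset = np.load(self.file_path, allow_pickle=True)
```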
@@ -476,30 +485,48 @@ def collate_fn(self, batch):


class GPTSFTPackedDataset(GPTSFTDataset):
def __init__(self, file_path: str, tokenizer: TokenizerSpec, **kwargs):
def __init__(self, file_path: str, tokenizer: TokenizerSpec, return_cu_seqlen: bool = True, **kwargs):
super().__init__(file_path, tokenizer, **kwargs)
assert self.virtual_tokens == 0, "P-Tuning with packed sequence is not supported."
self._load_packed_dataset(file_path)

# Whether to return `cu_seqlen` to pass to model. This should be true for almost all use cases.
self.return_cu_seqlen = return_cu_seqlen

np.random.seed(self.seed)

def __getitem__(self, idx):
if self.samples_mapping is not None:
# assert idx < len(self.samples_mapping)
idx = self.samples_mapping[idx]

input_ids = self.indexed_dataset[idx]['input_ids']
seq_boundaries = self.indexed_dataset[idx]['seq_start_id'] + [len(input_ids)]
loss_mask = self.indexed_dataset[idx]['loss_mask']
return {'input_ids': input_ids, 'seq_boundaries': seq_boundaries, 'loss_mask': loss_mask}

def __len__(self):
return len(self.indexed_dataset)
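For reference, each record in the packed dataset carries three fields that `__getitem__` above reads: the concatenated `input_ids`, the `seq_start_id` offsets of the packed sequences, and a `loss_mask`. A toy record (field names from the code, values invented):

```python
# Hypothetical packed record, shaped the way __getitem__ reads it.
packed_example = {
    'input_ids': [5, 17, 9, 2, 5, 31, 8, 44, 2],  # two sequences concatenated into one bin
    'seq_start_id': [0, 4],                        # start offset of each packed sequence
    'loss_mask': [0, 1, 1, 1, 0, 0, 1, 1, 1],      # 1 where a token contributes to the loss
}

# __getitem__ appends the total length so every packed sequence gets a closing boundary.
seq_boundaries = packed_example['seq_start_id'] + [len(packed_example['input_ids'])]  # [0, 4, 9]
```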

def _load_packed_dataset(self, file_path):
def _load_dataset(self):
try:
self.indexed_dataset = np.load(file_path, allow_pickle=True)
self.indexed_dataset = np.load(self.file_path, allow_pickle=True)
except Exception as e:
logging.error(
f"Failed to load packed dataset. The dataset should be a `.npy` file. "
f"Please check if the packed dataset was prepared correctly. The original error was:\n {e}",
)
exit(1)

def _build_samples_mapping(self):
if self.max_num_samples is not None:
# custom samples mapping logic, following the format for unpacked sft dataset
# Note: this is epoch-level shuffling, i.e. sampling without replacement until end of epoch, then repeat.
# Unpacked dataset shuffles by sampling with replacement indefinitely.
dataset_len = len(self.indexed_dataset)
max_num_epochs = np.ceil(self.max_num_samples / dataset_len)
indices = np.arange(dataset_len)[None, :].repeat(max_num_epochs, axis=0)
[np.random.shuffle(x) for x in indices]
self.samples_mapping = indices.reshape(1, -1).squeeze()[: self.max_num_samples]
else:
self.samples_mapping = None
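The override above draws indices epoch by epoch: it tiles the index range over as many epochs as `max_num_samples` requires, shuffles each epoch independently, and truncates the flattened result. A standalone illustration with toy sizes (not the class itself):

```python
import numpy as np

np.random.seed(1234)

dataset_len = 4        # number of packed bins in the .npy file
max_num_samples = 10   # samples requested for the run

max_num_epochs = int(np.ceil(max_num_samples / dataset_len))              # 3 epochs needed
indices = np.arange(dataset_len)[None, :].repeat(max_num_epochs, axis=0)  # shape (3, 4)
for epoch in indices:
    np.random.shuffle(epoch)  # in-place shuffle per epoch: no repeats until the epoch ends

samples_mapping = indices.reshape(-1)[:max_num_samples]
print(samples_mapping)        # 10 indices, e.g. [2 0 3 1 1 3 0 2 3 1]
```

Unlike the unpacked dataset's mapping, which keeps sampling with replacement, every bin is visited exactly once before any bin repeats.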

def _build_loss_mask(self, processed_example):
if self.answer_only_loss:
seq_boundaries = processed_example['seq_boundaries']
@@ -565,28 +592,42 @@ def collate_fn(self, batch):
position_ids[0]
), "Dataset problem: input_ids and position_ids lengths don't match"

cu_seqlens = self._collate_item(cu_seqlens, max_length=max(len(l) for l in cu_seqlens) + 1, pad_id=-1)
input_ids = self._collate_item(input_ids, max_length=max_length, pad_id=self.tokenizer.eos_id)
labels = self._collate_item(labels, max_length=max_length, pad_id=self.tokenizer.eos_id)
loss_mask = self._collate_item(loss_mask, max_length=max_length, pad_id=0)
position_ids = self._collate_item(position_ids, max_length=max_length, pad_id=0)

# Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensor to avoid device-to-host copies.
cu_seqlens = torch.IntTensor(cu_seqlens)
cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)
seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
max_seqlen, _ = seqlens.max(dim=1, keepdim=True)

processed_batch = {
'tokens': torch.LongTensor(input_ids),
'labels': torch.LongTensor(labels),
'attention_mask': torch.LongTensor([1] * len(input_ids)), # no attention mask is needed for packed seq
'loss_mask': torch.LongTensor(loss_mask),
'position_ids': torch.LongTensor(position_ids),
'cu_seqlens': torch.IntTensor(cu_seqlens), # cu_seqlens_q must be in dtype torch.int32
'token_count': token_count,
'cu_seqlens_argmin': cu_seqlens_argmin,
'max_seqlen': max_seqlen,
}

if self.return_cu_seqlen:
cu_seqlens = self._collate_item(cu_seqlens, max_length=max(len(l) for l in cu_seqlens) + 1, pad_id=-1)

# Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensor to avoid device-to-host copies.
cu_seqlens = torch.IntTensor(cu_seqlens)
cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)
seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
max_seqlen, _ = seqlens.max(dim=1, keepdim=True)

processed_batch.update(
{
'attention_mask': torch.LongTensor(
[1] * len(input_ids)
), # no attention mask is needed for packed seq
'cu_seqlens': torch.IntTensor(cu_seqlens), # cu_seqlens_q must be in dtype torch.int32
'cu_seqlens_argmin': cu_seqlens_argmin, # only required for perf
'max_seqlen': max_seqlen, # only required for perf
}
)
else:
attention_mask = [self._create_attention_mask(max_length) for _ in batch]
processed_batch.update(
{'attention_mask': torch.stack(attention_mask),}
)

return processed_batch
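When `return_cu_seqlen` is enabled, the collate pads each sample's `cu_seqlens` to the longest row plus one extra slot using `-1`, which guarantees every row contains a pad for `argmin` to find; `cu_seqlens_argmin` and `max_seqlen` are then computed once on CPU so the model does not have to copy them back from the device later. A minimal sketch with two hand-written rows (values invented):

```python
import torch

# Two packed samples, padded with -1 to a common width.
cu_seqlens = torch.IntTensor([
    [0, 5, 9, 12, -1],   # three sequences of lengths 5, 4, 3
    [0, 7, 12, -1, -1],  # two sequences of lengths 7, 5
])

# The first -1 in each row marks the end of the valid prefix.
cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)  # tensor([[4], [3]])

# Successive differences recover per-sequence lengths; the pad positions go
# negative, so they never win the max.
seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
max_seqlen, _ = seqlens.max(dim=1, keepdim=True)                   # tensor([[5], [7]])
```

With `return_cu_seqlen` disabled, the batch instead carries a stacked per-sample attention mask built by `_create_attention_mask`, matching the regular padded-batch path.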
Second changed file
@@ -258,11 +258,13 @@ def _build_dataset(self, data_cfg, is_train=True):
8 * self.cfg.get('tensor_model_parallel_size', 1) if self.cfg.get('sequence_parallel', False) else 16
)

dataset_kwargs = {}
for file_path, num_samples in zip(data_cfg.file_names, num_train_samples_per_dataset):
if self.cfg.data.get("chat", False):
dataset_cls = GPTSFTChatDataset
elif packed_sequence:
dataset_cls = GPTSFTPackedDataset
dataset_kwargs = {'return_cu_seqlen': data_cfg.get("packed_sequence_return_cu_seqlen", True)}
assert data_cfg.micro_batch_size == 1, "Micro batch size must be 1 if using packed sequence"
else:
dataset_cls = GPTSFTDataset
@@ -281,7 +283,7 @@ def _build_dataset(self, data_cfg, is_train=True):
add_eos=data_cfg.get('add_eos', True),
add_sep=data_cfg.get('add_sep', False),
sep_id=self.sep_id,
max_num_samples=num_samples[0] if not packed_sequence else None,
max_num_samples=num_samples[0],
seed=data_cfg.get('seed', 1234),
label_key=data_cfg.get('label_key', 'answer'),
answer_only_loss=self.cfg.get('answer_only_loss', True),
@@ -306,6 +308,7 @@ def _build_dataset(self, data_cfg, is_train=True):
'chat_prompt_tokens', None
), # special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '<extra_id_0>', 'turn_start': '<extra_id_1>', 'label_start': '<extra_id_2>', 'end_of_turn': '\n', "end_of_name": "\n"}
is_test=not is_train,
**dataset_kwargs,
)
datasets.append(dataset)
if is_train:
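On the model side, the packed branch now forwards one extra keyword argument through dataset_kwargs, read from the data config with a default of True. A hypothetical config fragment showing the lookup (only `packed_sequence_return_cu_seqlen`, `micro_batch_size`, and `file_names` appear in the diff; the file name is a placeholder):

```python
from omegaconf import OmegaConf

data_cfg = OmegaConf.create({
    'file_names': ['train_packed_4096.npy'],   # placeholder path to a packed .npy file
    'micro_batch_size': 1,                     # must be 1 when packed sequences are used
    'packed_sequence_return_cu_seqlen': True,  # set False to fall back to a stacked attention_mask
})

# Mirrors the lookup in _build_dataset above: a missing key defaults to True.
dataset_kwargs = {'return_cu_seqlen': data_cfg.get('packed_sequence_return_cu_seqlen', True)}
print(dataset_kwargs)  # {'return_cu_seqlen': True}
```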
