
Commit 62f824b

pad to product of tp and sp group size

jquesnelle committed Aug 14, 2024
1 parent: e34cabb

Showing 3 changed files with 19 additions and 12 deletions.
run_train.py: 1 change (1 addition, 0 deletions)

@@ -189,6 +189,7 @@ def get_dataloader_from_data_stage(
         conversation_column_name=data.dataset.conversation_column_name,
         dp_rank=trainer.parallel_context.dp_pg.rank(),
         dp_ranks_size=trainer.parallel_context.dp_pg.size(),
+        tp_ranks_size=trainer.parallel_context.tp_pg.size(),
         sp_ranks_size=trainer.parallel_context.sp_pg.size(),
         seed=data.seed,
     )
src/nanotron/data/chat_dataset.py: 8 changes (5 additions, 3 deletions)

@@ -48,6 +48,7 @@ def __init__(
         split: str = "train",
         dp_rank: int = 0,
         dp_ranks_size: int = 1,
+        tp_ranks_size: int = 1,
         sp_ranks_size: int = 1,
         skip_num_samples: int = None,  # TODO(tj.solergibert) Delete, check later comment
         seed: int = 1234,
@@ -66,6 +67,7 @@ def __init__(
         self.seed = seed
         self.pack_samples = pack_samples
         self.sp_chunks = sp_ranks_size * 2 if sp_ranks_size > 1 else 1
+        self.sp_tp_product = self.sp_chunks * tp_ranks_size

         # Load, split and shuffle dataset
         self.dataset = load_dataset(dataset_path, split=split, streaming=True)
@@ -143,14 +145,14 @@ def __iter__(self):
             is_completition = self.create_labels(tokens, is_completition)
             input_mask = ([1] * len(tokens))

-            rem = len(tokens) % self.sp_chunks
+            rem = len(tokens) % self.sp_tp_product
             if rem != 0:
-                pad_amount = self.sp_chunks - rem
+                pad_amount = self.sp_tp_product - rem
                 tokens.extend([self.chat_tokenizer.tokenizer.pad_token_id] * pad_amount)
                 is_completition.extend([False] * pad_amount)
                 input_mask.extend([0] * pad_amount)

-            if self.sp_chunks > 1:
+            if self.sp_tp_product > 1:
                 # sequence needs to be of length (closest multiple of 2 * sp_pg.size()) + 1
                 # + 1 is so we have (closest multiple of 2 * sp_pg.size()) after shifting by one to get causal prediction
                 tokens.append(self.chat_tokenizer.tokenizer.pad_token_id)
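The change above pads each sample to a multiple of sp_chunks * tp_ranks_size rather than sp_chunks alone, so a sequence divides evenly across both the sequence-parallel chunks (two per SP rank when sp_ranks_size > 1) and the tensor-parallel shards. Below is a minimal self-contained sketch of the length arithmetic under assumed sizes; pad_for_sp_tp and PAD_TOKEN_ID are illustrative names, not part of the commit:

# A sketch of the padding rule, not the committed implementation.
PAD_TOKEN_ID = 0  # stands in for chat_tokenizer.tokenizer.pad_token_id

def pad_for_sp_tp(tokens, sp_size, tp_size, pad_id=PAD_TOKEN_ID):
    # Sequence parallelism here splits every sequence into 2 chunks per
    # SP rank, hence the factor of 2 when sp_size > 1.
    sp_chunks = sp_size * 2 if sp_size > 1 else 1
    sp_tp_product = sp_chunks * tp_size
    rem = len(tokens) % sp_tp_product
    if rem != 0:
        tokens = tokens + [pad_id] * (sp_tp_product - rem)
    if sp_tp_product > 1:
        # One extra token so the length is still a multiple of
        # sp_tp_product after shifting by one for causal prediction.
        tokens = tokens + [pad_id]
    return tokens

# Example: sp=2, tp=4 -> sp_chunks=4, sp_tp_product=16. A 45-token sample
# is padded to 48, then gets one extra token (49); after the causal shift,
# inputs and labels are 48 tokens, divisible by 16.
assert len(pad_for_sp_tp(list(range(45)), sp_size=2, tp_size=4)) == 49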
src/nanotron/data/dataloader_builder.py: 22 changes (13 additions, 9 deletions)

@@ -70,21 +70,25 @@ def build_chat_dataloader(
     output_pp_rank: int,
     dataloader_pin_memory: bool = True,
 ) -> DataLoader:
+    pack_samples = dataset.pack_samples

     # Case of ranks not requiring data. We give them a dummy dataset, then the collator will do his job
     if dist.get_rank(parallel_context.pp_pg) not in [input_pp_rank, output_pp_rank]:
         dataset_length = 1_000_000  # len(dataset) TODO find a more elegant way to specify this dummy dataset
         dataset = EmptyInfiniteDataset(length=dataset_length)

-    data_collator = DataCollatorForSFT(
-        input_pp_rank=input_pp_rank,
-        output_pp_rank=output_pp_rank,
-        parallel_context=parallel_context,
-    ) if dataset.pack_samples else DataCollatorForUnpackedSFT(
-        input_pp_rank=input_pp_rank,
-        output_pp_rank=output_pp_rank,
-        parallel_context=parallel_context,
-    )
+    if pack_samples:
+        data_collator = DataCollatorForSFT(
+            input_pp_rank=input_pp_rank,
+            output_pp_rank=output_pp_rank,
+            parallel_context=parallel_context,
+        )
+    else:
+        data_collator = DataCollatorForUnpackedSFT(
+            input_pp_rank=input_pp_rank,
+            output_pp_rank=output_pp_rank,
+            parallel_context=parallel_context,
+        )

     dp_rank = parallel_context.dp_pg.rank()
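One detail worth noting: pack_samples is now read from the dataset before the dummy-dataset substitution. Assuming EmptyInfiniteDataset exposes no pack_samples attribute, the old ternary, which evaluated dataset.pack_samples after the swap, would raise AttributeError on pipeline ranks outside [input_pp_rank, output_pp_rank]. A small sketch of that failure mode with a stand-in class (only the length=... constructor argument is taken from the diff; the class body is assumed for illustration):

class EmptyInfiniteDataset:
    # Stand-in for nanotron's dummy dataset on ranks that carry no data.
    def __init__(self, length: int):
        self._length = length

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, idx: int) -> dict:
        return {}  # dummy sample; real tensors come from the data ranks

dataset = EmptyInfiniteDataset(length=1_000_000)
try:
    _ = dataset.pack_samples  # the pre-commit ternary did this after the swap
except AttributeError:
    print("old code path would crash on non-data pipeline ranks")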
