diff --git a/session/translation/end-to-end/nanot5-small.sh b/session/translation/end-to-end/nanot5-small.sh
index fbe31949..6af4f458 100644
--- a/session/translation/end-to-end/nanot5-small.sh
+++ b/session/translation/end-to-end/nanot5-small.sh
@@ -1,21 +1,24 @@
-WANDB_PROJECT="nanot5-small-malaysian-cased-translation-v3" \
+WANDB_PROJECT="nanot5-small-malaysian-cased-translation-v4" \
 torchrun \
---nproc_per_node 4 \
+--nproc_per_node 1 \
 -m run_t5_v2 \
 --model_name_or_path mesolitica/nanot5-small-malaysian-cased \
 --num_train_epochs 2 \
 --eval_steps 1000000000 \
 --logging_steps 2 \
---save_steps 1500 \
+--save_steps 200 \
 --save_total_limit 3 \
 --do_train \
---train_file malaysian-translation \
---output_dir nanot5-small-malaysian-cased-translation-v3 \
---per_device_train_batch_size=12 \
+--train_file mosaic \
+--output_dir nanot5-small-malaysian-cased-translation-v4-v2 \
+--dataloader_num_workers=10 \
+--per_device_train_batch_size=2 \
 --per_device_eval_batch_size=3 \
---gradient_accumulation_steps=2 \
---max_source_length 4096 \
---max_target_length 4096 \
+--gradient_accumulation_steps=16 \
+--max_source_length 2048 \
+--max_target_length 2048 \
 --learning_rate 2e-4 \
 --gradient_checkpointing true \
---bf16
\ No newline at end of file
+--weight_decay 0.01 \
+--bf16 \
+--run_name nanot5-small-malaysian-cased-translation-v4-1
\ No newline at end of file
diff --git a/session/translation/end-to-end/run_t5_v2.py b/session/translation/end-to-end/run_t5_v2.py
index 2c11e81a..d3b19d9d 100644
--- a/session/translation/end-to-end/run_t5_v2.py
+++ b/session/translation/end-to-end/run_t5_v2.py
@@ -23,7 +23,7 @@
 import os
 import sys
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Any, Union
 
 import torch
 import datasets
@@ -51,6 +51,9 @@ from transformers.utils import check_min_version, send_example_telemetry
 from transformers.utils.versions import require_version
 from tokenizers import AddedToken
 
+from transformers.data.data_collator import pad_without_fast_tokenizer_warning
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from transformers.utils import PaddingStrategy
 from streaming import LocalDataset
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -70,6 +73,140 @@
 MBart50TokenizerFast, M2M100Tokenizer]
 
 
+@dataclass
+class DataCollatorForSeq2Seq:
+    """
+    Data collator that will dynamically pad the inputs received, as well as the labels.
+
+    Args:
+        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
+            The tokenizer used for encoding the data.
+        model ([`PreTrainedModel`], *optional*):
+            The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
+            prepare the *decoder_input_ids*
+
+            This is useful when using *label_smoothing* to avoid calculating loss twice.
+        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+            among:
+
+            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
+              sequence is provided).
+            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+              acceptable input length for the model if that argument is not provided.
+            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
+        max_length (`int`, *optional*):
+            Maximum length of the returned list and optionally padding length (see above).
+        pad_to_multiple_of (`int`, *optional*):
+            If set will pad the sequence to a multiple of the provided value.
+
+            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+            7.5 (Volta).
+        label_pad_token_id (`int`, *optional*, defaults to -100):
+            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
+        return_tensors (`str`, *optional*, defaults to `"pt"`):
+            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
+    """
+
+    tokenizer: PreTrainedTokenizerBase
+    model: Optional[Any] = None
+    padding: Union[bool, str, PaddingStrategy] = True
+    max_length: Optional[int] = None
+    pad_to_multiple_of: Optional[int] = None
+    label_pad_token_id: int = -100
+    return_tensors: str = "pt"
+
+    def __call__(self, features, return_tensors=None):
+        if return_tensors is None:
+            return_tensors = self.return_tensors
+
+        label_name = 'labels'
+        labels = [feature[label_name] for feature in features if feature is not None]
+        # reconvert list[None] to None if necessary
+        # this might occur when we pass {..., "labels": None}
+        if labels is not None and all(label is None for label in labels):
+            labels = None
+        non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features if feature is not None]
+
+        # run through tokenizer without labels to ensure no side effects
+        batch = pad_without_fast_tokenizer_warning(
+            self.tokenizer,
+            non_labels_features,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors=return_tensors,
+        )
+
+        # we have to pad the labels manually as we cannot rely on `tokenizer.pad` and we need them to be of the same length to return tensors
+        no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD
+        if labels is not None:
+            if no_padding:
+                if isinstance(features[0][label_name], list):
+                    batch["labels"] = list(labels)
+                else:
+                    batch["labels"] = [np.concatenate([label, []]) for label in labels]
+            else:
+                max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None
+                max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length
+                if self.pad_to_multiple_of is not None:
+                    max_label_length = (
+                        (max_label_length + self.pad_to_multiple_of - 1)
+                        // self.pad_to_multiple_of
+                        * self.pad_to_multiple_of
+                    )
+
+                padding_side = self.tokenizer.padding_side
+                if isinstance(features[0][label_name], list):
+                    batch["labels"] = [
+                        label + [self.label_pad_token_id] * (max_label_length - len(label))
+                        if padding_side == "right"
+                        else [self.label_pad_token_id] * (max_label_length - len(label)) + label
+                        for label in labels
+                    ]
+                else:
+                    batch["labels"] = [
+                        np.concatenate(
+                            [
+                                label,
+                                np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
+                            ]
+                        )
+                        if padding_side == "right"
+                        else np.concatenate(
+                            [
+                                np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
+                                label,
+                            ]
+                        )
+                        for label in labels
+                    ]
+
+        # reintroduce side effects via tokenizer that return respective datatypes for the `return_tensors` argument
+        if batch.get("labels", None) is not None:
+            if return_tensors == "pt":
+                import torch
+
+                batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
+            elif return_tensors == "tf":
+                import tensorflow as tf
+
+                batch["labels"] = tf.constant(batch["labels"], dtype=tf.int64)
+            else:
+                batch["labels"] = np.array(batch["labels"], dtype=np.int64)
+        else:
+            batch["labels"] = None
+
+        # prepare decoder_input_ids
+        if (
+            labels is not None
+            and self.model is not None
+            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
+        ):
+            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"])
+            batch["decoder_input_ids"] = decoder_input_ids
+
+        return batch
 @dataclass
 class ModelArguments:
@@ -325,6 +462,17 @@ def __getitem__(self, idx):
             max_length=max_target_length,
             truncation=True,
         )
+        left_eos = outputs["input_ids"][-1] == tokenizer.eos_token_id
+        right_eos = labels["input_ids"][-1] == tokenizer.eos_token_id
+
+        if left_eos and not right_eos:
+            print(left_eos, right_eos, 'skip')
+            return None
+
+        if not left_eos and right_eos:
+            print(left_eos, right_eos, 'skip')
+            return None
+
         return {
             "input_ids": outputs["input_ids"],
             "attention_mask": outputs["attention_mask"],
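Note on the new launch flags: with --nproc_per_node 1, --per_device_train_batch_size=2 and --gradient_accumulation_steps=16, the effective batch is 2 x 16 x 1 = 32 sequences per optimizer step, at the shorter 2048-token source/target lengths.

The two Python changes work as a pair: __getitem__ now returns None for samples whose source and target disagree on a trailing EOS token, and the vendored DataCollatorForSeq2Seq drops those None entries via its "if feature is not None" guards before padding. The sketch below is a minimal, self-contained approximation of that interaction using plain lists and numpy only; the names toy_features and collate_skip_none are illustrative and not part of this patch, and the real collator additionally pads inputs through the tokenizer, honors return_tensors, and builds decoder_input_ids.

import numpy as np

# Hypothetical toy batch: a None entry stands in for a sample the dataset
# skipped because its source/target EOS tokens disagreed.
toy_features = [
    {"input_ids": [5, 6, 1], "labels": [7, 1]},
    None,  # skipped sample
    {"input_ids": [5, 1], "labels": [8, 9, 10, 1]},
]

def collate_skip_none(features, label_pad_token_id=-100, pad_token_id=0):
    """Minimal sketch of the None-filtering plus label-padding logic."""
    # Drop skipped samples, mirroring the `if feature is not None` guards.
    features = [f for f in features if f is not None]
    max_in = max(len(f["input_ids"]) for f in features)
    max_lab = max(len(f["labels"]) for f in features)
    # Right-pad inputs with the pad token and labels with -100 so the loss
    # ignores padded label positions.
    input_ids = np.array(
        [f["input_ids"] + [pad_token_id] * (max_in - len(f["input_ids"])) for f in features]
    )
    labels = np.array(
        [f["labels"] + [label_pad_token_id] * (max_lab - len(f["labels"])) for f in features]
    )
    return {"input_ids": input_ids, "labels": labels}

batch = collate_skip_none(toy_features)
print(batch["input_ids"].shape, batch["labels"].shape)  # (2, 3) (2, 4)

Because the filtering happens in the collator, the DataLoader can keep a fixed batch size while mismatched-EOS samples are silently dropped from each batch rather than raising during padding.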