diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index 435e439d..5ed8ef66 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -13,15 +13,14 @@
 # limitations under the License.
 """This module contains prompt tuning through PEFT"""
 # Standard
+from datetime import datetime
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 import gc
 import json
 import os
-import tempfile
 
 # Third Party
-from datasets import Dataset
-from datasets import IterableDataset as TransformersIterableDataset
+from accelerate import Accelerator
 from peft import (
     MultitaskPromptTuningConfig,
     PeftConfig,
@@ -31,8 +30,12 @@
     TaskType,
     get_peft_model,
 )
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
+from tqdm import tqdm
 from transformers import AutoModelForCausalLM, default_data_collator
 from transformers.models.auto.tokenization_auto import AutoTokenizer
+from transformers.optimization import get_linear_schedule_with_warmup
 import numpy as np
 import torch
 import transformers
@@ -56,7 +59,12 @@
     PromptOutputModelType,
     TuningConfig,
 )
-from ...resources.pretrained_model import HFAutoCausalLM, HFAutoSeq2SeqLM
+from ...resources.pretrained_model import (
+    HFAutoCausalLM,
+    HFAutoSeq2SeqLM,
+    PretrainedModelBase,
+)
+from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper
 from ...toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype
 from ...toolkit.task_specific_utils import convert_to_generation_record
 from ...toolkit.text_generation.model_run_utils import (
@@ -64,13 +72,6 @@
     generate_text_func,
     generate_text_func_stream,
 )
-from ...toolkit.text_generation.training_utils import (
-    ALLOWED_TRAINING_ARGS,
-    collect_trainer_arguments,
-    infer_max_steps,
-    launch_training,
-    preprocess_function,
-)
 from ...toolkit.trainer_utils import validate_training_data
 from ...toolkit.verbalizer_utils import render_verbalizer
 from .peft_config import TuningType, get_peft_config, resolve_base_model
@@ -296,7 +297,7 @@ def train(
         batch_size: Optional[int] = 8,
         max_source_length: Optional[int] = 256,
         max_target_length: Optional[int] = 128,
-        accumulate_steps: Optional[int] = 1,
+        accumulate_steps: Optional[int] = 32,
         torch_dtype: Optional[str] = None,  # TODO: Optional[Union[torch.dtype, str]]
         silence_progress_bars: Optional[bool] = True,
         seed: int = RANDOM_SEED,
@@ -336,7 +337,7 @@ def train(
             max_target_length: int
                 Max length of target sequences being predicted. Default: 128.
             accumulate_steps: int
-                Optional, number of steps to use for gradient accumulation. Default: 1.
+                Number of steps to use for gradient accumulation. Default: 32.
             torch_dtype: str
                 TODO: Optional[Union[torch.dtype, str]]
                 Data type to use for training/inference of the underlying text generation model.
@@ -361,34 +362,12 @@
         # but it can have impact on performance.
         # transformers.enable_full_determinism(seed)
-        torch_dtype = get_torch_dtype(torch_dtype)
+        # HACK - These things can't be passed through the train API currently
 
-        # Coerce the passed model into a resource; if we have one, this is a noop
-        # TODO: When splitting up this mono-module, use the configured resource
-        # type of the concrete class to bootstrap
-        base_model = resolve_base_model(base_model, cls, torch_dtype)
+        metric = kwargs.get("metric")
+        base_model = resolve_base_model(base_model, cls, torch_dtype)
         base_model_name = base_model._model_name
-
-        # Enable gradient checkpointing on base model
-        # PeftModel checks if the base_model has gradient checkpointing
-        # enabled and then configures the tensors it creates with appropriate
-        # setting. If we do not enable this, then we will get `tensor 0` requires
-        # grad error, where `tensor 0` is created by peft
-        base_model.model.gradient_checkpointing_enable()
-
-        # Get config of the base model
-        base_model_config = base_model.get_config()
-
-        # Remove _name_or_path field as a model can be
-        # saved in different location but still same
-        del base_model_config["_name_or_path"]
-        error.value_check(
-            "",
-            "_name_or_path" not in base_model_config,
-            "_name_or_path needs to be removed from config!",
-        )
-
         task_type, output_model_types, peft_config, tuning_type = get_peft_config(
             tuning_type,
             tuning_config,
@@ -398,8 +377,6 @@
             verbalizer,
         )
 
-        log.debug("Peft config [%s]", peft_config)
-
         # Check if data is within limit allowed for this module and model
         validate_training_data(
             train_stream,
@@ -407,8 +384,12 @@
             cls.MODULE_ID,
         )
 
-        train_stream = train_stream.map(convert_to_generation_record)
+        # Coerce the passed model into a resource; if we have one, this is a noop
+        # TODO: When splitting up this mono-module, use the configured resource
+        # type of the concrete class to bootstrap
+        torch_dtype = get_torch_dtype(torch_dtype)
 
+        train_stream = train_stream.map(convert_to_generation_record)
         if val_stream:
             error.value_check(
                 "", len(val_stream) > 0, "val_stream cannot be empty"
@@ -416,6 +397,19 @@
             val_stream = val_stream.map(convert_to_generation_record)
 
+        # Convert our datastreams -> data loaders by disguising them as PyTorch iterable datasets
+        train_dataloader, val_dataloader = cls.create_dataloaders_from_stream(
+            base_model=base_model,
+            task_type=task_type,
+            train_stream=train_stream,
+            verbalizer=verbalizer,
+            validation_stream=val_stream or None,
+            batch_size=batch_size,
+            max_source_length=max_source_length,
+            max_target_length=max_target_length,
+        )
+
+        log.debug("Peft config [%s]", peft_config)
         # FIXME: Should only do following line for causal LM (and bloomz?) - check that is the case
         if isinstance(base_model, HFAutoCausalLM):
             base_model.model.config.d_model = 1024
@@ -425,88 +419,34 @@
         # Convert our Peft model (not just the underlying
         # transformers model) to the right underlying type.
         device = cls._get_device(device)
-
-        # cls.convert_peft_model_to_type(device, peft_model, torch_dtype)
-
-        ## Generate data loader from stream
-        training_dataset: Union[
-            Dataset, TransformersIterableDataset
-        ] = preprocess_function(
-            base_model=base_model,
-            train_stream=train_stream,
+        cls.convert_peft_model_to_type(device, peft_model, torch_dtype)
+
+        training_loss_tracker = cls._execute_train_loop(
+            peft_model,
+            num_epochs,
+            train_dataloader,
+            device,
+            eval_dataloader=val_dataloader,
+            metric=metric,
+            learning_rate=learning_rate,
             tokenizer=base_model.tokenizer,
-            max_source_length=max_source_length,
-            max_target_length=max_target_length,
-            shuffle=True,
-            use_iterable_dataset=False,
-            random_seed=cls.RANDOM_SEED,
-            task_ids=0,
+            accumulate_steps=accumulate_steps,
+            silence_progress_bars=silence_progress_bars,
+            torch_dtype=torch_dtype,
         )
-        # Filter **training_arguments to only process allowed ones
-        filtered_training_arguments = {
-            k: v for k, v in kwargs.items() if k in ALLOWED_TRAINING_ARGS
-        }
+        # Get config of the base model
+        base_model_config = base_model.get_config()
 
-        extra_training_args = set(kwargs.keys()).difference(
-            filtered_training_arguments.keys()
+        # Remove _name_or_path field as a model can be
+        # saved in different location but still same
+        del base_model_config["_name_or_path"]
+        error.value_check(
+            "",
+            "_name_or_path" not in base_model_config,
+            "_name_or_path needs to be removed from config!",
         )
-        if extra_training_args:
-            log.warning(
-                "",
-                f"{extra_training_args} parameter(s) not allowed by \
-                {cls.__name__} currently and will be ignored!",
-            )
-
-        if num_epochs < 1:
-            log.warning(
-                "",
-                f"Number of epochs configured is {num_epochs} which is less than minimum 1. \
-                No training will be performed",
-            )
-
-            return PeftPromptTuning(
-                tokenizer=base_model.tokenizer,
-                model=peft_model,
-                base_model_config=base_model_config,
-                base_model_name=base_model_name,
-                verbalizer=verbalizer,
-                task_type=task_type,
-                tuning_type=tuning_type,
-                output_model_types=output_model_types,
-                training_metadata={"loss": []},
-            )
-
-        # Open an intermediate checkpoint directory until we've bootstrapped
-        # our model or we've early exited (if epochs < 1)
-        with tempfile.TemporaryDirectory() as checkpoint_dir:
-
-            # Formulate training arguments
-            training_args = collect_trainer_arguments(
-                torch_dtype,
-                checkpoint_dir,
-                batch_size,
-                num_epochs,
-                cls.RANDOM_SEED,
-                learning_rate,
-                max_steps=infer_max_steps(num_epochs, batch_size, training_dataset),
-                silence_progress_bars=silence_progress_bars,
-                accumulate_steps=accumulate_steps,
-                # NOTE: following can override above arguments in order
-                **filtered_training_arguments,
-            )
-
-            # Use HF Trainer to kick off training on either
-            # CPU or GPU
-            training_loss_history = launch_training(
-                peft_model,
-                training_dataset,
-                training_args,
-                checkpoint_dir,
-                base_model,
-            )
-
         # Wrap up the trained model in a class instance
         return cls(
             tokenizer=base_model.tokenizer,
@@ -517,7 +457,7 @@
             task_type=task_type,
             tuning_type=tuning_type,
             output_model_types=output_model_types,
-            training_metadata={"loss": training_loss_history},
+            training_metadata=training_loss_tracker,
             # TODO: Export other training params to model as well
         )
 
@@ -777,6 +717,83 @@
         return prompt_dict
 
+    @classmethod
+    def create_dataloaders_from_stream(
+        cls,
+        base_model: "caikit_nlp.resources.pretrained_model.base.PretrainedModelBase",
+        task_type: str,
+        train_stream: DataStream[GenerationTrainRecord],
+        verbalizer: str,
+        batch_size: int,
+        max_source_length: int,
+        max_target_length: int,
+        validation_stream: Union[DataStream[GenerationTrainRecord], None] = None,
+        collate_fn: Callable = None,
+    ) -> Tuple[DataLoader]:
+        """Build PyTorch data loaders around training and (optionally) evaluation DataStreams.
+
+        Args:
+            base_model: caikit_nlp.resources.pretrained_model.base.PretrainedModelBase
+                Base resource model used for underlying generation.
+            task_type: str
+                Str indicating which task is being accomplished; currently used for determining
+                tokenization / preprocessing behavior.
+            train_stream: DataStream[GenerationTrainRecord]
+                Data to be used for training the prompt vectors of the generation model.
+            verbalizer: str
+                Verbalizer template with which we will render text at both train & inference time.
+            batch_size: int
+                Batch size to be used for train/eval data loaders.
+            max_source_length: int
+                Maximum length to be used for tokenized sequences.
+            max_target_length: int
+                Max length of target sequences being predicted.
+            validation_stream: Union[DataStream[GenerationTrainRecord], None]
+                Data to be used for validation throughout the train process or None.
+            collate_fn: Callable
+                Function to be used for forming batches via lists of dataset inputs.
+
+        Returns:
+            Tuple[torch.utils.data.DataLoader]
+                Training & evaluation data loaders for the provided data, respectively. If no
+                validation_stream is provided, the returned loader for validation_stream will
+                be None.
+        """
+        if collate_fn is None:
+            # collate_fn -> pads and maps our inputs to PyTorch vectors
+            collate_fn = cls._get_collate_fn(base_model.tokenizer, task_type)
+
+        # Grab the data loaders for this task.
+        # NOTE: Currently we do not expose the buffer size and we
+        # default to loading the whole dataset into memory
+        train_dataloader = cls._get_data_loaders_from_stream(
+            base_model,
+            train_stream,
+            base_model.tokenizer,
+            batch_size,
+            collate_fn,
+            verbalizer,
+            max_source_length,
+            max_target_length,
+            shuffle=True,
+        )
+        if validation_stream is not None:
+            val_dataloader = cls._get_data_loaders_from_stream(
+                base_model,
+                validation_stream,
+                base_model.tokenizer,
+                batch_size,
+                collate_fn,
+                verbalizer,
+                max_source_length,
+                max_target_length,
+                shuffle=False,
+            )
+        else:
+            val_dataloader = None
+
+        return train_dataloader, val_dataloader
+
     @classmethod
     def create_hf_tuning_config(
         cls,
@@ -902,6 +919,251 @@ def _get_collate_fn(tokenizer: AutoTokenizer, task_type: str) -> Callable:
         # want to set labels ourselves. TODO: centralize collator management.
         return default_data_collator
 
+    @staticmethod
+    def _get_data_loaders_from_stream(
+        base_model: PretrainedModelBase,
+        train_stream: DataStream[GenerationTrainRecord],
+        tokenizer: AutoTokenizer,
+        batch_size: int,
+        collate_fn: Callable,
+        verbalizer: str,
+        max_source_length: int,
+        max_target_length: int,
+        shuffle: bool,
+    ) -> DataLoader:
+        """Get the data loaders for train / evaluation.
+        Args:
+            base_model: caikit_nlp.resources.pretrained_model.base.PretrainedModelBase
+                Base resource model used for underlying generation.
+            train_stream: DataStream[GenerationTrainRecord]
+                Data to be used for training the prompt vectors of the generation model.
+            tokenizer: AutoTokenizer
+                Model tokenizer to be used in preprocessing, i.e., when we iterate over our data.
+            batch_size: int
+                Batch size to be used when building the DataLoader around the stream.
+            collate_fn: Callable
+                Function to be used for forming batches via lists of dataset inputs.
+            verbalizer: str
+                Verbalizer template to be used for formatting data. This template may use brackets
+                to indicate where fields from the data model GenerationTrainRecord must be rendered.
+            max_source_length: int
+                Max length of sequences being considered.
+            max_target_length: int
+                Max length of target sequences being predicted.
+            shuffle: bool
+                Indicates whether or not the stream should reshuffle upon reentry.
+
+        Returns:
+            torch.utils.data.DataLoader
+                DataLoader to be used for training / evaluating the stream data.
+        """
+        (tokenize_function, _,) = base_model.build_task_tokenize_closure(
+            tokenizer, max_source_length, max_target_length, verbalizer, task_ids=0
+        )
+        mapped_stream = train_stream.map(tokenize_function)
+        # TODO: Deprecate and remove stream wrapper & use trainer
+        wrapped_stream = SimpleIterableStreamWrapper(mapped_stream, shuffle=shuffle)
+        dataloader = DataLoader(
+            wrapped_stream, collate_fn=collate_fn, batch_size=batch_size
+        )
+
+        return dataloader
+
+    @classmethod
+    def _execute_train_loop(
+        cls,
+        model: PeftModel,
+        num_epochs: int,
+        train_dataloader: DataLoader,
+        device: str,
+        eval_dataloader: Union[DataLoader, None] = None,
+        metric: Optional[Callable] = None,
+        learning_rate: float = 1e-3,
+        tokenizer: Union[AutoTokenizer, None] = None,
+        accumulate_steps: int = 1,
+        silence_progress_bars: bool = True,
+        torch_dtype: "torch.dtype" = torch.float32,
+    ) -> Dict:
+        """Execute the core training logic for training the prompt vectors on the frozen model.
+        Note that this is done by reference.
+
+        Args:
+            model: PeftModel
+                Underlying model being leveraged for text generation via prompt tuning.
+            num_epochs: int
+                Number of epochs to train.
+            train_dataloader: torch.utils.data.DataLoader
+                DataLoader to be used for loading training data.
+            device: str
+                Device to be used for training the model.
+            eval_dataloader: Union[DataLoader, None]
+                DataLoader to be used for loading eval data or None.
+            metric: Union[Callable, None]
+                Function to be used for evaluating data if an eval data loader is provided.
+                Default: None.
+            learning_rate: float
+                Learning rate to be used while tuning prompt vectors. Default: 1e-3.
+            tokenizer: Union[AutoTokenizer, None]
+                Tokenizer for default evaluation; only used if no metric is provided and we have
+                an eval dataloader.
+                TODO: This can likely be removed.
+            accumulate_steps: int
+                Number of steps to use for gradient accumulation. Default: 1.
+            silence_progress_bars: bool
+                Silences TQDM progress bars. Default: True.
+            torch_dtype: torch.dtype
+                Dtype to be used for training. Default: torch.float32
+
+        Returns:
+            training_metadata: Dict
+                Metadata computed during training
+        """
+        optimizer = AdamW(params=model.parameters(), lr=learning_rate)
+        lr_scheduler = get_linear_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=0,
+            num_training_steps=(len(train_dataloader) * num_epochs),
+        )
+
+        # Enable gradient checkpointing
+        model.gradient_checkpointing_enable()
+
+        if torch_dtype == torch.float16:
+            mixed_precision = "fp16"
+        elif (
+            torch.cuda.is_available()
+            and torch.cuda.is_bf16_supported()
+            and torch_dtype == torch.bfloat16
+        ):
+            mixed_precision = "bf16"
+        else:
+            mixed_precision = "no"
+
+        accelerator = Accelerator(
+            gradient_accumulation_steps=accumulate_steps,
+            device_placement=True,
+            mixed_precision=mixed_precision,
+        )
+
+        # Disable cache for training
+        model.config.use_cache = False
+
+        # Below would send all the data and model to
+        # configured device and convert them to required dtypes
+        model, optimizer, new_train_dataloader, lr_scheduler = accelerator.prepare(
+            model,
+            optimizer,
+            train_dataloader,
+            lr_scheduler,
+        )
+
+        training_loss_tracker = []
+
+        step_count = 1
+
+        for epoch in range(num_epochs):
+            step_loss_log = {}
+            model.train()
+            total_loss = 0
+            tqdm_loader = tqdm(new_train_dataloader, disable=silence_progress_bars)
+            for batch in tqdm_loader:
+
+                tqdm_loader.set_description("Epoch: {}".format(epoch))
+
+                # TODO Can this dict comprehension always replace "batch.to(device)" for us?
+                try:
+                    with accelerator.accumulate(model):
+                        outputs = model(**batch)
+                        loss = outputs.loss
+                        # We are converting loss to float explicitly for later use;
+                        # keeping it in tensor form can potentially cause memory issues
+                        loss_float = loss.detach().float().item()
+                        total_loss += loss_float
+                        accelerator.backward(loss)
+                        optimizer.step()
+                        lr_scheduler.step()
+                        optimizer.zero_grad()
+                        step_loss_log[step_count] = loss_float
+                        step_count += 1
+                except (
+                    torch.cuda.OutOfMemoryError  # pylint: disable=catching-non-exception
+                ):
+                    error(
+                        "",
+                        MemoryError("Not enough memory available for training!"),
+                    )
+
+            log.info("", {"loss": loss_float, "epoch": epoch})
+
+            for step, loss_val in step_loss_log.items():
+
+                # Below is added to be propagated and stored as training_metadata
+                training_loss_tracker.append(
+                    {
+                        "epoch": epoch,
+                        "step": step,
+                        "value": loss_val,
+                        "timestamp": datetime.isoformat(datetime.now()),
+                    }
+                )
+
+            if eval_dataloader is not None:
+                model.eval()
+
+                if metric is not None:
+                    for _, batch in enumerate(
+                        tqdm(eval_dataloader, disable=silence_progress_bars)
+                    ):
+                        batch.to(device)
+                        with torch.no_grad():
+                            outputs = model(**batch)
+                        predictions = outputs.logits.argmax(dim=-1)
+                        references = batch["labels"]
+                        metric.add_batch(
+                            predictions=predictions,
+                            references=references,
+                        )
+                    eval_metric = metric.compute()
+
+                    log.info("epoch %s: %s", epoch, eval_metric)
+                else:
+                    eval_loss = 0
+                    # TODO Can we get away with not maintaining eval_preds?
+                    eval_preds = []
+                    for _, batch in enumerate(
+                        tqdm(eval_dataloader, disable=silence_progress_bars)
+                    ):
+                        batch = {k: v.to(device) for k, v in batch.items()}
+                        with torch.no_grad():
+                            outputs = model(**batch)
+                        loss = outputs.loss
+                        eval_loss += loss.detach().float()
+
+                        if tokenizer is not None:
+                            eval_preds.extend(
+                                tokenizer.batch_decode(
+                                    torch.argmax(outputs.logits, -1)
+                                    .detach()
+                                    .cpu()
+                                    .numpy(),
+                                    skip_special_tokens=True,
+                                )
+                            )
+
+                    eval_epoch_loss = eval_loss / len(eval_dataloader)
+                    eval_ppl = torch.exp(eval_epoch_loss)
+                    train_epoch_loss = total_loss / len(train_dataloader)
+                    train_ppl = torch.exp(train_epoch_loss)
+                    log.debug(
+                        "epoch %s: %s %s %s %s",
+                        epoch,
+                        train_ppl,
+                        train_epoch_loss,
+                        eval_ppl,
+                        eval_epoch_loss,
+                    )
+        return {"loss": training_loss_tracker}
+
     @classmethod
     def _filter_params_for_prompt_config(cls, prompt_config, params):
         """Utility function to filter out required parameters for prompt_config
diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py
index 65934c77..e0558d4a 100644
--- a/caikit_nlp/modules/text_generation/text_generation_local.py
+++ b/caikit_nlp/modules/text_generation/text_generation_local.py
@@ -70,9 +70,6 @@ class TextGeneration(ModuleBase):
 
     # Below list is taken from
     # https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments
-    # FIXME: Temporarily disable duplicate code check here as
-    # we will remove below code in next iteration when we consolidate HF Trainer
-    # pylint: disable=duplicate-code
     allowed_training_args = {
         "weight_decay",
         "adam_beta1",
diff --git a/caikit_nlp/resources/pretrained_model/base.py b/caikit_nlp/resources/pretrained_model/base.py
index 6adad903..eba74744 100644
--- a/caikit_nlp/resources/pretrained_model/base.py
+++ b/caikit_nlp/resources/pretrained_model/base.py
@@ -279,7 +279,6 @@ def get_trainer(
         train_dataset: IterableDataset,
         eval_dataset: Union[IterableDataset, None] = None,
         optimizers=(None, None),
-        model=None,
         **kwargs,
     ):
         """
@@ -304,10 +303,6 @@
             "optimizers": optimizers,
             "eval_dataset": eval_dataset,
         }
 
-        # If extra model is provided, we will configure trainer
-        # with that model
-        if model:
-            return LoggingTrainer(model, training_args, **trainer_arguments)
 
         return LoggingTrainer(self._model, training_args, **trainer_arguments)
diff --git a/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py b/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py
index 089dcfd9..598b0136 100644
--- a/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py
+++ b/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py
@@ -98,7 +98,6 @@ def get_trainer(
         train_dataset: IterableDataset,
         eval_dataset: Union[IterableDataset, None] = None,
         optimizers=(None, None),
-        model=None,
         **kwargs
     ):
         """
@@ -129,11 +128,6 @@
             # "generation_max_length": max_target_length,
         }
 
-        # If extra model is provided, we will configure trainer
-        # with that model
-        if model:
-            return LoggingTrainer(model, training_args, **trainer_arguments)
-
         return LoggingTrainer(self._model, training_args, **trainer_arguments)
 
     def _get_data_collator(self, **kwargs):
diff --git a/caikit_nlp/toolkit/text_generation/training_utils.py b/caikit_nlp/toolkit/text_generation/training_utils.py
deleted file mode 100644
index ac905fad..00000000
--- a/caikit_nlp/toolkit/text_generation/training_utils.py
+++ /dev/null
@@ -1,249 +0,0 @@
-# Copyright The Caikit Authors
-#
-# Licensed under
the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utility script that contains logic for training""" - -# Standard -from typing import List, Optional, Union - -# Third Party -from datasets import Dataset -from datasets import IterableDataset as TransformersIterableDataset -from transformers import AutoTokenizer -import torch - -# First Party -from caikit.core.data_model import DataStream -from caikit.core.toolkit import error_handler -import alog - -# Local -from ...data_model import GenerationTrainRecord -from ...resources.pretrained_model import PretrainedModelBase - -log = alog.use_channel("TXTGEN_TRN_UTLS") -error = error_handler.get(log) - -# Below list is taken from -# https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments -ALLOWED_TRAINING_ARGS = { - "weight_decay", - "adam_beta1", - "adam_beta2", - "adam_epsilon", - "max_grad_norm", - "lr_scheduler_type", - "warmup_ratio", - "warmup_steps", - "use_ipex", - "disable_tqdm", - "label_names", - "optim", - "optim_args", - "group_by_length", - "dataloader_pin_memory", - "gradient_checkpointing", - "full_determinism", -} - -# Create trainer arguments -def collect_trainer_arguments( - torch_dtype, - output_dir, - batch_size, - num_epochs, - random_seed, - learning_rate, - max_steps, - silence_progress_bars=True, - accumulate_steps=1, - **kwargs -): - """Utility function to return processed HF Trainer argument dictionary""" - - # NOTE: Following is not exhaustive list of all parameters - # for all dtypes - if torch_dtype == torch.float16: - dtype_based_params = { - "fp16": True, - } - elif torch_dtype == torch.bfloat16: - dtype_based_params = { - "bf16": True, - } - else: - # default to float32 - dtype_based_params = {} - - return { - # trainer settings - "output_dir": output_dir, - # NOTE: We have disabled evaluation for now - "do_eval": False, - "do_train": True, - "no_cuda": not torch.cuda.is_available(), - # NOTE: This is explicitly set to false since it will - # negatively impact the performance - "full_determinism": False, - # logging configuration - "logging_strategy": "steps", - "logging_steps": 1, # logging at every step - "disable_tqdm": silence_progress_bars, - # computation configurations - "seed": random_seed, - "per_device_train_batch_size": batch_size, - "per_device_eval_batch_size": batch_size, - "num_train_epochs": num_epochs, - "learning_rate": learning_rate, - "weight_decay": 0.01, - "save_total_limit": 3, - "gradient_checkpointing": True, - "gradient_accumulation_steps": accumulate_steps, - # huggingface configurations - "push_to_hub": False, - # dataset configurations - "remove_unused_columns": True, - "dataloader_pin_memory": False, - # Required for iterable dataset - "max_steps": max_steps, - # others - "auto_find_batch_size": True, - **dtype_based_params, - **kwargs, - } - - -def preprocess_function( - base_model: PretrainedModelBase, - train_stream: DataStream[GenerationTrainRecord], - tokenizer: AutoTokenizer, - max_source_length: int, - max_target_length: int, - shuffle: bool, - 
use_iterable_dataset: bool, - random_seed: int, - task_ids: Optional[List[int]] = None, -): - """Pre-process each example to get it prepared for training.""" - dataset_type = TransformersIterableDataset if use_iterable_dataset else Dataset - log.debug("Loading dataset class: [%s]", dataset_type.__name__) - fn_kwargs = { - "tokenizer": tokenizer, - "max_source_length": max_source_length, - "max_target_length": max_target_length, - } - if task_ids is not None: - fn_kwargs["task_ids"] = task_ids - - # TODO: Add check for empty training stream - dataset = dataset_type.from_generator( - get_record, gen_kwargs={"train_stream": train_stream} - ) - mapped_dataset = dataset.map( - base_model.tokenize_function, - fn_kwargs=fn_kwargs, - # For now, we hardcode to False, since causal LM chunking is not exposed yet - batched=False, - # batched=base_model.REQUIRES_TOKEN_UNWRAPPING, - # Drop the input / output columns; we need to do this for dimensions to play - # happily when operating on batched inputs for causal language modeling. - remove_columns=["input", "output"], - ) - - if shuffle: - log.debug("Shuffling the dataset") - return mapped_dataset.shuffle(seed=random_seed) - - return mapped_dataset - - -def launch_training( - base_model, - training_dataset, - training_args, - checkpoint_dir, - caikit_resource=None, - tokenizer=None, -) -> None: - """Utility function to wrap trainer and execute training""" - # If we have a caikit resource, grab the trainer through it - if caikit_resource is not None: - trainer = caikit_resource.get_trainer( - train_dataset=training_dataset, model=base_model, **training_args - ) - else: - # If trainer is not provided fetch it from base_model - if hasattr(base_model, "get_trainer"): - trainer = base_model.get_trainer( - train_dataset=training_dataset, **training_args - ) - else: - error("", "could not resolve trainer. Check base model type!") - - # Start training via Trainer.train function - result = trainer.train() - - # Log the output of the training. This will include stats about training - log.info("", "Training completed. Summary: {}".format(result)) - - # save the model temporarily and reload it - # this is done, since otherwise the model might be distributed in different - # devices, in which case its better to use trainer's `prediction_step` - # functions, but then, they don't always give API similar to `generate` - # and thus cause incompatibilities in `run` function - trainer.save_state() - trainer.save_model(checkpoint_dir) - - # save tokenizer explicitly - if hasattr(base_model, "tokenizer"): - base_model.tokenizer.save_pretrained(checkpoint_dir) - elif tokenizer: - tokenizer.save_pretrained(checkpoint_dir) - else: - log.warning( - "", - "Cannot save tokenizer as not available to train function.", - ) - - # Below will return log history but launch will automatically attach rank to it. - # if started in distributed fashion - return trainer.state.log_history - - -def infer_max_steps( - num_epochs: int, - batch_size: int, - training_dataset: Union[Dataset, TransformersIterableDataset], -): - # Calculate the number of samples that we have - if isinstance(training_dataset, Dataset): - data_len = len(training_dataset) - else: - data_len = 0 - for _ in training_dataset: - data_len += 1 - # Figure out how many batches we'll have per epoch - num_batches = data_len // batch_size - # Assume drop_last=False; in general, this doesn't really matter. - # We mostly do this to avoid strange behavior when the dataset - # size is smaller than the batch size. 
- if num_batches != (data_len * batch_size): - num_batches += 1 - num_steps = num_batches * num_epochs - log.debug("Number of inferred steps: [%s]", num_steps) - return num_steps - - -def get_record(train_stream): - for data in train_stream: - yield {"input": data.input, "output": data.output}
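The replacement training path in this patch drives tuning through Hugging Face Accelerate (AdamW plus a linear warmup schedule, with `accelerator.accumulate` handling gradient accumulation) instead of the removed HF Trainer utilities. Below is a minimal, self-contained sketch of that Accelerate pattern only; the tiny linear model and random tensors are illustrative placeholders standing in for the PEFT model and the stream-backed DataLoader, and none of the names below are caikit_nlp APIs.

# Sketch of the Accelerate-based accumulation loop mirrored by _execute_train_loop.
# Only the Accelerator / optimizer / scheduler wiring reflects the patch; the toy
# model and data exist so the snippet runs standalone.
import torch
from accelerate import Accelerator
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from transformers.optimization import get_linear_schedule_with_warmup

accumulate_steps = 4
num_epochs = 2

# Toy regression data and model (placeholders for the prompt-tuned PEFT model
# and the SimpleIterableStreamWrapper-backed DataLoader in the patch)
features = torch.randn(64, 8)
targets = torch.randn(64, 1)
train_dataloader = DataLoader(TensorDataset(features, targets), batch_size=8)
model = torch.nn.Linear(8, 1)

optimizer = AdamW(model.parameters(), lr=1e-3)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_dataloader) * num_epochs,
)

accelerator = Accelerator(
    gradient_accumulation_steps=accumulate_steps,
    device_placement=True,
    mixed_precision="no",
)
# prepare() moves the model/optimizer/data to the configured device(s)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)

for epoch in range(num_epochs):
    model.train()
    for inputs, labels in train_dataloader:
        # Inside accumulate(), gradients are only synced and the optimizer only
        # really steps once every `accumulate_steps` batches.
        with accelerator.accumulate(model):
            loss = torch.nn.functional.mse_loss(model(inputs), labels)
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
    print(f"epoch {epoch} done; last batch loss {loss.item():.4f}")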